trw.train.losses

Module Contents

Classes

LossDiceMulticlass

Implementation of the soft Dice Loss (multi-class) for N-d images

LossCrossEntropyCsiMulticlass

Optimize a metric similar to Critical Success Index (CSI) on the cross-entropy

LossFocalMulticlass

This criterion is an implementation of Focal Loss, which is proposed in

LossTriplets

Implement a triplet loss

LossBinaryF1

The macro F1-score is non-differentiable. Instead, use a surrogate that is differentiable

LossCenter

Center loss: penalizes features that fall far from their class feature center.

LossContrastive

Implementation of the contrastive loss.

LossMsePacked

Mean squared error loss with target packed as an integer (e.g., classification)

Functions

one_hot(targets: trw.basic_typing.TorchTensorNX, num_classes: int, dtype=torch.float32, device: Optional[torch.device] = None) → trw.basic_typing.TorchTensorNCX

Encode the targets (a tensor of integers representing a class)

_total_variation_norm_2d(x, beta)

_total_variation_norm_3d(x, beta)

total_variation_norm(x, beta)

Calculate the total variation norm

trw.train.losses.one_hot(targets: trw.basic_typing.TorchTensorNX, num_classes: int, dtype=torch.float32, device: Optional[torch.device] = None) trw.basic_typing.TorchTensorNCX

Encode the targets (a tensor of integers representing a class) as a one-hot encoding.

Supports targets given as N-dimensional data (e.g., a 3D segmentation map).

Equivalent to torch.nn.functional.one_hot; provided for backward compatibility with PyTorch 1.0.

Parameters
  • num_classes – the total number of classes

  • targets – an N-dimensional integral tensor (e.g., 1D for classification, 2D for a 2D segmentation map…)

  • dtype – the type of the output tensor

  • device – the device of the one-hot encoded tensor. If None, use the target’s device

Returns

a one-hot encoding of an N-dimensional integral tensor
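
Example (a minimal usage sketch; the sample values are illustrative):

    import torch
    from trw.train.losses import one_hot

    # 1D classification targets: 4 samples, 3 classes
    targets = torch.tensor([0, 2, 1, 2])
    encoded = one_hot(targets, num_classes=3)
    # `encoded` is a float tensor with an added class dimension: N x C, i.e., (4, 3) here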

class trw.train.losses.LossDiceMulticlass(normalization_fn: Callable[[torch.Tensor], torch.Tensor] = partial(nn.Softmax, dim=1), eps: float = 1e-05, return_dice_by_class: bool = False, smooth: float = 0.001, power: float = 1.0, per_class_weights: Sequence[float] = None, discard_background_loss: bool = True)

Bases: torch.nn.Module

Implementation of the soft Dice Loss (multi-class) for N-d images

If multi-class, compute the loss for each class then average the losses

References

[1] “V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation” https://arxiv.org/pdf/1606.04797.pdf

forward(self, output, target)
Parameters
  • output – must have N x C x d0 x … x dn shape, where C is the total number of classes to predict

  • target – must have N x 1 x d0 x … x dn shape

Returns

if return_dice_by_class is False, return 1 - dice score suitable for optimization. Else, return the (numerator, cardinality) by class and by sample
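
Example (a minimal sketch for a batch of 2D segmentations; the shapes follow the documented N x C x d0 x … x dn / N x 1 x d0 x … x dn convention, the sizes themselves are illustrative):

    import torch
    from trw.train.losses import LossDiceMulticlass

    loss_fn = LossDiceMulticlass()
    output = torch.randn(2, 3, 32, 32, requires_grad=True)  # 2 samples, 3 classes, 32x32 maps
    target = torch.randint(0, 3, (2, 1, 32, 32))            # integer class per pixel
    loss = loss_fn(output, target)                          # 1 - dice, suitable for optimization
    loss.mean().backward()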

class trw.train.losses.LossCrossEntropyCsiMulticlass

Bases: torch.nn.Module

Optimize a metric similar to Critical Success Index (CSI) on the cross-entropy

A loss for heavily unbalanced data (orders of magnitude more negatives than positives). Calculate the cross-entropy and keep only the loss contributions from the TP, FP and FN; the loss from the TN is simply discarded.

forward(self, outputs, targets, important_class=1)
Parameters
  • outputs – a N x C tensor with N the number of samples and C the number of classes

  • targets – a N integral tensor

  • important_class – the class for which the cross-entropy loss is kept even when the classification is correct

Returns

a N floating tensor representing the loss of each sample
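
Example (a minimal sketch; 8 samples and 2 classes are illustrative):

    import torch
    from trw.train.losses import LossCrossEntropyCsiMulticlass

    loss_fn = LossCrossEntropyCsiMulticlass()
    outputs = torch.randn(8, 2)          # N x C scores
    targets = torch.randint(0, 2, (8,))  # N integral labels
    per_sample_loss = loss_fn(outputs, targets, important_class=1)  # shape (8,)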

class trw.train.losses.LossFocalMulticlass(alpha=None, gamma=2, reduction='mean')

Bases: torch.nn.Module

This criterion is an implementation of Focal Loss, which is proposed in Focal Loss for Dense Object Detection, https://arxiv.org/pdf/1708.02002.pdf

Loss(x, class) = -alpha * (1 - softmax(x)[class])^gamma * log(softmax(x)[class])

Parameters
  • alpha (1D Tensor, Variable) – the weighting factor for this criterion; one weight factor for each class.

  • gamma (float, double) – gamma > 0; reduces the relative loss for well-classified examples (p > .5), putting more focus on hard, misclassified examples

forward(self, outputs, targets)
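
Example (a minimal classification sketch; it assumes the same outputs/targets layout as the cross-entropy style losses above, which is not spelled out here):

    import torch
    from trw.train.losses import LossFocalMulticlass

    loss_fn = LossFocalMulticlass(gamma=2, reduction='mean')
    outputs = torch.randn(8, 4)          # assumed: N x C logits
    targets = torch.randint(0, 4, (8,))  # assumed: N integral labels
    loss = loss_fn(outputs, targets)
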
class trw.train.losses.LossTriplets(margin=1.0, distance=nn.PairwiseDistance(p=2))

Bases: torch.nn.Module

Implement a triplet loss

The goal of the triplet loss is to make sure that:

  • Two examples with the same label have their embeddings close together in the embedding space

  • Two examples with different labels have their embeddings far away.

However, we don’t want to push the training embeddings of each label to collapse into very small clusters. The only requirement is that, given two positive examples of the same class and one negative example, the negative should be farther away than the positive by some margin. This is very similar to the margin used in SVMs: here we want the clusters of each class to be separated by the margin.

The loss implements the following equation:

L = max(d(a, p) - d(a, n) + margin, 0), where a is the anchor sample, p a positive sample of the same class, n a negative sample of a different class, and d the chosen distance.

forward(self, samples, positive_samples, negative_samples)

Calculate the triplet loss

Parameters
  • samples – the samples

  • positive_samples – the samples that belong to the same group as samples

  • negative_samples – the samples that belong to a different group than samples

Returns

a 1D tensor (N) representing the loss per sample
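
Example (a minimal sketch with 16 embeddings of size 64; the sizes are illustrative):

    import torch
    from trw.train.losses import LossTriplets

    loss_fn = LossTriplets(margin=1.0)
    anchors = torch.randn(16, 64)
    positives = torch.randn(16, 64)  # same group as the anchors
    negatives = torch.randn(16, 64)  # different group
    per_sample_loss = loss_fn(anchors, positives, negatives)  # shape (16,)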

class trw.train.losses.LossBinaryF1(eps=0.0001)

Bases: torch.nn.Module

The macro F1-score is non-differentiable. Instead, use a surrogate that is differentiable and correlates well with the macro F1-score by working on the class probabilities rather than the discrete classification.

For example, if the ground truth is 1 and the model prediction is 0.8, we count it as 0.8 true positive and 0.2 false negative.

forward(self, outputs, targets)
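
Example (a minimal sketch; it assumes N x 2 class scores and N binary integral targets, which is not spelled out above):

    import torch
    from trw.train.losses import LossBinaryF1

    loss_fn = LossBinaryF1()
    outputs = torch.randn(8, 2)          # assumed: N x 2 class scores
    targets = torch.randint(0, 2, (8,))  # assumed: N binary labels
    loss = loss_fn(outputs, targets)
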
class trw.train.losses.LossCenter(number_of_classes, number_of_features, alpha=1.0)

Bases: torch.nn.Module

Center loss: penalizes features that fall far from their class feature center.

In most of the available CNNs, the softmax loss function is used as the supervision signal to train the deep model. In order to enhance the discriminative power of the deeply learned features, this loss can be used as a new supervision signal. Specifically, the center loss simultaneously learns a center for deep features of each class and penalizes the distances between the deep features and their corresponding class centers.

An implementation of center loss: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

Note

This loss must be part of a parent module or explicitly optimized by an optimizer. If not, the centers will not be modified.

forward(self, x, classes)
Parameters
  • x – the features, an arbitrary n-d tensor (N * C * …). Features should ideally be in range [0..1]

  • classes – a 1D integral tensor (N) representing the class of each x

Returns

a 1D tensor (N) representing the loss per sample
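
Example (a minimal sketch; 10 classes and 64 features are illustrative; remember from the note above that the loss's internal centers must be exposed to an optimizer, e.g., by registering the loss as a sub-module of the trained model):

    import torch
    from trw.train.losses import LossCenter

    loss_fn = LossCenter(number_of_classes=10, number_of_features=64)
    features = torch.rand(8, 64)          # features, ideally in [0..1]
    classes = torch.randint(0, 10, (8,))  # class of each sample
    per_sample_loss = loss_fn(features, classes)  # shape (8,)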

class trw.train.losses.LossContrastive(margin=1.0)

Bases: torch.nn.Module

Implementation of the contrastive loss.

L(x0, x1, y) = 0.5 * (1 - y) * d(x0, x1)^2 + 0.5 * y * max(0, m - d(x0, x1))^2

with y = 1 for samples x0 and x1 deemed dissimilar and y = 0 for similar samples. Dissimilar pairs contribute to the loss only when their distance is within the radius m, while similar pairs are penalized by their distance d(x0, x1).

See Dimensionality Reduction by Learning an Invariant Mapping, Raia Hadsell, Sumit Chopra, Yann LeCun, 2006.

forward(self, x0, x1, same_target)
Parameters
  • x0 – N-D tensor

  • x1 – N-D tensor

  • same_target – a 1D tensor of 0s and 1s. 1 means that x0 and x1 belong to the same class, while 0 means they come from different classes

Returns

a 1D tensor (N) representing the loss per sample
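
Example (a minimal sketch; the embedding size is illustrative):

    import torch
    from trw.train.losses import LossContrastive

    loss_fn = LossContrastive(margin=1.0)
    x0 = torch.randn(8, 32)
    x1 = torch.randn(8, 32)
    same_target = torch.randint(0, 2, (8,))  # 1: same class, 0: different classes
    per_sample_loss = loss_fn(x0, x1, same_target)  # shape (8,)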

trw.train.losses._total_variation_norm_2d(x, beta)
trw.train.losses._total_variation_norm_3d(x, beta)
trw.train.losses.total_variation_norm(x, beta)

Calculate the total variation norm

Parameters
  • x – a tensor with format (samples, components, dn, …, d0)

  • beta – the exponent

Returns

a scalar
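
Example (a minimal sketch on a batch of 2D images; beta=2.0 is illustrative):

    import torch
    from trw.train.losses import total_variation_norm

    images = torch.rand(4, 3, 64, 64)            # (samples, components, d1, d0)
    tv = total_variation_norm(images, beta=2.0)  # a scalar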

class trw.train.losses.LossMsePacked(reduction: typing_extensions.Literal['mean', 'none'] = 'mean')

Bases: torch.nn.Module

Mean squared error loss with target packed as an integer (e.g., classification)

The packed_target will be one-hot encoded, and the mean squared error is computed between the result and the tensor.

forward(self, tensor, packed_target)
Parameters
  • tensor – a NxCx… tensor

  • packed_target – a Nx1x… tensor
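
Example (a minimal sketch for a 3-class 2D problem; the shapes follow the documented NxCx… / Nx1x… convention, the sizes are illustrative):

    import torch
    from trw.train.losses import LossMsePacked

    loss_fn = LossMsePacked()
    tensor = torch.randn(8, 3, 16, 16)                   # N x C x ... scores
    packed_target = torch.randint(0, 3, (8, 1, 16, 16))  # N x 1 x ... integer classes
    loss = loss_fn(tensor, packed_target)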