trw.train.losses

Module Contents

Classes

LossDiceMulticlass

Implementation of the Dice Loss (multi-class) for N-d images

class trw.train.losses.LossDiceMulticlass(normalization_fn=nn.Sigmoid, eps=0.0001)

Bases: torch.nn.Module

Implementation of the Dice Loss (multi-class) for N-d images

For multi-class outputs, the loss is computed separately for each class and the per-class losses are then averaged.
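The per-class averaging described above can be sketched in PyTorch as follows. This is a minimal sketch, not trw's own code. It assumes that `output` already holds per-class probabilities (the normalization function has been applied), that each class contributes a loss of 1 − Dice with `eps` added to both numerator and denominator for numerical stability, and that one loss value is returned per sample. The function name `dice_loss_multiclass` is hypothetical.

```python
import torch
import torch.nn.functional as F


def dice_loss_multiclass(output: torch.Tensor, target: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """Hypothetical sketch of a multi-class soft Dice loss.

    output: W x C x d0 x ... x dn tensor of per-class probabilities (assumed normalized)
    target: W x d0 x ... x dn tensor of integer class labels in [0, C)
    Returns a tensor of shape (W,): the per-class losses averaged over C.
    """
    num_classes = output.shape[1]
    # One-hot encode the labels: (W, d0..dn) -> (W, d0..dn, C) -> (W, C, d0..dn)
    one_hot = F.one_hot(target.long(), num_classes).movedim(-1, 1).to(output.dtype)
    # Reduce over the spatial dimensions d0..dn only, keeping shape (W, C)
    spatial_dims = tuple(range(2, output.dim()))
    intersection = (output * one_hot).sum(dim=spatial_dims)
    cardinality = output.sum(dim=spatial_dims) + one_hot.sum(dim=spatial_dims)
    # Soft Dice coefficient per sample and per class (assumed eps placement)
    dice_per_class = (2 * intersection + eps) / (cardinality + eps)
    # Assumed per-class loss of 1 - Dice, averaged over the C classes
    return (1 - dice_per_class).mean(dim=1)
```

A perfect prediction gives a loss close to 0 for every sample, while a prediction that puts all probability on the wrong classes gives a loss close to 1.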

forward(self, output, target)
Parameters
  • output – tensor of shape W x C x d0 x … x dn, where W is the batch size and C is the total number of classes to predict

  • target – tensor of shape W x d0 x … x dn, holding one class index per element (it has no class dimension)

Returns

The Dice score
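The shape contract of `forward` can be checked before calling the loss. The helper below is a hypothetical sketch, not part of trw: it only verifies that `output` is W x C x d0 x … x dn, that `target` is W x d0 x … x dn, and that the batch and spatial sizes agree. Applying `nn.Sigmoid()` to raw logits mirrors the documented default `normalization_fn=nn.Sigmoid`; that the class is instantiated and applied to `output` inside the loss is an assumption based on the signature.

```python
import torch
import torch.nn as nn


def dice_shapes_compatible(output: torch.Tensor, target: torch.Tensor) -> bool:
    """Hypothetical helper: True if the tensors follow the documented forward() shapes."""
    return (
        output.dim() >= 3                          # W x C x at least one spatial dim d0
        and target.dim() == output.dim() - 1       # target has no class dimension C
        and target.shape[0] == output.shape[0]     # same batch size W
        and target.shape[1:] == output.shape[2:]   # same spatial sizes d0..dn
    )


logits = torch.randn(2, 3, 16, 16)           # W=2, C=3, d0=d1=16 (raw network output)
target = torch.randint(0, 3, (2, 16, 16))    # W=2, one class index per pixel
probs = nn.Sigmoid()(logits)                 # mirrors the default normalization_fn=nn.Sigmoid
print(dice_shapes_compatible(probs, target))
```

A common mistake is passing a target that keeps a singleton class dimension (W x 1 x d0 x … x dn); the check above rejects it.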