trw.train.losses
Module Contents
Classes
LossDiceMulticlass | Implementation of the Dice Loss (multi-class) for N-d images
- class trw.train.losses.LossDiceMulticlass(normalization_fn=nn.Sigmoid, eps=0.0001)
Bases: torch.nn.Module
Implementation of the Dice Loss (multi-class) for N-d images
For multi-class inputs, the loss is computed separately for each class and the per-class losses are then averaged.
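The per-class-then-average scheme can be sketched with NumPy as follows. This is a minimal, hypothetical re-implementation, not trw's code. It assumes that `output` holds raw scores squashed by a sigmoid (mirroring the default `normalization_fn=nn.Sigmoid`), that `target` holds integer class labels, and that `eps` smooths both the numerator and the denominator of each per-class Dice ratio. The function name `dice_loss_multiclass` is made up for illustration, and returning the averaged `1 - Dice` loss is an assumption about the return convention.

```python
import numpy as np


def dice_loss_multiclass(output, target, eps=1e-4):
    """Hypothetical soft Dice loss: per-class Dice, then averaged over classes.

    output: float array of shape (W, C, d0, ..., dn), raw (unnormalized) scores
    target: int array of shape (W, d0, ..., dn), class labels in [0, C)
    Returns one loss value per sample, shape (W,).
    """
    num_classes = output.shape[1]
    # Normalize the raw scores; a sigmoid mirrors the default normalization_fn
    probs = 1.0 / (1.0 + np.exp(-output))
    # One-hot encode the target and move the class axis to position 1
    one_hot = np.eye(num_classes)[target]   # (W, d0, ..., dn, C)
    one_hot = np.moveaxis(one_hot, -1, 1)   # (W, C, d0, ..., dn)
    # Sum over all spatial axes, keeping the (W, C) axes
    spatial_axes = tuple(range(2, output.ndim))
    intersection = (probs * one_hot).sum(axis=spatial_axes)
    cardinality = probs.sum(axis=spatial_axes) + one_hot.sum(axis=spatial_axes)
    # Smoothed per-class Dice coefficient, shape (W, C)
    dice = (2.0 * intersection + eps) / (cardinality + eps)
    # Per-class loss is 1 - Dice; average the class losses for each sample
    return (1.0 - dice).mean(axis=1)
```

Averaging per class, rather than pooling all voxels of all classes together, keeps small structures from being swamped by large ones: every class contributes equally to the final value regardless of how many voxels it covers.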
- forward(self, output, target)
- Parameters
output – must have shape W x C x d0 x … x dn, where C is the total number of classes to predict
target – must have shape W x d0 x … x dn (the same layout as output, without the class axis C)
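The shape contract above can be expressed as a small validation helper. `check_dice_shapes` is hypothetical and not part of trw. It only encodes the documented relationship: `output` carries an extra class axis C at position 1, while the leading W axis and the spatial sizes d0 … dn must match between the two arguments.

```python
def check_dice_shapes(output_shape, target_shape):
    """Hypothetical check of the documented forward() shape contract.

    output_shape: (W, C, d0, ..., dn)
    target_shape: (W, d0, ..., dn)
    Returns C, the number of classes.
    """
    output_shape = tuple(output_shape)
    target_shape = tuple(target_shape)
    # output needs a W axis, a class axis C and at least one spatial axis
    if len(output_shape) < 3:
        raise ValueError(f"output must be W x C x d0 x ... x dn, got {output_shape}")
    # target has the same layout as output minus the class axis C
    if len(output_shape) != len(target_shape) + 1:
        raise ValueError(
            f"target must have one axis fewer than output, got {output_shape} vs {target_shape}"
        )
    if output_shape[0] != target_shape[0]:
        raise ValueError("output and target must share the leading W axis")
    if output_shape[2:] != target_shape[1:]:
        raise ValueError("output and target must have the same spatial sizes d0 ... dn")
    return output_shape[1]
```

For example, a batch of 3-D volumes with output shape (8, 3, 64, 64, 32) and target shape (8, 64, 64, 32) passes the check and yields C = 3.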
- Returns
The Dice score