
Module Contents



  • Output – base class for all outputs

  • OutputEmbedding – represent an embedding

  • OutputClassification – classification output

  • OutputClassificationBinary – classification output for binary classification

  • OutputSegmentation – segmentation output

  • OutputSegmentationBinary – output for binary segmentation

  • OutputRegression – regression output

  • OutputTriplets – output for triplet-based losses

  • OutputLoss – represent a given loss as an output



Transform all torch.Tensor values of a dictionary-like object into numpy arrays.

extract_metrics(metrics_outputs, outputs)

Extract metrics from an output

segmentation_criteria_ce_dice(output, truth, per_voxel_weights=None, ce_weight=0.5, per_class_weights=None, power=1.0, smooth=1.0, focal_gamma=None)

Loss combining cross entropy and multi-class Dice.

criterion_softmax_cross_entropy(output, output_truth)


  • x – a Tensor




Transform all torch.Tensor values of a dictionary-like object into numpy arrays.
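A minimal sketch of this conversion, not the library's implementation: the helper name dict_values_to_numpy is hypothetical, and tensors are detected by duck typing (any value exposing detach(), cpu() and numpy(), as torch.Tensor does) so the sketch runs without importing torch. Non-tensor values are passed through unchanged.

```python
from typing import Any, Dict, Mapping


def dict_values_to_numpy(values: Mapping[str, Any]) -> Dict[str, Any]:
    """Return a new dict in which every tensor-like value is a numpy array.

    Hypothetical helper for illustration. A value is treated as a tensor
    if it exposes detach(), cpu() and numpy(), as torch.Tensor does;
    every other value is copied unchanged.
    """
    converted: Dict[str, Any] = {}
    for key, value in values.items():
        if all(hasattr(value, name) for name in ("detach", "cpu", "numpy")):
            # detach from the autograd graph and move to host memory
            # before converting, since numpy() rejects both cases
            converted[key] = value.detach().cpu().numpy()
        else:
            converted[key] = value
    return converted
```

With real PyTorch tensors, a CUDA tensor that requires gradients would be detached, copied to the CPU and returned as a numpy array.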

class trw.train.outputs_trw.Output(metrics, output, criterion_fn, collect_output=False, sample_uid_name=None)

Base class for all outputs.

output_ref_tag = output_ref

The tag name used to find the output reference back from outputs.

evaluate_batch(self, batch, is_training)

Evaluate a batch of data and extract important outputs.

  • batch – the batch of data

  • is_training – if True, this was a training batch

Returns a tuple (dictionary of values, dictionary of metrics).

loss_term_cleanup(self, loss_term)

This function is called for each batch just before switching to another batch.

It can be used to clean up large stored arrays or release CUDA memory.

trw.train.outputs_trw.extract_metrics(metrics_outputs, outputs)

Extract metrics from an output

  • metrics_outputs – a list of metrics

  • outputs – the result of Output.evaluate_batch


Returns a dictionary of (key, value) pairs.
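A minimal sketch of how metrics can be extracted from an evaluated output, assuming each metric is a callable that takes the dictionary returned by Output.evaluate_batch and returns either a dictionary of values or None when it does not apply. The name extract_metrics_sketch and the choice of keying results by the metric object are assumptions for illustration.

```python
from collections import OrderedDict
from typing import Any, Callable, Dict, Iterable, Optional

Metric = Callable[[Dict[str, Any]], Optional[Dict[str, Any]]]


def extract_metrics_sketch(
    metrics_outputs: Optional[Iterable[Metric]],
    outputs: Dict[str, Any],
) -> Dict[Metric, Dict[str, Any]]:
    """Apply each metric to the evaluated outputs and collect the results."""
    history: Dict[Metric, Dict[str, Any]] = OrderedDict()
    if metrics_outputs is None:
        return history
    for metric in metrics_outputs:
        result = metric(outputs)
        # a metric may decline to produce a value for this output (None)
        if result is not None:
            history[metric] = result
    return history
```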

class trw.train.outputs_trw.OutputEmbedding(output, clean_loss_term_each_batch=False, sample_uid_name=default_sample_uid_name, functor=None)

Bases: Output

Represent an embedding

This is only used to record a tensor that we consider an embedding (e.g., to be exported to TensorBoard).

evaluate_batch(self, batch, is_training)

Evaluate a batch of data and extract important outputs.

  • batch – the batch of data

  • is_training – if True, this was a training batch

Returns a tuple (dictionary of values, dictionary of metrics).

loss_term_cleanup(self, loss_term)

This function is called for each batch just before switching to another batch.

It can be used to clean up large stored arrays or release CUDA memory.

trw.train.outputs_trw.segmentation_criteria_ce_dice(output, truth, per_voxel_weights=None, ce_weight=0.5, per_class_weights=None, power=1.0, smooth=1.0, focal_gamma=None)

Loss combining cross entropy and multi-class Dice.

  • output – the output value, with shape [N, C, Dn..D0]

  • truth – the truth, with shape [N, 1, Dn..D0]

  • ce_weight – the weight of the cross entropy, in the range [0, 1]. This controls the importance of the cross-entropy loss in the overall segmentation loss.

  • per_class_weights – the weight of each class, a 1D vector of size C. This is used for the cross-entropy loss.

  • per_voxel_weights – the weight of each truth voxel. Must be of shape [N, Dn..D0].


Returns the loss as a torch tensor.
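To illustrate how the two terms can be combined, here is a NumPy sketch that assumes the loss is the convex combination ce_weight * CE + (1 - ce_weight) * (1 - Dice), where Dice is the soft multi-class Dice coefficient averaged over samples and classes. This is an assumption about how ce_weight is applied, not the library's code; per_voxel_weights, per_class_weights, power and focal_gamma are deliberately left out, and ce_dice_loss is a hypothetical name.

```python
import numpy as np


def ce_dice_loss(logits, truth, ce_weight=0.5, smooth=1.0):
    """Convex combination of cross entropy and multi-class soft Dice loss.

    logits: float array [N, C, D...] (raw scores, no activation applied)
    truth:  int array [N, 1, D...] holding class indices in [0, C)
    """
    num_classes = logits.shape[1]
    # numerically stable log-softmax over the class axis
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    probs = np.exp(log_probs)

    # one-hot encode the truth: [N, 1, D...] -> [N, C, D...]
    labels = truth[:, 0]
    one_hot = np.stack([labels == c for c in range(num_classes)], axis=1).astype(probs.dtype)

    # cross entropy: mean over voxels of -log p(true class)
    ce = -np.mean(np.sum(one_hot * log_probs, axis=1))

    # soft Dice per sample and class, summed over the spatial axes
    spatial_axes = tuple(range(2, logits.ndim))
    intersection = np.sum(probs * one_hot, axis=spatial_axes)
    denominator = np.sum(probs, axis=spatial_axes) + np.sum(one_hot, axis=spatial_axes)
    dice = (2.0 * intersection + smooth) / (denominator + smooth)
    dice_loss = 1.0 - np.mean(dice)

    return ce_weight * ce + (1.0 - ce_weight) * dice_loss
```

With uniform logits over two classes the cross-entropy term equals log 2, while confident and correct logits drive both terms towards 0.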

trw.train.outputs_trw.criterion_softmax_cross_entropy(output, output_truth)
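criterion_softmax_cross_entropy is listed without a description. Judging only from its name, a sketch of per-sample softmax cross entropy follows; it assumes raw logits of shape [N, C] and integer targets of shape [N], and returns unreduced losses so that a separate reduction (such as loss_reduction in OutputClassification) can be applied afterwards. The shapes, the lack of reduction and the name softmax_cross_entropy are assumptions.

```python
import numpy as np


def softmax_cross_entropy(logits, target):
    """Per-sample softmax cross entropy.

    logits: float array [N, C]; target: int array [N] of class indices.
    Returns an array [N] of losses (reduction is left to the caller).
    """
    # log-softmax via the log-sum-exp trick for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # pick the log-probability of the true class for each sample
    return -log_probs[np.arange(logits.shape[0]), target]
```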
class trw.train.outputs_trw.OutputClassification(output, output_truth, *, criterion_fn=lambda : ..., collect_output=True, collect_only_non_training_output=False, metrics: List[OutputClassification.__init__.metrics] = metrics.default_classification_metrics(), loss_reduction=torch.mean, weights=None, per_voxel_weights=None, loss_scaling=1.0, output_postprocessing=functools.partial(torch.argmax, dim=1, keepdim=True), maybe_optional=False, classes_name='unknown', sample_uid_name=default_sample_uid_name)

Bases: Output

Classification output

evaluate_batch(self, batch, is_training)

Evaluate a batch of data and extract important outputs.

  • batch – the batch of data

  • is_training – if True, this was a training batch

Returns a tuple (dictionary of values, dictionary of metrics).

loss_term_cleanup(self, loss_term)

This function is called for each batch just before switching to another batch.

It can be used to clean up large stored arrays or release CUDA memory.

class trw.train.outputs_trw.OutputClassificationBinary(output, output_truth, *, criterion_fn=lambda : ..., collect_output=True, collect_only_non_training_output=False, metrics: List[OutputClassificationBinary.__init__.metrics] = metrics.default_classification_metrics(), loss_reduction=torch.mean, weights=None, per_voxel_weights=None, loss_scaling=1.0, output_postprocessing=lambda x: ..., maybe_optional=False, classes_name='unknown', sample_uid_name=default_sample_uid_name)

Bases: OutputClassification

Classification output for binary classification

  • output – the output with shape [N, 1, {X}], without any activation applied (i.e., logits)

  • output_truth – the truth with shape [N, 1, {X}]
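Because this output expects logits, both the loss and the hard prediction can be computed without an explicit sigmoid. The sketch below assumes a binary cross-entropy criterion and a decision threshold at logit 0 (probability 0.5); the actual criterion_fn and output_postprocessing are elided (lambda : ...) in the signature above, so treat both as assumptions, and binary_logits_loss_and_prediction as a hypothetical name.

```python
import numpy as np


def binary_logits_loss_and_prediction(logits, truth):
    """Binary cross entropy on raw logits plus a hard 0/1 prediction.

    logits, truth: arrays of shape [N, 1, ...]; truth holds 0 or 1.
    """
    # stable BCE with logits: max(z, 0) - z * y + log(1 + exp(-|z|))
    loss = np.maximum(logits, 0) - logits * truth + np.log1p(np.exp(-np.abs(logits)))
    # sigmoid(z) > 0.5 is equivalent to z > 0, so no activation is needed
    prediction = (logits > 0).astype(np.int64)
    return loss, prediction
```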

class trw.train.outputs_trw.OutputSegmentation(output: torch.Tensor, output_truth: torch.Tensor, criterion_fn: Callable[[], Any] = LossDiceMulticlass, collect_output: bool = False, collect_only_non_training_output: bool = False, metrics: List[OutputSegmentation.__init__.metrics] = metrics.default_segmentation_metrics(), loss_reduction: Callable[[torch.Tensor], torch.Tensor] = torch.mean, weights=None, per_voxel_weights=None, loss_scaling=1.0, output_postprocessing=functools.partial(torch.argmax, dim=1, keepdim=True), maybe_optional=False, sample_uid_name=default_sample_uid_name)

Bases: OutputClassification

Segmentation output

class trw.train.outputs_trw.OutputSegmentationBinary(output: torch.Tensor, output_truth: torch.Tensor, criterion_fn: Callable[[], Any] = LossDiceMulticlass, collect_output: bool = False, collect_only_non_training_output: bool = False, metrics: List[OutputSegmentationBinary.__init__.metrics] = metrics.default_segmentation_metrics(), loss_reduction: Callable[[torch.Tensor], torch.Tensor] = torch.mean, weights=None, per_voxel_weights=None, loss_scaling=1.0, output_postprocessing=lambda x: ..., maybe_optional=False, sample_uid_name=default_sample_uid_name)

Bases: OutputSegmentation

Output for binary segmentation.

  • output – the output with shape [N, 1, {X}], as raw logits (no activation applied)

  • output_truth – the truth with shape [N, 1, {X}], with values 0 or 1


  • x – a Tensor

Returns the mean of all values.

class trw.train.outputs_trw.OutputRegression(output, output_truth, criterion_fn=lambda : ..., collect_output=True, collect_only_non_training_output=False, metrics=metrics.default_regression_metrics(), loss_reduction=mean_all, weights=None, loss_scaling=1.0, output_postprocessing=lambda x: ..., target_name=None, sample_uid_name=default_sample_uid_name)

Bases: Output

Regression output

evaluate_batch(self, batch, is_training)

Evaluate a batch of data and extract important outputs.

  • batch – the batch of data

  • is_training – if True, this was a training batch

Returns a tuple (dictionary of values, dictionary of metrics).

class trw.train.outputs_trw.OutputTriplets(samples, positive_samples, negative_samples, criterion_fn=lambda : ..., metrics=metrics.default_generic_metrics(), loss_reduction=mean_all, weight_name=None, loss_scaling=1.0, sample_uid_name=default_sample_uid_name)

Bases: Output

Triplet output: each entry of samples is compared against the matching entries of positive_samples and negative_samples using criterion_fn.
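The default criterion_fn of OutputTriplets is elided in the signature above. A standard choice for anchor, positive and negative triplets is the triplet margin loss, sketched here in NumPy under that assumption (treating samples as the anchors, Euclidean distance, and a hypothetical margin of 1.0).

```python
import numpy as np


def triplet_margin_loss(anchors, positives, negatives, margin=1.0):
    """Per-triplet hinge loss: max(d(a, p) - d(a, n) + margin, 0).

    All inputs are float arrays [N, D]; d is the Euclidean distance.
    """
    d_pos = np.linalg.norm(anchors - positives, axis=1)
    d_neg = np.linalg.norm(anchors - negatives, axis=1)
    # zero loss once the negative is at least `margin` farther than the positive
    return np.maximum(d_pos - d_neg + margin, 0.0)
```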

evaluate_batch(self, batch, is_training)

Evaluate a batch of data and extract important outputs.

  • batch – the batch of data

  • is_training – if True, this was a training batch

Returns a tuple (dictionary of values, dictionary of metrics).

class trw.train.outputs_trw.OutputLoss(losses, loss_reduction=torch.mean, metrics=metrics.default_generic_metrics(), sample_uid_name=default_sample_uid_name)

Bases: Output

Represent a given loss as an output.

This can be useful to add an additional regularizer to the training (e.g., trw.train.LossCenter).
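As an example of a regularizer that could be wrapped this way, here is a sketch of the usual center-loss formulation, which pulls each feature vector towards a per-class center. It assumes trw.train.LossCenter follows this formulation; center_loss below is a hypothetical NumPy stand-in, not that class, and it averages over the batch rather than summing.

```python
import numpy as np


def center_loss(features, labels, centers):
    """Mean over the batch of 0.5 * ||x_i - c_{y_i}||^2.

    features: [N, D]; labels: int [N]; centers: [C, D], one center per class.
    """
    # distance of each feature vector to the center of its own class
    diff = features - centers[labels]
    return 0.5 * np.mean(np.sum(diff ** 2, axis=1))
```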

evaluate_batch(self, batch, is_training)

Evaluate a batch of data and extract important outputs.

  • batch – the batch of data

  • is_training – if True, this was a training batch

Returns a tuple (dictionary of values, dictionary of metrics).

loss_term_cleanup(self, loss_term)

This function is called for each batch just before switching to another batch.

It can be used to clean up large stored arrays or release CUDA memory.