trw.train.outputs

Module Contents

Classes

Output

This is a tag name to find the output reference back from outputs

OutputEmbedding

Represent an embedding

OutputSegmentation

Segmentation output

OutputClassification

Classification output

OutputRegression

Regression output

OutputRecord

Record the raw value, but do not compute any loss from it.

Functions

extract_history_from_outputs_and_metrics(metrics, outputs)

Extract a history from metrics and an output result

segmentation_criteria_ce_dice(output, truth, ce_weight=0.5)

Loss combining cross entropy and multiclass Dice

segmentation_output_postprocessing(mask_pb)

Post-process the segmentation mask probabilities into a discrete segmentation map

mean_all(x)

Return the mean of all values of a Tensor

class trw.train.outputs.Output(output, criterion_fn, collect_output=False, sample_uid_name=None)

This is a tag name to find the output reference back from outputs

output_ref_tag = output_ref

evaluate_batch(self, batch, is_training)

Evaluate a batch of data and extract important outputs

Parameters
  • batch – the batch of data

  • is_training – if True, this was a training batch

Returns

a dictionary

loss_term_cleanup(self, loss_term)

This function is called for each batch just before switching to another batch.

It can be used to clean up large arrays stored in the loss term.

extract_history(self, outputs)

Summarize epoch statistics from the calculated outputs to populate a history

Parameters

outputs – the aggregated evaluate_batch output

Returns

a dictionary
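The three hooks above can be read as a small per-batch lifecycle: evaluate, summarize, then clean up. The sketch below is a dependency-free illustration only. `ToyOutput`, the dictionary keys (`losses`, `loss`, `raw_output`) and the `target` batch key are hypothetical and are not taken from trw. It also summarizes a single batch rather than an aggregated epoch, to stay short.

```python
class ToyOutput:
    """Hypothetical stand-in mirroring the three hooks documented for Output."""

    def __init__(self, output, criterion_fn):
        self.output = output              # model values for one batch (plain list here)
        self.criterion_fn = criterion_fn  # returns one loss value per sample

    def evaluate_batch(self, batch, is_training):
        # Compute per-sample losses and keep everything a later step may need.
        losses = self.criterion_fn(self.output, batch['target'])
        return {
            'losses': losses,
            'loss': sum(losses) / len(losses),
            'raw_output': list(self.output),
        }

    def loss_term_cleanup(self, loss_term):
        # Drop the potentially large per-sample entries before the next batch.
        loss_term.pop('raw_output', None)
        loss_term.pop('losses', None)

    def extract_history(self, outputs):
        # Keep only the scalar summary for the history.
        return {'loss': outputs['loss']}


def squared_error(output, target):
    return [(o - t) ** 2 for o, t in zip(output, target)]


out = ToyOutput([1.0, 2.0], squared_error)
loss_term = out.evaluate_batch({'target': [1.0, 4.0]}, is_training=True)
history = out.extract_history(loss_term)
out.loss_term_cleanup(loss_term)
```

After the cleanup call only the scalar entry remains in `loss_term`, which is the kind of memory saving the `loss_term_cleanup` description refers to.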

trw.train.outputs.extract_history_from_outputs_and_metrics(metrics, outputs)

Extract a history from metrics and an output result

Parameters
  • metrics – a list of metrics

  • outputs – the result of Output.evaluate_batch

Returns

a dictionary of key, value

class trw.train.outputs.OutputEmbedding(output, clean_loss_term_each_batch=False, sample_uid_name=default_sample_uid_name)

Bases: Output

Represent an embedding

This is only used to record a tensor that we consider an embedding (e.g., to be exported to TensorBoard)

evaluate_batch(self, batch, is_training)

Evaluate a batch of data and extract important outputs

Parameters
  • batch – the batch of data

  • is_training – if True, this was a training batch

Returns

a dictionary

loss_term_cleanup(self, loss_term)

This function is called for each batch just before switching to another batch.

It can be used to clean up large arrays stored in the loss term.

trw.train.outputs.segmentation_criteria_ce_dice(output, truth, ce_weight=0.5)

Loss combining cross entropy and multiclass Dice

Parameters
  • output – the output value

  • truth – the truth

  • ce_weight – the weight of the cross entropy to use. This controls the importance of the cross entropy loss to the overall segmentation loss

Returns

a torch tensor
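To make the role of ce_weight concrete, here is a minimal sketch of one common way to blend the two terms. It assumes the combination is `ce_weight * cross_entropy + (1 - ce_weight) * dice_loss`, which the documentation above does not confirm. It uses plain Python lists instead of torch tensors so it runs without dependencies, and the helper names (`softmax`, `cross_entropy`, `multiclass_dice_loss`, `ce_dice`) are illustrative, not part of trw.

```python
import math


def softmax(logits):
    # Numerically stable softmax over one pixel's class logits.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs, truth):
    # Mean negative log-likelihood of the true class over all pixels.
    return -sum(math.log(p[t]) for p, t in zip(probs, truth)) / len(truth)


def multiclass_dice_loss(probs, truth, eps=1e-5):
    # Soft Dice per class, averaged over classes, turned into a loss (1 - dice).
    num_classes = len(probs[0])
    dice_per_class = []
    for c in range(num_classes):
        intersection = sum(p[c] for p, t in zip(probs, truth) if t == c)
        cardinality = sum(p[c] for p in probs) + sum(1 for t in truth if t == c)
        dice_per_class.append((2.0 * intersection + eps) / (cardinality + eps))
    return 1.0 - sum(dice_per_class) / num_classes


def ce_dice(logits, truth, ce_weight=0.5):
    # Assumed blend: ce_weight scales the cross entropy, the rest goes to Dice.
    probs = [softmax(pixel) for pixel in logits]
    ce = cross_entropy(probs, truth)
    dice = multiclass_dice_loss(probs, truth)
    return ce_weight * ce + (1.0 - ce_weight) * dice


# Three pixels, two classes: one row of logits per pixel.
logits = [[4.0, 0.0], [0.0, 4.0], [3.0, 1.0]]
truth = [0, 1, 0]
loss = ce_dice(logits, truth, ce_weight=0.5)
```

Under this assumption, ce_weight=1.0 reduces the loss to pure cross entropy and ce_weight=0.0 to pure Dice.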

trw.train.outputs.segmentation_output_postprocessing(mask_pb)

Post-process the segmentation mask probabilities into a discrete segmentation map
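A typical way to turn per-class probabilities into a discrete map is to take the most probable class at each pixel. The signatures in this module show that the default output_postprocessing of OutputSegmentation is `functools.partial(torch.argmax, dim=1)`. Whether segmentation_output_postprocessing does exactly the same is an assumption here. The sketch uses nested lists indexed as `[class][row][col]` for a single sample, and `argmax_over_classes` is a hypothetical helper name.

```python
def argmax_over_classes(mask_pb):
    """Return, for each pixel, the index of the class with the highest probability.

    mask_pb is a nested list indexed as [class][row][col].
    """
    num_classes = len(mask_pb)
    height = len(mask_pb[0])
    width = len(mask_pb[0][0])
    return [
        [max(range(num_classes), key=lambda c: mask_pb[c][y][x]) for x in range(width)]
        for y in range(height)
    ]


mask_pb = [
    [[0.9, 0.2], [0.4, 0.1]],  # probabilities for class 0
    [[0.1, 0.8], [0.6, 0.9]],  # probabilities for class 1
]
segmentation = argmax_over_classes(mask_pb)  # [[0, 1], [1, 1]]
```

With batched torch tensors of shape (N, C, H, W), `torch.argmax(mask_pb, dim=1)` performs the same reduction over the class axis.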

class trw.train.outputs.OutputSegmentation(output, target_name, criterion_fn=lambda : ..., collect_only_non_training_output=True, metrics=metrics.default_segmentation_metrics(), loss_reduction=torch.mean, weight_name=None, loss_scaling=1.0, collect_output=True, output_postprocessing=functools.partial(torch.argmax, dim=1), sample_uid_name=default_sample_uid_name)

Bases: Output

Segmentation output

extract_history(self, outputs)

Summarize epoch statistics from the calculated outputs to populate a history

Parameters

outputs – the aggregated evaluate_batch output

Returns

a dictionary

evaluate_batch(self, batch, is_training)

Evaluate a batch of data and extract important outputs

Parameters
  • batch – the batch of data

  • is_training – if True, this was a training batch

Returns

a dictionary

class trw.train.outputs.OutputClassification(output, classes_name, criterion_fn=lambda : ..., collect_output=True, collect_only_non_training_output=False, metrics=metrics.default_classification_metrics(), loss_reduction=torch.mean, weight_name=None, loss_scaling=1.0, output_postprocessing=functools.partial(torch.argmax, dim=1), maybe_optional=False, sample_uid_name=default_sample_uid_name)

Bases: Output

Classification output

extract_history(self, outputs)

Summarize epoch statistics from the calculated outputs to populate a history

Parameters

outputs – the aggregated evaluate_batch output

Returns

a dictionary

evaluate_batch(self, batch, is_training)

Evaluate a batch of data and extract important outputs

Parameters
  • batch – the batch of data

  • is_training – if True, this was a training batch

Returns

a dictionary

trw.train.outputs.mean_all(x)

Parameters

x – a Tensor

Returns

the mean of all values
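mean_all is the default loss_reduction of OutputRegression (see its signature). It reduces over every element rather than per sample. The sketch below shows that distinction with nested lists in place of a Tensor. `mean_all_nested` and `mean_per_sample` are hypothetical names used only for illustration.

```python
def mean_all_nested(x):
    """Mean of every scalar in an arbitrarily nested list."""
    def flatten(values):
        for v in values:
            if isinstance(v, (list, tuple)):
                yield from flatten(v)
            else:
                yield float(v)

    flat = list(flatten(x))
    return sum(flat) / len(flat)


def mean_per_sample(x):
    """Mean over each sample's values, keeping one number per sample."""
    return [mean_all_nested(sample) for sample in x]


# Two samples with three loss values each.
batch_losses = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
overall = mean_all_nested(batch_losses)      # 3.5
per_sample = mean_per_sample(batch_losses)   # [2.0, 5.0]
```

For a torch Tensor, `torch.mean(x)` with no dim argument gives the same all-element reduction.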

class trw.train.outputs.OutputRegression(output, target_name, criterion_fn=lambda : ..., collect_output=True, collect_only_non_training_output=False, metrics=metrics.default_regression_metrics(), loss_reduction=mean_all, weight_name=None, loss_scaling=1.0, output_postprocessing=lambda x: ..., sample_uid_name=default_sample_uid_name)

Bases: Output

Regression output

extract_history(self, outputs)

Summarize epoch statistics from the calculated outputs to populate a history

Parameters

outputs – the aggregated evaluate_batch output

Returns

a dictionary

evaluate_batch(self, batch, is_training)

Evaluate a batch of data and extract important outputs

Parameters
  • batch – the batch of data

  • is_training – if True, this was a training batch

Returns

a dictionary

class trw.train.outputs.OutputRecord(output)

Bases: Output

Record the raw value, but do not compute any loss from it.

This is useful, e.g., to collect UIDs so that we can save them in the network result and further post-process it (e.g., k-fold cross validation)

evaluate_batch(self, batch, is_training)

Evaluate a batch of data and extract important outputs

Parameters
  • batch – the batch of data

  • is_training – if True, this was a training batch

Returns

a dictionary