trw.train.trainer

Module Contents

Functions

create_losses_fn(datasets, generic_loss)

Create a dictionary of loss functions, one for each of the datasets

aggregate_values(values)

aggregate_list_of_dicts(list_of_dicts)

aggregate_list_of_metrics(list_of_metrics)

generic_aggregate_loss_terms(loss_terms_history)

Aggregate the loss terms for all the internal_nodes of an epoch

loss_term_cleanup(loss_terms)

Perform cleanup on all the loss terms

train_loop(options, device, dataset_name, split_name, split, optimizer, per_step_scheduler, model, loss_fn, history, callbacks_per_batch, callbacks_per_batch_loss_terms, gradient_scaler=None)

Run the train loop (i.e., the model parameters will be updated)

eval_loop(options, device, dataset_name, split_name, split, model, loss_fn, history, callbacks_per_batch=None, callbacks_per_batch_loss_terms=None)

Run the eval loop (i.e., the model parameters will NOT be updated)

approximate_batch_size_from_loss_terms(all_loss_terms)

Calculate an approximation of the number of samples from the loss terms. The error can be up to the number of samples within one batch

epoch_train_eval(options, datasets, optimizers, model, losses, schedulers, per_step_schedulers, history, callbacks_per_batch, callbacks_per_batch_loss_terms, run_eval, force_eval_mode, eval_loop_fn=eval_loop, train_loop_fn=train_loop)

default_pre_training_callbacks(logger=default_logger, with_lr_finder=False, with_export_augmentations=True, with_reporting_server=True, with_profiler=False, additional_callbacks=None)

Default callbacks to be performed before the fitting of the model

default_per_epoch_callbacks(logger=default_logger, with_worst_samples_by_epoch=True, with_activation_statistics=False, convolutional_kernel_export_frequency=None, additional_callbacks=None)

Default callbacks to be performed at the end of each epoch

default_post_training_callbacks(embedding_name='embedding', dataset_name=None, split_name=None, discard_train_error_export=False, export_errors=True, explain_decision=True, additional_callbacks=None)

Default callbacks to be performed after the model has been trained

trainer_callbacks_per_batch(dataset_name, split_name, batch)

Postprocessing step to be run on the batches (e.g., if a batch value is a functor, call it and replace the value with its result)

strip_unpickable(outputs)

Remove the objects that cannot be pickled

Attributes

autocast

logger

default_logger

trw.train.trainer.autocast
trw.train.trainer.logger
trw.train.trainer.create_losses_fn(datasets, generic_loss)

Create a dictionary of loss functions, one for each of the datasets

Parameters
  • datasets – the datasets

  • generic_loss – a loss function

Returns

A dictionary of losses, one for each of the datasets
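A minimal pure-Python sketch of the behavior described above: the same generic loss is mapped to every dataset name. The dataset names and the `mse` helper are illustrative assumptions, not the actual trw implementation.

```python
# Sketch: map the same generic loss to every dataset name.
def create_losses_fn(datasets, generic_loss):
    return {dataset_name: generic_loss for dataset_name in datasets}

# illustrative generic loss (mean squared error on plain lists)
def mse(outputs, targets):
    return sum((o - t) ** 2 for o, t in zip(outputs, targets)) / len(outputs)

losses = create_losses_fn({'mnist': None, 'cifar10': None}, mse)
assert set(losses) == {'mnist', 'cifar10'}
assert losses['mnist']([1.0, 3.0], [1.0, 1.0]) == 2.0
```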

trw.train.trainer.aggregate_values(values)
trw.train.trainer.aggregate_list_of_dicts(list_of_dicts)
trw.train.trainer.aggregate_list_of_metrics(list_of_metrics)
trw.train.trainer.generic_aggregate_loss_terms(loss_terms_history)

Aggregate the loss terms for all the internal_nodes of an epoch

Parameters

loss_terms_history – a list of loss terms

Returns

a tuple (output, history). output is kept alive only during the current epoch; history is kept in memory for the whole training
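The aggregation idea can be illustrated with a small sketch: per-batch loss-term dictionaries collected over an epoch are reduced to one averaged history entry. This is an assumption about the aggregation semantics, not the actual trw code.

```python
# Sketch: average each loss term across all batches of an epoch.
def aggregate_list_of_dicts(list_of_dicts):
    keys = list_of_dicts[0].keys()
    return {k: sum(d[k] for d in list_of_dicts) / len(list_of_dicts) for k in keys}

epoch_terms = [{'loss': 2.0, 'accuracy': 0.5}, {'loss': 4.0, 'accuracy': 1.0}]
history_entry = aggregate_list_of_dicts(epoch_terms)
assert history_entry == {'loss': 3.0, 'accuracy': 0.75}
```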

trw.train.trainer.loss_term_cleanup(loss_terms)

Perform cleanup on all the loss terms

Requires the outputs.Output.output_ref_tag tag on each loss term; terms without it are not cleaned up.

Parameters

loss_terms – the loss terms to be cleaned up
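A sketch of the cleanup contract described above: a loss term carrying the reference tag has that heavyweight reference removed so the associated graph and tensors can be garbage collected. The tag value `'output_ref'` is an assumption standing in for outputs.Output.output_ref_tag.

```python
OUTPUT_REF_TAG = 'output_ref'  # assumed tag name

def loss_term_cleanup(loss_terms):
    for term in loss_terms.values():
        term.pop(OUTPUT_REF_TAG, None)  # terms without the tag are untouched

loss_terms = {
    'classification': {'loss': 0.5, 'output_ref': object()},
    'untagged': {'loss': 0.1},
}
loss_term_cleanup(loss_terms)
assert loss_terms == {'classification': {'loss': 0.5}, 'untagged': {'loss': 0.1}}
```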

trw.train.trainer.train_loop(options, device, dataset_name, split_name, split, optimizer, per_step_scheduler, model, loss_fn, history, callbacks_per_batch, callbacks_per_batch_loss_terms, gradient_scaler=None)

Run the train loop (i.e., the model parameters will be updated)

Note

If callbacks_per_batch or callbacks_per_batch_loss_terms raises a StopIteration exception, the train loop will be stopped

Parameters
  • device – the device to be used to optimize the model

  • dataset_name – the name of the dataset

  • split_name – the name of the split

  • split – a dictionary of feature name and values

  • optimizer – an optimizer to optimize the model

  • per_step_scheduler – scheduler to be applied per-batch

  • model – the model to be optimized

  • loss_fn – the loss function

  • history – a list of history step

  • callbacks_per_batch – the callbacks to be run on each batch; if None, no callbacks are run

  • callbacks_per_batch_loss_terms – the callbacks to be run on each loss term; if None, no callbacks are run

  • gradient_scaler – if mixed precision is enabled, this is the scale to be used for the gradient update

Notes

If optimizer is None, a .backward() call MUST still be performed to free the graph and its memory.
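A skeleton of the control flow in the note above: per-batch callbacks can raise StopIteration to end the loop early. `step_fn` stands in for the forward/backward/optimizer step; all names are illustrative, not trw internals.

```python
# Sketch: run per-batch callbacks, stopping the loop on StopIteration.
def train_loop_sketch(split, step_fn, callbacks_per_batch=None):
    steps = 0
    for batch in split:
        try:
            if callbacks_per_batch is not None:
                for callback in callbacks_per_batch:
                    callback(batch)
        except StopIteration:
            break  # a callback requested an early stop
        step_fn(batch)
        steps += 1
    return steps

def stop_after_two(batch):
    if batch >= 2:
        raise StopIteration

assert train_loop_sketch(range(10), step_fn=lambda b: None,
                         callbacks_per_batch=[stop_after_two]) == 2
```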

trw.train.trainer.eval_loop(options, device, dataset_name, split_name, split, model, loss_fn, history, callbacks_per_batch=None, callbacks_per_batch_loss_terms=None)

Run the eval loop (i.e., the model parameters will NOT be updated)

Note

If callbacks_per_batch or callbacks_per_batch_loss_terms raises StopIteration, the eval loop will be stopped

Parameters
  • device

  • dataset_name

  • split_name

  • split

  • model

  • loss_fn

  • history

  • callbacks_per_batch

  • callbacks_per_batch_loss_terms

Returns

trw.train.trainer.approximate_batch_size_from_loss_terms(all_loss_terms)

Calculate an approximation of the number of samples from the loss terms. The error can be up to the number of samples within one batch

trw.train.trainer.epoch_train_eval(options, datasets, optimizers, model, losses, schedulers, per_step_schedulers, history, callbacks_per_batch, callbacks_per_batch_loss_terms, run_eval, force_eval_mode, eval_loop_fn=eval_loop, train_loop_fn=train_loop)
Parameters
  • options

  • datasets

  • optimizers

  • model

  • losses

  • schedulers

  • per_step_schedulers

  • history

  • callbacks_per_batch

  • callbacks_per_batch_loss_terms

  • run_eval

  • force_eval_mode

  • eval_loop_fn

  • train_loop_fn

Returns

trw.train.trainer.default_logger
trw.train.trainer.default_pre_training_callbacks(logger=default_logger, with_lr_finder=False, with_export_augmentations=True, with_reporting_server=True, with_profiler=False, additional_callbacks=None)

Default callbacks to be performed before the fitting of the model

trw.train.trainer.default_per_epoch_callbacks(logger=default_logger, with_worst_samples_by_epoch=True, with_activation_statistics=False, convolutional_kernel_export_frequency=None, additional_callbacks=None)

Default callbacks to be performed at the end of each epoch

trw.train.trainer.default_post_training_callbacks(embedding_name='embedding', dataset_name=None, split_name=None, discard_train_error_export=False, export_errors=True, explain_decision=True, additional_callbacks=None)

Default callbacks to be performed after the model has been trained

trw.train.trainer.trainer_callbacks_per_batch(dataset_name, split_name, batch)

Postprocessing step to be run on the batches (e.g., if a batch value is a functor, call it and replace the value with its result)

Parameters
  • dataset_name

  • split_name

  • batch

Returns
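A sketch of the functor postprocessing described above: any callable stored in the batch is invoked and replaced by its result, so downstream code only sees concrete values. This is an illustration of the idea, not the exact trw code.

```python
# Sketch: replace every callable feature in the batch by its result.
def trainer_callbacks_per_batch(dataset_name, split_name, batch):
    for feature_name, value in batch.items():
        if callable(value):
            batch[feature_name] = value()

batch = {'images': [1, 2, 3], 'lazy_labels': lambda: [0, 1, 0]}
trainer_callbacks_per_batch('mnist', 'train', batch)
assert batch == {'images': [1, 2, 3], 'lazy_labels': [0, 1, 0]}
```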

trw.train.trainer.strip_unpickable(outputs)

Remove the objects that cannot be pickled
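The idea can be sketched with the standard pickle module: keep only the entries of `outputs` that survive a pickling attempt (lambdas, open files, and similar are dropped). The dictionary shape of `outputs` is an assumption for illustration.

```python
import pickle

# Sketch: drop every value that cannot be pickled.
def strip_unpickable(outputs):
    kept = {}
    for name, value in outputs.items():
        try:
            pickle.dumps(value)
        except Exception:
            continue  # value cannot be pickled: drop it
        kept[name] = value
    return kept

clean = strip_unpickable({'metrics': {'loss': 0.5}, 'hook': lambda: None})
assert clean == {'metrics': {'loss': 0.5}}
```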