trw.train.trainer

Module Contents

Classes

Trainer

This is the main class to train a model

Functions

postprocess_batch(dataset_name, split_name, batch, callbacks_per_batch)

Post process a batch of data (e.g., this can be useful to add additional data to the current batch)

prepare_loss_terms(outputs, batch, is_training)

Return the loss_terms for the given outputs

create_losses_fn(datasets, generic_loss)

Create a dictionary of loss functions, one for each of the datasets

generic_aggregate_loss_terms(loss_terms_history)

Aggregate the loss terms for all the internal_nodes of an epoch

loss_term_cleanup(loss_terms)

Perform cleanup on all the loss terms

train_loop(device, dataset_name, split_name, split, optimizer, model, loss_fn, history, callbacks_per_batch, callbacks_per_batch_loss_terms, apply_backward=True)

Run the train loop (i.e., the model parameters will be updated)

eval_loop(device, dataset_name, split_name, split, model, loss_fn, history, callbacks_per_batch=None, callbacks_per_batch_loss_terms=None)

Run the eval loop (i.e., the model parameters will NOT be updated)

epoch_train_eval(options, datasets, optimizers, model, losses, schedulers, history, callbacks_per_batch, callbacks_per_batch_loss_terms, run_eval, eval_loop_fn=eval_loop, train_loop_fn=train_loop)

Orchestrate the train and evaluation loops

default_pre_training_callbacks(logger=default_logger, with_lr_finder=False, with_export_augmentations=True)

Default callbacks to be performed before the fitting of the model

default_per_epoch_callbacks(logger=default_logger, with_worst_samples_by_epoch=True, with_activation_statistics=False, convolutional_kernel_export_frequency=None)

Default callbacks to be performed at the end of each epoch

default_post_training_callbacks(embedding_name='embedding', dataset_name=None, split_name=None, discard_train_error_export=False, export_errors=True, explain_decision=True)

Default callbacks to be performed after the model has been trained

default_sum_all_losses(dataset_name, batch, loss_terms)

Default loss is the sum of all loss terms

trainer_callbacks_per_batch(dataset_name, split_name, batch)

Postprocessing step to be run on the batches (e.g., if a feature value is a functor, call it and replace the value with its result)

strip_unpickable(outputs)

Remove the objects that cannot be pickled

run_trainer_repeat(trainer, options, inputs_fn, model_fn, optimizers_fn, losses_fn=default_sum_all_losses, loss_creator=create_losses_fn, run_prefix='default', eval_every_X_epoch=1, number_of_training_runs=10, post_init_fn=None)

Manages multiple runs of a trainer, for example to repeat the training and get an idea of the variance of a model

Attributes

logger

default_logger

trw.train.trainer.logger
trw.train.trainer.postprocess_batch(dataset_name, split_name, batch, callbacks_per_batch)

Post process a batch of data (e.g., this can be useful to add additional data to the current batch)

Parameters
  • dataset_name (str) – the name of the dataset the batch belongs to

  • split_name (str) – the name of the split the batch belongs to

  • batch – the current batch of data

  • callbacks_per_batch (list) – the callbacks to be executed for each batch. Each callback must be callable with (dataset_name, split_name, batch). If None, no callbacks are executed
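
A minimal usage sketch (the callback and the feature names below are illustrative, not part of the library):

    import trw

    # hypothetical per-batch callback: add a derived feature to the batch in place
    def add_sample_count(dataset_name, split_name, batch):
        batch['nb_samples'] = len(batch['values'])

    batch = {'values': [0.1, 0.2, 0.3]}  # a batch is a dictionary of features
    trw.train.trainer.postprocess_batch('my_dataset', 'train', batch, [add_sample_count])
    assert batch['nb_samples'] == 3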

trw.train.trainer.prepare_loss_terms(outputs, batch, is_training)

Return the loss_terms for the given outputs

trw.train.trainer.create_losses_fn(datasets, generic_loss)

Create a dictionary of loss functions, one for each of the datasets

Parameters
  • datasets – the datasets

  • generic_loss – a loss function

Returns

A dictionary of losses, one for each of the datasets
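
A short sketch, assuming default_sum_all_losses is reused as the generic loss for every dataset (the dataset contents are elided):

    import trw

    datasets = {'mnist': {}, 'cifar10': {}}  # two datasets, splits elided
    losses = trw.train.trainer.create_losses_fn(
        datasets, trw.train.trainer.default_sum_all_losses)
    # `losses` maps each dataset name to a loss function, e.g. losses['mnist']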

trw.train.trainer.generic_aggregate_loss_terms(loss_terms_history)

Aggregate the loss terms for all the internal_nodes of an epoch

Parameters

loss_terms_history – a list of loss terms

Returns

a tuple (output, history). output is kept alive only during the current epoch, while history is kept in memory for the whole training

trw.train.trainer.loss_term_cleanup(loss_terms)

Perform cleanup on all the loss terms

Each loss term must carry the outputs.Output.output_ref_tag tag; otherwise no cleanup is performed for that loss term.

Parameters

loss_terms – the loss terms to be cleaned up

trw.train.trainer.train_loop(device, dataset_name, split_name, split, optimizer, model, loss_fn, history, callbacks_per_batch, callbacks_per_batch_loss_terms, apply_backward=True)

Run the train loop (i.e., the model parameters will be updated)

Note

If a callback in callbacks_per_batch or callbacks_per_batch_loss_terms raises StopIteration, the train loop is stopped (see the sketch after the parameter list)

Parameters
  • device – the device to be used to optimize the model

  • dataset_name – the name of the dataset

  • split_name – the name of the split

  • split – a dictionary of feature name and values

  • optimizer – an optimizer to optimize the model

  • model – the model to be optimized

  • loss_fn – the loss function

  • history – a list of history steps

  • callbacks_per_batch – the callbacks to be performed on each batch. If None, no callbacks are run

  • callbacks_per_batch_loss_terms – the callbacks to be performed on each loss term. If None, no callbacks are run

  • apply_backward – if True, the gradient will be back-propagated
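
As noted above, raising StopIteration from a per-batch callback ends the epoch early. A minimal sketch (the batch budget below is illustrative):

    # hypothetical callback: stop the epoch after a fixed number of batches
    class StopAfterNBatches:
        def __init__(self, nb_batches):
            self.nb_batches = nb_batches
            self.seen = 0

        def __call__(self, dataset_name, split_name, batch):
            self.seen += 1
            if self.seen >= self.nb_batches:
                raise StopIteration()

    # pass it to train_loop as callbacks_per_batch=[StopAfterNBatches(100)]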

trw.train.trainer.eval_loop(device, dataset_name, split_name, split, model, loss_fn, history, callbacks_per_batch=None, callbacks_per_batch_loss_terms=None)

Run the eval loop (i.e., the model parameters will NOT be updated)

Note

If a callback in callbacks_per_batch or callbacks_per_batch_loss_terms raises StopIteration, the eval loop is stopped

Parameters
  • device – the device used to run the evaluation

  • dataset_name – the name of the dataset

  • split_name – the name of the split

  • split – a dictionary of feature name and values

  • model – the model to be evaluated

  • loss_fn – the loss function

  • history – a list of history steps

  • callbacks_per_batch – the callbacks to be performed on each batch. If None, no callbacks are run

  • callbacks_per_batch_loss_terms – the callbacks to be performed on each loss term. If None, no callbacks are run


trw.train.trainer.epoch_train_eval(options, datasets, optimizers, model, losses, schedulers, history, callbacks_per_batch, callbacks_per_batch_loss_terms, run_eval, eval_loop_fn=eval_loop, train_loop_fn=train_loop)

Orchestrate the train and evaluation loops

Parameters
  • options

  • datasets

  • optimizers – if None, no optimization will be performed on the train split; else a dictionary of optimizers (one for each dataset)

  • model – the model to be trained and evaluated

  • losses – a dictionary of loss functions, one for each dataset

  • schedulers – the learning rate schedulers, if any

  • history – a list of history steps

  • callbacks_per_batch – the callbacks to be performed on each batch

  • callbacks_per_batch_loss_terms – the callbacks to be performed on each loss term

  • run_eval – if True, run the evaluation

  • eval_loop_fn – the eval function to be used

  • train_loop_fn – the train function to be used

trw.train.trainer.default_logger
trw.train.trainer.default_pre_training_callbacks(logger=default_logger, with_lr_finder=False, with_export_augmentations=True)

Default callbacks to be performed before the fitting of the model

trw.train.trainer.default_per_epoch_callbacks(logger=default_logger, with_worst_samples_by_epoch=True, with_activation_statistics=False, convolutional_kernel_export_frequency=None)

Default callbacks to be performed at the end of each epoch

trw.train.trainer.default_post_training_callbacks(embedding_name='embedding', dataset_name=None, split_name=None, discard_train_error_export=False, export_errors=True, explain_decision=True)

Default callbacks to be performed after the model has been trained
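
These three factories are the constructor defaults of Trainer below. To customize one of them while keeping the others, wrap the corresponding default in a zero-argument functor. A sketch (assuming the factories accept the keyword overrides listed in their signatures above):

    import trw

    # enable the learning-rate finder in the pre-training callbacks,
    # keep every other default unchanged
    trainer = trw.train.trainer.Trainer(
        callbacks_pre_training_fn=lambda: trw.train.trainer.default_pre_training_callbacks(
            with_lr_finder=True))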

trw.train.trainer.default_sum_all_losses(dataset_name, batch, loss_terms)

Default loss is the sum of all loss terms
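
Conceptually, it behaves like the sketch below (the 'loss' key is an assumption about the internal layout of a loss term, not a documented contract):

    # minimal sketch of the summing idea, not the library code
    def sum_all_losses(dataset_name, batch, loss_terms):
        total = 0.0
        for name, term in loss_terms.items():
            loss = term.get('loss')  # assumed key holding the scalar loss of a term
            if loss is not None:
                total += loss
        return total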

trw.train.trainer.trainer_callbacks_per_batch(dataset_name, split_name, batch)

Postprocessing step to be run on the batches (e.g., if a feature value is a functor, call it and replace the value with its result)

Parameters
  • dataset_name – the name of the dataset

  • split_name – the name of the split

  • batch – the current batch of data

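The functor replacement can be pictured with the self-contained sketch below. This mimics the idea only; it is not the library code, and the zero-argument call is an assumption (the real callback may pass additional context to the functor):

    # sketch of the functor-replacement idea: a callable feature value is
    # replaced by the result of calling it
    def replace_functors(dataset_name, split_name, batch):
        for name, value in batch.items():
            if callable(value):
                batch[name] = value()

    batch = {'z': lambda: 42, 'x': [1, 2, 3]}
    replace_functors('my_dataset', 'train', batch)
    assert batch['z'] == 42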

class trw.train.trainer.Trainer(callbacks_per_batch_fn=None, callbacks_per_batch_loss_terms_fn=None, callbacks_per_epoch_fn=default_per_epoch_callbacks, callbacks_pre_training_fn=default_pre_training_callbacks, callbacks_post_training_fn=default_post_training_callbacks, trainer_callbacks_per_batch=trainer_callbacks_per_batch, run_epoch_fn=epoch_train_eval)

This is the main class to train a model

static save_model(model, result, path)

Save a model

Parameters
  • model – a PyTorch model

  • result – None or the result of the model

  • path – where to store the model. The result will be saved at path + ‘.result’

static load_model(path, with_result=False, device=None)

Load a saved model

Parameters
  • path – the path of the saved model. The result will be loaded from path + ‘.result’

  • with_result – if True, the results of the model will be loaded

  • device – where to load the model. For example, models are typically trained on GPU, but for deployment, CPU might be good enough. If None, use the same device as when the model was exported

Returns

a tuple (model, result)
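
A save/load round trip might look like this (the path is illustrative):

    import torch
    import torch.nn as nn
    import trw

    model = nn.Linear(4, 2)  # any PyTorch model
    trw.train.trainer.Trainer.save_model(model, None, '/tmp/model.pt')  # no result stored

    # reload on CPU; leave device=None to reuse the device the model was exported from
    model, result = trw.train.trainer.Trainer.load_model(
        '/tmp/model.pt', device=torch.device('cpu'))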

fit(self, options, inputs_fn, model_fn, optimizers_fn, losses_fn=default_sum_all_losses, loss_creator=create_losses_fn, run_prefix='default', with_final_evaluation=True, eval_every_X_epoch=1)

Fit the model

Requirements:

  • enough main memory to store the outputs of all the datasets of a single epoch.

    If this cannot be satisfied, sub-sample the epoch so that it can fit in main memory.

Notes:

  • if a feature value is Callable, its value will be replaced by the result of the call

    (e.g., this can be useful to generate z embedding in GANs)

Parameters
  • options

  • inputs_fn

    a functor returning a dictionary of datasets. Alternatively, datasets infos can be specified. inputs_fn must return one of:

    • datasets: a dictionary of datasets

    • (datasets, datasets_infos): a dictionary of datasets and additional infos

    We define:

    • datasets: a dictionary of datasets. A dataset is a dictionary of splits; a split is a dictionary of batched features.

    • datasets_infos: additional information useful for debugging the dataset (e.g., class mappings, sample UIDs). Datasets infos are typically much smaller than the datasets and should be loadable in memory.

    See the usage sketch at the end of this method’s documentation.

  • model_fn – a functor with parameter options and returning a Module or a ModuleDict

Depending on the type of the model, it will be used as follows:

  • Module: optimizer will optimize model.parameters()

  • ModuleDict: for each dataset name, the optimizer will optimize

    model[dataset_name].parameters(). Note that a forward method will need to be implemented

  • losses_fn

  • optimizers_fn

  • loss_creator

  • eval_every_X_epoch – evaluate the model every X epochs

  • run_prefix – the prefix of the output folder

  • with_final_evaluation – if True, once the model is fitted, evaluate all the data again in eval mode

Returns

a tuple (model, result)
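
Putting it together, a typical call looks like the sketch below. The helpers and names used here (trw.train.create_default_options, trw.datasets.create_mnist_datasset, trw.train.create_sgd_optimizers_fn, trw.train.OutputClassification, and the ‘images’/‘targets’ feature names) follow the library’s tutorials but are assumptions that may differ in your version:

    import torch.nn as nn
    import trw

    class Net(nn.Module):
        # hypothetical model: a linear classifier over flattened MNIST images
        def __init__(self):
            super().__init__()
            self.classifier = nn.Linear(28 * 28, 10)

        def forward(self, batch):
            x = batch['images'].view(batch['images'].shape[0], -1)
            return {'digit': trw.train.OutputClassification(self.classifier(x), 'targets')}

    options = trw.train.create_default_options(num_epochs=10)
    trainer = trw.train.Trainer()
    model, results = trainer.fit(
        options,
        inputs_fn=lambda: trw.datasets.create_mnist_datasset(),
        model_fn=lambda options: Net(),
        optimizers_fn=lambda datasets, model: trw.train.create_sgd_optimizers_fn(
            datasets, model, learning_rate=0.1),
        run_prefix='mnist_example')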

trw.train.trainer.strip_unpickable(outputs)

Remove the objects that cannot be pickled

trw.train.trainer.run_trainer_repeat(trainer, options, inputs_fn, model_fn, optimizers_fn, losses_fn=default_sum_all_losses, loss_creator=create_losses_fn, run_prefix='default', eval_every_X_epoch=1, number_of_training_runs=10, post_init_fn=None)

Manages multiple runs of a trainer, for example to repeat the training and get an idea of the variance of a model

Parameters
  • trainer – the trainer to be run repeatedly

  • options – the training options

  • inputs_fn – a functor returning the datasets (see Trainer.fit())

  • model_fn – a functor returning the model (see Trainer.fit())

  • optimizers_fn

  • losses_fn

  • loss_creator

  • run_prefix – the prefix of the output folder

  • eval_every_X_epoch – evaluate the model every X epochs

  • number_of_training_runs – the number of times the training is repeated

  • post_init_fn – if not None, a function to be called before each training repeat

Returns

a tuple (model, result) for the last model trained
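
For example, to repeat the training ten times and get an idea of the run-to-run variance (reusing the hypothetical Net, options and helpers from the fit sketch above):

    # only the (model, result) of the last run is returned
    model, result = trw.train.trainer.run_trainer_repeat(
        trw.train.Trainer(),
        options,
        inputs_fn=lambda: trw.datasets.create_mnist_datasset(),
        model_fn=lambda options: Net(),
        optimizers_fn=lambda datasets, model: trw.train.create_sgd_optimizers_fn(
            datasets, model, learning_rate=0.1),
        number_of_training_runs=10,
        run_prefix='mnist_repeat')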