trw.train.trainer
Module Contents
Classes
Trainer | This is the main class to train a model
Functions
postprocess_batch | Post-process a batch of data (e.g., this can be useful to add additional data to the current batch)
prepare_loss_terms | Return the loss_terms for the given outputs
create_losses_fn | Create a dictionary of loss functions for each dataset
generic_aggregate_loss_terms | Aggregate the loss terms for all the internal_nodes of an epoch
loss_term_cleanup | Perform cleanup on all the loss terms
train_loop | Run the train loop (i.e., the model parameters will be updated)
eval_loop | Run the eval loop (i.e., the model parameters will NOT be updated)
epoch_train_eval | Orchestrate the train and evaluation loops
default_pre_training_callbacks | Default callbacks to be performed before the fitting of the model
default_per_epoch_callbacks | Default callbacks to be performed at the end of each epoch
default_post_training_callbacks | Default callbacks to be performed after the model has been trained
default_sum_all_losses | Default loss is the sum of all loss terms
trainer_callbacks_per_batch | Postprocessing step to be run on the batches (e.g., if a feature is a functor, call it and replace the feature by the result)
strip_unpickable | Remove the objects that cannot be pickled
run_trainer_repeat | Manages multiple runs of a trainer, for example to repeat the training and estimate the variance of a model
Attributes
logger
default_logger
- trw.train.trainer.logger
- trw.train.trainer.postprocess_batch(dataset_name, split_name, batch, callbacks_per_batch)
Post-process a batch of data (e.g., this can be useful to add additional data to the current batch)
- Parameters
dataset_name (str) – the name of the dataset the batch belongs to
split_name (str) – the name of the split the batch belongs to
batch – the current batch of data
callbacks_per_batch (list) – the callbacks to be executed for each batch. Each callback must be callable with (dataset_name, split_name, batch). If None, no callbacks are run.
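The calling convention of callbacks_per_batch can be illustrated with a minimal sketch. `postprocess_batch_sketch` and `add_batch_size` are hypothetical names, not trw code. The sketch assumes a batch is a dictionary of feature names to values and that callbacks modify it in place.

```python
def add_batch_size(dataset_name, split_name, batch):
    # Hypothetical callback: adds a feature derived from an existing one
    batch['batch_size'] = len(batch['images'])


def postprocess_batch_sketch(dataset_name, split_name, batch, callbacks_per_batch):
    # Hypothetical re-implementation: None means "no callbacks"
    if callbacks_per_batch is None:
        return
    for callback in callbacks_per_batch:
        # Every callback receives the same (dataset_name, split_name, batch) triple
        callback(dataset_name, split_name, batch)


batch = {'images': [[0.0], [1.0], [2.0]], 'targets': [0, 1, 0]}
postprocess_batch_sketch('mnist', 'train', batch, [add_batch_size])
print(batch['batch_size'])  # 3
```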
- trw.train.trainer.prepare_loss_terms(outputs, batch, is_training)
Return the loss_terms for the given outputs
- trw.train.trainer.create_losses_fn(datasets, generic_loss)
Create a dictionary of loss functions for each dataset
- Parameters
datasets – the datasets
generic_loss – a loss function
- Returns
A dictionary of losses, one for each dataset
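The following sketch assumes the simplest reading of create_losses_fn: the same generic loss is associated with every dataset name. Whether trw wraps the loss further is not stated here. `create_losses_sketch` and `my_loss` are hypothetical.

```python
def my_loss(dataset_name, batch, loss_terms):
    # Hypothetical generic loss, same signature as default_sum_all_losses
    return 0.0


def create_losses_sketch(datasets, generic_loss):
    # One entry per dataset name, all sharing the same generic loss function
    return {dataset_name: generic_loss for dataset_name in datasets}


datasets = {'mnist': {'train': [], 'test': []}, 'cifar10': {'train': []}}
losses = create_losses_sketch(datasets, my_loss)
print(sorted(losses))  # ['cifar10', 'mnist']
```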
- trw.train.trainer.generic_aggregate_loss_terms(loss_terms_history)
Aggregate the loss terms for all the internal_nodes of an epoch
- Parameters
loss_terms_history – a list of loss terms
- Returns
a tuple (output, history). output is kept alive only during the current epoch, while history is kept in memory for the whole training.
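The split between a large, short-lived output and a small, long-lived history can be sketched as follows. The per-batch layout `{output_name: {'loss': float, 'output': list}}` and the choice of a mean loss as the history summary are assumptions for illustration only. `aggregate_sketch` is hypothetical.

```python
def aggregate_sketch(loss_terms_history):
    # Hypothetical layout: each item is {output_name: {'loss': float, 'output': list}}
    output = {}   # large per-epoch data: concatenated raw outputs
    losses = {}   # per-batch losses collected for the summary
    for loss_terms in loss_terms_history:
        for name, term in loss_terms.items():
            output.setdefault(name, []).extend(term['output'])
            losses.setdefault(name, []).append(term['loss'])
    # Small summary meant to be kept for the whole training: mean loss per output
    history = {name: {'loss': sum(values) / len(values)} for name, values in losses.items()}
    return output, history


epoch = [
    {'classification': {'loss': 1.0, 'output': [0, 1]}},
    {'classification': {'loss': 0.5, 'output': [1]}},
]
output, history = aggregate_sketch(epoch)
print(output['classification'])   # [0, 1, 1]
print(history['classification'])  # {'loss': 0.75}
```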
- trw.train.trainer.loss_term_cleanup(loss_terms)
Perform cleanup on all the loss terms
Requires
an outputs.Output.output_ref_tag tag for each loss term; otherwise, no cleanup is done for that loss term.
- Parameters
loss_terms – the loss terms to be cleaned up
- trw.train.trainer.train_loop(device, dataset_name, split_name, split, optimizer, model, loss_fn, history, callbacks_per_batch, callbacks_per_batch_loss_terms, apply_backward=True)
Run the train loop (i.e., the model parameters will be updated)
Note
If callbacks_per_batch or callbacks_per_batch_loss_terms raise a StopIteration exception, the train loop will be stopped.
- Parameters
device – the device to be used to optimize the model
dataset_name – the name of the dataset
split_name – the name of the split
split – a dictionary of feature names and values
optimizer – an optimizer to optimize the model
model – the model to be optimized
loss_fn – the loss function
history – a list of history steps
callbacks_per_batch – the callbacks to be performed on each batch. If None, no callbacks are run
callbacks_per_batch_loss_terms – the callbacks to be performed on each loss term. If None, no callbacks are run
apply_backward – if True, the gradient will be back-propagated
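The StopIteration convention in the note above can be shown with a pure-Python sketch. `train_loop_sketch`, `stop_after_two` and `step_fn` are hypothetical. `step_fn` stands in for the forward pass, loss and optimizer update. Running the callbacks before the step is an assumption, not necessarily trw's ordering.

```python
def train_loop_sketch(split, step_fn, callbacks_per_batch=None):
    # Hypothetical loop; step_fn stands in for forward, loss and optimizer update
    history = []
    for batch in split:
        try:
            if callbacks_per_batch is not None:
                for callback in callbacks_per_batch:
                    callback('dataset', 'train', batch)
        except StopIteration:
            # A callback requested the loop to stop early
            break
        history.append(step_fn(batch))
    return history


def stop_after_two(dataset_name, split_name, batch):
    # Hypothetical callback: stop once the batch index reaches 2
    if batch['index'] >= 2:
        raise StopIteration()


split = [{'index': i} for i in range(5)]
steps = train_loop_sketch(split, lambda batch: batch['index'], [stop_after_two])
print(steps)  # [0, 1]
```

The StopIteration is raised from an ordinary function inside the loop body, not from the iterator, so the `for` statement does not swallow it. The explicit `try`/`except` is what ends the loop.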
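- trw.train.trainer.eval_loop(device, dataset_name, split_name, split, model, loss_fn, history, callbacks_per_batch=None, callbacks_per_batch_loss_terms=None)
Run the eval loop (i.e., the model parameters will NOT be updated)
Note
If callbacks_per_batch or callbacks_per_batch_loss_terms raise a StopIteration exception, the eval loop will be stopped.
- Parameters
device – the device on which the model is evaluated
dataset_name – the name of the dataset
split_name – the name of the split
split – a dictionary of feature names and values
model – the model to be evaluated
loss_fn – the loss function
history – a list of history steps
callbacks_per_batch – the callbacks to be performed on each batch. If None, no callbacks are run
callbacks_per_batch_loss_terms – the callbacks to be performed on each loss term. If None, no callbacks are run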
- Returns
- trw.train.trainer.epoch_train_eval(options, datasets, optimizers, model, losses, schedulers, history, callbacks_per_batch, callbacks_per_batch_loss_terms, run_eval, eval_loop_fn=eval_loop, train_loop_fn=train_loop)
Orchestrate the train and evaluation loops
- Parameters
options –
datasets –
optimizers – if None, no optimization will be performed on the train split; otherwise, a dictionary of optimizers (one for each dataset)
model –
losses –
schedulers –
history –
callbacks_per_batch –
callbacks_per_batch_loss_terms –
run_eval – if True, run the evaluation
eval_loop_fn – the eval function to be used
train_loop_fn – the train function to be used
- Returns
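The following sketch shows one possible dispatch between train_loop_fn and eval_loop_fn, based on the optimizers and run_eval descriptions above. `epoch_train_eval_sketch`, `train_fn` and `eval_fn` are hypothetical. Three points are assumptions: the train split is named 'train', a train split without an optimizer falls back to evaluation, and the loop functions take the simplified signatures used here.

```python
def epoch_train_eval_sketch(datasets, optimizers, run_eval, train_loop_fn, eval_loop_fn):
    # Hypothetical orchestration over {dataset_name: {split_name: split}}
    results = {}
    for dataset_name, splits in datasets.items():
        for split_name, split in splits.items():
            is_train = split_name == 'train'  # assumption: train split is named 'train'
            if is_train and optimizers is not None:
                optimizer = optimizers[dataset_name]  # one optimizer per dataset
                results[(dataset_name, split_name)] = train_loop_fn(dataset_name, split_name, split, optimizer)
            elif run_eval:
                results[(dataset_name, split_name)] = eval_loop_fn(dataset_name, split_name, split)
    return results


def train_fn(dataset_name, split_name, split, optimizer):
    return 'trained with ' + optimizer


def eval_fn(dataset_name, split_name, split):
    return 'evaluated'


datasets = {'mnist': {'train': [], 'test': []}}
print(epoch_train_eval_sketch(datasets, {'mnist': 'sgd'}, True, train_fn, eval_fn))
# {('mnist', 'train'): 'trained with sgd', ('mnist', 'test'): 'evaluated'}
print(epoch_train_eval_sketch(datasets, None, True, train_fn, eval_fn))
# {('mnist', 'train'): 'evaluated', ('mnist', 'test'): 'evaluated'}
```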
- trw.train.trainer.default_logger
- trw.train.trainer.default_pre_training_callbacks(logger=default_logger, with_lr_finder=False, with_export_augmentations=True)
Default callbacks to be performed before the fitting of the model
- trw.train.trainer.default_per_epoch_callbacks(logger=default_logger, with_worst_samples_by_epoch=True, with_activation_statistics=False, convolutional_kernel_export_frequency=None)
Default callbacks to be performed at the end of each epoch
- trw.train.trainer.default_post_training_callbacks(embedding_name='embedding', dataset_name=None, split_name=None, discard_train_error_export=False, export_errors=True, explain_decision=True)
Default callbacks to be performed after the model has been trained
- trw.train.trainer.default_sum_all_losses(dataset_name, batch, loss_terms)
Default loss is the sum of all loss terms
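A minimal sketch of "the sum of all loss terms" follows. It assumes loss_terms maps each output name to a dictionary that may hold a 'loss' entry. That layout, and skipping terms without a loss, are assumptions. `sum_all_losses_sketch` is hypothetical.

```python
def sum_all_losses_sketch(dataset_name, batch, loss_terms):
    # Hypothetical layout: {output_name: {'loss': value, ...}}; terms without 'loss' are skipped
    return sum(term['loss'] for term in loss_terms.values() if 'loss' in term)


loss_terms = {
    'classification': {'loss': 0.25},
    'segmentation': {'loss': 0.5},
    'embedding': {'output': [0.1, 0.2]},  # no loss: ignored
}
print(sum_all_losses_sketch('mnist', {}, loss_terms))  # 0.75
```

With PyTorch, the 'loss' values would typically be scalar tensors. The built-in `sum` also works on them, because it only relies on `+`.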
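- trw.train.trainer.trainer_callbacks_per_batch(dataset_name, split_name, batch)
Postprocessing step to be run on the batches (e.g., if a feature is a functor, call it and replace the feature by the result of the call)
- Parameters
dataset_name – the name of the dataset the batch belongs to
split_name – the name of the split the batch belongs to
batch – the current batch of data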
- Returns
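The functor replacement described above (and in the notes of Trainer.fit) can be sketched as follows. `replace_functors_sketch` is hypothetical. Calling the functor with no arguments is an assumption. The random latent vector mimics the GAN use case mentioned in fit.

```python
import random


def replace_functors_sketch(dataset_name, split_name, batch):
    # Replace every callable feature by the result of calling it (no arguments assumed)
    for name, value in list(batch.items()):
        if callable(value):
            batch[name] = value()


batch = {
    'images': [[0.0], [1.0]],
    'z': lambda: [random.gauss(0.0, 1.0) for _ in range(4)],  # e.g., a GAN latent vector
}
replace_functors_sketch('gan', 'train', batch)
print(len(batch['z']))  # 4
```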
- class trw.train.trainer.Trainer(callbacks_per_batch_fn=None, callbacks_per_batch_loss_terms_fn=None, callbacks_per_epoch_fn=default_per_epoch_callbacks, callbacks_pre_training_fn=default_pre_training_callbacks, callbacks_post_training_fn=default_post_training_callbacks, trainer_callbacks_per_batch=trainer_callbacks_per_batch, run_epoch_fn=epoch_train_eval)
This is the main class to train a model
- static save_model(model, result, path)
Save a model
- Parameters
model – a PyTorch model
result – None or the result of the model
path – where to store the model. The result will be saved at path + ‘.result’
- static load_model(path, with_result=False, device=None)
Load a saved model
- Parameters
path – where the model is stored. The result will be loaded from path + ‘.result’
with_result – if True, the results of the model will be loaded
device – where to load the model. For example, models are typically trained on GPU, but for deployment, CPU might be good enough. If None, use the same device as when the model was exported
- Returns
a tuple model, result
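The following sketch illustrates the path + '.result' convention shared by save_model and load_model. `save_model_sketch` and `load_model_sketch` are hypothetical. They use the standard `pickle` module only to stay self-contained, and trw may serialize differently (e.g., with PyTorch's own functions). The device argument is not modelled.

```python
import os
import pickle
import tempfile


def save_model_sketch(model, result, path):
    # The model goes to `path`; a non-None result goes to `path + '.result'`
    with open(path, 'wb') as f:
        pickle.dump(model, f)
    if result is not None:
        with open(path + '.result', 'wb') as f:
            pickle.dump(result, f)


def load_model_sketch(path, with_result=False):
    # Returns a (model, result) tuple; result stays None unless requested
    with open(path, 'rb') as f:
        model = pickle.load(f)
    result = None
    if with_result:
        with open(path + '.result', 'rb') as f:
            result = pickle.load(f)
    return model, result


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'model.pkl')
    save_model_sketch({'weights': [1, 2, 3]}, {'accuracy': 0.9}, path)
    files = sorted(os.listdir(tmp))  # ['model.pkl', 'model.pkl.result']
    model, result = load_model_sketch(path, with_result=True)
    model_only, no_result = load_model_sketch(path)
```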
- fit(self, options, inputs_fn, model_fn, optimizers_fn, losses_fn=default_sum_all_losses, loss_creator=create_losses_fn, run_prefix='default', with_final_evaluation=True, eval_every_X_epoch=1)
Fit the model
Requirements:
- enough main memory to store the outputs of all the datasets of a single epoch.
If this cannot be satisfied, sub-sample the epoch so that it can fit in main memory.
Notes:
- if a feature value is Callable, its value will be replaced by the result of the call
(e.g., this can be useful to generate z embedding in GANs)
- Parameters
options –
inputs_fn –
a functor returning a dictionary of datasets. Alternatively, datasets infos can be specified. inputs_fn must return one of:
datasets: dictionary of dataset
(datasets, datasets_infos): dictionary of dataset and additional infos
We define:
datasets: a dictionary of datasets. A dataset is a dictionary of splits. A split is a dictionary of batched features.
Datasets infos are additional infos useful for the debugging of the dataset (e.g., class mappings, sample UIDs).
Datasets infos are typically much smaller than datasets and should be loadable in memory.
model_fn – a functor with parameter options and returning a Module or a ModuleDict
Depending on the type of the model, this is how it will be used:
- Module: the optimizer will optimize model.parameters()
- ModuleDict: for each dataset name, the optimizer will optimize
model[dataset_name].parameters(). Note that a forward method will need to be implemented
losses_fn –
optimizers_fn –
loss_creator –
eval_every_X_epoch – evaluate the model every X epochs
run_prefix – the prefix of the output folder
with_final_evaluation – if True, once the model is fitted, evaluate all the data again in eval mode
- Returns
a tuple model, result
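The nesting expected from inputs_fn (dataset, then split, then batched features) and its two accepted return forms can be sketched as follows. The feature names, values and class mapping are made-up illustrations. `unpack_inputs` is a hypothetical helper, not part of trw.

```python
def inputs_fn():
    # Hypothetical datasets: {dataset_name: {split_name: {feature_name: batched values}}}
    datasets = {
        'mnist': {
            'train': {'images': [[0.0, 1.0], [1.0, 0.0]], 'targets': [1, 0]},
            'test': {'images': [[0.5, 0.5]], 'targets': [1]},
        }
    }
    # Optional, small debugging information (e.g., class mappings)
    datasets_infos = {'mnist': {'train': {'class_mapping': {0: 'zero', 1: 'one'}}}}
    return datasets, datasets_infos


def unpack_inputs(inputs):
    # Accept both documented return forms: datasets or (datasets, datasets_infos)
    if isinstance(inputs, tuple):
        datasets, datasets_infos = inputs
    else:
        datasets, datasets_infos = inputs, None
    return datasets, datasets_infos


datasets, datasets_infos = unpack_inputs(inputs_fn())
print(list(datasets['mnist']))  # ['train', 'test']
```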
- trw.train.trainer.strip_unpickable(outputs)
Remove the objects that cannot be pickled
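One way to detect objects that cannot be pickled is to attempt `pickle.dumps` on each value, as in the sketch below. `strip_unpickable_sketch` is hypothetical. It assumes a flat dictionary of outputs, and the real function may walk nested structures or use another test.

```python
import pickle
import threading


def strip_unpickable_sketch(outputs):
    # Keep only the values that pickle.dumps can serialize
    kept = {}
    for name, value in outputs.items():
        try:
            pickle.dumps(value)
        except Exception:
            continue  # e.g., locks, lambdas, open handles
        kept[name] = value
    return kept


outputs = {'loss': 0.1, 'lock': threading.Lock(), 'fn': lambda x: x}
print(sorted(strip_unpickable_sketch(outputs)))  # ['loss']
```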
- trw.train.trainer.run_trainer_repeat(trainer, options, inputs_fn, model_fn, optimizers_fn, losses_fn=default_sum_all_losses, loss_creator=create_losses_fn, run_prefix='default', eval_every_X_epoch=1, number_of_training_runs=10, post_init_fn=None)
Manages multiple runs of a trainer, for example to repeat the training and estimate the variance of a model
- Parameters
trainer –
options –
inputs_fn –
model_fn –
optimizers_fn –
losses_fn –
loss_creator –
run_prefix –
eval_every_X_epoch –
number_of_training_runs –
post_init_fn – if not None, a function to be called before each training repeat
- Returns
a tuple model, result of the last model trained
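The repeat-and-measure-variance idea can be sketched as follows. `run_repeat_sketch` and `toy_fit` are hypothetical. Calling post_init_fn with no arguments is an assumption. The returned all_results list is an extra not mentioned for run_trainer_repeat, which returns only the last (model, result). The metric values are toy inputs, not real results.

```python
import statistics


def run_repeat_sketch(fit_fn, number_of_training_runs=10, post_init_fn=None):
    # Hypothetical helper: fit_fn() -> (model, result) performs one full training run
    model, result = None, None
    all_results = []
    for _ in range(number_of_training_runs):
        if post_init_fn is not None:
            post_init_fn()  # e.g., reset random seeds before each repeat
        model, result = fit_fn()
        all_results.append(result)
    # The last (model, result) mirrors run_trainer_repeat; all_results is a sketch-only extra
    return model, result, all_results


toy_metrics = iter([0.90, 0.92, 0.91])  # illustrative values only


def toy_fit():
    return 'model', {'metric': next(toy_metrics)}


calls = []
model, result, all_results = run_repeat_sketch(
    toy_fit, number_of_training_runs=3, post_init_fn=lambda: calls.append(1))
metrics = [r['metric'] for r in all_results]
print(statistics.mean(metrics), statistics.stdev(metrics))
```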