trw.train.trainer¶
Module Contents¶
Functions¶
| Function | Summary |
|---|---|
| create_losses_fn | Create a dictionary of loss functions, one per dataset |
| aggregate_values | |
| aggregate_list_of_dicts | |
| aggregate_list_of_metrics | |
| generic_aggregate_loss_terms | Aggregate the loss terms for all the internal_nodes of an epoch |
| loss_term_cleanup | Perform cleanup on all the loss terms |
| train_loop | Run the train loop (i.e., the model parameters will be updated) |
| eval_loop | Run the eval loop (i.e., the model parameters will NOT be updated) |
| approximate_batch_size_from_loss_terms | Calculate an approximation of the number of samples from the loss terms. The error can be up to the number of samples within one batch |
| epoch_train_eval | |
| default_pre_training_callbacks | Default callbacks to be performed before the fitting of the model |
| default_per_epoch_callbacks | Default callbacks to be performed at the end of each epoch |
| default_post_training_callbacks | Default callbacks to be performed after the model has been trained |
| trainer_callbacks_per_batch | Postprocessing step to be run on the batches (e.g., if we have functors, run the functor and replace it) |
| strip_unpickable | Remove the objects that cannot be pickled |
Attributes¶
- trw.train.trainer.autocast¶
- trw.train.trainer.logger¶
- trw.train.trainer.create_losses_fn(datasets, generic_loss)¶
Create a dictionary of loss functions, one per dataset
- Parameters
datasets – the datasets
generic_loss – a loss function
- Returns
A dictionary of losses, one per dataset
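The shape of this mapping can be sketched as follows (a minimal sketch, assuming `datasets` is a dictionary keyed by dataset name; the helper name `make_losses` is hypothetical, not the library's API):

```python
def make_losses(datasets, generic_loss):
    # Map each dataset name to the same generic loss function,
    # mirroring the {dataset_name: loss_fn} dictionary described above.
    return {dataset_name: generic_loss for dataset_name in datasets}

# Usage: any dict keyed by dataset name works as `datasets`
datasets = {'mnist': None, 'cifar10': None}
losses = make_losses(datasets, generic_loss=lambda output, target: abs(output - target))
```

Each dataset can then look up its loss function by name during training.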
- trw.train.trainer.aggregate_values(values)¶
- trw.train.trainer.aggregate_list_of_dicts(list_of_dicts)¶
- trw.train.trainer.aggregate_list_of_metrics(list_of_metrics)¶
- trw.train.trainer.generic_aggregate_loss_terms(loss_terms_history)¶
Aggregate the loss terms for all the internal_nodes of an epoch
- Parameters
loss_terms_history – a list of loss terms
- Returns
- a tuple (output, history). output is kept alive only during the current epoch;
history is kept in memory for the whole training
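The per-batch to per-epoch aggregation can be sketched like this (a minimal sketch, not the library's implementation; `aggregate_batches` is a hypothetical name):

```python
def aggregate_batches(list_of_dicts):
    # Concatenate the values recorded under each key across all the
    # batches of an epoch, e.g.
    # [{'loss': [0.5]}, {'loss': [0.3]}] -> {'loss': [0.5, 0.3]}
    aggregated = {}
    for batch in list_of_dicts:
        for key, values in batch.items():
            aggregated.setdefault(key, []).extend(values)
    return aggregated
```

The aggregated dictionary is what a per-epoch summary (mean loss, metrics) can then be computed from.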
- trw.train.trainer.loss_term_cleanup(loss_terms)¶
Perform cleanup on all the loss terms
Requires
an outputs.Output.output_ref_tag tag for each loss term; otherwise no cleanup will be done for that loss term.
- Parameters
loss_terms – the loss terms to be cleaned up
- trw.train.trainer.train_loop(options, device, dataset_name, split_name, split, optimizer, per_step_scheduler, model, loss_fn, history, callbacks_per_batch, callbacks_per_batch_loss_terms, gradient_scaler=None)¶
Run the train loop (i.e., the model parameters will be updated)
Note
If callbacks_per_batch or callbacks_per_batch_loss_terms raise an exception StopIteration, the train loop will be stopped
- Parameters
device – the device to be used to optimize the model
dataset_name – the name of the dataset
split_name – the name of the split
split – a dictionary of feature name and values
optimizer – an optimizer to optimize the model
per_step_scheduler – scheduler to be applied per-batch
model – the model to be optimized
loss_fn – the loss function
history – a list of history steps
callbacks_per_batch – the callbacks to be performed on each batch. If None, no callbacks are run
callbacks_per_batch_loss_terms – the callbacks to be performed on each loss term. If None, no callbacks are run
gradient_scaler – if mixed precision is enabled, this is the scale to be used for the gradient update
Notes
If optimizer is None, there MUST be a .backward() to free the graph and memory.
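The documented StopIteration behavior can be sketched as follows (a minimal sketch, not the library's implementation; `run_batches` and the callback signature are hypothetical):

```python
def run_batches(batches, callbacks_per_batch=None):
    # Process batches in order; a callback raising StopIteration
    # stops the loop early, as documented for train_loop/eval_loop.
    processed = []
    for batch in batches:
        try:
            if callbacks_per_batch is not None:
                for callback in callbacks_per_batch:
                    callback(batch)
        except StopIteration:
            break
        processed.append(batch)
    return processed
```

A callback can therefore implement early stopping simply by raising StopIteration once its condition is met.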
- trw.train.trainer.eval_loop(options, device, dataset_name, split_name, split, model, loss_fn, history, callbacks_per_batch=None, callbacks_per_batch_loss_terms=None)¶
Run the eval loop (i.e., the model parameters will NOT be updated)
Note
If callbacks_per_batch or callbacks_per_batch_loss_terms raise StopIteration, the eval loop will be stopped
- Parameters
device – the device used to evaluate the model
dataset_name – the name of the dataset
split_name – the name of the split
split – a dictionary of feature name and values
model – the model to be evaluated
loss_fn – the loss function
history – a list of history steps
callbacks_per_batch – the callbacks to be performed on each batch. If None, no callbacks are run
callbacks_per_batch_loss_terms – the callbacks to be performed on each loss term. If None, no callbacks are run
- Returns
- trw.train.trainer.approximate_batch_size_from_loss_terms(all_loss_terms)¶
Calculate an approximation of the number of samples from the loss terms. The error can be up to the number of samples within one batch
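The estimate can be sketched like this (a minimal sketch under an assumed layout where each batch's loss terms map term names to dictionaries holding a per-sample `'losses'` list; names and layout are hypothetical):

```python
def approximate_sample_count(all_loss_terms):
    # Sum, per batch, the length of the first per-sample loss list
    # found. The last (possibly partial) batch is counted at its real
    # size, so the error is at most one batch of samples.
    total = 0
    for loss_terms in all_loss_terms:
        for term in loss_terms.values():
            losses = term.get('losses')
            if losses is not None:
                total += len(losses)
                break  # one term per batch is enough for the estimate
    return total
```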
- trw.train.trainer.epoch_train_eval(options, datasets, optimizers, model, losses, schedulers, per_step_schedulers, history, callbacks_per_batch, callbacks_per_batch_loss_terms, run_eval, force_eval_mode, eval_loop_fn=eval_loop, train_loop_fn=train_loop)¶
- Parameters
options –
datasets –
optimizers –
model –
losses –
schedulers –
per_step_schedulers –
history –
callbacks_per_batch –
callbacks_per_batch_loss_terms –
run_eval –
force_eval_mode –
eval_loop_fn –
train_loop_fn –
- Returns
- trw.train.trainer.default_logger¶
- trw.train.trainer.default_pre_training_callbacks(logger=default_logger, with_lr_finder=False, with_export_augmentations=True, with_reporting_server=True, with_profiler=False, additional_callbacks=None)¶
Default callbacks to be performed before the fitting of the model
- trw.train.trainer.default_per_epoch_callbacks(logger=default_logger, with_worst_samples_by_epoch=True, with_activation_statistics=False, convolutional_kernel_export_frequency=None, additional_callbacks=None)¶
Default callbacks to be performed at the end of each epoch
- trw.train.trainer.default_post_training_callbacks(embedding_name='embedding', dataset_name=None, split_name=None, discard_train_error_export=False, export_errors=True, explain_decision=True, additional_callbacks=None)¶
Default callbacks to be performed after the model has been trained
- trw.train.trainer.trainer_callbacks_per_batch(dataset_name, split_name, batch)¶
Postprocessing step to be run on the batches (e.g., if we have functors, run the functor and replace it)
- Parameters
dataset_name –
split_name –
batch –
- Returns
- trw.train.trainer.strip_unpickable(outputs)¶
Remove the objects that cannot be pickled
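One way to implement such a filter can be sketched like this (a minimal sketch assuming `outputs` is a dictionary; not the library's implementation, and `strip_unpicklable` is a hypothetical name):

```python
import pickle

def strip_unpicklable(outputs):
    # Keep only the entries whose values survive pickling; anything
    # that raises during pickle.dumps (e.g. lambdas, open file
    # handles, device tensors) is dropped from the result.
    stripped = {}
    for key, value in outputs.items():
        try:
            pickle.dumps(value)
        except Exception:
            continue
        stripped[key] = value
    return stripped
```

Filtering before serialization lets the remaining outputs be safely written to disk or sent across process boundaries.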