trw.train.trainer_v2
Module Contents¶
Classes¶
Attributes¶
- trw.train.trainer_v2.logger¶
- class trw.train.trainer_v2.TrainerV2(callbacks_per_batch=None, callbacks_per_batch_loss_terms=None, callbacks_per_epoch=default_per_epoch_callbacks(), callbacks_pre_training=default_pre_training_callbacks(), callbacks_post_training=default_post_training_callbacks(), trainer_callbacks_per_batch=trainer_callbacks_per_batch, run_epoch_fn=epoch_train_eval, logging_level=logging.DEBUG, skip_eval_epoch_0=True)¶
- static save_model(model, metadata: trw.train.utilities.RunMetadata, path, pickle_module=pickle)¶
Save a model to file
- Parameters
model – the model to serialize
metadata – an optional result file associated with the model
path – the base path to save the model
pickle_module – the serialization module that will be used to save the model and results
- static load_state(model: torch.nn.Module, path: str, device: torch.device = None, pickle_module: Any = pickle, strict: bool = True) None ¶
Load the state of a model
- Parameters
model – where to load the state
path – where the model’s state was saved
device – where to locate the model
pickle_module – how to read the model parameters and metadata
strict – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function
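The effect of strict can be illustrated without any framework. The sketch below is not the trw or PyTorch implementation; it only mimics the key comparison described above, and the helper names compare_state_keys and check_state_keys are hypothetical.

```python
from typing import Iterable, List, Tuple


def compare_state_keys(
    model_keys: Iterable[str], saved_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) key lists, both sorted.

    missing: keys the model expects but the saved state lacks.
    unexpected: keys present in the saved state but unknown to the model.
    """
    model_set, saved_set = set(model_keys), set(saved_keys)
    missing = sorted(model_set - saved_set)
    unexpected = sorted(saved_set - model_set)
    return missing, unexpected


def check_state_keys(model_keys, saved_keys, strict: bool = True):
    """Raise if strict is True and the two key sets differ (hypothetical helper)."""
    missing, unexpected = compare_state_keys(model_keys, saved_keys)
    if strict and (missing or unexpected):
        raise RuntimeError(
            f"missing keys: {missing}, unexpected keys: {unexpected}"
        )
    # With strict=False the mismatch is only reported, not enforced.
    return missing, unexpected
```

With strict=False, a partially matching state can still be loaded, which is useful when only part of a model is reused.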
- static load_model(path: str, model_kwargs: Optional[Dict[Any, Any]] = None, with_result: bool = False, device: torch.device = None, pickle_module: Any = pickle) Tuple[torch.nn.Module, trw.train.utilities.RunMetadata] ¶
Load a previously saved model
Construct a model from the RunMetadata.class_name class, with arguments model_kwargs
- Parameters
path – where the model was stored. Results will be loaded from path + ‘.result’
model_kwargs – arguments used to instantiate the model stored in RunMetadata.class_name
with_result – if True, the results of the model will be loaded
device – where to load the model. For example, models are typically trained on GPU, but for deployment, CPU might be good enough. If None, use the same device as when the model was exported
pickle_module – the de-serialization module to be used to load model and results
- Returns
a tuple (model, metadata)
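The path + ‘.result’ convention can be sketched with the standard library alone. This is not the trw implementation: the helpers save_with_result and load_with_result are hypothetical, the model is replaced by a plain dictionary, and returning None for the metadata when with_result is False is an assumption.

```python
import os
import pickle
import tempfile


def save_with_result(model_state, metadata, path, pickle_module=pickle):
    """Write the model state to path and the metadata to path + '.result'."""
    with open(path, "wb") as f:
        pickle_module.dump(model_state, f)
    with open(path + ".result", "wb") as f:
        pickle_module.dump(metadata, f)


def load_with_result(path, with_result=False, pickle_module=pickle):
    """Read the model state, and optionally the metadata stored next to it."""
    with open(path, "rb") as f:
        model_state = pickle_module.load(f)
    metadata = None  # assumption: no metadata is returned unless requested
    if with_result:
        with open(path + ".result", "rb") as f:
            metadata = pickle_module.load(f)
    return model_state, metadata


with tempfile.TemporaryDirectory() as tmp:
    model_path = os.path.join(tmp, "model.pkl")
    save_with_result({"weight": [0.1, 0.2]}, {"class_name": "MyNet"}, model_path)
    state, metadata = load_with_result(model_path, with_result=True)
    state_only, no_metadata = load_with_result(model_path)
```

Passing the same pickle_module to both the save and the load side keeps the two files readable by each other.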
- fit(self, options, datasets, model: torch.nn.Module, optimizers_fn, losses_fn=default_sum_all_losses, loss_creator=create_losses_fn, log_path=None, with_final_evaluation=True, history=None, erase_logging_folder=True, eval_every_X_epoch=1) trw.train.utilities.RunMetadata ¶
Fit the model
- Parameters
options –
datasets –
a functor returning a dictionary of datasets. Alternatively, datasets infos can be returned as well. The functor must return one of:
datasets: a dictionary of datasets
(datasets, datasets_infos): a dictionary of datasets and additional infos
We define:
datasets: a dictionary of datasets. A dataset is a dictionary of splits. A split is a dictionary of batched features.
datasets infos: additional information useful for debugging the dataset (e.g., class mappings, sample UIDs). Datasets infos are typically much smaller than the datasets and should be loadable in memory
model – a Module or a ModuleDict
optimizers_fn –
losses_fn –
loss_creator –
log_path – the path of the logs to be exported during the training of the model. If log_path is not an absolute path, options.workflow_options.logging_directory is used as root
with_final_evaluation –
history –
erase_logging_folder – if True, the logging folder will be erased when fitting starts
eval_every_X_epoch – evaluate the model every X epochs
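The nesting described for the datasets parameter (dataset, then split, then batched features) can be sketched with plain dictionaries. The dataset name, feature names, values, and the layout of datasets_infos below are all hypothetical; real splits would hold tensors or arrays rather than lists.

```python
def make_datasets():
    """Functor returning (datasets, datasets_infos), as described for datasets."""
    datasets = {
        "toy_digits": {  # dataset name (hypothetical)
            # each split is a dictionary of batched features
            "train": {"images": [[0.0, 1.0], [1.0, 0.0]], "targets": [0, 1]},
            "valid": {"images": [[0.5, 0.5]], "targets": [1]},
        }
    }
    # Small, in-memory debugging information (layout is an assumption).
    datasets_infos = {
        "toy_digits": {"class_mapping": {0: "zero", 1: "one"}}
    }
    return datasets, datasets_infos


def iterate_features(datasets):
    """Yield (dataset, split, feature, batch size) for every batched feature."""
    for dataset_name, splits in datasets.items():
        for split_name, features in splits.items():
            for feature_name, values in features.items():
                yield dataset_name, split_name, feature_name, len(values)


datasets, datasets_infos = make_datasets()
rows = list(iterate_features(datasets))
```

Returning only datasets instead of the tuple corresponds to the first of the two accepted return forms.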
Returns:
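The rule stated for the log_path parameter (a relative path is rooted at options.workflow_options.logging_directory) can be sketched as follows. The helper resolve_log_path is hypothetical and options is mocked with SimpleNamespace; how the trainer handles log_path=None is not covered here.

```python
import os
from types import SimpleNamespace


def resolve_log_path(options, log_path):
    """Return log_path unchanged if absolute, else root it at the logging directory."""
    if os.path.isabs(log_path):
        return log_path
    return os.path.join(options.workflow_options.logging_directory, log_path)


# Stand-in for the real options object (only the attribute used above is mocked).
options = SimpleNamespace(
    workflow_options=SimpleNamespace(logging_directory=os.path.abspath("logs"))
)

relative = resolve_log_path(options, "experiment_1")
absolute = resolve_log_path(options, os.path.abspath("elsewhere"))
```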