trw.train.trainer_v2

Module Contents

Classes

TrainerV2

Attributes

logger

trw.train.trainer_v2.logger
class trw.train.trainer_v2.TrainerV2(callbacks_per_batch=None, callbacks_per_batch_loss_terms=None, callbacks_per_epoch=default_per_epoch_callbacks(), callbacks_pre_training=default_pre_training_callbacks(), callbacks_post_training=default_post_training_callbacks(), trainer_callbacks_per_batch=trainer_callbacks_per_batch, run_epoch_fn=epoch_train_eval, logging_level=logging.DEBUG, skip_eval_epoch_0=True)
static save_model(model, metadata: trw.train.utilities.RunMetadata, path, pickle_module=pickle)

Save a model to file

Parameters
  • model – the model to serialize

  • metadata – optional run metadata (e.g., training results) to save alongside the model

  • path – the base path to save the model

  • pickle_module – the serialization module that will be used to save the model and results
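The parameters above imply two artifacts: the serialized model at path, and its metadata, which load_model documents as being read from path + '.result'. The following is a minimal sketch of that two-file layout using only the standard pickle module. It is not the trw implementation, which is assumed to go through torch.save. The helper save_model_sketch and the dictionary standing in for RunMetadata are hypothetical.

```python
import os
import pickle
import tempfile


def save_model_sketch(model, metadata, path, pickle_module=pickle):
    """Hypothetical sketch: serialize `model` to `path` and, if provided,
    `metadata` to `path + '.result'` (the suffix documented for load_model)."""
    with open(path, "wb") as f:
        pickle_module.dump(model, f)
    if metadata is not None:
        with open(path + ".result", "wb") as f:
            pickle_module.dump(metadata, f)


# Usage: any picklable object stands in for a torch.nn.Module here.
tmp_dir = tempfile.mkdtemp()
model_path = os.path.join(tmp_dir, "model.model")
save_model_sketch({"weights": [0.1, 0.2]}, {"outputs": {"loss": 0.5}}, model_path)
print(sorted(os.listdir(tmp_dir)))  # ['model.model', 'model.model.result']
```

Keeping the metadata in a separate file lets a caller inspect the results without deserializing the model itself.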

static load_state(model: torch.nn.Module, path: str, device: torch.device = None, pickle_module: Any = pickle, strict: bool = True) -> None

Load the state of a model

Parameters
  • model – the model into which the state is loaded

  • path – where the model’s state was saved

  • device – the device on which to load the model

  • pickle_module – how to read the model parameters and metadata

  • strict – whether to strictly enforce that the keys in the saved state_dict match the keys returned by the model’s state_dict() function
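The strict flag follows the semantics of PyTorch’s Module.load_state_dict(state_dict, strict=True), which raises on mismatched keys when strict is True and otherwise reports missing and unexpected keys. The sketch below reproduces only that key comparison on plain lists of parameter names; check_state_keys is a hypothetical helper, not part of trw or PyTorch.

```python
def check_state_keys(model_keys, state_keys, strict=True):
    """Hypothetical sketch of `strict` key matching: report keys the model expects
    but the saved state lacks (missing) and keys the state has but the model does
    not (unexpected). With strict=True, any mismatch is an error."""
    model_keys, state_keys = set(model_keys), set(state_keys)
    missing = sorted(model_keys - state_keys)
    unexpected = sorted(state_keys - model_keys)
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing keys: {missing}, unexpected keys: {unexpected}")
    return missing, unexpected


model_keys = ["fc.weight", "fc.bias"]
state_keys = ["fc.weight", "fc.bias", "head.weight"]
print(check_state_keys(model_keys, state_keys, strict=False))  # ([], ['head.weight'])
```

With strict=True the same call would raise, which is usually what you want when restoring a checkpoint into an identical architecture.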

static load_model(path: str, model_kwargs: Optional[Dict[Any, Any]] = None, with_result: bool = False, device: torch.device = None, pickle_module: Any = pickle) -> Tuple[torch.nn.Module, trw.train.utilities.RunMetadata]

Load a previously saved model

The model is constructed from the class stored in RunMetadata.class_name, instantiated with the arguments model_kwargs.
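One way to construct an object from a stored class name is to import the class dynamically. The sketch below assumes RunMetadata.class_name holds a fully qualified module.ClassName string; that format and the helper instantiate_from_class_name are assumptions for illustration, not trw’s documented behavior.

```python
import importlib


def instantiate_from_class_name(class_name, model_kwargs=None):
    """Hypothetical sketch: resolve a fully qualified 'package.module.Class' string
    and instantiate it with `model_kwargs` (an empty dict when None)."""
    module_name, _, attr_name = class_name.rpartition(".")
    cls = getattr(importlib.import_module(module_name), attr_name)
    return cls(**(model_kwargs or {}))


# A standard-library class stands in for a model class here.
counter = instantiate_from_class_name("collections.Counter", {"a": 2, "b": 1})
print(counter)  # Counter({'a': 2, 'b': 1})
```

Because only the class name is stored, any constructor arguments the model needs must be supplied again through model_kwargs.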

Parameters
  • path – where the model is stored. The results will be loaded from path + ‘.result’

  • model_kwargs – arguments used to instantiate the model stored in RunMetadata.class_name

  • with_result – if True, the results of the model will be loaded

  • device – where to load the model. For example, models are typically trained on GPU, but for deployment, CPU might be good enough. If None, use the same device as when the model was exported

  • pickle_module – the deserialization module to be used to load the model and results

Returns

a tuple (model, metadata)
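The (model, metadata) return value and the path + ‘.result’ convention can be illustrated with the standard pickle module. This is a sketch only: load_model_sketch is hypothetical, it does not rebuild the model from RunMetadata.class_name or handle device placement, and returning None for the metadata when with_result is False is an assumption.

```python
import os
import pickle
import tempfile


def load_model_sketch(path, with_result=False, pickle_module=pickle):
    """Hypothetical sketch: read the model from `path` and, when `with_result`
    is True, the metadata from `path + '.result'`; return (model, metadata)."""
    with open(path, "rb") as f:
        model = pickle_module.load(f)
    metadata = None  # assumption: no metadata is returned unless requested
    if with_result:
        with open(path + ".result", "rb") as f:
            metadata = pickle_module.load(f)
    return model, metadata


# Write the two files first so the example is self-contained.
tmp_dir = tempfile.mkdtemp()
model_path = os.path.join(tmp_dir, "model.model")
with open(model_path, "wb") as f:
    pickle.dump({"weights": [0.1, 0.2]}, f)
with open(model_path + ".result", "wb") as f:
    pickle.dump({"outputs": {"loss": 0.5}}, f)

model, metadata = load_model_sketch(model_path, with_result=True)
print(model, metadata)  # {'weights': [0.1, 0.2]} {'outputs': {'loss': 0.5}}
```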

fit(self, options, datasets, model: torch.nn.Module, optimizers_fn, losses_fn=default_sum_all_losses, loss_creator=create_losses_fn, log_path=None, with_final_evaluation=True, history=None, erase_logging_folder=True, eval_every_X_epoch=1) -> trw.train.utilities.RunMetadata

Fit the model

Parameters
  • options

  • datasets

    a functor returning a dictionary of datasets. Alternatively, datasets infos can also be returned. The functor must return one of:

    • datasets: dictionary of dataset

    • (datasets, datasets_infos): dictionary of dataset and additional infos

    We define:

    • datasets: a dictionary of datasets. A dataset is a dictionary of splits. A split is a dictionary of batched features.

    • datasets infos: additional information useful for debugging the datasets (e.g., class mappings, sample UIDs). Datasets infos are typically much smaller than the datasets and should be loadable in memory.

  • model – a Module or a ModuleDict

  • optimizers_fn

  • losses_fn

  • loss_creator

  • log_path – the path where logs are exported during training. If log_path is not an absolute path, options.workflow_options.logging_directory is used as the root.

  • with_final_evaluation

  • history

  • erase_logging_folder – if True, the logging folder is erased when fitting starts

  • eval_every_X_epoch – evaluate the model every X epochs
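The datasets parameter accepts a functor that returns either datasets alone or a (datasets, datasets_infos) pair, with datasets nested as dataset, then split, then a dictionary of batched features. The sketch below shows one such functor and one way to normalize its two possible return shapes. The names normalize_datasets_output and make_datasets, the dataset name "toy", and the feature names are hypothetical.

```python
def normalize_datasets_output(output):
    """Hypothetical sketch: accept either `datasets` or `(datasets, datasets_infos)`
    as returned by the datasets functor, and always return both parts."""
    if isinstance(output, tuple) and len(output) == 2:
        datasets, datasets_infos = output
    else:
        datasets, datasets_infos = output, None
    return datasets, datasets_infos


def make_datasets():
    """Hypothetical datasets functor returning (datasets, datasets_infos)."""
    # datasets -> dataset -> split -> dictionary of batched features
    datasets = {
        "toy": {
            "train": {"values": [[0.0, 1.0], [2.0, 3.0]], "targets": [0, 1]},
            "test": {"values": [[4.0, 5.0]], "targets": [1]},
        }
    }
    # Small, in-memory debugging information, e.g. a class mapping.
    datasets_infos = {"toy": {"class_mapping": {0: "negative", 1: "positive"}}}
    return datasets, datasets_infos


datasets, datasets_infos = normalize_datasets_output(make_datasets())
print(list(datasets["toy"]))  # ['train', 'test']
print(datasets_infos["toy"]["class_mapping"][1])  # positive
```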
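The log_path rule (an absolute path is used as-is, a relative one is rooted at options.workflow_options.logging_directory) can be sketched as follows. The helper resolve_log_path is hypothetical, SimpleNamespace stands in for trw’s options object, and the default log_path=None case is not covered.

```python
import os
from types import SimpleNamespace


def resolve_log_path(log_path, options):
    """Hypothetical sketch of the documented rule: an absolute log_path is kept,
    a relative one is joined to options.workflow_options.logging_directory."""
    if os.path.isabs(log_path):
        return log_path
    return os.path.join(options.workflow_options.logging_directory, log_path)


# SimpleNamespace stands in for trw's options object.
options = SimpleNamespace(workflow_options=SimpleNamespace(logging_directory="/tmp/logs"))
print(resolve_log_path("toy_run", options))       # on POSIX: /tmp/logs/toy_run
print(resolve_log_path("/data/runs/x", options))  # on POSIX: /data/runs/x
```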

Returns: