trw.train.utilities¶
Module Contents¶
Classes¶
- NullableContextManager – Accept a possibly None context manager: if None do nothing, else execute the context manager's enter and exit.
- CleanAddedHooks – Context manager that automatically tracks added hooks on the model and removes them when the context is released.
- RuntimeFormatter – Report the time since this formatter is instantiated.
Functions¶
- safe_filename – Clean the filename so that it can be used as a valid filename.
- log_info – Log the message to a log file as info.
- log_and_print – Log the message to a log file as info and print it to the console.
- log_console – Log the message to the console.
- create_or_recreate_folder – Check if the path exists. If yes, remove the folder then recreate it, else create it.
- make_unique_colors – Return a set of unique and easily distinguishable colors.
- make_unique_colors_f – Return a set of unique and easily distinguishable colors.
- get_class_name
- get_classification_mappings – Return the output mappings of a classification output from the datasets infos.
- get_classification_mapping – Return the output mappings of a classification output from the datasets infos.
- set_optimizer_learning_rate – Set the learning rate of the optimizer to a specific value.
- transfer_batch_to_device – Transfer the Tensors and numpy arrays to the specified device. Other types will not be moved.
- get_device – Return the device of a module. This may be incorrect if we have a module split across different devices.
- find_default_dataset_and_split_names – Return a good choice of dataset name and split name, possibly not the train split.
- make_triplet_indices – Make random index triplets (anchor, positive, negative) such that anchor and positive belong to the same target while negative belongs to a different target.
- make_pair_indices – Make random indices of pairs of samples that belong or do not belong to the same target.
- update_json_config – Update a JSON document stored on a local drive.
- prepare_loss_terms – Return the loss_terms for the given outputs.
- default_sum_all_losses – Default loss is the sum of all loss terms.
- postprocess_batch – Post process a batch of data (e.g., this can be useful to add additional data to the current batch).
- apply_spectral_norm – Apply spectral norm on every sub-module.
- apply_gradient_clipping – Apply gradient clipping recursively on a module as callback.
Attributes¶
- trw.train.utilities.logger¶
- trw.train.utilities.safe_filename(filename)¶
Clean the filename so that it can be used as a valid filename
- trw.train.utilities.log_info(msg)¶
Log the message to a log file as info
- trw.train.utilities.log_and_print(msg)¶
Log the message to a log file as info and print it to the console
- trw.train.utilities.log_console(msg)¶
Log the message to the console
- trw.train.utilities.create_or_recreate_folder(path, nb_tries=3, wait_time_between_tries=2.0)¶
Check if the path exists. If yes, remove the folder then recreate it, else create it
- Parameters
path – the path to create or recreate
nb_tries – the number of tries to be performed before failure
wait_time_between_tries – the time to wait before the next try
- Returns
True if successful or False if failed
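A minimal usage sketch (the output path is illustrative; assumes trw is installed):
```python
from trw.train.utilities import create_or_recreate_folder

# guarantee an empty output folder, retrying up to nb_tries times on failure
success = create_or_recreate_folder('/tmp/experiment_output', nb_tries=3)
assert success, 'could not create the output folder'
```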
- trw.train.utilities.make_unique_colors()¶
Return a set of unique and easily distinguishable colors
- Returns
a list of RGB colors
- trw.train.utilities.make_unique_colors_f()¶
Return a set of unique and easily distinguishable colors
- Returns
a list of RGB colors
- trw.train.utilities.get_class_name(mapping, classid)¶
- trw.train.utilities.get_classification_mappings(datasets_infos, dataset_name, split_name)¶
Return the output mappings of a classification output from the datasets infos
- Parameters
datasets_infos – the info of the datasets
dataset_name – the name of the dataset
split_name – the split name
- Returns
a dictionary {outputs: {‘mapping’: {name->ID}, ‘mappinginv’: {ID->name}}}
- trw.train.utilities.get_classification_mapping(datasets_infos, dataset_name, split_name, output_name)¶
Return the output mappings of a classification output from the datasets infos
- Parameters
datasets_infos – the info of the datasets
dataset_name – the name of the dataset
split_name – the split name
output_name – the output name
- Returns
a dictionary {‘mapping’: {name->ID}, ‘mappinginv’: {ID->name}}
- trw.train.utilities.set_optimizer_learning_rate(optimizer, learning_rate)¶
Set the learning rate of the optimizer to a specific value
- Parameters
optimizer – the optimizer to update
learning_rate – the learning rate to set
- Returns
None
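A minimal sketch of lowering the learning rate mid-training (the model and values are illustrative; assumes torch and trw are installed):
```python
import torch
from trw.train.utilities import set_optimizer_learning_rate

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# drop the learning rate, e.g., after a plateau; presumably applied to all parameter groups
set_optimizer_learning_rate(optimizer, 0.01)
```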
- trw.train.utilities.transfer_batch_to_device(batch, device, non_blocking=True)¶
Transfer the Tensors and numpy arrays to the specified device. Other types will not be moved.
- Parameters
batch – the batch of data to be transferred
device – the device to move the tensors to
non_blocking – use non-blocking memory transfer to the GPU
- Returns
a batch of data on the specified device
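A minimal sketch of moving a heterogeneous batch (the batch keys are illustrative; assumes torch, numpy and trw are installed):
```python
import numpy as np
import torch
from trw.train.utilities import transfer_batch_to_device

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
batch = {
    'images': torch.zeros(4, 1, 28, 28),    # moved to `device`
    'labels': np.zeros(4, dtype=np.int64),  # numpy arrays are moved as well
    'uids': ['a', 'b', 'c', 'd'],           # other types are left as-is
}
batch_on_device = transfer_batch_to_device(batch, device, non_blocking=True)
```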
- class trw.train.utilities.NullableContextManager(base_context_manager: Optional[Any])¶
Accept a possibly None context manager: if it is None, do nothing, else execute the context manager's enter and exit.
This is a helper class to simplify the handling of a possibly None context manager.
- __enter__(self)¶
- __exit__(self, type, value, traceback)¶
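A minimal sketch showing how this removes the need for a separate code path when a context manager may be absent (the autocast flag is illustrative; assumes torch and trw are installed):
```python
import torch
from trw.train.utilities import NullableContextManager

use_amp = False  # illustrative flag
context = torch.cuda.amp.autocast() if use_amp else None

# no special-casing needed: a None context manager simply does nothing
with NullableContextManager(context):
    y = torch.ones(2) * 2
```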
- class trw.train.utilities.CleanAddedHooks(model)¶
Context manager that automatically tracks added hooks on the model and removes them when the context is released
- __enter__(self)¶
- __exit__(self, type, value, traceback)¶
- static record_hooks(module_source)¶
Record hooks
- Parameters
module_source – the module to track the hooks
- Returns
a tuple (forward, backward). forward and backward are dictionaries of hook IDs by module
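A minimal sketch: a hook registered inside the context is removed on exit (the model and hook are illustrative; assumes torch and trw are installed):
```python
import torch
from trw.train.utilities import CleanAddedHooks

model = torch.nn.Linear(4, 2)
captured = []

def forward_hook(module, inputs, outputs):
    captured.append(outputs.detach())

with CleanAddedHooks(model):
    model.register_forward_hook(forward_hook)
    model(torch.zeros(1, 4))
# the hook added inside the context has now been removed from the model
```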
- trw.train.utilities.get_device(module, batch=None)¶
Return the device of a module. This may be incorrect if we have a module split across different devices
- class trw.train.utilities.RuntimeFormatter(*args, **kwargs)¶
Bases:
logging.Formatter
Report the time since this formatter is instantiated
- formatTime(self, record, datefmt=None)¶
Return the creation time of the specified LogRecord as formatted text.
This method should be called from format() by a formatter which wants to make use of a formatted time. This method can be overridden in formatters to provide for any specific requirement, but the basic behaviour is as follows: if datefmt (a string) is specified, it is used with time.strftime() to format the creation time of the record. Otherwise, an ISO8601-like (or RFC 3339-like) format is used. The resulting string is returned.
This function uses a user-configurable function to convert the creation time to a tuple. By default, time.localtime() is used; to change this for a particular formatter instance, set the ‘converter’ attribute to a function with the same signature as time.localtime() or time.gmtime(). To change it for all formatters, for example if you want all logging times to be shown in GMT, set the ‘converter’ attribute in the Formatter class.
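A minimal sketch wiring the formatter into standard logging, assuming the constructor forwards its arguments to logging.Formatter (the format string is illustrative):
```python
import logging
from trw.train.utilities import RuntimeFormatter

handler = logging.StreamHandler()
handler.setFormatter(RuntimeFormatter('%(asctime)s %(levelname)s %(message)s'))

logger = logging.getLogger('example')
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.info('training started')  # %(asctime)s reports time elapsed since formatter creation
```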
- trw.train.utilities.find_default_dataset_and_split_names(datasets, default_dataset_name=None, default_split_name=None, train_split_name=None)¶
Return a good choice of dataset name and split name, possibly not the train split.
- Parameters
datasets – the datasets
default_dataset_name – a possible dataset name. If None, find a suitable dataset; otherwise, the dataset must be present
default_split_name – a possible split name. If None, find a suitable split; otherwise, the split must be present. If train_split_name is specified, the selected split name will be different from train_split_name
train_split_name – if not None, exclude the train split
- Returns
a tuple (dataset_name, split_name)
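A minimal sketch (the datasets dictionary is a placeholder; assumes trw is installed):
```python
from trw.train.utilities import find_default_dataset_and_split_names

datasets = {'mnist': {'train': ..., 'valid': ...}}  # placeholder split data
dataset_name, split_name = find_default_dataset_and_split_names(
    datasets, train_split_name='train')
# with the train split excluded, we would expect ('mnist', 'valid')
```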
- trw.train.utilities.make_triplet_indices(targets)¶
Make random index triplets (anchor, positive, negative) such that anchor and positive belong to the same target while negative belongs to a different target
- Parameters
targets – a 1D integral tensor in range [0..C]
- Returns
a tuple of indices (samples, samples_positive, samples_negative)
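A minimal sketch of sampling triplets for a triplet-style loss (the targets are illustrative; assumes torch and trw are installed):
```python
import torch
from trw.train.utilities import make_triplet_indices

targets = torch.tensor([0, 0, 1, 1, 2, 2])
anchors, positives, negatives = make_triplet_indices(targets)
# targets[anchors] == targets[positives], targets[anchors] != targets[negatives]
```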
- trw.train.utilities.make_pair_indices(targets, same_target_ratio=0.5)¶
Make random indices of pairs of samples that belong or do not belong to the same target.
- Parameters
same_target_ratio – specify the ratio of same target to be generated for sample pairs
targets – a 1D integral tensor in range [0..C] to be used to group the samples into same or different target
- Returns
a tuple with (samples_0 indices, samples_1 indices, same_target)
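A minimal sketch of sampling pairs for a contrastive-style setup (the targets are illustrative; assumes torch and trw are installed):
```python
import torch
from trw.train.utilities import make_pair_indices

targets = torch.tensor([0, 0, 1, 1])
samples_0, samples_1, same_target = make_pair_indices(targets, same_target_ratio=0.5)
# same_target[i] indicates whether samples_0[i] and samples_1[i] share a target
```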
- trw.train.utilities.update_json_config(path_to_json, config_update)¶
Update a JSON document stored on a local drive.
- Parameters
path_to_json – the path to the local JSON configuration
config_update – a possibly nested dictionary
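A minimal sketch (the path and keys are illustrative; whether nested keys are merged rather than replaced wholesale is an assumption based on the description; assumes trw is installed):
```python
from trw.train.utilities import update_json_config

# merge an update into the stored configuration; nested dictionaries are
# presumably merged into the existing document
update_json_config('config.json', {'training': {'learning_rate': 0.001}})
```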
- trw.train.utilities.prepare_loss_terms(outputs, batch, is_training)¶
Return the loss_terms for the given outputs
- trw.train.utilities.default_sum_all_losses(dataset_name, batch, loss_terms)¶
Default loss is the sum of all loss terms
- trw.train.utilities.postprocess_batch(dataset_name, split_name, batch, callbacks_per_batch, batch_id=None)¶
Post process a batch of data (e.g., this can be useful to add additional data to the current batch)
- Parameters
dataset_name (str) – the name of the dataset the batch belongs to
split_name (str) – the name of the split the batch belongs to
batch – the current batch of data
callbacks_per_batch (list) – the callbacks to be executed for each batch. Each callback must be callable with (dataset_name, split_name, batch). If None, no callbacks are executed
batch_id – indicate the current batch within an epoch. May be None. This can be useful for embedding an optimizer within a module (e.g., scheduler support)
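A minimal sketch of a per-batch callback using the documented (dataset_name, split_name, batch) signature (the dataset, split and added key are illustrative; assumes torch and trw are installed):
```python
import torch
from trw.train.utilities import postprocess_batch

def add_flag(dataset_name, split_name, batch):
    # callbacks may augment the batch in place
    batch['is_train'] = split_name == 'train'

batch = {'images': torch.zeros(2, 1, 8, 8)}
postprocess_batch('mnist', 'train', batch, callbacks_per_batch=[add_flag])
```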
- trw.train.utilities.apply_spectral_norm(module, n_power_iterations=1, eps=1e-12, dim=None, name='weight', discard_layer_types=(torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d))¶
Apply spectral norm on every sub-module
- Parameters
module – the parent module to apply spectral norm to
discard_layer_types – layers of these types will not have spectral norm applied
n_power_iterations – number of power iterations to calculate spectral norm
eps – epsilon for numerical stability in calculating norms
dim – dimension corresponding to number of outputs; the default is 0, except for modules that are instances of ConvTranspose{1,2,3}d, where it is 1
name – name of weight parameter
- Returns
the same module as the input module
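A minimal sketch of spectral-normalizing a small network (the architecture is illustrative; assumes torch and trw are installed):
```python
import torch.nn as nn
from trw.train.utilities import apply_spectral_norm

discriminator = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3),
    nn.ReLU(),
    nn.Conv2d(8, 1, kernel_size=3),
)
discriminator = apply_spectral_norm(discriminator)  # same module, returned for chaining
```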
- trw.train.utilities.apply_gradient_clipping(module: torch.nn.Module, value)¶
Apply gradient clipping recursively on a module as callback.
Every time a gradient is calculated, it is intercepted and clipping is applied.
- Parameters
module – a module where sub-modules will have their gradients clipped
value – the maximum value of the gradient
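A minimal sketch (the model and data are illustrative; clipping to magnitude at most value is an assumption based on the description; assumes torch and trw are installed):
```python
import torch
from trw.train.utilities import apply_gradient_clipping

model = torch.nn.Linear(4, 2)
apply_gradient_clipping(model, value=1.0)

loss = model(torch.randn(8, 4)).sum()
loss.backward()  # gradients are intercepted and clipped (presumably to magnitude <= value)
```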
- class trw.train.utilities.RunMetadata(options: Optional[trw.train.options.Options], history: Optional[trw.basic_typing.History], outputs: Optional[Any], datasets_infos: Optional[trw.basic_typing.DatasetsInfo] = None, class_name: Optional[str] = None)¶