trw.train.utils

Module Contents

Classes

time_it

Simple decorator to measure the time taken to execute a function

CleanAddedHooks

Context manager that automatically tracks added hooks on the model and removes them when the context is released

RuntimeFormatter

Report the time since this formatter was instantiated

Functions

safe_filename(filename)

Clean the filename so that it can be used as a valid filename

log_info(msg)

Log the message to a log file as info

log_and_print(msg)

Log the message to a log file as info and print it to the console

log_console(msg)

Log the message to the console

len_batch(batch)

Return the number of elements within a data split or a collections.Sequence

create_or_recreate_folder(path, nb_tries=3, wait_time_between_tries=2.0)

Check if the path exists. If it does, remove the folder and recreate it; otherwise, create it

to_value(v)

Convert where appropriate from tensors to numpy arrays

make_unique_colors()

Return a set of unique and easily distinguishable colors

make_unique_colors_f()

Return a set of unique and easily distinguishable colors

get_class_name(mapping, classid)

get_classification_mappings(datasets_infos, dataset_name, split_name)

Return the output mappings of a classification output from the datasets infos

get_classification_mapping(datasets_infos, dataset_name, split_name, output_name)

Return the output mappings of a classification output from the datasets infos

set_optimizer_learning_rate(optimizer, learning_rate)

Set the learning rate of the optimizer to a specific value

collate_tensors(values, device, pin_memory=False, non_blocking=False)

Express values as a torch.Tensor

collate_dicts(batch, device, pin_memory=False, non_blocking=False)

Default function to collate a dictionary of samples to a dictionary of torch.Tensor

collate_list_of_dicts(batches, device, pin_memory=False, non_blocking=False)

Default function to collate a list of dictionaries to a dictionary of `torch.Tensor`s

default_collate_fn(batch, device, pin_memory=False, non_blocking=False)

Collate a dictionary of features or a list of dictionaries of features into a dictionary of torch.Tensor

transfer_batch_to_device(batch, device, non_blocking=False)

Transfer the Tensors and numpy arrays to the specified device. Other types will not be moved.

get_device(module, batch=None)

Return the device of a module. This may be incorrect if the module is split across different devices

find_default_dataset_and_split_names(datasets, default_dataset_name=None, default_split_name=None, train_split_name=None)

Return a good choice of dataset name and split name, possibly not the train split.

Attributes

logger

trw.train.utils.logger
trw.train.utils.safe_filename(filename)

Clean the filename so that it can be used as a valid filename
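The sanitization rule is not spelled out in the docstring; a minimal sketch of one plausible implementation (the character set and replacement are assumptions, not the actual trw code) could look like:

```python
import re

def safe_filename(filename):
    # Assumed behavior: replace characters that are commonly invalid in
    # file names with an underscore. The real implementation may use a
    # different character set or replacement.
    return re.sub(r'[\\/:"*?<>|]+', '_', filename)
```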

trw.train.utils.log_info(msg)

Log the message to a log file as info

trw.train.utils.log_and_print(msg)

Log the message to a log file as info and print it to the console

trw.train.utils.log_console(msg)

Log the message to the console

trw.train.utils.len_batch(batch)
Parameters

batch – a data split or a collections.Sequence

Returns

the number of elements within a data split
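As a sketch of the documented behavior (an assumption, not the actual trw code): a data split is taken to be a mapping of feature name to a sequence of values, all sharing the same length.

```python
import collections.abc

def len_batch(batch):
    # A data split is assumed to be a dict-like mapping of
    # feature name -> sequence of values, all of the same length.
    if isinstance(batch, collections.abc.Mapping):
        for value in batch.values():
            return len(value)  # length of the first feature
        return 0  # empty batch
    if isinstance(batch, collections.abc.Sequence):
        return len(batch)
    raise NotImplementedError(f'cannot compute the length of {type(batch)}')
```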

trw.train.utils.create_or_recreate_folder(path, nb_tries=3, wait_time_between_tries=2.0)

Check if the path exists. If it does, remove the folder and recreate it; otherwise, create it

class trw.train.utils.time_it(time_name=None, log=None)

Simple decorator to measure the time taken to execute a function

Parameters
  • time_name – the name of the function to time, else we will use fn.__str__()

  • log – how to log the timing

__call__(self, fn, *args, **kwargs)
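
A minimal sketch of how such a decorator could be implemented (the message format and timer choice are assumptions, not the actual trw code):

```python
import time

class time_it:
    # Sketch: wrap a function and report the elapsed wall-clock time
    # through `log`, a callable taking a message string.
    def __init__(self, time_name=None, log=None):
        self.time_name = time_name
        self.log = log

    def __call__(self, fn, *args, **kwargs):
        def timed(*fn_args, **fn_kwargs):
            start = time.perf_counter()
            result = fn(*fn_args, **fn_kwargs)
            elapsed = time.perf_counter() - start
            name = self.time_name if self.time_name is not None else str(fn)
            if self.log is not None:
                self.log(f'time={elapsed:.6f}s fn={name}')
            return result
        return timed
```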
trw.train.utils.to_value(v)

Convert where appropriate from tensors to numpy arrays
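A sketch of the documented conversion (the lazy import is an assumption so the example also runs without torch installed):

```python
def to_value(v):
    # Sketch: convert torch.Tensor values to numpy arrays; leave
    # everything else untouched. torch is imported lazily so the
    # function degrades gracefully when it is unavailable.
    try:
        import torch
        if isinstance(v, torch.Tensor):
            return v.detach().cpu().numpy()
    except ImportError:
        pass
    return v
```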

trw.train.utils.make_unique_colors()

Return a set of unique and easily distinguishable colors

Returns

a list of RGB colors

trw.train.utils.make_unique_colors_f()

Return a set of unique and easily distinguishable colors

Returns

a list of RGB colors

trw.train.utils.get_class_name(mapping, classid)
trw.train.utils.get_classification_mappings(datasets_infos, dataset_name, split_name)

Return the output mappings of a classification output from the datasets infos

Parameters
  • datasets_infos – the info of the datasets

  • dataset_name – the name of the dataset

  • split_name – the split name

  • output_name – the output name

Returns

a dictionary {outputs: {‘mapping’: {name->ID}, ‘mappinginv’: {ID->name}}}

trw.train.utils.get_classification_mapping(datasets_infos, dataset_name, split_name, output_name)

Return the output mappings of a classification output from the datasets infos

Parameters
  • datasets_infos – the info of the datasets

  • dataset_name – the name of the dataset

  • split_name – the split name

  • output_name – the output name

Returns

a dictionary {‘mapping’: {name->ID}, ‘mappinginv’: {ID->name}}
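Given the documented mapping structure, get_class_name is presumably a lookup in the inverse mapping; a sketch (an assumption, not the actual trw code):

```python
def get_class_name(mapping, classid):
    # Sketch assuming `mapping` follows the documented structure
    # {'mapping': {name -> ID}, 'mappinginv': {ID -> name}}.
    return mapping['mappinginv'].get(int(classid))
```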

trw.train.utils.set_optimizer_learning_rate(optimizer, learning_rate)

Set the learning rate of the optimizer to a specific value

Parameters
  • optimizer – the optimizer to update

  • learning_rate – the learning rate to set

Returns

None
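
This is the standard way to change the learning rate of a PyTorch optimizer: every optimizer exposes its settings through param_groups, and updating the lr key of each group takes effect on the next step. A sketch, tested below with a stand-in object rather than a real torch optimizer:

```python
def set_optimizer_learning_rate(optimizer, learning_rate):
    # PyTorch optimizers expose their settings through `param_groups`;
    # updating the `lr` key of each group changes the learning rate.
    for param_group in optimizer.param_groups:
        param_group['lr'] = learning_rate
```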

trw.train.utils.collate_tensors(values, device, pin_memory=False, non_blocking=False)

express values as a torch.Tensor

Parameters
  • values – nd.array or torch.Tensor

  • device – the device where to create the torch.Tensor

  • pin_memory – if True, pin the memory. Required to be a CUDA allocated torch.Tensor

Returns

a torch.Tensor if the input is of type numpy.ndarray; otherwise, the input unchanged

trw.train.utils.collate_dicts(batch, device, pin_memory=False, non_blocking=False)

Default function to collate a dictionary of samples to a dictionary of torch.Tensor

Parameters
  • batch – a dictionary of features

  • device – the device where to create the torch.Tensor

  • pin_memory – if True, pin the memory. Required to be a CUDA allocated torch.Tensor

Returns

a dictionary of torch.Tensor

trw.train.utils.collate_list_of_dicts(batches, device, pin_memory=False, non_blocking=False)

Default function to collate a list of dictionaries to a dictionary of `torch.Tensor`s

Parameters
  • batches – a list of dictionaries of features

  • device – the device where to create the torch.Tensor

  • pin_memory – if True, pin the memory. Required to be a CUDA allocated torch.Tensor

Returns

a dictionary of torch.Tensor

trw.train.utils.default_collate_fn(batch, device, pin_memory=False, non_blocking=False)
Parameters
  • batch – a dictionary of features or a list of dictionaries of features

  • device – the device where to create the torch.Tensor

  • pin_memory – if True, pin the memory. Required to be a CUDA allocated torch.Tensor

Returns

a dictionary of torch.Tensor
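
To illustrate the dispatch between the two accepted input shapes, here is a simplified, CPU-only sketch that uses numpy arrays in place of torch.Tensor (the name default_collate_sketch and the numpy substitution are assumptions for illustration, not the trw implementation):

```python
import numpy as np

def default_collate_sketch(batch):
    # A list of dictionaries is first merged into a single dictionary
    # of lists (one list per feature), then each feature is stacked.
    if isinstance(batch, list):
        batch = {name: [b[name] for b in batch] for name in batch[0]}
    return {name: np.asarray(values) for name, values in batch.items()}
```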

trw.train.utils.transfer_batch_to_device(batch, device, non_blocking=False)

Transfer the Tensors and numpy arrays to the specified device. Other types will not be moved.

Parameters
  • batch – the batch of data to be transferred

  • device – the device to move the tensors to

  • non_blocking – non blocking memory transfer to GPU

Returns

a batch of data on the specified device

class trw.train.utils.CleanAddedHooks(model)

Context manager that automatically tracks added hooks on the model and removes them when the context is released

__enter__(self)
__exit__(self, type, value, traceback)
static record_hooks(module_source)

Record the hooks of a module

Parameters
  • module_source – the module to track the hooks

Returns

a tuple (forward, backward), where forward and backward are dictionaries of hook IDs by module
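
A simplified sketch of the context-manager logic, reduced to a single module for clarity (the real class records hooks per submodule; here a "module" is any object exposing the `_forward_hooks` and `_backward_hooks` dictionaries, as torch.nn.Module does):

```python
class CleanAddedHooks:
    # Sketch: snapshot the hook IDs on entry and delete any hooks
    # that appeared inside the `with` block on exit.
    def __init__(self, model):
        self.model = model

    def __enter__(self):
        self.forward_before, self.backward_before = self.record_hooks(self.model)
        return self

    def __exit__(self, type, value, traceback):
        forward_after, backward_after = self.record_hooks(self.model)
        for hook_id in forward_after - self.forward_before:
            del self.model._forward_hooks[hook_id]
        for hook_id in backward_after - self.backward_before:
            del self.model._backward_hooks[hook_id]
        return False

    @staticmethod
    def record_hooks(module_source):
        return (set(module_source._forward_hooks),
                set(module_source._backward_hooks))
```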

trw.train.utils.get_device(module, batch=None)

Return the device of a module. This may be incorrect if the module is split across different devices

class trw.train.utils.RuntimeFormatter(*args, **kwargs)

Bases: logging.Formatter

Report the time since this formatter is instantiated

formatTime(self, record, datefmt=None)

Return the creation time of the specified LogRecord as formatted text.

This method should be called from format() by a formatter which wants to make use of a formatted time. This method can be overridden in formatters to provide for any specific requirement, but the basic behaviour is as follows: if datefmt (a string) is specified, it is used with time.strftime() to format the creation time of the record. Otherwise, an ISO8601-like (or RFC 3339-like) format is used. The resulting string is returned.

This function uses a user-configurable function to convert the creation time to a tuple. By default, time.localtime() is used; to change this for a particular formatter instance, set the ‘converter’ attribute to a function with the same signature as time.localtime() or time.gmtime(). To change it for all formatters, for example if you want all logging times to be shown in GMT, set the ‘converter’ attribute in the Formatter class.
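
A sketch of a formatter with this behavior: overriding formatTime to report the time elapsed since the formatter was created, rather than the absolute record time (the HH:MM:SS format is an assumption, not the trw output format):

```python
import logging
import time

class RuntimeFormatter(logging.Formatter):
    # Sketch: replace the absolute timestamp with the elapsed time
    # since the formatter was instantiated.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.start_time = time.time()

    def formatTime(self, record, datefmt=None):
        elapsed = record.created - self.start_time
        hours, remainder = divmod(int(elapsed), 3600)
        minutes, seconds = divmod(remainder, 60)
        return f'{hours:02d}:{minutes:02d}:{seconds:02d}'
```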

trw.train.utils.find_default_dataset_and_split_names(datasets, default_dataset_name=None, default_split_name=None, train_split_name=None)

Return a good choice of dataset name and split name, possibly not the train split.

Parameters
  • datasets – the datasets

  • default_dataset_name – a possible dataset name. If None, find a suitable dataset; otherwise, the named dataset must be present

  • default_split_name – a possible split name. If None, find a suitable split; otherwise, the named split must be present. If train_split_name is specified, the selected split name will be different from train_split_name

  • train_split_name – if not None, exclude the train split

Returns

a tuple (dataset_name, split_name)
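
A sketch of the selection logic described above, assuming `datasets` is a dictionary of {dataset_name: {split_name: split}} (the fallback behavior when only the train split exists is an assumption):

```python
def find_default_dataset_and_split_names(datasets,
                                         default_dataset_name=None,
                                         default_split_name=None,
                                         train_split_name=None):
    # Pick the first dataset when none is requested.
    if default_dataset_name is None:
        default_dataset_name = next(iter(datasets))
    splits = datasets.get(default_dataset_name)
    if splits is None:
        return None, None
    if default_split_name is None:
        # Prefer a split different from the train split when possible.
        candidates = [name for name in splits if name != train_split_name]
        default_split_name = candidates[0] if candidates else next(iter(splits))
    return default_dataset_name, default_split_name
```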