trw.callbacks.callback_reporting_layer_statistics

Module Contents

Classes

CallbackReportingLayerStatistics

Report the activation and gradient statistics layer by layer

Functions

generic_tracing()

Return the module types to trace, restricted to basic building blocks to avoid too much clutter

collect_gradient(model, gradient_store)

Collect the gradient of each parameter of a given model

aggregate_stats(all_stats, batch_stat)

aggregate_stats_end(all_stats)

calculate_stats_gradient(model, sequence, nb_samples, aggregate_stats_fn=aggregate_stats, aggregate_stats_end_fn=aggregate_stats_end, modules_type_to_trace=generic_tracing())

Collect the activation statistics and the gradient update statistics for each layer

Attributes

logger

trw.callbacks.callback_reporting_layer_statistics.logger

trw.callbacks.callback_reporting_layer_statistics.generic_tracing()

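Return the module types to trace, restricted to basic building blocks to avoid too much clutter. Its result is the default value of the modules_type_to_trace argument of calculate_stats_gradient.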
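The idea of tracing only basic building blocks can be illustrated with a short PyTorch sketch. This is not trw's implementation: the tuple returned by the hypothetical `basic_building_blocks()` is an assumption (the exact set of types `generic_tracing()` returns is not listed here), and `modules_to_trace` is a hypothetical helper that keeps only the sub-modules whose type is in that tuple.

```python
import torch.nn as nn


def basic_building_blocks():
    """Hypothetical stand-in for generic_tracing(): a tuple of layer types to trace.

    Assumption: the real trw function may return a different set of types.
    """
    return (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
            nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def modules_to_trace(model, types_to_trace):
    """Return the (name, module) pairs whose type is one of types_to_trace."""
    return [(name, module) for name, module in model.named_modules()
            if isinstance(module, types_to_trace)]


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
traced = modules_to_trace(model, basic_building_blocks())
# The container itself and the ReLU are skipped; only the two Linear layers remain.
print([name for name, _ in traced])  # ['0', '2']
```

Filtering by type rather than by name keeps the report focused on layers that carry parameters or normalize activations, which is what keeps the output readable for deep models.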

trw.callbacks.callback_reporting_layer_statistics.collect_gradient(model, gradient_store)

Collect the gradient of each parameter of a given model.

Parameters:

model: the model

gradient_store: where to store the parameter gradients
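A minimal sketch of collecting per-parameter gradients in PyTorch. It assumes `gradient_store` is a dict keyed by parameter name that receives a detached copy of each gradient; the storage format trw actually uses is not documented here, so `collect_gradient_sketch` is a hypothetical stand-in rather than the library function.

```python
import torch
import torch.nn as nn


def collect_gradient_sketch(model, gradient_store):
    """Hypothetical re-implementation: copy each parameter's gradient into
    gradient_store, keyed by parameter name. Assumption: a dict of tensors.
    Parameters that received no gradient are skipped."""
    for name, param in model.named_parameters():
        if param.grad is not None:
            # detach + clone so later backward passes do not modify the stored copy
            gradient_store[name] = param.grad.detach().clone()


torch.manual_seed(0)
model = nn.Linear(3, 1)
loss = model(torch.randn(5, 3)).sum()
loss.backward()

store = {}
collect_gradient_sketch(model, store)
print(sorted(store))  # ['bias', 'weight']
```

With a summed output over 5 samples, the bias gradient is exactly 5, which makes the stored value easy to check.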

trw.callbacks.callback_reporting_layer_statistics.aggregate_stats(all_stats, batch_stat)

trw.callbacks.callback_reporting_layer_statistics.aggregate_stats_end(all_stats)

trw.callbacks.callback_reporting_layer_statistics.calculate_stats_gradient(model, sequence, nb_samples, aggregate_stats_fn=aggregate_stats, aggregate_stats_end_fn=aggregate_stats_end, modules_type_to_trace=generic_tracing())

Collect the activation statistics and the gradient update statistics for each layer.

Returns:

a tuple (gradient stats, activation stats)
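Both kinds of statistics can be gathered with forward hooks (activations) and by reading `param.grad` after a backward pass (gradients). The sketch below is hypothetical and simplified: it assumes `batches` yields plain input tensors, uses `model(inputs).sum()` as a placeholder loss, counts samples along the first tensor dimension, and reduces everything to a mean and a population standard deviation per layer. The trw signature instead takes a `sequence` and exposes `aggregate_stats_fn` and `aggregate_stats_end_fn` for customizing aggregation, which this sketch replaces with a fixed mean/std summary.

```python
import torch
import torch.nn as nn


def layer_stats_sketch(model, batches, nb_samples, types_to_trace=(nn.Linear,)):
    """Hypothetical sketch loosely modelled on calculate_stats_gradient.

    Returns (gradient_stats, activation_stats); each maps a name to
    {'mean': float, 'std': float}. Gradients are keyed by parameter name,
    activations by module name (both assumptions).
    """
    activations = {}  # module name -> list of flattened outputs
    gradients = {}    # parameter name -> list of flattened gradients

    def make_hook(name):
        def hook(module, inputs, output):
            activations.setdefault(name, []).append(output.detach().flatten())
        return hook

    handles = [module.register_forward_hook(make_hook(name))
               for name, module in model.named_modules()
               if isinstance(module, types_to_trace)]
    try:
        seen = 0
        for inputs in batches:
            if seen >= nb_samples:
                break
            model.zero_grad()
            model(inputs).sum().backward()  # placeholder loss (assumption)
            for name, param in model.named_parameters():
                if param.grad is not None:
                    gradients.setdefault(name, []).append(param.grad.detach().flatten())
            seen += inputs.shape[0]
    finally:
        # always detach the hooks so the model is left unchanged
        for handle in handles:
            handle.remove()

    def summarize(store):
        out = {}
        for name, values in store.items():
            v = torch.cat(values)
            mean = v.mean()
            std = ((v - mean) ** 2).mean().sqrt()  # population std, defined for 1 element
            out[name] = {'mean': mean.item(), 'std': std.item()}
        return out

    return summarize(gradients), summarize(activations)


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
batches = [torch.randn(10, 4) for _ in range(3)]
grad_stats, act_stats = layer_stats_sketch(model, batches, nb_samples=20)
print(sorted(act_stats))   # ['0', '2']
print(sorted(grad_stats))  # ['0.bias', '0.weight', '2.bias', '2.weight']
```

Here only the first two batches are processed, since 20 samples have been seen before the third batch. Removing the hooks in a `finally` block keeps the model usable even if a forward pass raises.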

class trw.callbacks.callback_reporting_layer_statistics.CallbackReportingLayerStatistics(dataset_name=None, split_name=None, nb_samples=500, table_name='layer')

Bases: trw.callbacks.callback.Callback

Report the activation and gradient statistics layer by layer

first_time(self, options, datasets)

__call__(self, options, history, model, losses, outputs, datasets, datasets_infos, callbacks_per_batch, **kwargs)