trw.train.callback_embedding_statistics

Module Contents

Classes

CollectBatchAndProcessStats

Collect statistics on batches and aggregate them

CallbackTensorboardEmbedding

This callback records the statistics of specified embeddings

Functions

default_statistics()

Attributes

logger

trw.train.callback_embedding_statistics.logger
trw.train.callback_embedding_statistics.default_statistics()
class trw.train.callback_embedding_statistics.CollectBatchAndProcessStats(model, embedding_names, statistics_fn, number_of_samples_to_evaluate, embedding_output_name)

Collect statistics on batches and aggregate them

__call__(self, dataset_name, split_name, batch, loss_terms)
get_stats(self)
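The summary above (collect per batch, aggregate at the end) can be illustrated with a minimal, self-contained sketch. This is not the trw implementation: the class `SimpleBatchStatsCollector`, its dictionary of named statistics functions, and the helper `per_dimension_mean` are hypothetical simplifications that only mirror the roles of `number_of_samples_to_evaluate` (a cap on collected samples) and `statistics_fn` (applied once all batches are seen).

```python
from typing import Callable, Dict, List, Sequence


class SimpleBatchStatsCollector:
    """Hypothetical, simplified stand-in for CollectBatchAndProcessStats.

    Accumulates per-sample embedding vectors batch by batch, stops collecting
    after `max_samples`, then applies each statistics function to the samples.
    """

    def __init__(self, statistics_fns: Dict[str, Callable[[List[Sequence[float]]], object]], max_samples: int):
        self.statistics_fns = statistics_fns
        self.max_samples = max_samples
        self.samples: List[Sequence[float]] = []

    def __call__(self, batch_embeddings: Sequence[Sequence[float]]) -> None:
        # Keep only as many samples as are still needed to reach the cap
        remaining = self.max_samples - len(self.samples)
        if remaining > 0:
            self.samples.extend(batch_embeddings[:remaining])

    def get_stats(self) -> Dict[str, object]:
        # Aggregate once all batches have been seen
        return {name: fn(self.samples) for name, fn in self.statistics_fns.items()}


def per_dimension_mean(samples: List[Sequence[float]]) -> List[float]:
    # Mean of each embedding dimension over all collected samples
    n = len(samples)
    return [sum(column) / n for column in zip(*samples)]


collector = SimpleBatchStatsCollector({"mean": per_dimension_mean}, max_samples=3)
collector([[1.0, 2.0], [3.0, 4.0]])  # first batch: both samples kept
collector([[5.0, 6.0], [7.0, 8.0]])  # second batch: only one more sample kept
print(collector.get_stats())  # {'mean': [3.0, 4.0]}
```

Capping the number of samples keeps memory bounded regardless of the dataset size, at the cost of computing the statistics on a subset of the split.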
class trw.train.callback_embedding_statistics.CallbackTensorboardEmbedding(embedding_names, dataset_name=None, split_name='test', number_of_samples=2000, statistics=default_statistics(), embedding_output_name='output')

Bases: trw.train.callback.Callback

This callback records the statistics of specified embeddings

Note: the embeddings must be recalculated because each one has to be associated with a specific input (i.e., we can’t store everything in memory, so we collect what we need batch by batch).
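The note above can be sketched as follows, assuming a model that maps a batch (a dict containing an `"input"` list) to a dict of named outputs. The function `collect_input_embedding_pairs` and `toy_model` are hypothetical illustrations, not part of trw: the model is re-run batch by batch, only the requested embedding (compare `embedding_output_name`) is kept, each embedding stays paired with the input that produced it, and evaluation stops once enough samples are collected.

```python
from typing import Callable, Dict, Iterable, List, Tuple


def collect_input_embedding_pairs(
    batches: Iterable[Dict[str, list]],
    model: Callable[[Dict[str, list]], Dict[str, list]],
    embedding_name: str,
    max_samples: int,
) -> List[Tuple[object, object]]:
    """Hypothetical helper: recompute embeddings batch by batch and keep
    only (input, embedding) pairs for the requested output, up to max_samples."""
    pairs: List[Tuple[object, object]] = []
    for batch in batches:
        outputs = model(batch)                # recompute the outputs for this batch
        embeddings = outputs[embedding_name]  # keep only the requested embedding
        for x, e in zip(batch["input"], embeddings):
            pairs.append((x, e))              # embedding stays associated with its input
            if len(pairs) >= max_samples:
                return pairs                  # stop early: no further batches are evaluated
    return pairs


def toy_model(batch: Dict[str, list]) -> Dict[str, list]:
    # Toy "model": the embedding of x is [x, 2 * x]
    return {"embedding": [[x, 2 * x] for x in batch["input"]]}


batches = [{"input": [1, 2]}, {"input": [3, 4]}, {"input": [5, 6]}]
pairs = collect_input_embedding_pairs(batches, toy_model, "embedding", max_samples=3)
print(pairs)  # [(1, [1, 2]), (2, [2, 4]), (3, [3, 6])]
```

Here the third batch is never passed to the model, which is the point of collecting batch by batch: only the samples that will actually be reported are computed and kept.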

first_time(self, datasets)
__call__(self, options, history, model, losses, outputs, datasets, datasets_infos, callbacks_per_batch, **kwargs)