trw.train.callback_embedding_statistics
Module Contents
Classes
- CollectBatchAndProcessStats: Collect statistics on batches and aggregate them
- CallbackTensorboardEmbedding: This callback records the statistics of specified embeddings
Functions
- default_statistics()
Attributes
- trw.train.callback_embedding_statistics.logger
- trw.train.callback_embedding_statistics.default_statistics()
- class trw.train.callback_embedding_statistics.CollectBatchAndProcessStats(model, embedding_names, statistics_fn, number_of_samples_to_evaluate, embedding_output_name)
Collect statistics on batches and aggregate them
- __call__(self, dataset_name, split_name, batch, loss_terms)
- get_stats(self)
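The following is a minimal plain-Python sketch of the collect-then-aggregate pattern that `CollectBatchAndProcessStats` describes. It is not the trw implementation. Several details are hypothetical: the class name `BatchStatsCollector`, the batch layout (a dict mapping each embedding name to a list of per-sample vectors), and the reading of `statistics_fn` as a mapping from a statistic name to a function of one sample. In the sketch, at most `number_of_samples_to_evaluate` samples are processed and any later batches are ignored.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Sequence


class BatchStatsCollector:
    """Hypothetical sketch (not trw's class): compute per-sample statistics
    batch by batch and stop once enough samples have been processed."""

    def __init__(self, embedding_names: Sequence[str],
                 statistics_fn: Dict[str, Callable[[List[float]], float]],
                 number_of_samples_to_evaluate: int):
        self.embedding_names = list(embedding_names)
        self.statistics_fn = statistics_fn
        self.max_samples = number_of_samples_to_evaluate
        self.nb_samples = 0
        # stats[embedding_name][statistic_name] -> one value per kept sample
        self.stats = defaultdict(lambda: defaultdict(list))

    def __call__(self, batch: Dict[str, List[List[float]]]) -> None:
        if self.nb_samples >= self.max_samples:
            return  # enough samples already; ignore further batches
        batch_size = len(batch[self.embedding_names[0]])
        # keep only as many samples as are still needed
        nb_kept = min(batch_size, self.max_samples - self.nb_samples)
        for name in self.embedding_names:
            for sample in batch[name][:nb_kept]:
                for stat_name, fn in self.statistics_fn.items():
                    self.stats[name][stat_name].append(fn(sample))
        self.nb_samples += nb_kept

    def get_stats(self) -> Dict[str, Dict[str, List[float]]]:
        # aggregate: values from all batches, concatenated per statistic
        return {name: dict(per_stat) for name, per_stat in self.stats.items()}


collector = BatchStatsCollector(
    embedding_names=["conv1"],
    statistics_fn={"mean": lambda v: sum(v) / len(v), "max": max},
    number_of_samples_to_evaluate=3,
)
collector({"conv1": [[1.0, 3.0], [2.0, 2.0]]})
collector({"conv1": [[0.0, 4.0], [5.0, 5.0]]})  # only the first sample is kept
print(collector.get_stats())
# {'conv1': {'mean': [2.0, 2.0, 2.0], 'max': [3.0, 2.0, 4.0]}}
```

Capping the sample count bounds both memory use and evaluation time, whatever the size of the split.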
- class trw.train.callback_embedding_statistics.CallbackTensorboardEmbedding(embedding_names, dataset_name=None, split_name='test', number_of_samples=2000, statistics=default_statistics(), embedding_output_name='output')
Bases: trw.train.callback.Callback
This callback records the statistics of specified embeddings
Note: the embeddings must be recomputed because each one has to be associated with the specific input that produced it. Everything cannot be stored in memory, so the required values are collected batch by batch.
- first_time(self, datasets)
- __call__(self, options, history, model, losses, outputs, datasets, datasets_infos, callbacks_per_batch, **kwargs)
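The note on `CallbackTensorboardEmbedding` explains why embeddings are recomputed instead of cached. This hedged sketch illustrates that idea. It runs a model on each batch and keeps every embedding next to the label of the input that produced it. Once `number_of_samples` rows are held, it stops without evaluating the remaining batches. The names `collect_embeddings` and `toy_model` and the `(inputs, labels)` batch layout are hypothetical. The returned pair, an embedding matrix plus one metadata entry per row, is the kind of data that `torch.utils.tensorboard.SummaryWriter.add_embedding(mat, metadata=...)` accepts. This page does not state whether trw passes its data that way.

```python
from typing import Callable, Iterable, List, Sequence, Tuple

# Hypothetical batch layout: (inputs, labels), one label per input row
Batch = Tuple[List[List[float]], Sequence[str]]


def collect_embeddings(
    batches: Iterable[Batch],
    model_fn: Callable[[List[List[float]]], List[List[float]]],
    number_of_samples: int,
) -> Tuple[List[List[float]], List[str]]:
    """Hypothetical sketch: recompute embeddings batch by batch and keep each
    one paired with the label of the input that produced it."""
    embeddings: List[List[float]] = []
    metadata: List[str] = []
    for inputs, labels in batches:
        remaining = number_of_samples - len(embeddings)
        if remaining <= 0:
            break  # enough rows held; skip evaluating the remaining batches
        # run the model only on the inputs that will actually be kept
        embeddings.extend(model_fn(inputs[:remaining]))
        metadata.extend(labels[:remaining])
    return embeddings, metadata


def toy_model(inputs: List[List[float]]) -> List[List[float]]:
    # stand-in for a network: the "embedding" is the input doubled
    return [[2 * x for x in row] for row in inputs]


batches = [([[1.0], [2.0]], ["a", "b"]), ([[3.0], [4.0]], ["c", "d"])]
mat, labels = collect_embeddings(batches, toy_model, number_of_samples=3)
print(mat, labels)
# [[2.0], [4.0], [6.0]] ['a', 'b', 'c']
```

Because `batches` can be any iterable, including a lazy data loader, at most `number_of_samples` embeddings are held in memory at any time.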