trw.callbacks.callback_save_last_model
Module Contents
Classes
CallbackSaveLastModel | Save the current model to disk as well as metadata (history, outputs, infos).
Functions
exclude_large_embeddings | Remove from the outputs embeddings larger than a specified threshold.
Attributes
- trw.callbacks.callback_save_last_model.logger
- class trw.callbacks.callback_save_last_model.ModelWithLowestMetricBase
- abstract update(self, metric_value, model, metadata, root_path)
- class trw.callbacks.callback_save_last_model.ModelWithLowestMetric(dataset_name, split_name, output_name, metric_name, minimum_metric=0.2)
Bases: ModelWithLowestMetricBase
- update(self, metric_value, model, metadata, root_path)
Check the metrics and export the model if the thresholds are satisfied.
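The threshold check described above can be illustrated with a minimal, self-contained sketch. It assumes that "thresholds are satisfied" means the metric is below `minimum_metric` and also lower than any value seen so far; the page does not confirm this, and `LowestMetricTracker` is a hypothetical stand-in rather than the trw class. The actual export to `root_path` is left out.

```python
import math
from typing import Optional


class LowestMetricTracker:
    """Hypothetical stand-in for ModelWithLowestMetric; not the trw implementation."""

    def __init__(self, minimum_metric: float = 0.2):
        # Assumption: a metric must fall below this threshold before an export is considered.
        self.minimum_metric = minimum_metric
        self.best_metric = math.inf

    def update(self, metric_value: Optional[float]) -> bool:
        """Return True when the caller should export the model."""
        if metric_value is None or math.isnan(metric_value):
            return False
        if metric_value < self.minimum_metric and metric_value < self.best_metric:
            # New best value: remember it and signal that an export is due.
            self.best_metric = metric_value
            return True
        return False


tracker = LowestMetricTracker(minimum_metric=0.2)
decisions = [tracker.update(v) for v in (0.5, 0.15, 0.18, 0.1)]
print(decisions)  # [False, True, False, True]
```

Here 0.5 is above the threshold, 0.15 is the first value to qualify, 0.18 is below the threshold but not an improvement, and 0.1 is a new best.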
- trw.callbacks.callback_save_last_model.exclude_large_embeddings(outputs: trw.basic_typing.Datasets, counts_greater_than=10000) → Optional[trw.basic_typing.Datasets]
Remove from the outputs embeddings larger than a specified threshold.
- Parameters
outputs – the outputs to check
counts_greater_than – the number of elements above which the embedding will be stripped
- Returns
the outputs, with the large embeddings removed
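As a rough illustration of this filtering, the sketch below walks a nested dictionary and deletes any field whose element count exceeds the threshold. The nesting (dataset name, then split name, then output name, then field name) is an assumption about `trw.basic_typing.Datasets`, plain Python lists stand in for tensors or arrays, and `strip_large_fields` and `count_elements` are hypothetical names, not part of trw.

```python
from typing import Any, Dict, Optional

# Assumed layout: dataset name -> split name -> output name -> {field name: value}.
NestedOutputs = Dict[str, Dict[str, Dict[str, Dict[str, Any]]]]


def count_elements(value: Any) -> int:
    """Count scalar leaves in a (possibly nested) list or tuple; a scalar counts as 1."""
    if isinstance(value, (list, tuple)):
        return sum(count_elements(v) for v in value)
    return 1


def strip_large_fields(
    outputs: Optional[NestedOutputs], counts_greater_than: int = 10000
) -> Optional[NestedOutputs]:
    """Delete, in place, every field holding more than `counts_greater_than` elements."""
    if outputs is None:
        return None
    for dataset in outputs.values():
        for split in dataset.values():
            for output in split.values():
                # Iterate over a copy of the keys so entries can be deleted safely.
                for name in list(output):
                    if count_elements(output[name]) > counts_greater_than:
                        del output[name]
    return outputs


outputs = {"mnist": {"train": {"classification": {
    "loss": 0.12,
    "embedding": [[0.0] * 8 for _ in range(4)],  # 32 elements
}}}}
stripped = strip_large_fields(outputs, counts_greater_than=10)
print(sorted(stripped["mnist"]["train"]["classification"]))  # ['loss']
```

Stripping large fields before saving keeps the serialized metadata small, which is presumably why this function is the default for `post_process_outputs`.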
- trw.callbacks.callback_save_last_model.should_not_export_model(last_step, revert_if_nan_metrics)
- class trw.callbacks.callback_save_last_model.CallbackSaveLastModel(model_name='last', with_outputs=False, is_versioned=False, rolling_size=None, keep_model_with_best_metric: ModelWithLowestMetric = None, revert_if_nan_metrics: Optional[Sequence[str]] = ('loss',), post_process_outputs: Optional[Callable[[trw.basic_typing.Datasets], trw.basic_typing.Datasets]] = exclude_large_embeddings)
Bases: trw.callbacks.callback.Callback
Save the current model to disk as well as metadata (history, outputs, infos).
This callback can be used during training (e.g., checkpoint) or at the end of the training.
Optionally, record the best model for a given dataset, split, output and metric.
- __call__(self, options, history, model, losses, outputs, datasets, datasets_infos, callbacks_per_batch, **kwargs)