trw.callbacks.callback_explain_decision

Module Contents

Classes

ExplainableAlgorithm

Enumeration of the supported explanation algorithms.

CallbackExplainDecision

Explain the decision of a model

Functions

default_algorithm_args()

Return the default keyword arguments of each explanation algorithm.

run_classification_explanation(root, dataset_name, split_name, model, batch, datasets_infos, nb_samples, algorithm_name, algorithm_fn, output_name, algorithm_kwargs=None, nb_explanations=1, epoch=None, average_filters=True)

Run an explanation algorithm on a classification output.

fill_class_name(output, class_index, datasets_infos, dataset_name, split_name)

Return the class name if available; otherwise, return the class index.

Attributes

logger

trw.callbacks.callback_explain_decision.logger
class trw.callbacks.callback_explain_decision.ExplainableAlgorithm

Bases: enum.Enum

Enumeration of the explanation algorithms supported by this module.

GuidedBackPropagation
GradCAM
Gradient
IntegratedGradients
MeaningfulPerturbations
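The members name standard attribution methods. To illustrate what one of them computes, the following pure-Python sketch approximates Integrated Gradients for a toy differentiable function with a known analytic gradient. It is not trw's implementation (which works on PyTorch models); the function `integrated_gradients`, the toy sigmoid model and the all-zeros baseline are assumptions made for this example.

```python
import math
from typing import Callable, List, Sequence


def integrated_gradients(
    grad_fn: Callable[[Sequence[float]], List[float]],
    x: Sequence[float],
    baseline: Sequence[float],
    steps: int = 50,
) -> List[float]:
    """Midpoint Riemann approximation of Integrated Gradients.

    attribution_i = (x_i - baseline_i) * mean_k dF/dx_i(baseline + alpha_k * (x - baseline))
    """
    totals = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps  # midpoint of the k-th sub-interval of [0, 1]
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_fn(point)):
            totals[i] += g
    # Scale the averaged gradient by the distance from the baseline.
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, totals)]


# Toy "model" (assumption): F(x) = sigmoid(w . x), with its analytic gradient.
w = [2.0, -1.0, 0.5]


def f(x: Sequence[float]) -> float:
    return 1.0 / (1.0 + math.exp(-sum(wi * xi for wi, xi in zip(w, x))))


def grad_f(x: Sequence[float]) -> List[float]:
    s = f(x)
    return [s * (1.0 - s) * wi for wi in w]


x = [1.0, 2.0, -1.0]
baseline = [0.0, 0.0, 0.0]
attr = integrated_gradients(grad_f, x, baseline, steps=200)
# Completeness property: the attributions sum (approximately) to F(x) - F(baseline).
print(sum(attr), f(x) - f(baseline))
```

The completeness check is a convenient sanity test for any Integrated Gradients implementation: the closer the two printed values, the finer the path discretization.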
trw.callbacks.callback_explain_decision.default_algorithm_args()

Return the default keyword arguments of each explanation algorithm.
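This page does not document the structure returned by `default_algorithm_args()`. Assuming, from the `algorithms_kwargs` parameter name, that it is a dict mapping each algorithm to a dict of keyword arguments, the sketch below overrides one algorithm's arguments without mutating the defaults. `Algo`, `hypothetical_default_args`, `with_overrides` and `hypothetical_param` are illustrative names, not trw APIs.

```python
import copy
import enum


class Algo(enum.Enum):
    # Hypothetical local stand-in for a few ExplainableAlgorithm members.
    GradCAM = "GradCAM"
    Gradient = "Gradient"
    IntegratedGradients = "IntegratedGradients"


def hypothetical_default_args():
    # Assumed layout: one kwargs dict per algorithm (placeholder values only).
    return {
        Algo.GradCAM: {},
        Algo.Gradient: {},
        Algo.IntegratedGradients: {"hypothetical_param": 1},
    }


def with_overrides(defaults, overrides):
    """Return a deep copy of `defaults` with each algorithm's kwargs updated by `overrides`."""
    merged = copy.deepcopy(defaults)
    for algorithm, kwargs in overrides.items():
        merged.setdefault(algorithm, {}).update(kwargs)
    return merged


defaults = hypothetical_default_args()
custom = with_overrides(defaults, {Algo.IntegratedGradients: {"hypothetical_param": 2}})
```

Copying matters here: `CallbackExplainDecision` uses `default_algorithm_args()` as a default parameter value, and Python evaluates default values once, at function definition time, so mutating that dict in place would affect every instance relying on the default.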

trw.callbacks.callback_explain_decision.run_classification_explanation(root, dataset_name, split_name, model, batch, datasets_infos, nb_samples, algorithm_name, algorithm_fn, output_name, algorithm_kwargs=None, nb_explanations=1, epoch=None, average_filters=True)

Run an explanation algorithm on a classification output.
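Assuming `nb_explanations` controls how many of the highest-scoring classes are explained per sample (an interpretation of the parameter name that this page does not confirm), the class indices to explain could be selected as in the sketch below. `classes_to_explain` is a hypothetical helper, not part of trw.

```python
from typing import List, Sequence


def classes_to_explain(scores: Sequence[float], nb_explanations: int = 1) -> List[int]:
    """Return the indices of the `nb_explanations` highest scores, best first.

    Python's sort is stable (also with reverse=True), so ties keep their original order.
    """
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:nb_explanations]


probabilities = [0.1, 0.6, 0.05, 0.25]
print(classes_to_explain(probabilities, nb_explanations=2))  # [1, 3]
```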

trw.callbacks.callback_explain_decision.fill_class_name(output, class_index, datasets_infos, dataset_name, split_name)

Return the class name if available; otherwise, return the class index.
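The fallback behavior can be sketched as follows. The nested `datasets_infos` layout used here (`[dataset_name][split_name]["class_names"][output]`) is a hypothetical assumption for illustration; this page does not specify where trw stores the class-name mapping, and `class_name_or_index` is not a trw function.

```python
from typing import Any, Mapping, Optional, Union


def class_name_or_index(
    output: str,
    class_index: int,
    datasets_infos: Optional[Mapping[str, Any]],
    dataset_name: str,
    split_name: str,
) -> Union[str, int]:
    """Return the class name when a mapping is available, else the class index.

    Assumed (hypothetical) layout:
        datasets_infos[dataset_name][split_name]["class_names"][output] -> {index: name}
    """
    try:
        mapping = datasets_infos[dataset_name][split_name]["class_names"][output]
    except (KeyError, TypeError):
        # Missing entry, or datasets_infos is None: fall back to the raw index.
        return class_index
    return mapping.get(class_index, class_index)


infos = {"mnist": {"train": {"class_names": {"digit": {0: "zero", 1: "one"}}}}}
print(class_name_or_index("digit", 1, infos, "mnist", "train"))  # one
print(class_name_or_index("digit", 7, infos, "mnist", "train"))  # 7
```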

class trw.callbacks.callback_explain_decision.CallbackExplainDecision(max_samples=10, dirname='explained', dataset_name=None, split_name=None, algorithm=(ExplainableAlgorithm.MeaningfulPerturbations, ExplainableAlgorithm.GuidedBackPropagation, ExplainableAlgorithm.GradCAM, ExplainableAlgorithm.Gradient, ExplainableAlgorithm.IntegratedGradients), output_name=None, nb_explanations=1, algorithms_kwargs=default_algorithm_args(), average_filters=True)

Bases: trw.callbacks.callback.Callback

Explain the decision of a model

first_time(self, datasets, options)
static find_output_name(outputs, name)
__call__(self, options, history, model, losses, outputs, datasets, datasets_infos, callbacks_per_batch, **kwargs)