trw.train.guided_back_propagation
Module Contents
Classes
GuidedBackprop | Produces gradients generated with guided back propagation from the given image.
Functions
post_process_output_id(output) |
post_process_output_for_gradient_attribution(output) | Postprocess the output to be suitable for gradient attribution.
Attributes
- trw.train.guided_back_propagation.logger
- trw.train.guided_back_propagation.post_process_output_id(output: trw.train.outputs_trw.Output) → torch.Tensor
- trw.train.guided_back_propagation.post_process_output_for_gradient_attribution(output: trw.train.outputs_trw.Output)
Postprocess the output to be suitable for gradient attribution.
In particular, if the output is a trw.train.OutputClassification, a softmax must be applied so that the probability of a particular class can be backpropagated with the appropriate seed value (1.0).
- Parameters
output – a trw.train.OutputClassification
- Returns
a torch.Tensor
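A minimal sketch of why the softmax matters, written in plain PyTorch rather than with trw.train.OutputClassification. It assumes the classification output is a raw logits tensor of shape [N, C]; `backprop_class_probability` is a hypothetical helper, not part of trw. Seeding the backward pass with 1.0 on the target class of the softmax output computes the gradient of that class probability with respect to the input:

```python
import torch
import torch.nn.functional as F


def backprop_class_probability(logits: torch.Tensor, target_class: int) -> torch.Tensor:
    """Hypothetical helper: backpropagate the softmax probability of one class.

    Assumes `logits` has shape [N, C] and is part of an autograd graph.
    """
    probs = F.softmax(logits, dim=1)
    # One-hot seed: 1.0 for the target class, 0.0 for every other class
    seed = torch.zeros_like(probs)
    seed[:, target_class] = 1.0
    probs.backward(seed)
    return probs


x = torch.randn(2, 4, requires_grad=True)
logits = 2.0 * x  # stand-in for a model producing [N, C] logits
probs = backprop_class_probability(logits, target_class=1)
print(x.grad.shape)  # torch.Size([2, 4])
```

Because each softmax row sums to one, the gradient of a single class probability with respect to the logits sums to zero across classes.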
- class trw.train.guided_back_propagation.GuidedBackprop(model: torch.nn.Module, unguided_gradient: bool = False, post_process_output: Callable[[Any], torch.Tensor] = post_process_output_id)
Produces gradients generated with guided back propagation from the given image.
- update_relus(self) → None
Updates the ReLU activation functions so that they:
1. store their output in the forward pass;
2. impute zero for gradient values that are less than zero.
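The two rules above can be sketched as a standalone torch.autograd.Function. This is an illustration of the guided backpropagation ReLU under our own naming (`GuidedReLU` is hypothetical); per its docstring, update_relus modifies the model's existing ReLU activations, which is a different mechanism from the custom function shown here.

```python
import torch


class GuidedReLU(torch.autograd.Function):
    """Hypothetical guided-backpropagation ReLU (illustration only)."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        out = x.clamp(min=0)
        # Rule 1: keep the forward output for the backward pass
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        (out,) = ctx.saved_tensors
        # Usual ReLU gating: no gradient where the unit was inactive
        active = (out > 0).to(grad_output.dtype)
        # Rule 2: impute zero for negative incoming gradients
        return grad_output.clamp(min=0) * active


x = torch.tensor([-1.0, 2.0, 3.0], requires_grad=True)
y = GuidedReLU.apply(x)
y.backward(torch.tensor([5.0, -4.0, 1.0]))
print(x.grad)  # tensor([0., 0., 1.])
```

In the example, the first element is blocked by the forward gating and the second by its negative incoming gradient; only the third passes through.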
- static get_floating_inputs_with_gradients(inputs)
Extract the inputs that have a gradient.
- Parameters
inputs – a tensor or dictionary of tensors
- Returns
a list of (name, input) tuples for the inputs that have a gradient
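A minimal sketch of this extraction. It assumes that "has a gradient" means a floating-point tensor whose .grad was populated by a backward pass, and that a bare tensor is reported under the placeholder name "input"; both are assumptions, and `floating_inputs_with_gradients` is a hypothetical stand-in, not the trw method.

```python
from typing import Any, List, Mapping, Tuple, Union

import torch


def floating_inputs_with_gradients(
    inputs: Union[torch.Tensor, Mapping[str, Any]]
) -> List[Tuple[str, torch.Tensor]]:
    """Hypothetical stand-in: return (name, tensor) pairs that carry a gradient."""
    if isinstance(inputs, torch.Tensor):
        # Assumption: a bare tensor is reported under a placeholder name
        inputs = {"input": inputs}
    return [
        (name, value)
        for name, value in inputs.items()
        # Skip non-tensors (e.g. string ids), integer tensors and tensors
        # that did not receive a gradient
        if isinstance(value, torch.Tensor)
        and value.is_floating_point()
        and value.grad is not None
    ]


image = torch.randn(1, 3, requires_grad=True)
image.sum().backward()
batch = {"image": image, "label": torch.tensor([1]), "uid": ["sample-0"]}
print([name for name, _ in floating_inputs_with_gradients(batch)])  # ['image']
```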
- __call__(self, inputs: Tuple[torch.Tensor, trw.basic_typing.Batch], target_class: int, target_class_name: str) → Optional[Tuple[str, Mapping]]
Generate the guided back-propagation gradient.
- Parameters
inputs – a tensor or dictionary of tensors
target_class – the target class to be explained
target_class_name – the name of the output to explain when the model has multiple outputs
- Returns
a tuple (output_name, dictionary (input, gradient))
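The overall flow of such a call can be sketched in plain PyTorch. This is not the trw implementation: it assumes a model mapping [N, ...] inputs to [N, C] logits and built with non-inplace nn.ReLU modules, applies the guided rule through nn.Module.register_full_backward_hook, and returns only the input gradient rather than the (output_name, dictionary) tuple. The `guided` flag is presumably the opposite of the constructor's unguided_gradient option; `guided_gradient` is a hypothetical name.

```python
import torch
import torch.nn as nn


def guided_gradient(
    model: nn.Module, inputs: torch.Tensor, target_class: int, guided: bool = True
) -> torch.Tensor:
    """Hypothetical sketch: gradient of one class probability w.r.t. the inputs.

    Assumes `model` returns [N, C] logits and uses non-inplace nn.ReLU modules.
    """

    def clamp_negative(module, grad_input, grad_output):
        # grad_input already includes the ReLU gating (input > 0);
        # the guided rule additionally zeroes negative gradients
        return (grad_input[0].clamp(min=0.0),)

    handles = []
    if guided:
        for module in model.modules():
            if isinstance(module, nn.ReLU):
                handles.append(module.register_full_backward_hook(clamp_negative))
    try:
        leaf = inputs.detach().clone().requires_grad_(True)
        probs = torch.softmax(model(leaf), dim=1)
        # Seed the backward pass with 1.0 on the target class only
        seed = torch.zeros_like(probs)
        seed[:, target_class] = 1.0
        probs.backward(seed)
        return leaf.grad
    finally:
        # Restore the model: remove the temporary hooks
        for handle in handles:
            handle.remove()


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
x = torch.randn(2, 4)
print(guided_gradient(model, x, target_class=1).shape)  # torch.Size([2, 4])
```

Removing the hooks in a finally block leaves the model usable for ordinary (unguided) backpropagation afterwards.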
- static get_positive_negative_saliency(gradient: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]
Generates positive and negative saliency maps based on the gradient.
- Parameters
gradient (torch.Tensor) – gradient of the operation to visualize
- Returns
a tuple (pos_saliency, neg_saliency) of torch.Tensor: the positive and negative saliency maps
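One common way to build these two maps is to keep the positive part (respectively the magnitude of the negative part) of the gradient and divide it by the largest positive value (respectively the largest negative magnitude), so that both maps lie in [0, 1]. The sketch below assumes that normalization, which the docstring does not specify; `positive_negative_saliency` is a hypothetical stand-in, and the small `eps` guard against division by zero is our own addition.

```python
from typing import Tuple

import torch


def positive_negative_saliency(
    gradient: torch.Tensor, eps: float = 1e-12
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical stand-in: split a gradient into two maps scaled to [0, 1]."""
    # Positive part, divided by the largest gradient value
    pos = gradient.clamp(min=0) / gradient.max().clamp(min=eps)
    # Magnitude of the negative part, divided by the largest negative magnitude
    neg = (-gradient).clamp(min=0) / (-gradient.min()).clamp(min=eps)
    return pos, neg


g = torch.tensor([-2.0, 0.0, 1.0, 4.0])
pos, neg = positive_negative_saliency(g)
# pos holds [0, 0, 0.25, 1] and neg holds [1, 0, 0, 0]
```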