trw.train.graph_reflection

The purpose of this file is to group all functions related to PyTorch graph reflection, such as finding layers of specified types in an nn.Module or traversing the autograd graph through grad_fn.

Module Contents

Classes

_CaptureLastModuleType

Capture a module specified by type during forward traversal, with an optional relative index.

Functions

find_tensor_leaves_with_grad(tensor: torch.Tensor) → Sequence[torch.Tensor]

Find the input leaves of a tensor.

find_last_forward_types(model: torch.nn.Module, inputs: Any, types: Union[Any, Tuple[Any]], relative_index: int = 0) → Optional[Mapping]

Perform a forward pass of the model with given inputs and retrieve the last layer of the specified type

find_last_forward_convolution(model: torch.nn.Module, inputs: Any, types: Union[Any, Tuple[Any]] = (nn.Conv2d, nn.Conv3d, nn.Conv1d), relative_index=0) → Optional[Mapping]

Perform a forward pass of the model with given inputs and retrieve the last convolutional layer

find_first_forward_convolution(model: torch.nn.Module, inputs: Any = None, types: Union[Any, Tuple[Any]] = (nn.Conv2d, nn.Conv3d, nn.Conv1d), relative_index=0) → Optional[Mapping]

Retrieve the first convolutional layer of the model

Attributes

logger

trw.train.graph_reflection.logger

trw.train.graph_reflection.find_tensor_leaves_with_grad(tensor: torch.Tensor) → Sequence[torch.Tensor]

Find the input leaves of a tensor.

Input leaves must have requires_grad=True; otherwise they will not be found.

Parameters

tensor – a torch.Tensor

Returns

a list of the tensors with requires_grad=True that are inputs of tensor
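The lookup described above can be sketched by walking the autograd graph backward from `tensor.grad_fn`. This is a minimal illustration, not trw's implementation, and `find_tensor_leaves_sketch` is a hypothetical name. It relies on two PyTorch details: each node's `next_functions` lists its parent nodes (with `None` for inputs that do not require gradients), and the `AccumulateGrad` node attached to a leaf exposes that leaf as `.variable`. This is why leaves without `requires_grad=True` are never reached.

```python
from typing import List

import torch


def find_tensor_leaves_sketch(tensor: torch.Tensor) -> List[torch.Tensor]:
    """Hypothetical re-implementation: walk the autograd graph through grad_fn."""
    if tensor.grad_fn is None:
        # No recorded history: the tensor is itself a leaf (if it requires grad)
        return [tensor] if tensor.requires_grad else []

    leaves = []
    visited = set()
    stack = [tensor.grad_fn]
    while stack:
        node = stack.pop()
        if node in visited:
            continue
        visited.add(node)
        # AccumulateGrad nodes wrap a leaf tensor with requires_grad=True
        if hasattr(node, "variable"):
            leaves.append(node.variable)
        # next_functions holds (node, input_nr) pairs; node is None for inputs
        # that do not require grad, so those inputs are never visited
        for next_node, _ in node.next_functions:
            if next_node is not None:
                stack.append(next_node)
    return leaves


a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
c = torch.randn(3)  # requires_grad=False, so it is not reported as a leaf
y = (a * b + c).sum()
leaves = find_tensor_leaves_sketch(y)  # contains a and b, in traversal order
```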

class trw.train.graph_reflection._CaptureLastModuleType(types_of_module, relative_index=0)

Capture a module specified by type during forward traversal, with an optional relative index.

__call__(self, module, module_input, module_output)
get_module(self)
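A capture object with this `__call__(module, module_input, module_output)` signature matches what PyTorch passes to a forward hook after each module's forward. The sketch below is a hypothetical stand-in (`CaptureLastSketch` is not part of trw) and assumes that `relative_index` counts backward from the last matching module, with 0 meaning the last one.

```python
import torch
from torch import nn


class CaptureLastSketch:
    """Hypothetical forward hook recording modules whose type matches `types`."""

    def __init__(self, types, relative_index=0):
        self.types = types  # a single type or a tuple of types
        self.relative_index = relative_index
        self.captured = []  # (module, input, output) in execution order

    def __call__(self, module, module_input, module_output):
        # Called by PyTorch after module.forward(); returning None keeps the output unchanged
        if isinstance(module, self.types):
            self.captured.append((module, module_input, module_output))

    def get_module(self):
        # Assumed semantics: count backward from the last captured module
        index = len(self.captured) - 1 - self.relative_index
        if not 0 <= index < len(self.captured):
            return None
        return self.captured[index][0]


model = nn.Sequential(nn.Conv2d(1, 4, 3), nn.ReLU(), nn.Conv2d(4, 8, 3))
capture = CaptureLastSketch(nn.Conv2d)
handles = [m.register_forward_hook(capture) for m in model.modules()]
model(torch.zeros(1, 1, 8, 8))
for handle in handles:
    handle.remove()  # detach the hooks once the forward pass is done
```

Here `capture.get_module()` returns `model[2]`, the second `Conv2d`, because it is the last matching module to run.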
trw.train.graph_reflection.find_last_forward_types(model: torch.nn.Module, inputs: Any, types: Union[Any, Tuple[Any]], relative_index: int = 0) → Optional[Mapping]

Perform a forward pass of the model with the given inputs and retrieve the last layer of the specified type.

Parameters
  • inputs – the input of the model so that we can call model(inputs)

  • model – the model

  • types – the types to be captured. Can be a single type or a tuple of types

  • relative_index – which matching module to return, counted backward from the last collected module

Returns

None if no layer is found; otherwise a dictionary with the entries outputs, matched_module, matched_module_input and matched_module_output.
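The whole search can be sketched as: register a hook on every submodule, run `model(inputs)`, remove the hooks, then pick the match. `find_last_forward_types_sketch` is a hypothetical name. The dictionary keys follow the Returns description above, and the backward-counting `relative_index` is an assumption. This is not a copy of trw's code.

```python
from typing import Any, Mapping, Optional, Tuple, Union

import torch
from torch import nn


def find_last_forward_types_sketch(
    model: nn.Module,
    inputs: Any,
    types: Union[type, Tuple[type, ...]],
    relative_index: int = 0,
) -> Optional[Mapping]:
    captured = []  # (module, input, output) for every matching module, in execution order

    def hook(module, module_input, module_output):
        if isinstance(module, types):
            captured.append((module, module_input, module_output))

    handles = [m.register_forward_hook(hook) for m in model.modules()]
    try:
        outputs = model(inputs)
    finally:
        # Remove the hooks even if the forward pass raises
        for handle in handles:
            handle.remove()

    # Assumed semantics: relative_index counts backward from the last match
    index = len(captured) - 1 - relative_index
    if not 0 <= index < len(captured):
        return None
    module, module_input, module_output = captured[index]
    return {
        "outputs": outputs,
        "matched_module": module,
        "matched_module_input": module_input,  # tuple of positional inputs
        "matched_module_output": module_output,
    }


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
x = torch.zeros(1, 4)
result = find_last_forward_types_sketch(model, x, nn.Linear)
```

Wrapping the forward pass in `try`/`finally` keeps the model free of stale hooks, which matters if the same model is later used for training.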

trw.train.graph_reflection.find_last_forward_convolution(model: torch.nn.Module, inputs: Any, types: Union[Any, Tuple[Any]] = (nn.Conv2d, nn.Conv3d, nn.Conv1d), relative_index=0) → Optional[Mapping]

Perform a forward pass of the model with the given inputs and retrieve the last convolutional layer.

Parameters
  • inputs – the input of the model so that we can call model(inputs)

  • model – the model

  • types – the types to be captured. Can be a single type or a tuple of types

  • relative_index (int) – which matching module to return, counted backward from the last collected module

Returns

None if no layer is found; otherwise a dictionary with the entries outputs, matched_module, matched_module_input and matched_module_output.
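A forward pass matters because the order in which submodules are registered (what `model.modules()` yields) need not match the order in which `forward()` runs them. In the hypothetical `Reordered` model below, the last `Conv1d` by registration is `early`, while the last one actually executed is `late`. `last_forward_conv_sketch` is an illustrative helper, not trw's function, that reuses the default type tuple shown in the signature above.

```python
import torch
from torch import nn


class Reordered(nn.Module):
    """Hypothetical model whose registration order differs from its execution order."""

    def __init__(self):
        super().__init__()
        self.late = nn.Conv1d(4, 4, kernel_size=3)   # registered first, runs last
        self.early = nn.Conv1d(1, 4, kernel_size=3)  # registered second, runs first

    def forward(self, x):
        return self.late(self.early(x))


def last_forward_conv_sketch(model, inputs, types=(nn.Conv2d, nn.Conv3d, nn.Conv1d)):
    executed = []  # matching modules in the order forward() actually runs them

    def hook(module, module_input, module_output):
        if isinstance(module, types):
            executed.append(module)

    handles = [m.register_forward_hook(hook) for m in model.modules()]
    try:
        model(inputs)
    finally:
        for handle in handles:
            handle.remove()
    return executed[-1] if executed else None


model = Reordered()
last_registered = [m for m in model.modules() if isinstance(m, nn.Conv1d)][-1]
last_executed = last_forward_conv_sketch(model, torch.zeros(1, 1, 16))
# last_registered is model.early, whereas last_executed is model.late
```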

trw.train.graph_reflection.find_first_forward_convolution(model: torch.nn.Module, inputs: Any = None, types: Union[Any, Tuple[Any]] = (nn.Conv2d, nn.Conv3d, nn.Conv1d), relative_index=0) → Optional[Mapping]

Retrieve the first convolutional layer of the model.

Parameters
  • inputs – NOT USED

  • model – the model

  • types – the types to be captured. Can be a single type or a tuple of types

  • relative_index (int) – which matching module to return, counted from the first collected module

Returns

None if no layer is found; otherwise a dictionary with the entries outputs, matched_module, matched_module_input and matched_module_output.
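Because `inputs` is not used, the first layer can presumably be found without running the model, by scanning `model.modules()` in registration order. The sketch below assumes exactly that, and also assumes `relative_index` counts forward from the first match. `find_first_convolution_sketch` is hypothetical, and it only reports `matched_module` because no forward pass produces inputs or outputs to return. For a model whose `forward()` calls submodules in a different order than they were declared, this registration-order result may differ from the first layer executed.

```python
from typing import Mapping, Optional, Tuple, Union

from torch import nn


def find_first_convolution_sketch(
    model: nn.Module,
    types: Union[type, Tuple[type, ...]] = (nn.Conv2d, nn.Conv3d, nn.Conv1d),
    relative_index: int = 0,
) -> Optional[Mapping]:
    # No forward pass: modules are visited in registration order (model.modules())
    matches = [m for m in model.modules() if isinstance(m, types)]
    # Assumed semantics: relative_index counts forward from the first match
    if not 0 <= relative_index < len(matches):
        return None
    # Without a forward pass there are no inputs or outputs to report
    return {"matched_module": matches[relative_index]}


model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
first = find_first_convolution_sketch(model)  # matches model[0]
```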