trw.train.graph_reflection
This module groups the functions related to PyTorch graph reflection, such as finding layers of specified types in an nn.Module or traversing the autograd graph through grad_fn.
Module Contents
Classes
_CaptureLastModuleType | Capture the last module of a specified type during a forward traversal, with an optional relative index
Functions
find_tensor_leaves_with_grad(tensor) | Find the input leaves of a tensor.
find_last_forward_types(model, inputs, types, relative_index) | Perform a forward pass of the model with the given inputs and retrieve the last layer of the specified types
find_last_forward_convolution(model, inputs, types, relative_index) | Perform a forward pass of the model with the given inputs and retrieve the last convolutional layer
find_first_forward_convolution(model, inputs, types, relative_index) | Retrieve the first convolutional layer of the model
Attributes
- trw.train.graph_reflection.logger
- trw.train.graph_reflection.find_tensor_leaves_with_grad(tensor: torch.Tensor) → Sequence[torch.Tensor]
Find the input leaves of a tensor.
Input leaves must have requires_grad=True; otherwise they will not be found.
- Parameters
tensor – a torch.Tensor
- Returns
a list of the torch.Tensor leaves with requires_grad=True that are inputs of tensor
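The leaf search can be sketched as a depth-first walk over grad_fn.next_functions, collecting the tensors held by the AccumulateGrad nodes that autograd creates for leaves with requires_grad=True. This is a minimal sketch under that assumption, not the trw implementation; the helper name find_leaves_sketch is hypothetical.

```python
from typing import List

import torch


def find_leaves_sketch(tensor: torch.Tensor) -> List[torch.Tensor]:
    """Hypothetical helper: collect the requires_grad leaves that `tensor` depends on."""
    leaves: List[torch.Tensor] = []
    seen_nodes = set()
    seen_leaves = set()
    # A leaf tensor itself has grad_fn None, so the walk returns an empty list
    stack = [tensor.grad_fn]
    while stack:
        node = stack.pop()
        if node is None or node in seen_nodes:
            continue
        seen_nodes.add(node)
        # AccumulateGrad nodes expose the leaf tensor through `.variable`
        leaf = getattr(node, "variable", None)
        if leaf is not None and id(leaf) not in seen_leaves:
            seen_leaves.add(id(leaf))
            leaves.append(leaf)
        # next_functions holds (node, input_nr) pairs; node is None for
        # inputs that do not require gradients, so they are never reached
        for next_node, _ in node.next_functions:
            stack.append(next_node)
    return leaves


x = torch.randn(3, requires_grad=True)
w = torch.randn(3, requires_grad=True)
c = torch.randn(3)  # requires_grad=False: not a reachable leaf
y = (x * w + c).sum() + x.sum()
print(len(find_leaves_sketch(y)))  # 2 (x and w, each reported once)
```

The example also shows why the requires_grad=True condition matters: c contributes to y but has no AccumulateGrad node, so the walk cannot see it.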
- class trw.train.graph_reflection._CaptureLastModuleType(types_of_module, relative_index=0)
Capture the last module of a specified type during a forward traversal, with an optional relative index
- __call__(self, module, module_input, module_output)¶
- get_module(self)¶
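The __call__(self, module, module_input, module_output) signature matches the hook signature expected by nn.Module.register_forward_hook. A minimal sketch of such a capture object, assuming it keeps a bounded history of matches, could look like this; CaptureLastSketch and its history attribute are hypothetical names, not the trw class.

```python
import collections

import torch
from torch import nn


class CaptureLastSketch:
    """Hypothetical forward hook that remembers recent modules of given types.

    Keeps at most ``relative_index + 1`` matches, so the oldest retained entry
    is the module ``relative_index`` steps before the last match.
    """

    def __init__(self, types_of_module, relative_index=0):
        self.types_of_module = types_of_module
        self.relative_index = relative_index
        self.history = collections.deque(maxlen=relative_index + 1)

    def __call__(self, module, module_input, module_output):
        # Called by PyTorch after each hooked module's forward
        if isinstance(module, self.types_of_module):
            self.history.append((module, module_input, module_output))

    def get_module(self):
        # Not enough matches were seen to honour relative_index
        if len(self.history) < self.relative_index + 1:
            return None
        return self.history[0][0]


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
capture = CaptureLastSketch(nn.Linear, relative_index=1)
handles = [m.register_forward_hook(capture) for m in model.modules()]
model(torch.randn(3, 4))
for h in handles:
    h.remove()  # detach the hooks once the traversal is done
print(capture.get_module() is model[2])  # True: the second-to-last Linear
```

A deque with maxlen=relative_index + 1 keeps memory constant regardless of how many matching layers the model contains.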
- trw.train.graph_reflection.find_last_forward_types(model: torch.nn.Module, inputs: Any, types: Union[Any, Tuple[Any]], relative_index: int = 0) → Optional[Mapping]
Perform a forward pass of the model with the given inputs and retrieve the last layer of the specified types
- Parameters
inputs – the input of the model, such that model(inputs) can be called
model – the model
types – the types to be captured. Can be a single type or a tuple of types
relative_index – indicates which module to return, counting back from the last collected module (0 returns the last one)
- Returns
None if no layer is found, otherwise a dictionary of (outputs, matched_module, matched_module_input, matched_module_output)
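The steps described above (hook every submodule, run the forward pass, pick a match counted back from the last one) can be sketched as follows. This is an illustration under stated assumptions, not trw's code: the function name is hypothetical, and the dictionary keys simply mirror the names listed under Returns.

```python
from typing import Any, Dict, Optional

import torch
from torch import nn


def find_last_forward_types_sketch(model: nn.Module, inputs: Any, types,
                                   relative_index: int = 0) -> Optional[Dict[str, Any]]:
    """Hypothetical re-implementation for illustration only."""
    matches = []  # (module, input, output) in forward execution order

    def hook(module, module_input, module_output):
        if isinstance(module, types):  # types may be a class or a tuple of classes
            matches.append((module, module_input, module_output))

    handles = [m.register_forward_hook(hook) for m in model.modules()]
    try:
        outputs = model(inputs)
    finally:
        # Always remove the hooks, even if the forward pass raises
        for h in handles:
            h.remove()

    if len(matches) <= relative_index:
        return None
    module, module_input, module_output = matches[-1 - relative_index]
    return {
        'outputs': outputs,
        'matched_module': module,
        'matched_module_input': module_input,
        'matched_module_output': module_output,
    }


model = nn.Sequential(nn.Conv2d(1, 4, 3), nn.ReLU(), nn.Conv2d(4, 8, 3))
result = find_last_forward_types_sketch(model, torch.randn(1, 1, 10, 10), nn.Conv2d)
print(result['matched_module_output'].shape)  # torch.Size([1, 8, 6, 6])
```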
- trw.train.graph_reflection.find_last_forward_convolution(model: torch.nn.Module, inputs: Any, types: Union[Any, Tuple[Any]] = (nn.Conv2d, nn.Conv3d, nn.Conv1d), relative_index=0) → Optional[Mapping]
Perform a forward pass of the model with the given inputs and retrieve the last convolutional layer
- Parameters
inputs – the input of the model, such that model(inputs) can be called
model – the model
types – the types to be captured. Can be a single type or a tuple of types
relative_index (int) – indicates which module to return, counting back from the last collected module (0 returns the last one)
- Returns
None if no layer is found, otherwise a dictionary of (outputs, matched_module, matched_module_input, matched_module_output)
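One plausible reason for running a forward pass, rather than scanning the registered submodules, is that "last" in forward order can differ from "last" in registration order. This is our reading, not something the source states. The toy model OutOfOrderNet below is hypothetical; it registers its layers in the opposite order to the one in which forward calls them.

```python
import torch
from torch import nn


class OutOfOrderNet(nn.Module):
    """Hypothetical model whose submodules are registered in a different order than they run."""

    def __init__(self):
        super().__init__()
        self.head = nn.Conv2d(8, 2, kernel_size=1)  # registered first, runs last
        self.body = nn.Conv2d(3, 8, kernel_size=3)  # registered second, runs first

    def forward(self, x):
        return self.head(self.body(x))


model = OutOfOrderNet()
conv_types = (nn.Conv2d, nn.Conv3d, nn.Conv1d)

# Last convolution by registration order (no forward pass needed)
last_registered = [m for m in model.modules() if isinstance(m, conv_types)][-1]

# Last convolution by execution order (forward hooks record the call order)
executed = []
handles = [m.register_forward_hook(lambda mod, inp, out: executed.append(mod))
           for m in model.modules() if isinstance(m, conv_types)]
model(torch.randn(1, 3, 16, 16))
for h in handles:
    h.remove()
last_executed = executed[-1]

print(last_registered is model.body)  # True
print(last_executed is model.head)    # True
```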
- trw.train.graph_reflection.find_first_forward_convolution(model: torch.nn.Module, inputs: Any = None, types: Union[Any, Tuple[Any]] = (nn.Conv2d, nn.Conv3d, nn.Conv1d), relative_index=0) → Optional[Mapping]
Retrieve the first convolutional layer of the model
- Parameters
inputs – not used
model – the model
types – the types to be captured. Can be a single type or a tuple of types
relative_index (int) – indicates which module to return, counting from the first collected module
- Returns
None if no layer is found, otherwise a dictionary of (outputs, matched_module, matched_module_input, matched_module_output)
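Since inputs is not used, the first matching layer is presumably found without running the model. A minimal sketch, assuming "first" means first in registration order as yielded by model.modules(), is shown below; the helper name is hypothetical, and it returns only matched_module because no forward pass produces outputs to report.

```python
from typing import Any, Dict, Optional

from torch import nn


def find_first_convolution_sketch(model: nn.Module, types=(nn.Conv2d, nn.Conv3d, nn.Conv1d),
                                  relative_index: int = 0) -> Optional[Dict[str, Any]]:
    """Hypothetical helper: return the relative_index-th matching module in
    registration order, without running the model."""
    matches = [m for m in model.modules() if isinstance(m, types)]
    if len(matches) <= relative_index:
        return None
    return {'matched_module': matches[relative_index]}


model = nn.Sequential(nn.Conv1d(2, 4, 3), nn.ReLU(), nn.Conv1d(4, 4, 3))
print(find_first_convolution_sketch(model)['matched_module'] is model[0])  # True
```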