trw.layers.non_local
Module Contents
Classes

BlockNonLocal — Non local block implementation of [1]
Functions

linear_embedding
identity
- trw.layers.non_local.linear_embedding(config: trw.layers.layer_config.LayerConfig, input_channels: int, output_channels: int) → torch.nn.Module
- trw.layers.non_local.identity(config: trw.layers.layer_config.LayerConfig, input_channels: int, output_channels: int) → torch.nn.Module
- class trw.layers.non_local.BlockNonLocal(config: trw.layers.layer_config.LayerConfig, input_channels: int, intermediate_channels: int, f_mapping_fn: Callable[[trw.layers.layer_config.LayerConfig, int, int], torch.nn.Module] = identity, g_mapping_fn: Callable[[trw.layers.layer_config.LayerConfig, int, int], torch.nn.Module] = identity, w_mapping_fn: Callable[[trw.layers.layer_config.LayerConfig, int, int], torch.nn.Module] = linear_embedding, normalize_output_fn: torch.nn.Module = nn.Softmax(dim=-1))
Bases: torch.nn.Module
Non local block implementation of [1]
Defaults to the dot product of the features at each pair of locations, with a softmax layer normalizing the attention mask.
Supports n-d input data.
- forward(self, x: trw.basic_typing.TorchTensorNCX, return_non_local_map: bool = False)
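The default behavior described above (identity f and g mappings, dot-product similarity, softmax-normalized attention map) can be sketched with NumPy. This is an illustrative sketch of the non-local attention computation on a flattened NCX tensor, not the trw implementation; the shapes and variable names are assumptions chosen for the example.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax used to normalize the attention mask.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# Tiny hypothetical input: batch N=1, channels C=2, L=3 spatial positions
# (an NCX tensor with its spatial dimensions flattened into L).
x = np.array([[[1.0, 0.0, 1.0],
               [0.0, 1.0, 0.0]]])  # shape (N, C, L)

# With identity f and g mappings, similarity is the dot product of the
# features at every pair of locations; softmax normalizes each row.
theta = x.transpose(0, 2, 1)             # (N, L, C)
attention = softmax(theta @ x, axis=-1)  # (N, L, L) non-local map
y = (attention @ theta).transpose(0, 2, 1)  # (N, C, L) aggregated features
```

Each output location in `y` is a weighted sum of the features at all locations, which is what makes the block "non-local": the receptive field spans the whole input regardless of spatial distance.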