trw.layers.non_local

Module Contents

Classes

BlockNonLocal

Non-local block implementation of [1]

Functions

linear_embedding(config: trw.layers.layer_config.LayerConfig, input_channels: int, output_channels: int) → torch.nn.Module

identity(config: trw.layers.layer_config.LayerConfig, input_channels: int, output_channels: int) → torch.nn.Module

trw.layers.non_local.linear_embedding(config: trw.layers.layer_config.LayerConfig, input_channels: int, output_channels: int) → torch.nn.Module
trw.layers.non_local.identity(config: trw.layers.layer_config.LayerConfig, input_channels: int, output_channels: int) → torch.nn.Module
class trw.layers.non_local.BlockNonLocal(config: trw.layers.layer_config.LayerConfig, input_channels: int, intermediate_channels: int, f_mapping_fn: Callable[[trw.layers.layer_config.LayerConfig, int, int], torch.nn.Module] = identity, g_mapping_fn: Callable[[trw.layers.layer_config.LayerConfig, int, int], torch.nn.Module] = identity, w_mapping_fn: Callable[[trw.layers.layer_config.LayerConfig, int, int], torch.nn.Module] = linear_embedding, normalize_output_fn: torch.nn.Module = nn.Softmax(dim=-1))

Bases: torch.nn.Module

Non-local block implementation of [1]

Defaults to the dot product of the features at each pair of locations, with a softmax layer to normalize the attention map.

[1] https://openaccess.thecvf.com/content_cvpr_2018/papers/Wang_Non-Local_Neural_Networks_CVPR_2018_paper.pdf

Supports n-d input data.

forward(self, x: trw.basic_typing.TorchTensorNCX, return_non_local_map: bool = False)
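The default configuration above (identity f and g mappings, dot-product similarity, softmax normalization, and a linear w embedding) can be sketched in NumPy. This is a minimal illustration of the mechanism from [1], not the trw implementation: `non_local_block` and `w` are hypothetical names, and the residual connection assumes `w` maps back to the input channel count.

```python
import numpy as np


def softmax(z, axis=-1):
    # Numerically stable softmax along the given axis.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def non_local_block(x, w):
    """Sketch of a non-local block with identity f/g mappings.

    x: array of shape (N, C, *spatial) -- any number of spatial dims
    w: (C_out, C) linear embedding applied channel-wise (a 1x1 conv)
    Returns (output, attention_map).
    """
    n, c = x.shape[:2]
    flat = x.reshape(n, c, -1)                      # (N, C, L), L = prod(spatial)
    # f(x_i, x_j) = x_i . x_j: pairwise dot products over channels,
    # normalized with a softmax over the "j" locations.
    attn = softmax(np.einsum('ncl,ncm->nlm', flat, flat), axis=-1)   # (N, L, L)
    # y_i = sum_j attn_ij * g(x_j), with g = identity here.
    y = np.einsum('nlm,ncm->ncl', attn, flat)       # (N, C, L)
    # w mapping: linear embedding over channels, then residual connection
    # (assumes w.shape[0] == C so the shapes match, as in the paper).
    out = np.einsum('dc,ncl->ndl', w, y).reshape(n, w.shape[0], *x.shape[2:])
    return out + x, attn
```

For 2-d input the attention map has shape (N, H*W, H*W); the same code handles 1-d or 3-d inputs unchanged, which mirrors the "supports n-d input data" note above.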