trw.layers.unet_attention

Module Contents

Classes

BlockAttention

Attention U-Net style attention gate.

MergeBlockAttention_Gating_Input

Merges a gating signal with an input layer using an attention block (BlockAttention by default).

Attributes

UNetAttention

class trw.layers.unet_attention.BlockAttention(config: trw.layers.layer_config.LayerConfig, gating_channels: int, input_channels: int, intermediate_channels: int)

Bases: torch.nn.Module

Attention U-Net style attention gate.

See:

“Attention U-Net: Learning Where to Look for the Pancreas”, https://arxiv.org/pdf/1804.03999.pdf

Parameters

config (trw.layers.layer_config.LayerConfig) – configuration used to build the block's layers

gating_channels (int) – number of channels of the gating signal g

input_channels (int) – number of channels of the input x

intermediate_channels (int) – number of channels of the intermediate attention features

forward(self, g: torch.Tensor, x: torch.Tensor) → torch.Tensor
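The constructor channels map onto the additive attention gate described in the paper: the gating signal g and the input x are each projected to intermediate_channels, summed, and squashed into per-location attention coefficients that re-weight x. The following is a minimal plain-PyTorch sketch of that mechanism, not trw's implementation (which builds its layers from the LayerConfig); the class name, the 2D convolutions, and the 1x1 kernels are illustrative assumptions:

import torch
import torch.nn as nn

class AttentionGateSketch(nn.Module):
    def __init__(self, gating_channels: int, input_channels: int,
                 intermediate_channels: int):
        super().__init__()
        # project gating signal and input to a shared intermediate space
        self.w_g = nn.Conv2d(gating_channels, intermediate_channels, kernel_size=1)
        self.w_x = nn.Conv2d(input_channels, intermediate_channels, kernel_size=1)
        # collapse the intermediate features to one attention coefficient per location
        self.psi = nn.Conv2d(intermediate_channels, 1, kernel_size=1)

    def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # alpha in [0, 1] weights each spatial location of x
        alpha = torch.sigmoid(self.psi(torch.relu(self.w_g(g) + self.w_x(x))))
        return x * alpha

gate = AttentionGateSketch(gating_channels=64, input_channels=32, intermediate_channels=16)
g = torch.randn(2, 64, 24, 24)
x = torch.randn(2, 32, 24, 24)
assert gate(g, x).shape == x.shape  # x is re-weighted, not reshaped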
class trw.layers.unet_attention.MergeBlockAttention_Gating_Input(config: trw.layers.layer_config.LayerConfig, layer_channels: Sequence[int], attention_block_fn=BlockAttention, num_intermediate_fn: Callable[[Sequence[int]], int] = lambda layer_channels: ...)

Bases: torch.nn.Module

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # submodules assigned as attributes are registered automatically
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.
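For example, since conv1 and conv2 above were assigned as attributes and therefore registered, converting the parent module converts their parameters as well:

import torch

model = Model()
model.to(torch.float64)  # registered submodules are converted with the parent
print(next(model.conv1.parameters()).dtype)  # torch.float64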

Variables

training (bool) – Boolean representing whether this module is in training or evaluation mode.

get_output_channels(self)
forward(self, layers: Sequence[torch.Tensor]) → torch.Tensor
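The forward signature only reveals that the block receives the layers to merge (presumably the gating signal and the skip-connection input) and returns a single tensor. As a purely hypothetical illustration, reusing AttentionGateSketch from above and assuming the merge is a channel-wise concatenation of the gating signal with the attended input:

from typing import Sequence

import torch

def merge_gating_input_sketch(gate: AttentionGateSketch,
                              layers: Sequence[torch.Tensor]) -> torch.Tensor:
    # hypothetical layout: layers[0] is the gating signal, layers[1] the input
    g, x = layers
    attended = gate(g, x)                    # input re-weighted by attention
    return torch.cat([g, attended], dim=1)   # merged feature map

Under these assumptions, get_output_channels() would report the sum of the two layers' channel counts; the actual merge strategy is determined by attention_block_fn and num_intermediate_fn.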
trw.layers.unet_attention.UNetAttention