trw.layers.unet_base

Module Contents

Classes

DownType

Protocol for callables that build a down-sampling block.

BlockConvType

Protocol for callables that build a convolutional block.

BlockTypeConvSkip

Protocol for callables that build a convolutional block fusing a skip connection.

Down

Down-sampling stage of the UNet.

UpType

Protocol for callables that build an up-sampling block.

Up

Up-sampling stage of the UNet.

MiddleType

Protocol for callables that build the middle (bottleneck) block.

LatentConv

Concatenate a latent variable (possibly resized to the input shape) and apply a convolution.

UNetBase

Configurable UNet-like architecture.

Attributes

UpResize

class trw.layers.unet_base.DownType

Bases: typing_extensions.Protocol

Base class for protocol classes. Protocol classes are defined as:

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing), for example:

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing_extensions.runtime act as simple-minded runtime protocols that check only the presence of the given attributes, ignoring their type signatures.

Protocol classes can also be generic; they are defined as:

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...

__call__(self, layer_config: trw.layers.layer_config.LayerConfig, bloc_level: int, input_channels: int, output_channels: int, **kwargs) -> torch.nn.Module
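
Any callable matching this signature can serve as the down-block factory. A minimal sketch of a custom implementation (illustrative only; strided_conv_down is not part of the library, and a real block would build its layers from layer_config):

import torch.nn as nn
from trw.layers.layer_config import LayerConfig

def strided_conv_down(layer_config: LayerConfig,
                      bloc_level: int,
                      input_channels: int,
                      output_channels: int,
                      **kwargs) -> nn.Module:
    # A plain strided 2D convolution that halves the spatial resolution.
    # A real implementation would honor layer_config (dimensionality,
    # normalization, activation) instead of hard-coding Conv2d.
    return nn.Conv2d(input_channels, output_channels,
                     kernel_size=3, stride=2, padding=1)
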
class trw.layers.unet_base.BlockConvType

Bases: typing_extensions.Protocol

Protocol class; see DownType above for the protocol semantics.

__call__(self, config: trw.layers.layer_config.LayerConfig, input_channels: int, output_channels: int, **kwargs) -> torch.nn.Module
class trw.layers.unet_base.BlockTypeConvSkip

Bases: typing_extensions.Protocol

Protocol class; see DownType above for the protocol semantics.

__call__(self, config: trw.layers.layer_config.LayerConfig, skip_channels: int, input_channels: int, output_channels: int, **kwargs) -> torch.nn.Module
class trw.layers.unet_base.Down(layer_config: trw.layers.layer_config.LayerConfig, bloc_level: int, input_channels: int, output_channels: int, block: BlockConvType = BlockConvNormActivation, **block_kwargs)

Bases: torch.nn.Module

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

Variables

training (bool) – whether this module is in training or evaluation mode.

forward(self, x: torch.Tensor) -> torch.Tensor
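
A minimal usage sketch, assuming default_layer_config (referenced in the UNetBase signature below) accepts the dimensionality directly; the exact output resolution depends on the stride of the underlying block:

import torch
from trw.layers.layer_config import default_layer_config
from trw.layers.unet_base import Down

config = default_layer_config(dimensionality=2)  # assumed 2D setup
down = Down(config, bloc_level=0, input_channels=16, output_channels=32)
y = down(torch.randn(1, 16, 64, 64))  # (N, 16, 64, 64) -> (N, 32, H', W')
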
class trw.layers.unet_base.UpType

Bases: typing_extensions.Protocol

Protocol class; see DownType above for the protocol semantics.

__call__(self, layer_config: trw.layers.layer_config.LayerConfig, bloc_level: int, skip_channels: int, input_channels: int, output_channels: int, **kwargs) -> torch.nn.Module
class trw.layers.unet_base.Up(layer_config: trw.layers.layer_config.LayerConfig, bloc_level: int, skip_channels: int, input_channels: int, output_channels: int, block: BlockTypeConvSkip = BlockUpDeconvSkipConv, **block_kwargs)

Bases: torch.nn.Module

Up-sampling stage of the UNet; see Down above for the torch.nn.Module documentation this class inherits.

forward(self, skip: torch.Tensor, previous: torch.Tensor) -> torch.Tensor
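
A sketch of the expected call pattern, assuming the up-sampling brings previous back to the resolution of the matching skip tensor:

import torch
from trw.layers.layer_config import default_layer_config
from trw.layers.unet_base import Up

config = default_layer_config(dimensionality=2)  # assumed 2D setup
up = Up(config, bloc_level=0, skip_channels=16,
        input_channels=32, output_channels=16)
skip = torch.randn(1, 16, 64, 64)      # feature map saved on the down path
previous = torch.randn(1, 32, 32, 32)  # output of the deeper stage
y = up(skip, previous)                 # fused, up-sampled feature map
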
trw.layers.unet_base.UpResize
class trw.layers.unet_base.MiddleType

Bases: typing_extensions.Protocol

Protocol class; see DownType above for the protocol semantics.

__call__(self, layer_config: trw.layers.layer_config.LayerConfig, bloc_level: int, input_channels: int, output_channels: int, latent_channels: Optional[int], **kwargs) -> torch.nn.Module
class trw.layers.unet_base.LatentConv(layer_config: trw.layers.layer_config.LayerConfig, bloc_level: int, input_channels: int, output_channels: int, latent_channels: Optional[int] = None, block: BlockConvType = BlockConvNormActivation, **block_kwargs)

Bases: torch.nn.Module

Concatenate a latent variable (possibly resized to the input shape) and apply a convolution

forward(self, x: torch.Tensor, latent: Optional[torch.Tensor] = None) -> torch.Tensor
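
A sketch of both call modes, assuming a latent is only expected when latent_channels is set; shapes and channel counts are illustrative:

import torch
from trw.layers.layer_config import default_layer_config
from trw.layers.unet_base import LatentConv

config = default_layer_config(dimensionality=2)  # assumed 2D setup

# without a latent variable: behaves as a plain convolutional block
middle = LatentConv(config, bloc_level=3, input_channels=64, output_channels=64)
y = middle(torch.randn(1, 64, 8, 8))

# with a latent variable: latent is resized to x's spatial shape,
# concatenated on the channel axis, then convolved
middle = LatentConv(config, bloc_level=3, input_channels=64,
                    output_channels=64, latent_channels=10)
y = middle(torch.randn(1, 64, 8, 8), latent=torch.randn(1, 10, 1, 1))
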
class trw.layers.unet_base.UNetBase(dim: int, input_channels: int, channels: Sequence[int], output_channels: int, down_block_fn: DownType = Down, up_block_fn: UpType = UpResize, init_block_fn: trw.layers.blocks.ConvBlockType = BlockConvNormActivation, middle_block_fn: MiddleType = partial(LatentConv, block=partial(BlockConvNormActivation, kernel_size=5)), output_block_fn: trw.layers.blocks.ConvBlockType = BlockConvNormActivation, init_block_channels: Optional[int] = None, latent_channels: Optional[int] = None, kernel_size: Optional[int] = 3, strides: Union[int, Sequence[int]] = 2, activation: Optional[Any] = nn.PReLU, config: trw.layers.layer_config.LayerConfig = default_layer_config(dimensionality=None), add_last_downsampling_to_intermediates: bool = False)

Bases: torch.nn.Module, trw.layers.convs.ModuleWithIntermediate

Configurable UNet-like architecture

_build(self, config, init_block_fn, down_block_fn, up_block_fn, middle_block_fn, output_block_fn, strides)

forward_with_intermediate(self, x: torch.Tensor, latent: Optional[torch.Tensor] = None, **kwargs) -> Sequence[torch.Tensor]

forward(self, x: torch.Tensor, latent: Optional[torch.Tensor] = None) -> torch.Tensor
Parameters
  • x – the input image

  • latent – a latent variable appended by the middle block
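
Putting it together, a minimal end-to-end sketch based on the signature above (channel counts and sizes are arbitrary):

import torch
from trw.layers.unet_base import UNetBase

# 2D UNet: 3 input channels, three down/up levels, 1 output channel
unet = UNetBase(dim=2, input_channels=3, channels=[16, 32, 64], output_channels=1)

x = torch.randn(2, 3, 64, 64)  # spatial size should be divisible by the
                               # cumulative stride (2**3 = 8 here)
y = unet(x)                    # expected shape: (2, 1, 64, 64)
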