trw.layers.unet_base
Module Contents
Classes
- DownType – Protocol for factories that build a downsampling (encoder) block.
- BlockConvType – Protocol for factories that build a convolutional block.
- BlockTypeConvSkip – Protocol for factories that build a convolutional block that also consumes skip-connection features.
- Down – Downsampling block (a torch.nn.Module).
- UpType – Protocol for factories that build an upsampling (decoder) block.
- Up – Upsampling block that combines skip-connection features with the previous decoder output (a torch.nn.Module).
- MiddleType – Protocol for factories that build the middle block.
- LatentConv – Concatenate a latent variable (possibly resized to the input shape) and apply a convolution.
- UNetBase – Configurable UNet-like architecture.
Attributes
- UpResize
- class trw.layers.unet_base.DownType
Bases: typing_extensions.Protocol
Base class for protocol classes. Protocol classes are defined as:

    class Proto(Protocol):
        def meth(self) -> int:
            ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing), for example:

    class C:
        def meth(self) -> int:
            return 0

    def func(x: Proto) -> int:
        return x.meth()

    func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing_extensions.runtime act as simple-minded runtime protocols that check only for the presence of the given attributes, ignoring their type signatures.
Protocol classes can be generic; they are defined as:

    class GenProto(Protocol[T]):
        def meth(self) -> T:
            ...

- __call__(self, layer_config: trw.layers.layer_config.LayerConfig, bloc_level: int, input_channels: int, output_channels: int, **kwargs) -> torch.nn.Module
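A minimal sketch of structural typing against DownType, written without torch or trw so that it runs standalone. `DownLike` is a hypothetical mirror of the documented `__call__` signature, `Module` is a hypothetical stand-in for `torch.nn.Module`, and `build_encoder` is a hypothetical consumer; pairing consecutive channel counts there is an assumption for illustration, not how trw wires its levels. A plain function with a matching signature satisfies the protocol without inheriting from it.

```python
from typing import Any, List, Protocol, runtime_checkable


class Module:
    """Hypothetical stand-in for torch.nn.Module, so the sketch runs without torch."""

    def __init__(self, name: str, input_channels: int, output_channels: int) -> None:
        self.name = name
        self.input_channels = input_channels
        self.output_channels = output_channels


@runtime_checkable
class DownLike(Protocol):
    """Hypothetical mirror of the DownType.__call__ signature."""

    def __call__(self, layer_config: Any, bloc_level: int, input_channels: int,
                 output_channels: int, **kwargs: Any) -> Module:
        ...


def my_down(layer_config: Any, bloc_level: int, input_channels: int,
            output_channels: int, **kwargs: Any) -> Module:
    # No inheritance from DownLike: matching the call signature is enough
    return Module(f"down_{bloc_level}", input_channels, output_channels)


def build_encoder(down_fn: DownLike, channels: List[int]) -> List[Module]:
    # Hypothetical wiring: one block per consecutive pair of channel counts
    return [down_fn(None, level, c_in, c_out)
            for level, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:]))]


blocks = build_encoder(my_down, [16, 32, 64])
print([(b.name, b.input_channels, b.output_channels) for b in blocks])
# [('down_0', 16, 32), ('down_1', 32, 64)]
print(isinstance(my_down, DownLike))  # True
```

Because a runtime-checkable protocol only tests that `__call__` exists, `isinstance(print, DownLike)` is also `True`; the argument types are checked only by a static type checker.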
- class trw.layers.unet_base.BlockConvType
Bases: typing_extensions.Protocol
- __call__(self, config: trw.layers.layer_config.LayerConfig, input_channels: int, output_channels: int, **kwargs) -> torch.nn.Module
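`UNetBase` uses `partial(BlockConvNormActivation, kernel_size=5)` inside its default `middle_block_fn`. The sketch below shows the same pattern on a hypothetical block factory, `conv_block` (not part of trw, and returning a plain dict instead of a `torch.nn.Module`): `functools.partial` pre-binds extra keyword arguments while keeping the `(config, input_channels, output_channels)` call shape that BlockConvType describes.

```python
from functools import partial
from typing import Any, Dict


def conv_block(config: Any, input_channels: int, output_channels: int,
               kernel_size: int = 3, **kwargs: Any) -> Dict[str, int]:
    """Hypothetical block factory; describes the block instead of building it."""
    return {"input_channels": input_channels,
            "output_channels": output_channels,
            "kernel_size": kernel_size}


# Pre-bind kernel_size; callers still pass (config, input_channels, output_channels)
wide_block = partial(conv_block, kernel_size=5)

print(conv_block(None, 8, 16))  # {'input_channels': 8, 'output_channels': 16, 'kernel_size': 3}
print(wide_block(None, 8, 16))  # {'input_channels': 8, 'output_channels': 16, 'kernel_size': 5}
```

Keyword arguments passed at call time override the pre-bound ones, so `wide_block(None, 8, 16, kernel_size=7)` describes a block with a kernel of 7.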
- class trw.layers.unet_base.BlockTypeConvSkip
Bases: typing_extensions.Protocol
- __call__(self, config: trw.layers.layer_config.LayerConfig, skip_channels: int, input_channels: int, output_channels: int, **kwargs) -> torch.nn.Module
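BlockTypeConvSkip takes `skip_channels` separately from `input_channels` because the block must account for the features arriving over the skip connection. The sketch below tracks channel counts for one common design, assumed here and not confirmed by this page: upsample the input to `output_channels`, concatenate the skip features along the channel axis, then fuse them with a convolution. `skip_block_channels` is a hypothetical helper.

```python
from typing import Dict, Tuple


def skip_block_channels(skip_channels: int, input_channels: int,
                        output_channels: int) -> Dict[str, Tuple[int, int]]:
    """Hypothetical (in, out) channel plan for each stage of an assumed skip block."""
    # Stage 1 (assumed): upsample the decoder input to output_channels
    upsample = (input_channels, output_channels)
    # Stage 2: concatenating along the channel axis adds the skip channels
    fused_in = skip_channels + output_channels
    # Stage 3 (assumed): a convolution maps the concatenation back to output_channels
    fuse = (fused_in, output_channels)
    return {"upsample": upsample, "fuse": fuse}


print(skip_block_channels(skip_channels=64, input_channels=128, output_channels=64))
# {'upsample': (128, 64), 'fuse': (128, 64)}
```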
- class trw.layers.unet_base.Down(layer_config: trw.layers.layer_config.LayerConfig, bloc_level: int, input_channels: int, output_channels: int, block: BlockConvType = BlockConvNormActivation, **block_kwargs)
Bases: torch.nn.Module
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and their parameters will also be converted when you call to(), etc.
Variables: training (bool) – whether this module is in training or evaluation mode.
- forward(self, x: torch.Tensor) -> torch.Tensor
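The spatial size produced by a strided convolution follows the usual arithmetic, out = floor((n + 2p - d(k - 1) - 1) / s) + 1. The sketch below assumes a downsampling block that uses `UNetBase`'s defaults `kernel_size=3` and `strides=2` with padding `k // 2`; whether `Down` pads this way is an assumption, and `conv_output_size` is a hypothetical helper.

```python
from typing import Optional


def conv_output_size(size: int, kernel_size: int = 3, stride: int = 2,
                     padding: Optional[int] = None, dilation: int = 1) -> int:
    """Output length along one spatial axis of a strided convolution."""
    if padding is None:
        padding = kernel_size // 2  # assumed "same"-style padding
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (size + 2 * padding - effective_kernel) // stride + 1


print([conv_output_size(n) for n in (64, 65, 7)])  # [32, 33, 4]
```

With these settings each stride-2 level computes ceil(n / 2), so odd sizes round up.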
- class trw.layers.unet_base.UpType
Bases: typing_extensions.Protocol
- __call__(self, layer_config: trw.layers.layer_config.LayerConfig, bloc_level: int, skip_channels: int, input_channels: int, output_channels: int, **kwargs) -> torch.nn.Module
- class trw.layers.unet_base.Up(layer_config: trw.layers.layer_config.LayerConfig, bloc_level: int, skip_channels: int, input_channels: int, output_channels: int, block: BlockTypeConvSkip = BlockUpDeconvSkipConv, **block_kwargs)
Bases: torch.nn.Module
- forward(self, skip: torch.Tensor, previous: torch.Tensor) -> torch.Tensor
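`Up.forward` receives both `skip` and `previous`; if the two are concatenated along the channel axis, the upsampled `previous` must match the spatial size of `skip`. For a transposed convolution (dilation 1) the output length is (n - 1)·s - 2p + k + output_padding. The sketch below uses hypothetical settings (k=3, s=2, p=1, output_padding=1) and hypothetical helper names to show that an even size round-trips through one down/up level while an odd one does not. It makes no claim about which of crop, pad, or resize trw's blocks use.

```python
def transposed_conv_output_size(size: int, kernel_size: int = 3, stride: int = 2,
                                padding: int = 1, output_padding: int = 1) -> int:
    """Output length along one axis of a transposed convolution (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel_size + output_padding


def down_size(size: int, kernel_size: int = 3, stride: int = 2, padding: int = 1) -> int:
    # Matching strided convolution on the encoder side
    return (size + 2 * padding - kernel_size) // stride + 1


for n in (64, 65):
    down = down_size(n)
    up = transposed_conv_output_size(down)
    print(n, down, up, "matches" if up == n else "needs crop/pad/resize")
# 64 32 64 matches
# 65 33 66 needs crop/pad/resize
```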
- trw.layers.unet_base.UpResize
- class trw.layers.unet_base.MiddleType
Bases: typing_extensions.Protocol
- __call__(self, layer_config: trw.layers.layer_config.LayerConfig, bloc_level: int, input_channels: int, output_channels: int, latent_channels: Optional[int], **kwargs) -> torch.nn.Module
- class trw.layers.unet_base.LatentConv(layer_config: trw.layers.layer_config.LayerConfig, bloc_level: int, input_channels: int, output_channels: int, latent_channels: Optional[int] = None, block: BlockConvType = BlockConvNormActivation, **block_kwargs)
Bases: torch.nn.Module
Concatenate a latent variable (possibly resized to the input shape) and apply a convolution.
- forward(self, x: torch.Tensor, latent: Optional[torch.Tensor] = None) -> torch.Tensor
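The torch-free sketch below illustrates the concatenation step of LatentConv for a single channel-first sample. It assumes the latent is a vector with one value per latent channel that is tiled over the spatial grid; trw may resize the latent differently (for example by interpolation), and `tile` and `concat_latent` are hypothetical names. Under this assumption the convolution that follows sees `input_channels + latent_channels` channels.

```python
from typing import List, Optional

Grid = List[List[float]]  # one channel: [height][width]


def tile(value: float, height: int, width: int) -> Grid:
    """Repeat a scalar over a height x width grid (the assumed 'resize')."""
    return [[value] * width for _ in range(height)]


def concat_latent(x: List[Grid], latent: Optional[List[float]] = None) -> List[Grid]:
    """Append one tiled channel per latent value to a [C][H][W] sample."""
    if latent is None:
        return x  # no latent: the input passes through unchanged
    height, width = len(x[0]), len(x[0][0])
    return x + [tile(value, height, width) for value in latent]


x = [tile(1.0, 2, 3), tile(2.0, 2, 3)]         # input_channels = 2, H = 2, W = 3
merged = concat_latent(x, latent=[0.5, -1.0])  # latent_channels = 2
print(len(merged), merged[3])  # 4 [[-1.0, -1.0, -1.0], [-1.0, -1.0, -1.0]]
```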
- class trw.layers.unet_base.UNetBase(dim: int, input_channels: int, channels: Sequence[int], output_channels: int, down_block_fn: DownType = Down, up_block_fn: UpType = UpResize, init_block_fn: trw.layers.blocks.ConvBlockType = BlockConvNormActivation, middle_block_fn: MiddleType = partial(LatentConv, block=partial(BlockConvNormActivation, kernel_size=5)), output_block_fn: trw.layers.blocks.ConvBlockType = BlockConvNormActivation, init_block_channels: Optional[int] = None, latent_channels: Optional[int] = None, kernel_size: Optional[int] = 3, strides: Union[int, Sequence[int]] = 2, activation: Optional[Any] = nn.PReLU, config: trw.layers.layer_config.LayerConfig = default_layer_config(dimensionality=None), add_last_downsampling_to_intermediates: bool = False)
Bases: torch.nn.Module, trw.layers.convs.ModuleWithIntermediate
Configurable UNet-like architecture.
- _build(self, config, init_block_fn, down_block_fn, up_block_fn, middle_block_fn, output_block_fn, strides)
- forward_with_intermediate(self, x: torch.Tensor, latent: Optional[torch.Tensor] = None, **kwargs) -> Sequence[torch.Tensor]
- forward(self, x: torch.Tensor, latent: Optional[torch.Tensor] = None) -> torch.Tensor
Parameters:
x – the input image
latent – an optional latent variable appended by the middle block
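`UNetBase` accepts `strides` either as a single int or as a sequence. The sketch below assumes one stride per downsampling level (how many levels trw builds for a given `channels` is not stated on this page) and shows how such an argument can be normalized, and why an input whose spatial size is a multiple of the product of the strides is downsampled without rounding at any level. `expand_strides` and `divides_cleanly` are hypothetical helpers, not part of trw.

```python
from math import prod
from typing import List, Sequence, Union


def expand_strides(strides: Union[int, Sequence[int]], nb_levels: int) -> List[int]:
    """Hypothetical helper: normalize to one stride per downsampling level."""
    if isinstance(strides, int):
        return [strides] * nb_levels
    expanded = list(strides)
    if len(expanded) != nb_levels:
        raise ValueError(f"expected {nb_levels} strides, got {len(expanded)}")
    return expanded


def divides_cleanly(size: int, strides: Sequence[int]) -> bool:
    # If the size is a multiple of the total downsampling factor,
    # each level divides it exactly and no rounding occurs
    return size % prod(strides) == 0


levels = expand_strides(2, 3)
print(levels, prod(levels), divides_cleanly(96, levels), divides_cleanly(100, levels))
# [2, 2, 2] 8 True False
```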