trw.layers.backbone_decoder

Module Contents

Classes

BlockUpResizeDeconvSkipConv: Resizes the bottom features to match the transverse features using linear interpolation.

BackboneDecoder: U-Net-like model with a backbone used as the encoder.

class trw.layers.backbone_decoder.BlockUpResizeDeconvSkipConv(layer_config: trw.layers.layer_config.LayerConfig, bloc_level: int, skip_channels: int, input_channels: int, output_channels: int, *, kernel_size: Optional[trw.basic_typing.KernelSize] = None, resize_mode: typing_extensions.Literal[linear, nearest] = 'linear', block=BlockConvNormActivation)

Bases: torch.nn.Module

Resizes the bottom features to match the transverse features using linear interpolation, then applies a block to the concatenated (transverse, resampled bottom) features.

forward(self, skip: torch.Tensor, previous: torch.Tensor) -> torch.Tensor
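The resize-and-concatenate behaviour described above can be illustrated with a minimal, self-contained sketch (2D case, bilinear resampling). This is an illustrative re-implementation, not the trw block itself, and the class name UpResizeSkipConvSketch is hypothetical:

import torch
import torch.nn as nn
import torch.nn.functional as F

class UpResizeSkipConvSketch(nn.Module):
    # Sketch only, not the trw implementation: resample the bottom features to the
    # skip connection's spatial size, concatenate, then apply a Conv/Norm/Activation block.
    def __init__(self, skip_channels: int, input_channels: int, output_channels: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(skip_channels + input_channels, output_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, skip: torch.Tensor, previous: torch.Tensor) -> torch.Tensor:
        # resize the bottom ("previous") features to match the transverse ("skip") features
        previous = F.interpolate(previous, size=skip.shape[2:], mode='bilinear', align_corners=False)
        # block applied on the concatenated (transverse, resampled bottom) features
        return self.block(torch.cat([skip, previous], dim=1))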
class trw.layers.backbone_decoder.BackboneDecoder(decoding_channels: Sequence[int], output_channels: int, backbone: trw.layers.convs.ModuleWithIntermediate, backbone_transverse_connections: Sequence[int], backbone_input_shape: trw.basic_typing.ShapeNCX, *, up_block_fn: trw.layers.unet_base.UpType = BlockUpResizeDeconvSkipConv, middle_block_fn: trw.layers.unet_base.MiddleType = partial(LatentConv, block=partial(BlockConvNormActivation, kernel_size=5)), output_block_fn: trw.layers.blocks.ConvBlockType = BlockConvNormActivation, latent_channels: Optional[int] = None, kernel_size: Optional[int] = 3, strides: Union[int, Sequence[int]] = 2, activation: Optional[Any] = None, config: trw.layers.layer_config.LayerConfig = default_layer_config(dimensionality=None))

Bases: torch.nn.Module, trw.layers.convs.ModuleWithIntermediate

U-Net-like model with a backbone used as the encoder.

Examples

>>> import trw
>>> encoder = trw.layers.convs_3d(1, channels=[64, 128, 256])
>>> segmenter = trw.layers.BackboneDecoder([256, 128, 64], 3, encoder, [0, 1, 2], [1, 1, 64, 64, 64])
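A hedged continuation of the example above, showing a forward pass; the output spatial size shown assumes the decoder restores the input resolution, which this excerpt does not guarantee:

>>> import torch
>>> x = torch.randn(1, 1, 64, 64, 64)   # matches backbone_input_shape
>>> y = segmenter(x)
>>> y.shape                              # 3 output channels, resolution assumed preserved
torch.Size([1, 3, 64, 64, 64])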
forward_with_intermediate(self, x: torch.Tensor, latent: Optional[torch.Tensor] = None, **kwargs) -> List[torch.Tensor]
forward(self, x: torch.Tensor, latent: Optional[torch.Tensor] = None) -> torch.Tensor
Parameters
  • x – the input image

  • latent – an optional latent tensor appended to the bottleneck features by the middle block
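If latent_channels is set at construction, a latent tensor can be passed to forward. The sketch below is hypothetical: latent_channels=16 and the 8x8x8 bottleneck resolution (a 64^3 input downsampled by three stride-2 stages) are assumptions, not values taken from the library:

>>> segmenter_latent = trw.layers.BackboneDecoder(
...     [256, 128, 64], 3, encoder, [0, 1, 2], [1, 1, 64, 64, 64],
...     latent_channels=16)
>>> latent = torch.randn(1, 16, 8, 8, 8)  # assumed bottleneck shape
>>> y = segmenter_latent(x, latent=latent)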