trw.layers.backbone_decoder
Module Contents
Classes
BlockUpResizeDeconvSkipConv | Reshape the bottom features to match the transverse feature using linear interpolation
BackboneDecoder | U-net like model with backbone used as encoder.
- class trw.layers.backbone_decoder.BlockUpResizeDeconvSkipConv(layer_config: trw.layers.layer_config.LayerConfig, bloc_level: int, skip_channels: int, input_channels: int, output_channels: int, *, kernel_size: Optional[trw.basic_typing.KernelSize] = None, resize_mode: typing_extensions.Literal['linear', 'nearest'] = 'linear', block=BlockConvNormActivation)
Bases: torch.nn.Module
Resize the bottom features to match the transverse (skip) features using interpolation (linear by default, selected by resize_mode), then apply a block to the concatenation of the transverse features and the resampled bottom features.
- forward(self, skip: torch.Tensor, previous: torch.Tensor) -> torch.Tensor
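The resize-then-concatenate step described above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the trw implementation: the class name UpResizeSkipConv is hypothetical, the input is assumed to be 2D (N, C, H, W), and a simple conv + batch norm + ReLU stands in for the configurable block.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class UpResizeSkipConv(nn.Module):
    """Hypothetical stand-in for an up block (NOT the trw class)."""

    def __init__(self, skip_channels, input_channels, output_channels,
                 kernel_size=3, resize_mode="linear"):
        super().__init__()
        self.resize_mode = resize_mode
        # Assumed block: conv + norm + activation applied to the
        # concatenated (skip, resized bottom) features.
        self.block = nn.Sequential(
            nn.Conv2d(skip_channels + input_channels, output_channels,
                      kernel_size, padding=kernel_size // 2),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(),
        )

    def forward(self, skip, previous):
        # F.interpolate calls linear interpolation of 2D maps "bilinear".
        if self.resize_mode == "linear":
            resized = F.interpolate(previous, size=skip.shape[2:],
                                    mode="bilinear", align_corners=False)
        else:
            resized = F.interpolate(previous, size=skip.shape[2:],
                                    mode="nearest")
        # Concatenate along the channel axis, then apply the block.
        return self.block(torch.cat([skip, resized], dim=1))


skip = torch.zeros(2, 16, 32, 32)      # transverse (skip) features
previous = torch.zeros(2, 32, 16, 16)  # coarser bottom features
up = UpResizeSkipConv(skip_channels=16, input_channels=32, output_channels=8)
output = up(skip, previous)
output_shape = tuple(output.shape)
```

In this sketch the output takes the spatial size of the skip tensor and the channel count given by output_channels, which is why the convolution's input width is skip_channels + input_channels.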
- class trw.layers.backbone_decoder.BackboneDecoder(decoding_channels: Sequence[int], output_channels: int, backbone: trw.layers.convs.ModuleWithIntermediate, backbone_transverse_connections: Sequence[int], backbone_input_shape: trw.basic_typing.ShapeNCX, *, up_block_fn: trw.layers.unet_base.UpType = BlockUpResizeDeconvSkipConv, middle_block_fn: trw.layers.unet_base.MiddleType = partial(LatentConv, block=partial(BlockConvNormActivation, kernel_size=5)), output_block_fn: trw.layers.blocks.ConvBlockType = BlockConvNormActivation, latent_channels: Optional[int] = None, kernel_size: Optional[int] = 3, strides: Union[int, Sequence[int]] = 2, activation: Optional[Any] = None, config: trw.layers.layer_config.LayerConfig = default_layer_config(dimensionality=None))
Bases: torch.nn.Module, trw.layers.convs.ModuleWithIntermediate
U-net like model with backbone used as encoder.
Examples
>>> import trw
>>> encoder = trw.layers.convs_3d(1, channels=[64, 128, 256])
>>> segmenter = trw.layers.BackboneDecoder([256, 128, 64], 3, encoder, [0, 1, 2], [1, 1, 64, 64, 64])
- forward_with_intermediate(self, x: torch.Tensor, latent: Optional[torch.Tensor] = None, **kwargs) -> List[torch.Tensor]
- forward(self, x: torch.Tensor, latent: Optional[torch.Tensor] = None) -> torch.Tensor
- Parameters
x – the input image
latent – a latent variable appended by the middle block
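To make the "backbone used as encoder" idea concrete, the following toy sketch wires an encoder that exposes its intermediate feature maps into a U-net-like decoder. It is an illustration only and does not reproduce trw internals: ToyEncoder and ToyBackboneDecoder are hypothetical names, the data is assumed 2D, each encoder stage is assumed to halve the spatial size, and resizing the result back to the input resolution is a choice of this sketch. The latent argument is omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyEncoder(nn.Module):
    """Hypothetical backbone: each stage halves the spatial size."""

    def __init__(self, input_channels, channels):
        super().__init__()
        stages = []
        for c in channels:
            stages.append(nn.Sequential(
                nn.Conv2d(input_channels, c, 3, stride=2, padding=1),
                nn.ReLU()))
            input_channels = c
        self.stages = nn.ModuleList(stages)

    def forward_with_intermediate(self, x):
        # Return the output of every stage, shallowest first.
        outputs = []
        for stage in self.stages:
            x = stage(x)
            outputs.append(x)
        return outputs


class ToyBackboneDecoder(nn.Module):
    """Hypothetical U-net-like decoder on top of a backbone."""

    def __init__(self, encoder, encoder_channels, decoding_channels,
                 output_channels):
        super().__init__()
        self.encoder = encoder
        previous_channels = encoder_channels[-1]
        # Assumed middle block: one convolution on the deepest features.
        self.middle = nn.Conv2d(previous_channels, previous_channels, 3,
                                padding=1)
        ups = []
        # One up block per skip connection, deepest to shallowest.
        for skip_c, c in zip(reversed(encoder_channels), decoding_channels):
            ups.append(nn.Conv2d(skip_c + previous_channels, c, 3, padding=1))
            previous_channels = c
        self.ups = nn.ModuleList(ups)
        self.output = nn.Conv2d(previous_channels, output_channels, 1)

    def forward(self, x):
        skips = self.encoder.forward_with_intermediate(x)
        previous = torch.relu(self.middle(skips[-1]))
        for up, skip in zip(self.ups, reversed(skips)):
            # Resize to the skip resolution, concatenate, convolve.
            previous = F.interpolate(previous, size=skip.shape[2:],
                                     mode="nearest")
            previous = torch.relu(up(torch.cat([skip, previous], dim=1)))
        # Sketch-only choice: return to the input resolution.
        previous = F.interpolate(previous, size=x.shape[2:], mode="nearest")
        return self.output(previous)


encoder = ToyEncoder(1, channels=[8, 16, 32])
segmenter = ToyBackboneDecoder(encoder, [8, 16, 32], [32, 16, 8], 3)
segmentation = segmenter(torch.zeros(1, 1, 64, 64))
```

The decoding channels are listed from the deepest level upward, mirroring the ordering of [256, 128, 64] in the example above; whether trw pairs them with the backbone outputs in exactly this way is an assumption of the sketch.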