trw.layers.deep_supervision¶
Module Contents¶
Classes¶

OutputCreator — Base class for protocol classes.

DeepSupervision — Apply a deep supervision layer to help the flow of gradient reach top level layers.

Functions¶

adaptative_weighting — Weight the outputs proportionally to their spatial extent.

select_third_to_last_skip_before_last
- trw.layers.deep_supervision.adaptative_weighting(outputs: Sequence[trw.basic_typing.TorchTensorNCX]) → numpy.ndarray¶
Weight the outputs proportionally to their spatial extent
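The docstring does not spell the scheme out; the sketch below is one plausible reading, assuming each output's weight is proportional to its number of spatial elements (the X part of the NCX shape), normalized to sum to 1. This is an illustration, not necessarily trw's actual formula.

```python
import numpy as np

def adaptative_weighting(outputs):
    """Hypothetical sketch: weight each deep-supervision output in
    proportion to its spatial extent (product of all dims after N and C),
    normalized so the weights sum to 1."""
    extents = np.asarray([float(np.prod(o.shape[2:])) for o in outputs])
    return extents / extents.sum()

# Example: three 2D feature maps of decreasing resolution (NCHW)
outputs = [
    np.zeros((1, 2, 8, 16)),  # 128 spatial elements
    np.zeros((1, 2, 4, 8)),   # 32 spatial elements
    np.zeros((1, 2, 2, 4)),   # 8 spatial elements
]
weights = adaptative_weighting(outputs)
```

Under this assumption, the highest-resolution output dominates the weighting, which matches the intuition that the full-resolution prediction should carry most of the loss.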
- trw.layers.deep_supervision.select_third_to_last_skip_before_last(s: Sequence[trw.basic_typing.TorchTensorNCX]) → Sequence[trw.basic_typing.TorchTensorNCX]¶
- class trw.layers.deep_supervision.OutputCreator¶
Bases: typing_extensions.Protocol
Base class for protocol classes. Protocol classes are defined as:
class Proto(Protocol):
    def meth(self) -> int:
        ...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing), for example:
class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing_extensions.runtime act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures.
Protocol classes can be generic, they are defined as:
class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
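The runtime-checking behaviour described above can be demonstrated with the modern stdlib spelling of the decorator, typing.runtime_checkable (typing_extensions.runtime was an earlier name for the same idea):

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class HasMeth(Protocol):
    def meth(self) -> int:
        ...

class C:  # no inheritance from HasMeth: matching is structural
    def meth(self) -> int:
        return 0

# The runtime check only verifies that the attribute exists;
# it does not validate the signature or return type.
ok = isinstance(C(), HasMeth)
missing = isinstance(object(), HasMeth)
```

Here `ok` is True and `missing` is False, illustrating the "presence of attributes only" semantics the docstring describes.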
- __call__(self, output: trw.basic_typing.TensorNCX, output_truth: trw.basic_typing.TensorNCX, loss_scaling: float) → trw.train.outputs_trw.Output¶
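Because OutputCreator is a Protocol, any callable with this signature is accepted — no subclassing needed. A self-contained sketch (the names `scaled_output_creator` and the dict stand-in for trw's Output are hypothetical; in real use you would return e.g. an OutputSegmentation):

```python
from typing import Any, Protocol

class OutputCreator(Protocol):
    def __call__(self, output: Any, output_truth: Any, loss_scaling: float) -> Any:
        ...

def scaled_output_creator(output: Any, output_truth: Any, loss_scaling: float) -> dict:
    # A dict stands in for a trw Output object in this illustration.
    return {"output": output, "truth": output_truth, "loss_scaling": loss_scaling}

# A plain function satisfies the protocol structurally:
creator: OutputCreator = scaled_output_creator
result = creator("prediction", "ground_truth", loss_scaling=0.25)
```

DeepSupervision calls its `output_creator` once per selected intermediate output, passing the per-scale loss scaling.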
- class trw.layers.deep_supervision.DeepSupervision(backbone: trw.layers.convs.ModuleWithIntermediate, input_target_shape: trw.basic_typing.ShapeCX, output_creator: OutputCreator = OutputSegmentation, output_block: trw.layers.blocks.ConvBlockType = BlockConvNormActivation, select_outputs_fn: Callable[[Sequence[trw.basic_typing.TorchTensorNCX]], Sequence[trw.basic_typing.TorchTensorNCX]] = select_third_to_last_skip_before_last, resize_mode: typing_extensions.Literal['nearest', 'linear'] = 'linear', weighting_fn: Optional[Callable[[Sequence[trw.basic_typing.TorchTensorNCX]], Sequence[float]]] = adaptative_weighting, config: trw.layers.layer_config.LayerConfig = default_layer_config(dimensionality=None), return_intermediate: bool = False)¶
Bases: torch.nn.Module
Apply a deep supervision layer to help the flow of gradient reach top level layers.
This is mostly used for segmentation tasks.
Example
>>> import torch
>>> import trw
>>> backbone = trw.layers.UNetBase(dim=2, input_channels=3, channels=[2, 4, 8], output_channels=2)
>>> deep_supervision = DeepSupervision(backbone, [3, 8, 16])
>>> i = torch.zeros([1, 3, 8, 16], dtype=torch.float32)
>>> t = torch.zeros([1, 1, 8, 16], dtype=torch.long)
>>> outputs = deep_supervision(i, t)
- forward(self, x: torch.Tensor, target: torch.Tensor, latent: Optional[torch.Tensor] = None) → Union[List[trw.train.outputs_trw.Output], Tuple[List[trw.train.outputs_trw.Output], List[torch.Tensor]]]¶
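Conceptually, the forward pass resizes each selected intermediate output to the target shape and turns it into a weighted loss term. A hypothetical numpy sketch of that idea, using nearest-neighbour resizing by repetition and a mean-squared stand-in loss — this is an illustration of the mechanism, not trw's implementation:

```python
import numpy as np

def nearest_resize(x, target_hw):
    """Integer nearest-neighbour upsampling of an NCHW array by repetition."""
    n, c, h, w = x.shape
    th, tw = target_hw
    assert th % h == 0 and tw % w == 0, "illustration assumes integer factors"
    return x.repeat(th // h, axis=2).repeat(tw // w, axis=3)

def deep_supervision_losses(intermediates, target, weights):
    """One weighted loss term per scale: resize each intermediate to the
    target's spatial shape, then score it against the target (MSE stand-in)."""
    losses = []
    for x, w in zip(intermediates, weights):
        up = nearest_resize(x, target.shape[2:])
        losses.append(w * float(((up - target) ** 2).mean()))
    return losses

# Two scales, both predicting ones, against an all-zero target:
intermediates = [np.ones((1, 1, 2, 2)), np.ones((1, 1, 4, 4))]
target = np.zeros((1, 1, 4, 4))
losses = deep_supervision_losses(intermediates, target, weights=[0.5, 0.5])
```

When `return_intermediate=True`, the real forward additionally returns the intermediate tensors alongside the list of Output objects, as the return annotation above shows.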