trw.layers.deep_supervision

Module Contents

Classes

OutputCreator

Protocol for the callable that creates a trw.train.outputs_trw.Output from one output, its ground truth and a loss scaling factor.

DeepSupervision

Apply a deep supervision layer to help the gradient reach the top-level layers.

Functions

adaptative_weighting(outputs: Sequence[trw.basic_typing.TorchTensorNCX]) → numpy.ndarray

Weight the outputs proportionally to their spatial extent

select_third_to_last_skip_before_last(s: Sequence[trw.basic_typing.TorchTensorNCX]) → Sequence[trw.basic_typing.TorchTensorNCX]

trw.layers.deep_supervision.adaptative_weighting(outputs: Sequence[trw.basic_typing.TorchTensorNCX]) → numpy.ndarray

Weight the outputs proportionally to their spatial extent
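The weighting rule above can be sketched in pure Python on tensor shapes rather than tensors. This is a sketch, not the library's code: it assumes the spatial extent is the product of the dimensions after N and C (the X part of NCX) and that weights are normalized so the largest output gets 1.0. The name `spatial_extent_weights` is hypothetical, and it returns a list instead of a numpy.ndarray.

```python
from math import prod
from typing import List, Sequence, Tuple


def spatial_extent_weights(shapes: Sequence[Tuple[int, ...]]) -> List[float]:
    """Hypothetical sketch: weight each NCX output by its number of spatial elements."""
    # Spatial extent = product of the X dimensions (everything after N and C).
    extents = [prod(shape[2:]) for shape in shapes]
    largest = max(extents)
    # Assumption: normalize so the largest output gets weight 1.0.
    return [e / largest for e in extents]


# Three outputs at 1/4, 1/2 and full resolution of an 8x16 image.
print(spatial_extent_weights([(1, 2, 2, 4), (1, 2, 4, 8), (1, 2, 8, 16)]))  # [0.0625, 0.25, 1.0]
```

Under this assumption, a 2D level with a quarter of the resolution per axis gets 1/16 of the full-resolution weight, so coarse predictions nudge the loss rather than dominate it.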

trw.layers.deep_supervision.select_third_to_last_skip_before_last(s: Sequence[trw.basic_typing.TorchTensorNCX]) → Sequence[trw.basic_typing.TorchTensorNCX]

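No description is given for select_third_to_last_skip_before_last. One reading of its name, offered only as an assumption, is that it keeps the intermediate outputs from the third one up to, but excluding, the last (the last presumably being the final output, which is supervised anyway). The hypothetical sketch below implements that reading on plain sequences; the function name and the minimum-length check are illustrative assumptions.

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def select_third_to_last_skip_before_last_sketch(s: Sequence[T]) -> Sequence[T]:
    """Hypothetical reading of the name: keep s[2], ..., s[-2] and drop the last element."""
    # Assumption: require at least four intermediates so that one level is selected.
    assert len(s) >= 4, "expected at least 4 intermediate outputs"
    return s[2:-1]


print(select_third_to_last_skip_before_last_sketch(["i0", "i1", "i2", "i3", "i4"]))  # ['i2', 'i3']
```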
class trw.layers.deep_supervision.OutputCreator

Bases: typing_extensions.Protocol

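Protocol for the callable that creates a trw.train.outputs_trw.Output from one output, its ground truth and a loss scaling factor. It is a typing_extensions.Protocol; as described in PEP 544, protocol classes are defined as: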

from typing_extensions import Protocol

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing), for example:

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing_extensions.runtime act as simple-minded runtime protocols that check only the presence of the given attributes, ignoring their type signatures.

Protocol classes can be generic; they are defined as:

from typing import TypeVar

T = TypeVar("T")

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...

__call__(self, output: trw.basic_typing.TensorNCX, output_truth: trw.basic_typing.TensorNCX, loss_scaling: float) → trw.train.outputs_trw.Output
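Because OutputCreator is a protocol, any callable with a matching __call__ satisfies it without inheriting from it. The sketch below illustrates this structural check with the standard-library typing.Protocol and runtime_checkable. trw is not imported, so `OutputCreatorLike`, `FakeOutput` and `ScaledCreator` are hypothetical stand-ins, and tensors are typed as Any.

```python
from dataclasses import dataclass
from typing import Any, Protocol, runtime_checkable


@dataclass
class FakeOutput:
    """Hypothetical stand-in for trw.train.outputs_trw.Output."""
    output: Any
    output_truth: Any
    loss_scaling: float


@runtime_checkable
class OutputCreatorLike(Protocol):
    """Hypothetical protocol mirroring the OutputCreator.__call__ signature."""
    def __call__(self, output: Any, output_truth: Any, loss_scaling: float) -> FakeOutput:
        ...


class ScaledCreator:
    # No inheritance from OutputCreatorLike: the match is purely structural.
    def __call__(self, output: Any, output_truth: Any, loss_scaling: float) -> FakeOutput:
        return FakeOutput(output, output_truth, loss_scaling)


creator = ScaledCreator()
print(isinstance(creator, OutputCreatorLike))  # True
print(creator("pred", "truth", 0.25))
```

The isinstance check only verifies that a __call__ attribute exists, not its signature; a static type checker is what compares the full signature.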

class trw.layers.deep_supervision.DeepSupervision(backbone: trw.layers.convs.ModuleWithIntermediate, input_target_shape: trw.basic_typing.ShapeCX, output_creator: OutputCreator = OutputSegmentation, output_block: trw.layers.blocks.ConvBlockType = BlockConvNormActivation, select_outputs_fn: Callable[[Sequence[trw.basic_typing.TorchTensorNCX]], Sequence[trw.basic_typing.TorchTensorNCX]] = select_third_to_last_skip_before_last, resize_mode: typing_extensions.Literal['nearest', 'linear'] = 'linear', weighting_fn: Optional[Callable[[Sequence[trw.basic_typing.TorchTensorNCX]], Sequence[float]]] = adaptative_weighting, config: trw.layers.layer_config.LayerConfig = default_layer_config(dimensionality=None), return_intermediate: bool = False)

Bases: torch.nn.Module

Apply a deep supervision layer to help the gradient reach the top-level layers.

This is mostly used for segmentation tasks.

Example

>>> import torch
>>> import trw
>>> from trw.layers.deep_supervision import DeepSupervision
>>> backbone = trw.layers.UNetBase(dim=2, input_channels=3, channels=[2, 4, 8], output_channels=2)
>>> deep_supervision = DeepSupervision(backbone, [3, 8, 16])
>>> i = torch.zeros([1, 3, 8, 16], dtype=torch.float32)
>>> t = torch.zeros([1, 1, 8, 16], dtype=torch.long)
>>> outputs = deep_supervision(i, t)
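Deep supervision typically adds one loss term per supervised level, each multiplied by its loss scaling, so that gradients reach the intermediate layers directly. The sketch below shows only that combination step with made-up loss values. It assumes the scaled per-level losses are summed, which is the usual formulation but is not confirmed for DeepSupervision by this page; `combine_deep_supervision_losses` is a hypothetical name.

```python
from typing import Sequence


def combine_deep_supervision_losses(losses: Sequence[float], loss_scalings: Sequence[float]) -> float:
    """Hypothetical sketch: total loss = sum of per-level losses times their scaling."""
    if len(losses) != len(loss_scalings):
        raise ValueError("one loss scaling is required per supervised output")
    return sum(loss * scale for loss, scale in zip(losses, loss_scalings))


# Illustrative values only: coarse-to-fine losses and spatial-extent weights.
total = combine_deep_supervision_losses([0.8, 0.6, 0.4], [0.0625, 0.25, 1.0])
print(total)
```

Scaling by spatial extent means the full-resolution prediction still drives most of the update, while the coarser levels add a smaller, more direct gradient signal to the intermediate layers.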

forward(self, x: torch.Tensor, target: torch.Tensor, latent: Optional[torch.Tensor] = None) → Union[List[trw.train.outputs_trw.Output], Tuple[List[trw.train.outputs_trw.Output], List[torch.Tensor]]]