trw.layers.shift_scale

Module Contents

Classes

ShiftScale

Normalize a tensor with a mean and standard deviation

Functions

transfer_to_device(x: torch.Tensor, device: torch.device) → torch.Tensor

trw.layers.shift_scale.transfer_to_device(x: torch.Tensor, device: torch.device) → torch.Tensor

class trw.layers.shift_scale.ShiftScale(mean: Union[float, torch.Tensor], standard_deviation: Union[float, torch.Tensor], output_dtype: torch.dtype = torch.float32)

Bases: torch.nn.Module

Normalize a tensor with a mean and standard deviation

The output tensor will be (x - mean) / standard_deviation
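As a worked example of that formula, the plain-Python sketch below applies the same shift and scale to a flat list of floats. It does not use torch or trw. `shift_scale` is a hypothetical helper written only to illustrate the arithmetic, not part of the trw API.

```python
from typing import List


def shift_scale(values: List[float], mean: float, standard_deviation: float) -> List[float]:
    """Hypothetical helper: apply (x - mean) / standard_deviation to each value."""
    # Subtract the mean first (shift), then divide by the standard deviation (scale)
    return [(x - mean) / standard_deviation for x in values]


print(shift_scale([2.0, 4.0, 6.0], mean=4.0, standard_deviation=2.0))
# → [-1.0, 0.0, 1.0]
```

Values equal to the mean map to 0, and values one standard deviation away map to ±1.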

This layer simplifies preprocessing when using the trw.simple_layers package.

forward(self, x: torch.Tensor) → torch.Tensor
Parameters

x – a tensor

Returns: the normalized tensor, (x - mean) / standard_deviation
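Because mean and standard_deviation may also be tensors, the formula presumably follows torch broadcasting rules when combined with x. A common use, assumed here and not stated in this reference, is one mean and one standard deviation per channel. The plain-Python sketch below mimics that case for a (channels, pixels) layout. `shift_scale_per_channel` is a hypothetical helper, not part of trw.

```python
from typing import List, Sequence


def shift_scale_per_channel(
    image: Sequence[Sequence[float]],
    mean: Sequence[float],
    standard_deviation: Sequence[float],
) -> List[List[float]]:
    """Hypothetical helper: normalize each channel with its own mean and std.

    Mimics broadcasting a per-channel mean of shape (C, 1) against a (C, N) tensor.
    """
    if not (len(image) == len(mean) == len(standard_deviation)):
        raise ValueError("expected one mean and one standard deviation per channel")
    # Each channel is shifted and scaled by its own statistics
    return [
        [(x - m) / s for x in channel]
        for channel, m, s in zip(image, mean, standard_deviation)
    ]


image = [[0.0, 10.0], [100.0, 300.0]]  # 2 channels, 2 pixels each
print(shift_scale_per_channel(image, mean=[5.0, 200.0], standard_deviation=[5.0, 100.0]))
# → [[-1.0, 1.0], [-1.0, 1.0]]
```

Channels with very different value ranges end up on a comparable scale after normalization.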