trw.layers.shift_scale
Module Contents
Classes
ShiftScale(mean, standard_deviation, output_dtype) | Normalize a tensor with a mean and standard deviation
Functions
transfer_to_device(x, device) |
- trw.layers.shift_scale.transfer_to_device(x: torch.Tensor, device: torch.device) → torch.Tensor
- class trw.layers.shift_scale.ShiftScale(mean: Union[float, torch.Tensor], standard_deviation: Union[float, torch.Tensor], output_dtype: torch.dtype = torch.float32)
Bases: torch.nn.Module
Normalize a tensor with a mean and standard deviation.
The output tensor is (x - mean) / standard_deviation.
This layer simplifies preprocessing for the trw.simple_layers package.
- forward(self, x: torch.Tensor) → torch.Tensor
- Parameters
x – a tensor
Returns: the normalized tensor (x - mean) / standard_deviation
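A minimal sketch of the computation described above, assuming PyTorch. `ShiftScaleSketch` is a hypothetical stand-in written for illustration, not the trw implementation. Two details are assumptions rather than documented behavior: the input is cast to `output_dtype` before normalizing, and tensor statistics are moved to the input's device (the role `transfer_to_device` plausibly plays in the real layer).

```python
from typing import Union

import torch
import torch.nn as nn


class ShiftScaleSketch(nn.Module):
    """Hypothetical re-implementation: output = (x - mean) / standard_deviation."""

    def __init__(
        self,
        mean: Union[float, torch.Tensor],
        standard_deviation: Union[float, torch.Tensor],
        output_dtype: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()
        self.mean = mean
        self.standard_deviation = standard_deviation
        self.output_dtype = output_dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumption: cast first so integer inputs (e.g. uint8 images)
        # are normalized in floating point.
        x = x.to(self.output_dtype)
        mean, std = self.mean, self.standard_deviation
        # Assumption: tensor statistics are moved to the same device as x.
        if isinstance(mean, torch.Tensor):
            mean = mean.to(device=x.device, dtype=self.output_dtype)
        if isinstance(std, torch.Tensor):
            std = std.to(device=x.device, dtype=self.output_dtype)
        # Broadcasting applies per-channel statistics shaped e.g. [1, C, 1, 1].
        return (x - mean) / std


# Usage: per-channel statistics for a batch of 2 RGB 4x4 images stored as uint8.
images = torch.randint(0, 256, (2, 3, 4, 4), dtype=torch.uint8)
mean = torch.tensor([128.0, 128.0, 128.0]).view(1, 3, 1, 1)
std = torch.tensor([64.0, 64.0, 64.0]).view(1, 3, 1, 1)
layer = ShiftScaleSketch(mean, std)
out = layer(images)
print(out.dtype, out.shape)  # torch.float32 torch.Size([2, 3, 4, 4])

# Scalar statistics also work: (10 - 2) / 4 = 2.0
print(ShiftScaleSketch(2.0, 4.0)(torch.tensor([10.0])))  # tensor([2.])
```

Shaping the statistics as `[1, C, 1, 1]` lets a single mean and standard deviation per channel broadcast over the batch and spatial dimensions, which is the usual layout for image preprocessing.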