trw.layers.shift_scale

Module Contents

Classes

ShiftScale

Normalize a tensor with a mean and standard deviation

class trw.layers.shift_scale.ShiftScale(mean, standard_deviation)

Bases: torch.nn.Module

Normalize a tensor with a mean and standard deviation

The output tensor will be (x - mean) / standard_deviation

This layer simplifies the preprocessing for the trw.simple_layers package
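The formula above can be illustrated with a minimal sketch. It does not import trw; `ShiftScaleSketch` is a hypothetical stand-in written only for illustration, and storing the statistics as buffers and shaping them as (1, C, 1, 1) for per-channel broadcasting are assumptions, not confirmed details of the library's implementation.

```python
import torch
from torch import nn


class ShiftScaleSketch(nn.Module):
    """Hypothetical stand-in for ShiftScale (not the trw implementation)."""

    def __init__(self, mean, standard_deviation):
        super().__init__()
        # Assumption: keep the statistics as non-trainable buffers so they
        # move with the module when .to(device) is called.
        self.register_buffer("mean", torch.as_tensor(mean, dtype=torch.float32))
        self.register_buffer(
            "standard_deviation",
            torch.as_tensor(standard_deviation, dtype=torch.float32),
        )

    def forward(self, x):
        # (x - mean) / standard_deviation, broadcast over the input
        return (x - self.mean) / self.standard_deviation


# Assumed per-channel statistics for a batch of 2-channel images (N, C, H, W)
mean = torch.tensor([0.5, 0.25]).view(1, 2, 1, 1)
std = torch.tensor([0.5, 0.25]).view(1, 2, 1, 1)

layer = ShiftScaleSketch(mean, std)
x = torch.ones(4, 2, 3, 3)
out = layer(x)  # same shape as x; only the values are shifted and scaled
```

In this sketch the output keeps the shape of the input: every element of channel 0 becomes (1 - 0.5) / 0.5 and every element of channel 1 becomes (1 - 0.25) / 0.25.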

forward(self, x)
Parameters

x – a tensor

Returns: the normalized tensor, (x - mean) / standard_deviation