trw.train.upsample

Module Contents

Functions

_upsample_int_1d(tensor, size)

_upsample_int_2d(tensor, size)

_upsample_int_3d(tensor, size)

upsample(tensor, size, mode='linear')

Upsample a 1D, 2D, or 3D tensor

trw.train.upsample._upsample_int_1d(tensor, size)
trw.train.upsample._upsample_int_2d(tensor, size)
trw.train.upsample._upsample_int_3d(tensor, size)
trw.train.upsample.upsample(tensor, size, mode='linear')

Upsample a 1D, 2D, or 3D tensor

This is a wrapper around torch.nn.Upsample that makes it more practical to use. It also supports integer-based tensors.

Note

As of version 1.3, PyTorch does not support upsampling of non-floating-point tensors (see https://github.com/pytorch/pytorch/issues/13218 and https://github.com/pytorch/pytorch/issues/5580). This function uses a workaround instead (TODO: assess the speed impact).
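The workaround itself is not described on this page. As a rough illustration of why integer data can be upsampled at all in nearest mode, the sketch below copies existing values using integer index arithmetic only. It operates on plain Python lists rather than tensors, the helper names `nearest_upsample_1d` and `nearest_upsample_2d` are hypothetical, and the index rule `src = floor(dst * n_in / n_out)` is an assumption, not necessarily what this module does.

```python
def nearest_upsample_1d(values, size):
    """Resize a 1D sequence to `size` elements by nearest-neighbour copying.

    Hypothetical helper for illustration; not part of trw.train.upsample.
    """
    n_in = len(values)
    if n_in == 0 or size <= 0:
        raise ValueError("input must be non-empty and size must be positive")
    # Assumed index rule: src = floor(dst * n_in / size).
    # Only integer arithmetic is used, so integer labels stay exact.
    return [values[(i * n_in) // size] for i in range(size)]


def nearest_upsample_2d(rows, size):
    """Resize a 2D list of lists to size = (h, w), one axis at a time."""
    h_out, w_out = size
    h_in = len(rows)
    if h_in == 0 or h_out <= 0:
        raise ValueError("input must be non-empty and size must be positive")
    return [
        nearest_upsample_1d(rows[(y * h_in) // h_out], w_out)
        for y in range(h_out)
    ]


labels = [3, 7, 9]
print(nearest_upsample_1d(labels, 6))  # [3, 3, 7, 7, 9, 9]
print(nearest_upsample_2d([[1, 2], [3, 4]], (4, 4)))
```

Because no interpolation between values takes place, the output contains only values already present in the input, which is usually what is wanted for integer label maps.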

Parameters
  • tensor – 1D (shape = b x c x n), 2D (shape = b x c x h x w) or 3D (shape = b x c x d x h x w)

  • size – the target spatial size: n if 1D, h x w if 2D, d x h x w if 3D

  • mode – linear or nearest

Returns

an up-sampled tensor with the same batch size (b) and filter size (c) as the input
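The shape contract described above can be sketched in plain Python. Both helpers below are hypothetical and are not part of trw. Accepting a bare integer for the 1D `size` is an assumption, as is the idea that the wrapper translates `'linear'` into the dimension-specific mode names used by torch.nn.Upsample (`'linear'`, `'bilinear'`, `'trilinear'`).

```python
def expected_output_shape(input_shape, size):
    """Return the shape the up-sampled tensor should have.

    input_shape: (b, c, n), (b, c, h, w) or (b, c, d, h, w)
    size:        n, (h, w) or (d, h, w)
    """
    # Assumption: a bare int is accepted as shorthand for a 1D size.
    size = (size,) if isinstance(size, int) else tuple(size)
    spatial_dims = len(input_shape) - 2
    if spatial_dims not in (1, 2, 3):
        raise ValueError("expected a 1D, 2D or 3D tensor with batch and channel axes")
    if len(size) != spatial_dims:
        raise ValueError("size must have one entry per spatial dimension")
    # Batch and channel axes are unchanged; only spatial axes are resized.
    return tuple(input_shape[:2]) + size


def torch_mode_for(spatial_dims, mode="linear"):
    """Map the wrapper's mode to a torch.nn.Upsample mode name (assumed mapping)."""
    if mode == "nearest":
        return "nearest"
    if mode == "linear":
        return {1: "linear", 2: "bilinear", 3: "trilinear"}[spatial_dims]
    raise ValueError("mode must be 'linear' or 'nearest'")


print(expected_output_shape((8, 3, 32, 32), (64, 64)))  # (8, 3, 64, 64)
print(torch_mode_for(2))  # bilinear
```

A check like `expected_output_shape` can be useful in tests to confirm that a call such as `upsample(tensor, size)` preserved the batch and channel axes.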