trw.utils.upsample

Module Contents

Functions

_upsample_int_1d(tensor: trw.basic_typing.TorchTensorNCX, size: trw.basic_typing.ShapeNCX) → trw.basic_typing.TorchTensorNCX

_upsample_int_2d(tensor: trw.basic_typing.TorchTensorNCX, size: trw.basic_typing.ShapeNCX) → trw.basic_typing.TorchTensorNCX

_upsample_int_3d(tensor: trw.basic_typing.TorchTensorNCX, size: trw.basic_typing.ShapeNCX) → trw.basic_typing.TorchTensorNCX

upsample(tensor: trw.basic_typing.TensorNCX, size: trw.basic_typing.ShapeX, mode: typing_extensions.Literal[linear, nearest] = 'linear') → trw.basic_typing.TensorNCX

Upsample a 1D, 2D or 3D tensor.

trw.utils.upsample._upsample_int_1d(tensor: trw.basic_typing.TorchTensorNCX, size: trw.basic_typing.ShapeNCX) → trw.basic_typing.TorchTensorNCX
trw.utils.upsample._upsample_int_2d(tensor: trw.basic_typing.TorchTensorNCX, size: trw.basic_typing.ShapeNCX) → trw.basic_typing.TorchTensorNCX
trw.utils.upsample._upsample_int_3d(tensor: trw.basic_typing.TorchTensorNCX, size: trw.basic_typing.ShapeNCX) → trw.basic_typing.TorchTensorNCX
trw.utils.upsample.upsample(tensor: trw.basic_typing.TensorNCX, size: trw.basic_typing.ShapeX, mode: typing_extensions.Literal[linear, nearest] = 'linear') → trw.basic_typing.TensorNCX

Upsample a 1D, 2D or 3D tensor.

This is a convenience wrapper around torch.nn.Upsample. Unlike torch.nn.Upsample, it also supports integer tensors.

Note

As of version 1.3, PyTorch does not support upsampling non-floating-point tensors (see https://github.com/pytorch/pytorch/issues/13218 and https://github.com/pytorch/pytorch/issues/5580). A workaround is used for such tensors instead (TODO: assess the speed impact).
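The note does not describe the workaround itself. A common approach is nearest-neighbour resampling by index gathering. This sketch uses that approach as an assumption; it is not trw's actual code. Along each axis, output position i (input length n_in, output length n_out) reads source index floor(i * n_in / n_out). To our understanding, this is also the rule PyTorch's 'nearest' mode applies when an explicit size is given. Values are copied rather than blended, so integer inputs such as label maps remain integers. The helpers `nearest_indices` and `upsample_nearest_int_2d` are hypothetical, and nested Python lists stand in for a b x c x h x w tensor.

```python
from typing import List, Sequence

Tensor4D = List[List[List[List[int]]]]  # nested lists shaped b x c x h x w


def nearest_indices(n_in: int, n_out: int) -> List[int]:
    """Source index for each output position along one axis.

    Integer floor division computes floor(i * n_in / n_out) exactly,
    and the result is always < n_in because i < n_out.
    """
    return [(i * n_in) // n_out for i in range(n_out)]


def upsample_nearest_int_2d(tensor: Tensor4D, size: Sequence[int]) -> Tensor4D:
    """Hypothetical nearest-neighbour resize of a b x c x h x w integer tensor.

    Batch and channel dimensions are untouched; only h and w are resampled.
    """
    h_out, w_out = size
    h_in = len(tensor[0][0])
    w_in = len(tensor[0][0][0])
    rows = nearest_indices(h_in, h_out)
    cols = nearest_indices(w_in, w_out)
    # Gather: every output pixel copies one input pixel, so no new values appear
    return [
        [[[channel[r][c] for c in cols] for r in rows] for channel in sample]
        for sample in tensor
    ]


if __name__ == "__main__":
    labels = [[[[0, 1], [2, 3]]]]  # b=1, c=1, h=2, w=2
    for row in upsample_nearest_int_2d(labels, (4, 4))[0][0]:
        print(row)
```

Gathering by precomputed indices extends directly to 1D and 3D: compute one index list per spatial axis and index with all of them.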

Parameters
  • tensor – 1D (shape = b x c x n), 2D (shape = b x c x h x w) or 3D (shape = b x c x d x h x w)

  • size – the output spatial shape, without batch and channel dimensions: n if 1D, h x w if 2D, d x h x w if 3D

  • mode – linear or nearest (default: linear)

Returns

An upsampled tensor with the same batch size and number of channels as the input.
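The parameters and return value together imply a simple shape rule. `size` replaces only the spatial dimensions, and the batch and channel dimensions pass through unchanged. The hypothetical helper below computes that output shape. Its validation reflects our assumption about what a well-formed call looks like; it does not describe trw's own checks.

```python
from typing import Sequence, Tuple


def expected_output_shape(input_shape: Sequence[int], size: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical: shape of upsample(tensor, size) given tensor.shape.

    The first two dims (b, c) are kept; the spatial dims are replaced by size.
    """
    spatial_dims = len(input_shape) - 2  # everything after b x c
    if spatial_dims not in (1, 2, 3):
        raise ValueError(f"expected a 1D, 2D or 3D tensor (3 to 5 dims), got {len(input_shape)} dims")
    if len(size) != spatial_dims:
        raise ValueError(f"size needs {spatial_dims} entries for this tensor, got {len(size)}")
    return tuple(input_shape[:2]) + tuple(size)


if __name__ == "__main__":
    print(expected_output_shape((2, 3, 16, 16), (32, 32)))  # b=2, c=3 kept
```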