trw.utils.sub_tensor

Module Contents

Functions

sub_tensor(tensor: torch.Tensor, min_indices: trw.basic_typing.Shape, max_indices_exclusive: trw.basic_typing.Shape) → torch.Tensor

Select a region of a tensor (without copy)

trw.utils.sub_tensor.sub_tensor(tensor: torch.Tensor, min_indices: trw.basic_typing.Shape, max_indices_exclusive: trw.basic_typing.Shape) → torch.Tensor

Select a region of a tensor without copying its data.

Examples

>>> t = torch.randn([5, 10])
>>> sub_t = sub_tensor(t, [2, 3], [4, 8])

Returns t[2:4, 3:8].

>>> t = torch.randn([5, 10])
>>> sub_t = sub_tensor(t, [2], [4])

Returns t[2:4].

Parameters
  • tensor – the tensor to select a region from

  • min_indices – the first index (inclusive) to select along each dimension

  • max_indices_exclusive – the end index (exclusive) of the selection along each dimension

Returns

torch.Tensor
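
The "without copy" behaviour can be illustrated with plain PyTorch slicing. The sketch below is not the library's implementation: `sub_tensor_sketch` is a hypothetical stand-in that assumes the selection is equivalent to indexing with one `slice` per given dimension, as the examples above suggest. It shows that a write into the selected region is visible in the original tensor.

```python
import torch


def sub_tensor_sketch(tensor, min_indices, max_indices_exclusive):
    """Hypothetical stand-in for sub_tensor, built on basic slicing."""
    # Assumption: both index lists have the same length and cover
    # the leading dimensions of the tensor.
    assert len(min_indices) == len(max_indices_exclusive)
    slices = tuple(
        slice(lo, hi) for lo, hi in zip(min_indices, max_indices_exclusive)
    )
    # Basic slicing in PyTorch returns a view, not a copy.
    return tensor[slices]


t = torch.arange(50).reshape(5, 10)

# Select rows 2..3 and columns 3..7, i.e. t[2:4, 3:8].
region = sub_tensor_sketch(t, [2, 3], [4, 8])

# Fewer indices than dimensions: only the leading dimension is cropped.
rows = sub_tensor_sketch(t, [2], [4])

# Writing through the view modifies the original tensor.
region[0, 0] = -1
```

Because the result shares memory with the input, call `.clone()` on it if an independent copy is needed.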