trw.utils.clamp_n

Module Contents

Functions

clamp_n(tensor: torch.Tensor, min_values: Sequence[Any], max_values: Sequence[Any]) → torch.Tensor

Clamp a tensor with axis-dependent values.

trw.utils.clamp_n.clamp_n(tensor: torch.Tensor, min_values: Sequence[Any], max_values: Sequence[Any]) → torch.Tensor

Clamp a tensor with axis-dependent values: each position along the last dimension of the tensor has its own minimum and maximum.

Parameters
  • tensor – an N-D torch.Tensor to clamp

  • min_values – a 1D torch.Tensor of minimum values, one per element of the last dimension of tensor

  • max_values – a 1D torch.Tensor of maximum values, one per element of the last dimension of tensor

Returns

a tensor with the same shape as tensor, with each value clamped to the [min_values, max_values] range of its position along the last dimension
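To make the axis-dependent behaviour concrete, here is a minimal sketch in plain PyTorch. It assumes the bounds are indexed along the last dimension (as the 2 × 3 example with three bounds implies) and relies on broadcasting. `clamp_last_axis` is a hypothetical name for illustration only; this is not trw's implementation.

```python
import torch


def clamp_last_axis(tensor: torch.Tensor, min_values: torch.Tensor, max_values: torch.Tensor) -> torch.Tensor:
    """Clamp each position along the last dimension to its own [min, max] range.

    Hypothetical helper illustrating axis-dependent clamping; not part of trw.
    """
    # Bounds are 1D, with one entry per element of the last dimension.
    assert min_values.dim() == 1 and max_values.shape == min_values.shape
    assert tensor.shape[-1] == min_values.shape[0]
    # Broadcasting aligns trailing dimensions, so a (C,) bound vector is
    # applied to every "row" of a (..., C) tensor.
    return torch.maximum(torch.minimum(tensor, max_values), min_values)


t = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(clamp_last_axis(t, torch.tensor([3, 2, 4]), torch.tensor([3, 4, 8])).tolist())
```

Because broadcasting only looks at trailing dimensions, the same call works unchanged for 3-D or higher tensors whose last dimension matches the length of the bounds.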

Examples

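>>> import torch
>>> from trw.utils.clamp_n import clamp_n
>>> t = torch.LongTensor([[1, 2, 3], [4, 5, 6]])
>>> min_values = torch.LongTensor([3, 2, 4])
>>> max_values = torch.LongTensor([3, 4, 8])
>>> clamped_t = clamp_n(t, min_values, max_values)
>>> clamped_t.tolist()
[[3, 2, 4], [3, 4, 6]]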
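A typical use of per-axis clamping is keeping integer coordinates inside a volume, where each coordinate axis has a different valid range. The sketch below is a hypothetical scenario (the coordinates and the volume shape are made up for illustration) and uses `torch.clamp` with tensor bounds, which PyTorch supports from version 1.9, rather than trw's `clamp_n`.

```python
import torch

# Hypothetical use case: keep integer voxel coordinates (z, y, x) inside a
# volume of shape (4, 5, 6). Each column gets its own valid range.
coords = torch.tensor([[-1, 2, 7], [3, 9, 0], [2, 2, 2]])
volume_shape = torch.tensor([4, 5, 6])
min_values = torch.zeros(3, dtype=torch.long)
max_values = volume_shape - 1

# torch.clamp accepts tensor bounds (PyTorch >= 1.9); they broadcast against
# the last dimension of `coords`, giving a per-column clamp.
clamped = torch.clamp(coords, min=min_values, max=max_values)
print(clamped.tolist())
```

Each row is one point; out-of-range coordinates are pulled back to the nearest valid index for their own axis, while in-range coordinates are left unchanged.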