trw.transforms.transforms_unsqueeze

Module Contents

Classes

TransformUnsqueeze

Unsqueeze a dimension of a tensor.

Functions

unsqueeze_fn(feature_names: Sequence[str], batch: trw.basic_typing.Batch, axis: int) → trw.basic_typing.Batch

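trw.transforms.transforms_unsqueeze.unsqueeze_fn(feature_names: Sequence[str], batch: trw.basic_typing.Batch, axis: int) → trw.basic_typing.Batch

Insert a dimension of size 1 at position axis in each feature listed in feature_names, and return the resulting batch.
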
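A minimal sketch of what unsqueeze_fn might do. It assumes a batch is a plain dict mapping feature names to values and that the function returns a shallow copy instead of modifying its input. The Batch alias and the pass-through of non-array values are assumptions for illustration, not taken from trw. The only library calls used are torch.Tensor.unsqueeze and numpy.expand_dims, which both insert a size-1 dimension at the given position.

```python
from typing import Any, Dict, Sequence

import numpy as np

try:  # torch is treated as optional in this sketch
    import torch
except ImportError:
    torch = None

# Assumption: a batch is a dict of feature name -> value (stand-in for trw.basic_typing.Batch)
Batch = Dict[str, Any]


def unsqueeze_fn(feature_names: Sequence[str], batch: Batch, axis: int) -> Batch:
    """Sketch: insert a size-1 dimension at `axis` for each named feature."""
    new_batch = dict(batch)  # shallow copy: the input batch is left untouched
    for name in feature_names:
        value = batch[name]
        if torch is not None and isinstance(value, torch.Tensor):
            new_batch[name] = value.unsqueeze(axis)
        elif isinstance(value, np.ndarray):
            new_batch[name] = np.expand_dims(value, axis)
        # assumption: any other type is passed through unchanged
    return new_batch


# Usage: add a leading dimension to a (C, H, W) image
batch = {"image": np.zeros((3, 32, 32)), "uid": "case_0"}
out = unsqueeze_fn(["image"], batch, axis=0)
print(out["image"].shape)  # (1, 3, 32, 32)
```

Negative axis values count from the end, as in PyTorch and NumPy, so axis=-1 appends a trailing dimension.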
class trw.transforms.transforms_unsqueeze.TransformUnsqueeze(axis: int, criteria_fn: Optional[trw.transforms.transforms.CriteriaFn] = criteria_is_array_4_or_above)

Bases: trw.transforms.transforms.TransformBatchWithCriteria

Unsqueeze a tensor: insert a new dimension of size 1 at the position given by axis.

Only features of type numpy.ndarray or torch.Tensor are transformed; other features are left unchanged.
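The criteria-driven selection can be sketched as follows. This assumes that a criteria function receives the whole batch and returns the names of the features to transform, and that the default criteria_is_array_4_or_above selects arrays with at least four dimensions. UnsqueezeWithCriteria and is_array_4_or_above are hypothetical names used only for illustration; this is not the trw implementation.

```python
from typing import Any, Callable, Dict, List, Sequence

import numpy as np

try:  # torch is treated as optional in this sketch
    import torch
    ARRAY_TYPES = (np.ndarray, torch.Tensor)
except ImportError:
    ARRAY_TYPES = (np.ndarray,)

Batch = Dict[str, Any]
# Assumption: a criteria function maps a batch to the feature names to transform
CriteriaFn = Callable[[Batch], Sequence[str]]


def is_array_4_or_above(batch: Batch) -> List[str]:
    """Hypothetical stand-in for criteria_is_array_4_or_above (ndim >= 4 assumed)."""
    return [name for name, value in batch.items()
            if isinstance(value, ARRAY_TYPES) and len(value.shape) >= 4]


class UnsqueezeWithCriteria:
    """Sketch of a criteria-driven unsqueeze transform; not the trw class."""

    def __init__(self, axis: int, criteria_fn: CriteriaFn = is_array_4_or_above):
        self.axis = axis
        self.criteria_fn = criteria_fn

    def __call__(self, batch: Batch) -> Batch:
        new_batch = dict(batch)  # shallow copy; the input batch is not mutated
        for name in self.criteria_fn(batch):
            value = batch[name]
            if isinstance(value, np.ndarray):
                new_batch[name] = np.expand_dims(value, self.axis)
            elif isinstance(value, ARRAY_TYPES):  # torch.Tensor when torch is installed
                new_batch[name] = value.unsqueeze(self.axis)
            # any other type is left unchanged
        return new_batch


# Usage: only the 4-D feature is selected by the default criteria
batch = {"volume": np.zeros((2, 1, 8, 8)), "label": np.zeros((2,)), "name": "x"}
out = UnsqueezeWithCriteria(axis=2)(batch)
print(out["volume"].shape)  # (2, 1, 1, 8, 8)
print(out["label"].shape)   # (2,): fewer than 4 dimensions, so not selected
```

Separating selection (the criteria function) from the transformation keeps the same unsqueeze logic reusable with different feature filters.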