trw.transforms.transforms_cast

Module Contents

Classes

TransformCast

Cast tensors to a specified type.

Functions

cast_np(tensor: numpy.ndarray, cast_type: str) → numpy.ndarray

cast_torch(tensor: torch.Tensor, cast_type: str) → torch.Tensor

cast(feature_names: Sequence[str], batch: trw.basic_typing.Batch, cast_type: str) → trw.basic_typing.Batch

Attributes

NUMPY_CONVERSION

TORCH_CONVERSION

trw.transforms.transforms_cast.NUMPY_CONVERSION
trw.transforms.transforms_cast.TORCH_CONVERSION
trw.transforms.transforms_cast.cast_np(tensor: numpy.ndarray, cast_type: str) → numpy.ndarray
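The reference gives only the signature. The sketch below illustrates one plausible string-keyed cast of a NumPy array: it looks up a dtype by name and converts with `ndarray.astype`. The table `NUMPY_TYPES` and the function `cast_np_sketch` are hypothetical. The keys accepted by the real `NUMPY_CONVERSION` table are an assumption here and may differ.

```python
import numpy as np

# Hypothetical lookup table: the real NUMPY_CONVERSION keys and dtypes may differ.
NUMPY_TYPES = {
    "float": np.float32,
    "double": np.float64,
    "long": np.int64,
    "int": np.int32,
    "uint8": np.uint8,
}


def cast_np_sketch(tensor: np.ndarray, cast_type: str) -> np.ndarray:
    """Return a copy of `tensor` converted to the dtype named by `cast_type`."""
    dtype = NUMPY_TYPES.get(cast_type)
    if dtype is None:
        raise ValueError(f"unsupported cast_type: {cast_type!r}")
    # astype returns a new array by default, so the input is left untouched
    return tensor.astype(dtype)


a = np.array([1.7, 2.2, -3.9])
as_long = cast_np_sketch(a, "long")  # float-to-int casts truncate toward zero
```

Because `astype` copies by default, the caller's array keeps its original dtype.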
trw.transforms.transforms_cast.cast_torch(tensor: torch.Tensor, cast_type: str) → torch.Tensor
trw.transforms.transforms_cast.cast(feature_names: Sequence[str], batch: trw.basic_typing.Batch, cast_type: str) → trw.basic_typing.Batch
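Below is a minimal sketch of the batch-level behaviour. It assumes a batch is a plain dict mapping feature names to values, standing in for `trw.basic_typing.Batch`, and converts only the listed features that hold a `numpy.ndarray`. `cast_batch_sketch` and `NUMPY_TYPES` are hypothetical names. Silently skipping feature names absent from the batch is an assumption, not documented behaviour of `cast`. Torch tensors are left out so the sketch depends only on NumPy.

```python
from typing import Any, Dict, Sequence

import numpy as np

# Hypothetical stand-in for trw.basic_typing.Batch: feature name -> value.
Batch = Dict[str, Any]

# Hypothetical dtype table; the real NUMPY_CONVERSION may use different keys.
NUMPY_TYPES = {"float": np.float32, "long": np.int64}


def cast_batch_sketch(feature_names: Sequence[str], batch: Batch, cast_type: str) -> Batch:
    """Cast the named NumPy features to `cast_type`; all other entries pass through."""
    new_batch = dict(batch)  # shallow copy: the input batch is not modified
    for name in feature_names:
        value = batch.get(name)  # missing names yield None and are skipped
        if isinstance(value, np.ndarray):
            new_batch[name] = value.astype(NUMPY_TYPES[cast_type])
        # A torch.Tensor would be handled analogously (e.g. tensor.to(dtype));
        # values of any other type (lists, strings, ...) are kept as-is.
    return new_batch


batch = {"x": np.arange(3), "label": ["a", "b", "c"], "y": np.ones(2, dtype=np.int64)}
out = cast_batch_sketch(["x", "label"], batch, "float")
print(out["x"].dtype, out["y"].dtype)  # float32 int64
```

Here `label` is requested but, being a list, is passed through unchanged, and `y` keeps its dtype because it was not requested.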
class trw.transforms.transforms_cast.TransformCast(feature_names: Sequence[str], cast_type: str)

Bases: trw.transforms.transforms.TransformBatchWithCriteria

Cast tensors to a specified type.

Only numpy.ndarray and torch.Tensor values are cast.
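TransformCast packages this behaviour as a reusable transform object. The sketch below is a hypothetical stand-in, `CastTransformSketch`, that stores the feature names and target dtype once and then casts the selected NumPy features each time it is called on a batch. It does not derive from `TransformBatchWithCriteria`, whose interface is not described here. The dict-based batch and the `NUMPY_TYPES` table are also assumptions.

```python
from typing import Any, Dict, Sequence

import numpy as np

# Hypothetical stand-in for trw.basic_typing.Batch.
Batch = Dict[str, Any]

# Hypothetical dtype table; the real NUMPY_CONVERSION may differ.
NUMPY_TYPES = {"float": np.float32, "long": np.int64, "uint8": np.uint8}


class CastTransformSketch:
    """Callable that casts selected batch features (NumPy arrays only)."""

    def __init__(self, feature_names: Sequence[str], cast_type: str):
        # Resolve the dtype eagerly so a bad type name fails at construction time
        if cast_type not in NUMPY_TYPES:
            raise ValueError(f"unsupported cast_type: {cast_type!r}")
        self.feature_names = list(feature_names)
        self.dtype = NUMPY_TYPES[cast_type]

    def __call__(self, batch: Batch) -> Batch:
        new_batch = dict(batch)  # do not mutate the caller's batch
        for name in self.feature_names:
            value = batch.get(name)
            if isinstance(value, np.ndarray):  # non-array values are left as-is
                new_batch[name] = value.astype(self.dtype)
        return new_batch


transform = CastTransformSketch(["image"], "uint8")
batch = {"image": np.array([[0.0, 127.9], [255.0, 3.2]]), "id": 7}
out = transform(batch)
```

Resolving `cast_type` in the constructor means an unsupported type name is reported when the transform is built rather than on the first batch. Whether TransformCast does the same is not stated in this reference.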