trw.utils.flatten

Module Contents

Functions

flatten(x: trw.basic_typing.TorchTensorNCX) → trw.basic_typing.TorchTensorNX

Flatten a tensor

trw.utils.flatten.flatten(x: trw.basic_typing.TorchTensorNCX) → trw.basic_typing.TorchTensorNX

Flatten a tensor, keeping the first dimension N and merging all remaining dimensions into one.

For example, a tensor of shape [N, Z, Y, X] is reshaped to [N, Z * Y * X].
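The reshaping described above can be illustrated with a minimal sketch in plain PyTorch. The function name `flatten_sketch` is hypothetical and only mimics the documented behaviour (keep dimension N, merge the rest); it is not necessarily how trw implements `flatten`.

```python
import math

import torch


def flatten_sketch(x: torch.Tensor) -> torch.Tensor:
    """Collapse all dimensions after the first: [N, d1, ..., dk] -> [N, d1 * ... * dk].

    Hypothetical stand-in for trw.utils.flatten.flatten, not the library's own code.
    """
    n = x.shape[0]
    # Number of features per sample is the product of the trailing dimensions
    features = math.prod(x.shape[1:])
    return x.reshape(n, features)


x = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)  # shape [N, Z, Y, X]
y = flatten_sketch(x)
print(tuple(y.shape))  # (2, 60)
```

Each row of the result holds one sample's values in the original (row-major) order, so `y[i]` equals `x[i].reshape(-1)`.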

Parameters

x – the tensor to flatten; its first dimension N is preserved

Returns: the flattened tensor, of shape [N, product of the remaining dimensions]
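A common use of such a flattened [N, features] tensor is to feed convolutional feature maps into a fully connected layer. The sketch below assumes plain PyTorch modules; `flatten_sketch` is a hypothetical stand-in built on `torch.flatten(x, start_dim=1)` rather than the trw function itself, and the layer sizes are arbitrary.

```python
import torch
import torch.nn as nn


def flatten_sketch(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for trw.utils.flatten.flatten: keep dim 0, merge the rest
    return torch.flatten(x, start_dim=1)


conv = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
images = torch.randn(4, 1, 16, 16)  # [N, C, H, W]
features = conv(images)             # padding=1 keeps H and W: [4, 8, 16, 16]
flat = flatten_sketch(features)     # [4, 8 * 16 * 16] = [4, 2048]

# The linear layer's in_features must match the flattened feature count
head = nn.Linear(in_features=8 * 16 * 16, out_features=10)
logits = head(flat)
print(tuple(logits.shape))  # (4, 10)
```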