trw.train.collate

Module Contents

Functions

collate_tensors(values: Union[numpy.ndarray, torch.Tensor, Union[List[numpy.ndarray], List[torch.Tensor], List[numbers.Number], List[str], List[List[numpy.ndarray]], List[List[torch.Tensor]], List[List[numbers.Number]], List[List[str]]]], device: torch.device, pin_memory: bool = False, non_blocking: bool = False) → Union[torch.Tensor, List]

Express values as a torch.Tensor.

collate_dicts(batch: trw.basic_typing.Batch, device: torch.device, pin_memory: bool = False, non_blocking: bool = False) → Mapping[str, Union[torch.Tensor, List]]

Default function to collate a dictionary of samples into a dictionary of torch.Tensor.

collate_list_of_dicts(batches: Sequence[trw.basic_typing.Batch], device: torch.device, pin_memory: bool = False, non_blocking: bool = False) → Mapping[str, Union[torch.Tensor, List]]

Default function to collate a list of dictionaries into a dictionary of `torch.Tensor`s.

default_collate_fn(batch: Union[Sequence[Any], Mapping[str, Any]], device: torch.device, pin_memory: bool = False, non_blocking: bool = False)

Collate batch, given either as a dictionary of features or as a list of dictionaries of features.

Attributes

logger

trw.train.collate.logger

trw.train.collate.collate_tensors(values: Union[numpy.ndarray, torch.Tensor, Union[List[numpy.ndarray], List[torch.Tensor], List[numbers.Number], List[str], List[List[numpy.ndarray]], List[List[torch.Tensor]], List[List[numbers.Number]], List[List[str]]]], device: torch.device, pin_memory: bool = False, non_blocking: bool = False) → Union[torch.Tensor, List]

Express values as a torch.Tensor.

Parameters
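  • values – a numpy.ndarray, a torch.Tensor, or a list (possibly nested one level) of numpy.ndarray, torch.Tensor, numbers or str

  • device – the device on which to create the torch.Tensor

  • pin_memory – if True, pin the host memory of the tensor before the transfer. Only meaningful when device is a CUDA device

  • non_blocking – if True, use a non-blocking (asynchronous) memory transfer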

Returns

a torch.Tensor if values is of type numpy.ndarray; otherwise, values in its input type
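The conversion rules above can be illustrated with a minimal sketch. It is not the trw source: `collate_tensors_sketch` is a hypothetical name, and stacking lists of arrays or tensors along a new leading dimension (rather than concatenating them) is an assumption. In this sketch, host memory is pinned only when the tensor lives on the CPU and the target is a CUDA device.

```python
import numbers
from typing import Any, List, Union

import numpy as np
import torch


def collate_tensors_sketch(values: Any, device: torch.device,
                           pin_memory: bool = False,
                           non_blocking: bool = False) -> Union[torch.Tensor, List]:
    """Hypothetical sketch of collate_tensors; not the trw implementation."""
    tensor = values
    if isinstance(values, np.ndarray):
        # numpy.ndarray -> torch.Tensor (shares memory when possible)
        tensor = torch.as_tensor(values)
    elif isinstance(values, list) and len(values) > 0:
        first = values[0]
        if isinstance(first, torch.Tensor):
            # assumption: same-shaped tensors are stacked along a new batch dimension
            tensor = torch.stack(values)
        elif isinstance(first, np.ndarray):
            tensor = torch.as_tensor(np.stack(values))
        elif isinstance(first, numbers.Number):
            tensor = torch.as_tensor(np.asarray(values))
        # lists of str and nested lists are returned unchanged in this sketch

    if isinstance(tensor, torch.Tensor) and tensor.device != device:
        if pin_memory and device.type == 'cuda' and tensor.device.type == 'cpu':
            # page-locked host memory enables asynchronous host-to-GPU copies
            tensor = tensor.pin_memory()
        tensor = tensor.to(device, non_blocking=non_blocking)
    return tensor
```

For example, `collate_tensors_sketch([1, 2, 3], torch.device('cpu'))` yields a one-dimensional tensor holding 1, 2 and 3, while a list of strings comes back as the same list.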

trw.train.collate.collate_dicts(batch: trw.basic_typing.Batch, device: torch.device, pin_memory: bool = False, non_blocking: bool = False) → Mapping[str, Union[torch.Tensor, List]]

Default function to collate a dictionary of samples into a dictionary of torch.Tensor.

Parameters
  • batch – a dictionary of features

  • device – the device on which to create the torch.Tensor

  • pin_memory – if True, pin the host memory of the tensors before the transfer. Only meaningful when device is a CUDA device

  • non_blocking – if True, use a non-blocking (asynchronous) memory transfer

Returns

a dictionary mapping each feature name to a torch.Tensor (or to a list, for values that cannot be expressed as a tensor)
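As a rough illustration of the single-dictionary case, the sketch below assumes that each value in `batch` already holds every sample of one feature (an array, a list of numbers or a list of strings) and converts each value independently. `collate_dicts_sketch` and `_to_tensor` are hypothetical names, the conversion rules are a simplification of those documented for collate_tensors, and `pin_memory` is accepted but ignored.

```python
import numbers
from typing import Any, Dict, List, Mapping, Union

import numpy as np
import torch


def _to_tensor(values: Any, device: torch.device,
               non_blocking: bool = False) -> Union[torch.Tensor, List]:
    """Hypothetical helper: tensorize arrays and numeric lists, keep other values."""
    if isinstance(values, np.ndarray):
        values = torch.as_tensor(values)
    elif isinstance(values, list) and values and isinstance(values[0], numbers.Number):
        values = torch.as_tensor(values)
    if isinstance(values, torch.Tensor):
        return values.to(device, non_blocking=non_blocking)
    return values  # e.g. a list of str is kept as a list


def collate_dicts_sketch(batch: Mapping[str, Any], device: torch.device,
                         pin_memory: bool = False,
                         non_blocking: bool = False) -> Dict[str, Union[torch.Tensor, List]]:
    """Hypothetical sketch of collate_dicts: convert each feature independently."""
    # pin_memory is accepted for signature parity but ignored in this sketch
    return {name: _to_tensor(values, device, non_blocking) for name, values in batch.items()}


batch = {'x': np.zeros((4, 3), dtype=np.float32), 'y': [0, 1, 0, 1], 'name': ['a', 'b', 'c', 'd']}
out = collate_dicts_sketch(batch, torch.device('cpu'))
# out['x'] is a (4, 3) tensor, out['y'] a tensor of 4 labels, out['name'] the unchanged list
```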

trw.train.collate.collate_list_of_dicts(batches: Sequence[trw.basic_typing.Batch], device: torch.device, pin_memory: bool = False, non_blocking: bool = False) → Mapping[str, Union[torch.Tensor, List]]

Default function to collate a list of dictionaries into a dictionary of `torch.Tensor`s.

Parameters
  • batches – a list of dictionaries of features

  • device – the device on which to create the torch.Tensor

  • pin_memory – if True, pin the host memory of the tensors before the transfer. Only meaningful when device is a CUDA device

  • non_blocking – if True, use a non-blocking (asynchronous) memory transfer

Returns

a dictionary mapping each feature name to a torch.Tensor (or to a list, for values that cannot be expressed as a tensor)
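The list-of-dictionaries case can be sketched as gathering each feature across the dictionaries and then converting it. This is a hypothetical illustration (`collate_list_of_dicts_sketch` is not a trw name): it assumes every dictionary has the keys of the first one and that per-sample arrays or tensors share a shape, so they can be stacked along a new leading dimension; the library itself may concatenate instead.

```python
import numbers
from typing import Any, Dict, List, Mapping, Sequence, Union

import numpy as np
import torch


def collate_list_of_dicts_sketch(batches: Sequence[Mapping[str, Any]], device: torch.device,
                                 pin_memory: bool = False,
                                 non_blocking: bool = False) -> Dict[str, Union[torch.Tensor, List]]:
    """Hypothetical sketch: one dictionary per sample -> one dictionary per batch."""
    collated: Dict[str, Union[torch.Tensor, List]] = {}
    if len(batches) == 0:
        return collated
    for name in batches[0]:  # assumption: all dictionaries share these keys
        values = [sample[name] for sample in batches]
        first = values[0]
        if isinstance(first, (np.ndarray, torch.Tensor)):
            # stack same-shaped per-sample values along a new batch dimension
            value: Union[torch.Tensor, List] = torch.stack([torch.as_tensor(v) for v in values])
        elif isinstance(first, numbers.Number):
            value = torch.as_tensor(values)
        else:
            value = values  # e.g. strings stay a Python list
        if isinstance(value, torch.Tensor):
            value = value.to(device, non_blocking=non_blocking)
        collated[name] = value
    return collated


samples = [
    {'x': np.ones(3, dtype=np.float32), 'y': 0, 'name': 'first'},
    {'x': np.zeros(3, dtype=np.float32), 'y': 1, 'name': 'second'},
]
out = collate_list_of_dicts_sketch(samples, torch.device('cpu'))
# out['x'] has shape (2, 3), out['y'] holds [0, 1], out['name'] == ['first', 'second']
```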

trw.train.collate.default_collate_fn(batch: Union[Sequence[Any], Mapping[str, Any]], device: torch.device, pin_memory: bool = False, non_blocking: bool = False)

Parameters
  • batch – a dictionary of features or a list of dictionaries of features

  • device – the device on which to create the torch.Tensor

  • pin_memory – if True, pin the host memory of the tensors before the transfer. Only meaningful when device is a CUDA device

  • non_blocking – if True, use a non-blocking (asynchronous) memory transfer

Returns

a dictionary mapping each feature name to a torch.Tensor (or to a list, for values that cannot be expressed as a tensor)
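The two accepted input layouts suggest a simple dispatch, sketched below under stated assumptions: a `Mapping` is treated as a dictionary of features and a non-empty sequence of mappings as a list of per-sample dictionaries. `default_collate_fn_sketch` and `_convert` are hypothetical names, and binding `device` with `functools.partial` to obtain a one-argument `collate_fn` for `torch.utils.data.DataLoader` is one possible way to plug such a function in, not a documented trw recipe.

```python
import functools
import numbers
from collections.abc import Mapping, Sequence
from typing import Any, Dict, List, Union

import numpy as np
import torch
from torch.utils.data import DataLoader


def _convert(values: Any, device: torch.device, non_blocking: bool) -> Union[torch.Tensor, List]:
    """Hypothetical helper: tensorize arrays, tensors and numbers; keep everything else."""
    if isinstance(values, np.ndarray):
        values = torch.as_tensor(values)
    elif isinstance(values, list) and values and isinstance(values[0], (np.ndarray, torch.Tensor)):
        values = torch.stack([torch.as_tensor(v) for v in values])
    elif isinstance(values, list) and values and isinstance(values[0], numbers.Number):
        values = torch.as_tensor(values)
    if isinstance(values, torch.Tensor):
        return values.to(device, non_blocking=non_blocking)
    return values


def default_collate_fn_sketch(batch: Any, device: torch.device,
                              pin_memory: bool = False,
                              non_blocking: bool = False) -> Dict[str, Union[torch.Tensor, List]]:
    """Hypothetical dispatch between the two documented input layouts."""
    if isinstance(batch, Mapping):
        # dictionary of features: each value already spans the whole batch
        return {k: _convert(v, device, non_blocking) for k, v in batch.items()}
    if isinstance(batch, Sequence) and len(batch) > 0 and isinstance(batch[0], Mapping):
        # list of per-sample dictionaries: gather each feature first
        return {k: _convert([s[k] for s in batch], device, non_blocking) for k in batch[0]}
    raise TypeError(f'unsupported batch type: {type(batch)}')


samples = [{'x': np.full(3, i, dtype=np.float32), 'y': i} for i in range(4)]
loader = DataLoader(samples, batch_size=2,
                    collate_fn=functools.partial(default_collate_fn_sketch, device=torch.device('cpu')))
first_batch = next(iter(loader))
# first_batch['x'] has shape (2, 3) and first_batch['y'] holds [0, 1]
```

Because `DataLoader` hands its `collate_fn` a list of the samples returned by the dataset, per-sample dictionaries reach the second branch of the dispatch.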