trw.train.data_parallel_extented

Module Contents

Classes

DataParallelExtended

Customized version of torch.nn.DataParallel to support models with complex outputs such as trw.train.Output

Functions

gather_extended(outputs, target_device, dim=0)

Gathers tensors from different GPUs on a specified device

trw.train.data_parallel_extented.gather_extended(outputs, target_device, dim=0)
Gathers tensors from different GPUs on a specified device (-1 means the CPU).

This is an extended version of `` from PyTorch that additionally supports trw.train.Output.
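PyTorch's own gather recurses through dicts and sequences and concatenates tensors at the leaves. An extended gather can add cases for richer output types in the same way. The pure-Python sketch below illustrates that recursive pattern only; it is not the trw implementation. FakeOutput is a hypothetical stand-in for trw.train.Output, Python lists stand in for per-device tensors, concat_lists stands in for concatenation along dim, and moving the result to target_device is omitted.

```python
from typing import Any, Callable, Sequence


class FakeOutput:
    """Hypothetical stand-in for a structured output such as trw.train.Output.

    It wraps a single batched value; the real class has its own fields.
    """

    def __init__(self, value: Any) -> None:
        self.value = value


def concat_lists(parts: Sequence[list]) -> list:
    """Stand-in for tensor concatenation along dim 0."""
    merged: list = []
    for part in parts:
        merged.extend(part)
    return merged


def gather_nested(
    outputs: Sequence[Any],
    concat: Callable[[Sequence[Any]], Any] = concat_lists,
) -> Any:
    """Recursively merge one result per device into a single result.

    Assumes every entry of `outputs` has the same structure.
    """
    first = outputs[0]
    if isinstance(first, FakeOutput):
        # Unwrap, gather the payloads, then rewrap in the same type.
        return FakeOutput(gather_nested([o.value for o in outputs], concat))
    if isinstance(first, dict):
        # Merge key by key.
        return {key: gather_nested([o[key] for o in outputs], concat) for key in first}
    if isinstance(first, tuple):
        # Merge position by position.
        return tuple(gather_nested(list(parts), concat) for parts in zip(*outputs))
    if isinstance(first, list):
        # Leaf: a list plays the role of a per-device tensor batch.
        return concat(outputs)
    # Non-batched values (e.g. a tag): this sketch keeps the first device's copy.
    return first


per_device = [
    {"loss": FakeOutput([0.5, 0.7]), "ids": ([1, 2], "tag")},
    {"loss": FakeOutput([0.9]), "ids": ([3], "tag")},
]
merged = gather_nested(per_device)
print(merged["loss"].value)  # [0.5, 0.7, 0.9]
print(merged["ids"])  # ([1, 2, 3], 'tag')
```

The key design point is that the custom type is unwrapped, its contents gathered with the same recursion, and the result rewrapped, so callers receive one object of the same type rather than one per device.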

class trw.train.data_parallel_extented.DataParallelExtended(*arg, **argv)

Bases: torch.nn.DataParallel

Customized version of torch.nn.DataParallel to support models with complex outputs such as trw.train.Output.

gather(self, outputs, output_device)
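In torch.nn.DataParallel, forward() scatters the inputs, runs the replicas, and then returns self.gather(outputs, self.output_device), so overriding gather is enough to change how per-device results are merged. The toy below mimics that control flow in a single process to show the override. ToyDataParallel, ToyDataParallelExtended, and model are hypothetical names, lists stand in for tensors, and dict outputs stand in for complex outputs such as trw.train.Output.

```python
from typing import Any, Callable, List, Sequence


class ToyDataParallel:
    """Toy, single-process stand-in for torch.nn.DataParallel's control flow."""

    def __init__(self, module: Callable[[list], Any], num_devices: int = 2) -> None:
        self.module = module
        self.num_devices = num_devices
        self.output_device = 0  # placeholder; there are no real devices here

    def scatter(self, batch: list) -> List[list]:
        # Ceiling division so every element lands in some chunk.
        size = max(1, -(-len(batch) // self.num_devices))
        return [batch[i:i + size] for i in range(0, len(batch), size)]

    def gather(self, outputs: Sequence[Any], output_device: int) -> Any:
        # Base behaviour: only knows how to concatenate plain list batches.
        merged: list = []
        for out in outputs:
            merged.extend(out)
        return merged

    def forward(self, batch: list) -> Any:
        chunks = self.scatter(batch)
        outputs = [self.module(chunk) for chunk in chunks]
        # The hook subclasses override.
        return self.gather(outputs, self.output_device)


class ToyDataParallelExtended(ToyDataParallel):
    """Overrides only gather() so that dict outputs are merged key by key."""

    def gather(self, outputs: Sequence[Any], output_device: int) -> Any:
        if outputs and isinstance(outputs[0], dict):
            return {
                key: self.gather([o[key] for o in outputs], output_device)
                for key in outputs[0]
            }
        return super().gather(outputs, output_device)


def model(chunk: list) -> dict:
    # Hypothetical model returning a structured (dict) output per chunk.
    return {"doubled": [2 * x for x in chunk], "squared": [x * x for x in chunk]}


result = ToyDataParallelExtended(model).forward([1, 2, 3])
print(result)  # {'doubled': [2, 4, 6], 'squared': [1, 4, 9]}
```

With the base ToyDataParallel, the same call would flatten each dict into its keys and lose the values, which is the kind of failure an extended gather is meant to avoid.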