trw.train.data_parallel_extented
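Module Contents
Classes
DataParallelExtended | Customized version of torch.nn.DataParallel to support models with complex outputs such as trw.train.Output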
Functions
gather_extended(outputs, target_device, dim=0) | Gathers tensors from different GPUs on a specified device
- trw.train.data_parallel_extented.gather_extended(outputs, target_device, dim=0)
- Gathers tensors from different GPUs on a specified device (-1 means the CPU).
This is an extended version of `` from PyTorch that also supports outputs of type
trw.train.Output.
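To illustrate what gathering structured outputs involves, here is a minimal, framework-free sketch. It is not the trw implementation: `HypotheticalOutput` and `gather_sketch` are hypothetical names, plain Python lists stand in for tensors, and list concatenation stands in for concatenating along `dim=0` on the target device. Containers are gathered recursively, and structured outputs are rebuilt instead of being handed to a tensor-only gather.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class HypotheticalOutput:
    """Hypothetical structured output (NOT trw.train.Output).

    `values` is batch-sized and gets concatenated; `metadata` is shared
    across replicas and is copied from the first one.
    """
    values: List[float]
    metadata: Dict[str, Any] = field(default_factory=dict)


def gather_sketch(outputs: List[Any]) -> Any:
    """Merge one result per replica into a single result (batch dim 0).

    Lists stand in for tensors; concatenating them mimics torch.cat(dim=0).
    """
    first = outputs[0]
    if isinstance(first, HypotheticalOutput):
        # Concatenate the batch-sized part, keep the shared metadata once.
        values = gather_sketch([o.values for o in outputs])
        return HypotheticalOutput(values=values, metadata=dict(first.metadata))
    if isinstance(first, dict):
        # Gather every key independently.
        return {key: gather_sketch([o[key] for o in outputs]) for key in first}
    if isinstance(first, tuple):
        # Gather position by position and rebuild the tuple.
        return tuple(gather_sketch(list(items)) for items in zip(*outputs))
    if isinstance(first, list):
        # "Tensor" leaf: concatenate the replicas' batches in order.
        return [item for o in outputs for item in o]
    # Other leaves (ints, strings, ...) are assumed identical across replicas.
    return first


# Two replicas, each returning a dict with a structured output and a step counter.
replica_0 = {"classification": HypotheticalOutput([0.1, 0.9], {"loss": "ce"}), "step": 3}
replica_1 = {"classification": HypotheticalOutput([0.7], {"loss": "ce"}), "step": 3}
merged = gather_sketch([replica_0, replica_1])
print(merged["classification"].values)  # [0.1, 0.9, 0.7]
```

The recursion follows the structure of the first replica's result, so the sketch assumes every replica returns the same nesting of dicts, tuples and outputs.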
- class trw.train.data_parallel_extented.DataParallelExtended(*arg, **argv)
Bases: torch.nn.DataParallel
Customized version of torch.nn.DataParallel to support models with complex outputs such as trw.train.Output.
- gather(self, outputs, output_device)