trw.utils.global_pooling
Module Contents
Functions
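| global_max_pooling_2d(tensor) | 2D Global max pooling. |
| global_average_pooling_2d(tensor) | 2D Global average pooling. |
| global_max_pooling_3d(tensor) | 3D Global max pooling. |
| global_average_pooling_3d(tensor) | 3D Global average pooling. |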
- trw.utils.global_pooling.global_max_pooling_2d(tensor: trw.basic_typing.TorchTensorNCX) → trw.basic_typing.TorchTensorNCX
2D Global max pooling.
Compute the maximum over the spatial dimensions (H, W) for each sample and each channel of a tensor.
- Parameters
tensor – tensor with shape NCHW
- Returns
a tensor of shape NC
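As an illustration only, the reduction described above can be sketched in plain PyTorch. This is not trw's source: `max_pool_2d_sketch` is a hypothetical name, and the sketch assumes the documented behaviour (NCHW in, NC out, maximum over H and W).

```python
import torch
import torch.nn.functional as F


def max_pool_2d_sketch(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for global_max_pooling_2d: NCHW -> NC."""
    assert tensor.dim() == 4, "expected a tensor of shape NCHW"
    # adaptive max pooling to a 1x1 output keeps one value per (sample, channel): N x C x 1 x 1
    pooled = F.adaptive_max_pool2d(tensor, output_size=1)
    # drop the two singleton spatial dimensions to get N x C
    return pooled.flatten(start_dim=1)


x = torch.randn(2, 3, 5, 7)
y = max_pool_2d_sketch(x)
print(y.shape)  # torch.Size([2, 3])
# identical to reducing H and W directly with amax
print(torch.equal(y, x.amax(dim=(2, 3))))  # True
```

Because the reduction covers every spatial position, the output shape is NC regardless of H and W.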
- trw.utils.global_pooling.global_average_pooling_2d(tensor: trw.basic_typing.TorchTensorNCX) → trw.basic_typing.TorchTensorNCX
2D Global average pooling.
Compute the mean over the spatial dimensions (H, W) for each sample and each channel of a tensor.
- Parameters
tensor – tensor with shape NCHW
- Returns
a tensor of shape NC
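A minimal sketch of the averaging, assuming the documented NCHW to NC behaviour; `average_pool_2d_sketch` is a hypothetical name and does not reproduce trw's implementation. A small integer-valued input makes the per-channel means easy to check by hand.

```python
import torch
import torch.nn.functional as F


def average_pool_2d_sketch(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for global_average_pooling_2d: NCHW -> NC."""
    assert tensor.dim() == 4, "expected a tensor of shape NCHW"
    # mean over height and width, one value per (sample, channel)
    return tensor.mean(dim=(2, 3))


# sample 0, channel 0 holds [[0, 1], [2, 3]], channel 1 holds 4..7, channel 2 holds 8..11
x = torch.arange(2 * 3 * 2 * 2, dtype=torch.float32).reshape(2, 3, 2, 2)
y = average_pool_2d_sketch(x)
print(y.shape)  # torch.Size([2, 3])
print(y[0].tolist())  # [1.5, 5.5, 9.5]
# agrees with adaptive average pooling to a 1x1 output
print(torch.allclose(y, F.adaptive_avg_pool2d(x, output_size=1).flatten(start_dim=1)))  # True
```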
- trw.utils.global_pooling.global_max_pooling_3d(tensor: trw.basic_typing.TorchTensorNCX) → trw.basic_typing.TorchTensorNCX
3D Global max pooling.
Compute the maximum over the spatial dimensions (D, H, W) for each sample and each channel of a tensor.
- Parameters
tensor – tensor with shape NCDHW
- Returns
a tensor of shape NC
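The 3D case differs from the 2D one only in the number of reduced dimensions. A hedged sketch, assuming the documented NCDHW to NC behaviour; `max_pool_3d_sketch` is a hypothetical name, not part of trw. A single non-zero voxel shows that the maximum is taken independently per sample and channel.

```python
import torch


def max_pool_3d_sketch(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for global_max_pooling_3d: NCDHW -> NC."""
    assert tensor.dim() == 5, "expected a tensor of shape NCDHW"
    # maximum over depth, height and width in a single reduction
    return tensor.amax(dim=(2, 3, 4))


x = torch.zeros(2, 4, 3, 5, 6)
x[1, 2, 0, 4, 5] = 7.0  # a single peak in sample 1, channel 2
y = max_pool_3d_sketch(x)
print(y.shape)  # torch.Size([2, 4])
print(y[1, 2].item())  # 7.0
print(y.sum().item())  # 7.0, every other (sample, channel) pair stays 0
```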
- trw.utils.global_pooling.global_average_pooling_3d(tensor: trw.basic_typing.TorchTensorNCX) → trw.basic_typing.TorchTensorNCX
3D Global average pooling.
Compute the mean over the spatial dimensions (D, H, W) for each sample and each channel of a tensor.
- Parameters
tensor – tensor with shape NCDHW
- Returns
a tensor of shape NC
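Global average pooling is commonly used to collapse a feature map before a linear classifier. The sketch below illustrates that pattern under stated assumptions: `average_pool_3d_sketch` is a hypothetical stand-in for this function (not trw's code), and the channel count and the 10-class `nn.Linear` head are arbitrary example values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def average_pool_3d_sketch(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for global_average_pooling_3d: NCDHW -> NC."""
    assert tensor.dim() == 5, "expected a tensor of shape NCDHW"
    # adaptive average pooling to 1x1x1, then drop the singleton spatial dims
    return F.adaptive_avg_pool3d(tensor, output_size=1).flatten(start_dim=1)


# collapse a 3D feature map (16 channels) before an example 10-class linear head
features = torch.randn(2, 16, 4, 8, 8)
pooled = average_pool_3d_sketch(features)
classifier = nn.Linear(16, 10)
logits = classifier(pooled)
print(pooled.shape, logits.shape)  # torch.Size([2, 16]) torch.Size([2, 10])
# equivalent to a plain mean over D, H and W
print(torch.allclose(pooled, features.mean(dim=(2, 3, 4)), atol=1e-6))  # True
```

Since the pooled shape depends only on N and C, the same classifier accepts volumes with different D, H and W.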