trw.simple_layers.simple_layers_implementations

Module Contents

Classes

Input

Represent an input (i.e., a feature) to a network

OutputClassification

Output class for classification

OutputRecord

Record a field derived from the node input values or the batch

OutputEmbedding

Create an embedding for display purposes

ReLU

Generic module

Linear

Generic module

Flatten

Generic module

Conv2d

Generic module

Conv3d

Generic module

MaxPool2d

Generic module

ConcatChannels

Implement a channel concatenation layer

Functions

return_output(outputs, batch)

_conv_2d_shape_fn(node, module_args)

_conv_3d_shape_fn(node, module_args)

class trw.simple_layers.simple_layers_implementations.Input(shape: list, feature_name: str)

Bases: trw.simple_layers.simple_layers.SimpleLayerBase

Represent an input (i.e., a feature) to a network

get_module(self)

Return an nn.Module

class trw.simple_layers.simple_layers_implementations.OutputClassification(node, output_name, classes_name, **kwargs)

Bases: trw.simple_layers.simple_layers.SimpleOutputBase

Output class for classification

forward(self, inputs, batch)

Create a trw.train.Output from the inputs

Parameters
  • inputs – a list of inputs of the output node

  • batch – the batch of data fed to the network

Returns

a trw.train.Output object

get_module(self)

Return an nn.Module

trw.simple_layers.simple_layers_implementations.return_output(outputs, batch)

class trw.simple_layers.simple_layers_implementations.OutputRecord(node, output_name, functor=return_output)

Bases: trw.simple_layers.simple_layers.SimpleOutputBase

Record a field derived from the node input values or the batch

forward(self, inputs, batch)

Create a trw.train.Output from the inputs

Parameters
  • inputs – a list of inputs of the output node

  • batch – the batch of data fed to the network

Returns

a trw.train.Output object

get_module(self)

Return an nn.Module

class trw.simple_layers.simple_layers_implementations.OutputEmbedding(node, output_name)

Bases: trw.simple_layers.simple_layers.SimpleOutputBase

Create an embedding for display purposes

forward(self, inputs, batch)

Create a trw.train.Output from the inputs

Parameters
  • inputs – a list of inputs of the output node

  • batch – the batch of data fed to the network

Returns

a trw.train.Output object

get_module(self)

Return an nn.Module

class trw.simple_layers.simple_layers_implementations.ReLU(node)

Bases: trw.simple_layers.simple_layers.SimpleModule

Generic module

Module must have a single input and all the module’s parameters should be on the same device.

class trw.simple_layers.simple_layers_implementations.Linear(node, out_features)

Bases: trw.simple_layers.simple_layers.SimpleModule

Generic module

Module must have a single input and all the module’s parameters should be on the same device.

class trw.simple_layers.simple_layers_implementations.Flatten(node)

Bases: trw.simple_layers.simple_layers.SimpleModule

Generic module

Module must have a single input and all the module’s parameters should be on the same device.
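As an illustration of what a flatten layer does to tensor shapes, here is a minimal sketch in plain Python. It assumes the layer keeps the batch dimension and collapses all remaining dimensions into one; `flatten_shape` is a hypothetical helper written for this example and is not part of trw.

```python
import math
from typing import List, Sequence


def flatten_shape(input_shape: Sequence[int]) -> List[int]:
    """Hypothetical helper: shape after flattening everything but the batch axis.

    Assumes input_shape is [batch, dim_1, ..., dim_n] with integer sizes.
    """
    batch, *rest = input_shape
    # math.prod of an empty list is 1, so a [batch] input becomes [batch, 1]
    return [batch, math.prod(rest)]


print(flatten_shape([8, 16, 7, 7]))
```

Under these assumptions, a `[8, 16, 7, 7]` feature map becomes `[8, 784]`, which is the kind of shape a following `Linear` layer would consume.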

trw.simple_layers.simple_layers_implementations._conv_2d_shape_fn(node, module_args)

class trw.simple_layers.simple_layers_implementations.Conv2d(node, out_channels, kernel_size, stride=1, padding='same')

Bases: trw.simple_layers.simple_layers.SimpleModule

Generic module

Module must have a single input and all the module’s parameters should be on the same device.

trw.simple_layers.simple_layers_implementations._conv_3d_shape_fn(node, module_args)

class trw.simple_layers.simple_layers_implementations.Conv3d(node, out_channels, kernel_size, stride=1, padding='same')

Bases: trw.simple_layers.simple_layers.SimpleModule

Generic module

Module must have a single input and all the module’s parameters should be on the same device.
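Both Conv2d and Conv3d default to `padding='same'`, which suggests the spatial size is preserved when `stride=1`. The sketch below shows the shape arithmetic under two assumptions that the source does not confirm: that `'same'` resolves to `kernel_size // 2` padding on each side, and that standard convolution arithmetic without dilation applies. `conv_output_shape` is a hypothetical helper, not part of trw, and the actual `_conv_2d_shape_fn` and `_conv_3d_shape_fn` may behave differently.

```python
from typing import List, Sequence


def conv_output_shape(input_shape: Sequence[int], out_channels: int,
                      kernel_size: int, stride: int = 1) -> List[int]:
    """Hypothetical helper: output shape of an N-d convolution.

    input_shape is [batch, channels, *spatial].
    Assumption: 'same' padding means kernel_size // 2 on each side.
    """
    padding = kernel_size // 2
    batch, _in_channels, *spatial = input_shape
    # Standard convolution arithmetic, applied to every spatial axis
    out_spatial = [(size + 2 * padding - kernel_size) // stride + 1
                   for size in spatial]
    return [batch, out_channels, *out_spatial]


# 2D case: stride 1 keeps the 28x28 grid, stride 2 halves it
print(conv_output_shape([8, 1, 28, 28], out_channels=16, kernel_size=5))
print(conv_output_shape([8, 1, 28, 28], out_channels=16, kernel_size=5, stride=2))
# 3D case: the same arithmetic with one more spatial axis
print(conv_output_shape([2, 1, 16, 16, 16], out_channels=4, kernel_size=3))
```

With an odd kernel size and stride 1, this arithmetic leaves every spatial dimension unchanged, so only the channel count differs between input and output.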

class trw.simple_layers.simple_layers_implementations.MaxPool2d(node, kernel_size, stride=None)

Bases: trw.simple_layers.simple_layers.SimpleModule

Generic module

Module must have a single input and all the module’s parameters should be on the same device.
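The `stride=None` default can be read as "use the kernel size as the stride", which is how `torch.nn.MaxPool2d` treats a missing stride. Whether this class forwards the value unchanged is an assumption here. The sketch below works through the resulting shape arithmetic with no padding; `max_pool_2d_output_shape` is a hypothetical helper, not part of trw.

```python
from typing import List, Optional, Sequence


def max_pool_2d_output_shape(input_shape: Sequence[int], kernel_size: int,
                             stride: Optional[int] = None) -> List[int]:
    """Hypothetical helper: output shape of 2D max pooling without padding.

    input_shape is [batch, channels, height, width].
    Assumption: stride=None falls back to kernel_size.
    """
    if stride is None:
        stride = kernel_size
    batch, channels, height, width = input_shape
    out_height = (height - kernel_size) // stride + 1
    out_width = (width - kernel_size) // stride + 1
    # Pooling never changes the batch size or the channel count
    return [batch, channels, out_height, out_width]


print(max_pool_2d_output_shape([8, 16, 28, 28], kernel_size=2))
print(max_pool_2d_output_shape([8, 16, 28, 28], kernel_size=3, stride=1))
```

Under these assumptions, a 2x2 pool with the default stride halves each spatial dimension, while an explicit `stride=1` only trims `kernel_size - 1` from each.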

class trw.simple_layers.simple_layers_implementations.ConcatChannels(nodes, flatten=False)

Bases: trw.simple_layers.simple_layers.SimpleMergeBase

Implement a channel concatenation layer

static calculate_shape(parents)

get_module(self)

Return an nn.Module
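To make the shape bookkeeping of a channel concatenation concrete, here is a minimal sketch in plain Python. It assumes channels live on axis 1, that all other axes must match across inputs, and that `flatten=True` reshapes each input to `[batch, features]` before concatenating. None of this is confirmed by the source. `concat_channels_shape` is a hypothetical helper and is not the library's `calculate_shape`.

```python
import math
from typing import List, Sequence


def concat_channels_shape(parent_shapes: Sequence[Sequence[int]],
                          flatten: bool = False) -> List[int]:
    """Hypothetical helper: shape after concatenating inputs along axis 1.

    Assumption: flatten=True collapses each input to [batch, features] first.
    """
    if not parent_shapes:
        raise ValueError("at least one input shape is required")
    shapes = [list(shape) for shape in parent_shapes]
    if flatten:
        shapes = [[shape[0], math.prod(shape[1:])] for shape in shapes]

    reference = shapes[0]
    for shape in shapes[1:]:
        # Every axis except the channel axis must agree
        if (len(shape) != len(reference)
                or shape[0] != reference[0]
                or shape[2:] != reference[2:]):
            raise ValueError(f"incompatible shapes: {reference} vs {shape}")

    total_channels = sum(shape[1] for shape in shapes)
    return [reference[0], total_channels, *reference[2:]]


print(concat_channels_shape([[8, 16, 14, 14], [8, 32, 14, 14]]))
print(concat_channels_shape([[8, 16, 2, 2], [8, 10]], flatten=True))
```

In this sketch, flattening lets inputs of different rank be merged, since each one is reduced to a two-dimensional shape before the channel counts are summed.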