trw.simple_layers.simple_layers

Notes

  • Local functions and lambdas must not be used, so that the layers can be pickled and unpickled easily
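A pickle round-trip shows why this restriction exists. The sketch below uses only the standard `pickle` module. `double`, `make_local` and `can_pickle` are hypothetical helpers written for illustration and are not part of trw.

```python
import pickle


def double(x):
    # Module-level function: pickle stores it by qualified name, so it round-trips.
    return 2 * x


def make_local():
    def local_double(x):  # defined inside another function: a "local object"
        return 2 * x
    return local_double


def can_pickle(obj):
    """Return True if `obj` survives a pickle round-trip."""
    try:
        pickle.loads(pickle.dumps(obj))
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        # The exact exception type for unpicklable functions varies across Python versions.
        return False


print(can_pickle(double))           # True
print(can_pickle(lambda x: 2 * x))  # False
print(can_pickle(make_local()))     # False
```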

Module Contents

Classes

SimpleLayerBase

Base layer for our simplified network specification

SimpleMergeBase

Base class for nodes with multiple inputs

SimpleOutputBase

Base class to calculate an output

SimpleModule

Generic module

class trw.simple_layers.simple_layers.SimpleLayerBase(parents, shape)

Base layer for our simplified network specification

Record the network node by node and keep track of the important information: parents, children, size.
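The following is a rough illustration of this bookkeeping. It uses a minimal, hypothetical `Node` class, not the trw implementation, that records parents, children and shape. A hypothetical `descendants` helper shows why recording children is useful: the graph can then be walked forward from its inputs.

```python
class Node:
    """Hypothetical stand-in for a layer node: records parents, children and shape."""

    def __init__(self, parents, shape):
        # Accept None (input node), a single parent, or a list/tuple of parents.
        if parents is None:
            parents = []
        elif not isinstance(parents, (list, tuple)):
            parents = [parents]
        self.parents = list(parents)
        self.children = []
        self.shape = list(shape)
        # Register this node as a child of each parent so the graph can be walked both ways.
        for parent in self.parents:
            parent.children.append(self)


def descendants(node):
    """Return all nodes reachable from `node` through the recorded children."""
    seen, stack = [], list(node.children)
    while stack:
        current = stack.pop()
        if current not in seen:
            seen.append(current)
            stack.extend(current.children)
    return seen


inputs = Node(None, shape=[None, 1, 28, 28])
conv = Node(inputs, shape=[None, 16, 28, 28])
flat = Node(conv, shape=[None, 16 * 28 * 28])
print(len(descendants(inputs)))  # 2
```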

Note

  • nn.Module instances must be created during initialization, so that the network can easily be shared between different sub-models
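The reason for creating the module eagerly can be sketched without PyTorch. Here `FakeModule` stands in for `nn.Module`, and `EagerNode` and `LazyNode` are hypothetical classes. If the module is built once in `__init__`, every sub-model that asks for it receives the same instance, and therefore the same parameters.

```python
class FakeModule:
    """Stand-in for nn.Module so the sketch runs without PyTorch."""


class EagerNode:
    def __init__(self):
        # The module (and hence its weights) is created once, at construction time.
        self.module = FakeModule()

    def get_module(self):
        return self.module


class LazyNode:
    def get_module(self):
        # A new module on every call: two sub-models would NOT share weights.
        return FakeModule()


shared = EagerNode()
sub_model_a = [shared.get_module()]
sub_model_b = [shared.get_module()]
print(sub_model_a[0] is sub_model_b[0])  # True: both sub-models hold the same module

lazy = LazyNode()
print(lazy.get_module() is lazy.get_module())  # False
```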

get_module(self)

Return an nn.Module

class trw.simple_layers.simple_layers.SimpleMergeBase(parents, shape)

Bases: SimpleLayerBase

Base class for nodes with multiple inputs
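A merge node generally has to derive its output shape from all of its parents. The sketch below assumes concatenation for illustration and computes the resulting shape along one dimension, the channel dimension by default. `concat_shape` is a hypothetical helper, and trw merge layers may compute their shapes differently.

```python
def concat_shape(parent_shapes, dim=1):
    """Output shape of concatenating the parents' outputs along `dim`."""
    if len(parent_shapes) < 2:
        raise ValueError("a merge node expects at least two parents")
    reference = parent_shapes[0]
    for shape in parent_shapes[1:]:
        if len(shape) != len(reference):
            raise ValueError(f"rank mismatch: {shape} vs {reference}")
        # Every dimension except `dim` must agree (None, e.g. the batch size, matches None).
        for d, (a, b) in enumerate(zip(shape, reference)):
            if d != dim and a != b:
                raise ValueError(f"size mismatch on dim {d}: {shape} vs {reference}")
    merged = list(reference)
    merged[dim] = sum(shape[dim] for shape in parent_shapes)
    return merged


print(concat_shape([[None, 16, 8, 8], [None, 32, 8, 8]]))  # [None, 48, 8, 8]
```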

class trw.simple_layers.simple_layers.SimpleOutputBase(node, output_name, shape)

Bases: SimpleLayerBase

Base class to calculate an output

forward(self, inputs, batch)

Create a trw.train.Output from the inputs

Parameters
  • inputs – a list of inputs of the output node

  • batch – the batch of data fed to the network

Returns

a trw.train.Output object
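The sketch below only illustrates the shape of this contract. `OutputStandIn` is a hypothetical placeholder for `trw.train.Output`, whose constructor is not documented here. The single-input check and the lookup of the target in `batch` under `classes_name` are assumptions made for the example.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class OutputStandIn:
    """Hypothetical placeholder for trw.train.Output: a named prediction plus its target."""
    name: str
    prediction: Any
    target: Any


class ClassificationOutputSketch:
    def __init__(self, output_name, classes_name):
        self.output_name = output_name
        self.classes_name = classes_name  # key of the target in the batch (assumption)

    def forward(self, inputs: List[Any], batch: Dict[str, Any]) -> OutputStandIn:
        # In this sketch the output node has a single parent, so `inputs` holds one element.
        assert len(inputs) == 1, "expected exactly one input"
        return OutputStandIn(self.output_name, inputs[0], batch[self.classes_name])


node = ClassificationOutputSketch("softmax", "targets")
out = node.forward([[0.1, 0.9]], {"targets": [1]})
print(out)  # OutputStandIn(name='softmax', prediction=[0.1, 0.9], target=[1])
```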

class trw.simple_layers.simple_layers.SimpleModule(node, module, shape=None)

Bases: SimpleLayerBase

Generic module

The module must have a single input, and all of its parameters should be on the same device.

static calculate_shape(shape, module, parents)
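calculate_shape is listed without a description. One common way to compute a generic module's output shape is a dry run: feed a dummy input of the given shape through the module and measure the result. The stdlib-only sketch below uses nested lists in place of tensors. `infer_output_shape`, `zeros`, `shape_of`, `flatten` and `flatten_per_sample` are hypothetical helpers, and this is not necessarily how trw implements the method.

```python
def zeros(shape):
    """Build a nested list of zeros with the given shape."""
    if len(shape) == 1:
        return [0.0] * shape[0]
    return [zeros(shape[1:]) for _ in range(shape[0])]


def shape_of(x):
    """Read back the shape of a (non-ragged) nested list."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return shape


def infer_output_shape(shape, module):
    # Dry run: push a dummy input through the module and measure the result.
    return shape_of(module(zeros(shape)))


def flatten(value):
    """Recursively flatten a nested list into a flat list of scalars."""
    if isinstance(value, list):
        return [e for item in value for e in flatten(item)]
    return [value]


def flatten_per_sample(batch):
    """Toy 'module': flattens every sample of a batch into a single vector."""
    return [flatten(sample) for sample in batch]


print(infer_output_shape([2, 3, 4], flatten_per_sample))  # [2, 12]
```

With PyTorch tensors, the equivalent dry run would be along the lines of `module(torch.zeros(shape)).shape`.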

get_module(self)

Return an nn.Module