trw.simple_layers.simple_layers
Notes
We can’t use local functions or lambdas, because they cannot be pickled; avoiding them keeps pickling and unpickling simple.
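The constraint above can be illustrated with the standard `pickle` module. This is not code from trw. The snippet checks which callables survive a pickle round trip: a lambda and a function defined inside another function fail, because `pickle` stores functions by their qualified name and cannot look these up. A `functools.partial` over a module-level callable such as `operator.mul` pickles fine. The helper names (`is_picklable`, `make_local_scaler`) are hypothetical.

```python
import functools
import operator
import pickle


def is_picklable(obj) -> bool:
    """Return True if `obj` survives a pickle round trip."""
    try:
        pickle.loads(pickle.dumps(obj))
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        # Lambdas and local functions end up here: pickle saves functions
        # by qualified name, and such functions cannot be found by name.
        return False


def make_local_scaler(factor):
    # A function defined inside another function (a "local object")
    def scale(x):
        return x * factor
    return scale


lambda_scaler = lambda x: x * 2.0  # noqa: E731 (a lambda, on purpose)
local_scaler = make_local_scaler(2.0)
# A picklable alternative: partial application of a module-level callable
partial_scaler = functools.partial(operator.mul, 2.0)

print(is_picklable(lambda_scaler))   # False
print(is_picklable(local_scaler))    # False
print(is_picklable(partial_scaler))  # True
print(pickle.loads(pickle.dumps(partial_scaler))(3.0))  # 6.0
```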
Module Contents
Classes
SimpleLayerBase | Base layer for our simplified network specification
SimpleMergeBase | Base class for nodes with multiple inputs
SimpleOutputBase | Base class to calculate an output
SimpleModule | Generic module
- class trw.simple_layers.simple_layers.SimpleLayerBase(parents, shape)
Base layer for our simplified network specification
Record the network node by node and keep track of the important information: parents, children, size.
Note
- The nn.Module must be created during initialization, so that the network can easily be shared between different sub-models.
- get_module(self)
Return an nn.Module
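The node-by-node recording described for SimpleLayerBase can be sketched as follows. This is a minimal illustration, not trw's implementation: `LayerNodeSketch` and `IdentityModule` are hypothetical names, and `IdentityModule` stands in for a `torch.nn.Module` so the sketch runs without PyTorch. Each node links itself to its parents' `children` lists on construction, and creates its module once in `__init__`, so any sub-model reaching that node reuses the same module object.

```python
from typing import List, Optional, Sequence


class IdentityModule:
    """Hypothetical stand-in for a torch.nn.Module (keeps the sketch PyTorch-free)."""

    def __call__(self, x):
        return x


class LayerNodeSketch:
    """Hypothetical node recording its parents, its children and its shape."""

    def __init__(self, parents: Optional[Sequence["LayerNodeSketch"]], shape: Sequence[int]):
        self.parents: List["LayerNodeSketch"] = list(parents or [])
        self.children: List["LayerNodeSketch"] = []
        self.shape: List[int] = list(shape)
        # Link the graph in both directions as soon as the node is created
        for parent in self.parents:
            parent.children.append(self)
        # Create the module once, at construction time: every sub-model that
        # reuses this node then shares the same module (and its parameters)
        self.module = IdentityModule()

    def get_module(self) -> IdentityModule:
        return self.module


# Usage: one input node feeding two branches
root = LayerNodeSketch(None, [1, 28, 28])
left = LayerNodeSketch([root], [1, 28, 28])
right = LayerNodeSketch([root], [1, 28, 28])
print(len(root.children))                   # 2
print(left.parents[0] is root)              # True
print(left.get_module() is left.module)     # True
```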
- class trw.simple_layers.simple_layers.SimpleMergeBase(parents, shape)
Bases: SimpleLayerBase
Base class for nodes with multiple inputs
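What sets a merge node apart is that its output shape depends on several parents. As a hypothetical example (the documentation above does not say which merge operations trw provides), the sketch below computes the output shape of a node that concatenates its inputs along one axis: every other dimension must match, and the sizes along `axis` are summed. `concat_output_shape` is an illustrative name, not a trw function.

```python
from typing import List, Sequence


def concat_output_shape(parent_shapes: Sequence[Sequence[int]], axis: int = 1) -> List[int]:
    """Hypothetical shape rule for a node concatenating its parents along `axis`."""
    if not parent_shapes:
        raise ValueError("a merge node needs at least one parent")
    reference = list(parent_shapes[0])
    if not -len(reference) <= axis < len(reference):
        raise ValueError(f"axis {axis} out of range for {len(reference)} dimensions")
    axis %= len(reference)  # normalize a negative axis

    merged = list(reference)
    merged[axis] = 0
    for shape in parent_shapes:
        if len(shape) != len(reference):
            raise ValueError(f"rank mismatch: {list(shape)} vs {reference}")
        # All dimensions except the concatenation axis must agree
        for i, (size, ref_size) in enumerate(zip(shape, reference)):
            if i != axis and size != ref_size:
                raise ValueError(f"shape mismatch: {list(shape)} vs {reference}")
        merged[axis] += shape[axis]
    return merged


print(concat_output_shape([[8, 3, 32, 32], [8, 5, 32, 32]]))  # [8, 8, 32, 32]
```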
- class trw.simple_layers.simple_layers.SimpleOutputBase(node, output_name, shape)
Bases: SimpleLayerBase
Base class to calculate an output
- forward(self, inputs, batch)
Create a trw.train.Output from the inputs
- Parameters
inputs – a list of inputs of the output node
batch – the batch of data fed to the network
- Returns
a trw.train.Output object
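The role of `forward` can be sketched with stand-in types. This does not reproduce the real `trw.train.Output` API. `OutputRecord` is a hypothetical dataclass standing in for it, and `OutputNodeSketch` and its `target_name` parameter are also assumptions. The sketch assumes a single input, since `SimpleOutputBase` is built from one `node`, and shows that the target is read from `batch` (the data fed to the network) rather than from the graph.

```python
from dataclasses import dataclass
from typing import Any, Dict, Sequence


@dataclass
class OutputRecord:
    """Hypothetical stand-in for trw.train.Output: a named value plus its target."""
    name: str
    value: Any
    target: Any


class OutputNodeSketch:
    """Hypothetical output node pairing its parent's value with a target from the batch."""

    def __init__(self, output_name: str, target_name: str):
        self.output_name = output_name
        self.target_name = target_name

    def forward(self, inputs: Sequence[Any], batch: Dict[str, Any]) -> OutputRecord:
        # Built from a single parent node, so exactly one input is expected
        if len(inputs) != 1:
            raise ValueError(f"expected 1 input, got {len(inputs)}")
        # The target is looked up in the batch, not computed by the graph
        return OutputRecord(self.output_name, inputs[0], batch[self.target_name])


head = OutputNodeSketch(output_name="digit", target_name="label")
record = head.forward([[0.1, 0.9]], {"images": None, "label": 1})
print(record)  # OutputRecord(name='digit', value=[0.1, 0.9], target=1)
```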
- class trw.simple_layers.simple_layers.SimpleModule(node, module, shape=None)
Bases: SimpleLayerBase
Generic module
The module must have a single input, and all of its parameters should be on the same device.
- static calculate_shape(shape, module, parents)
- get_module(self)
Return an nn.Module
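`calculate_shape` is listed without a description. One common way to infer the output shape of an arbitrary single-input module, assumed here purely for illustration and not confirmed as trw's method, is to run the module once on a dummy input built from the input shape, replacing an unknown batch dimension (`None`) by 1. The single-input requirement keeps this simple: only one probe is needed. In the stdlib-only sketch below, nested lists stand in for tensors, and `calculate_shape_sketch`, `zeros`, `shape_of` and the toy `flatten` module are all hypothetical.

```python
from typing import Any, Callable, List, Optional, Sequence


def zeros(shape: Sequence[int]) -> Any:
    """Nested lists filled with 0.0: a stdlib stand-in for a tensor."""
    if not shape:
        return 0.0
    return [zeros(shape[1:]) for _ in range(shape[0])]


def shape_of(value: Any) -> List[int]:
    """Shape of a regular nested list (the length at each nesting level)."""
    dims: List[int] = []
    while isinstance(value, list):
        dims.append(len(value))
        value = value[0] if value else None
    return dims


def flatten_sample(value: Any) -> List[float]:
    """Flatten one sample into a flat list of numbers."""
    if isinstance(value, list):
        return [x for item in value for x in flatten_sample(item)]
    return [value]


def flatten(batch: List[Any]) -> List[List[float]]:
    """Toy single-input 'module': flattens every sample of the batch."""
    return [flatten_sample(sample) for sample in batch]


def calculate_shape_sketch(
    input_shape: Sequence[Optional[int]], module: Callable[[Any], Any]
) -> List[Optional[int]]:
    """Infer a module's output shape by running it once on a dummy input."""
    # Replace the unknown batch dimension (None) by 1 to build a concrete probe
    probe = zeros([1 if d is None else d for d in input_shape])
    output_shape: List[Optional[int]] = list(shape_of(module(probe)))
    # Restore the unknown batch dimension on the output
    if input_shape and input_shape[0] is None:
        output_shape[0] = None
    return output_shape


print(calculate_shape_sketch([None, 3, 4, 5], flatten))  # [None, 60]
```

With PyTorch, the same idea would use a real tensor as the probe and read the result's `.shape`.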