trw.simple_layers.compiled_net

Careful here: we MUST use an ordered set, as the order of execution is important! Otherwise, there may be inconsistencies between the compiled net and the executed net.
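As a minimal illustration of why the ordering matters, the sketch below uses a Python `dict` as an ordered set. Since Python 3.7, dict keys preserve insertion order, whereas iterating over a plain `set` gives no ordering guarantee. `ordered_unique` and the node names are hypothetical and not part of `trw`.

```python
def ordered_unique(items):
    """Return the unique items of `items`, keeping their first-seen order."""
    # dict.fromkeys keeps insertion order and drops duplicates,
    # so it behaves like an ordered set
    return list(dict.fromkeys(items))


execution_order = ordered_unique(["input", "conv1", "relu1", "conv1", "output"])
print(execution_order)  # ['input', 'conv1', 'relu1', 'output']
```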

Module Contents

Classes

RuntimeAction

Specifies the type of runtime action (e.g., execution of a node, release of a node's state or evaluation of a node)

CompiledNet

Encapsulate a compiled network so that we can efficiently calculate the outputs of the network

Functions

find_layer_type(nodes: list, layer_type)

Find all layers of a given type

nodes_mark_output_dependencies(output_nodes: list)

Marks nodes by output IDs

remove_weak_ref(nodes: list)

Run through all the nodes of the graph and remove the weak references.

create_weak_ref(output_nodes: list)

Recursively re-create the weak references to the children of the given output_nodes

compile_nn(output_nodes: list, other_outputs_to_keep_alive=None)

Compile a network to calculate output_nodes

trw.simple_layers.compiled_net.find_layer_type(nodes: list, layer_type)

Find all layers of a given type

Parameters
  • nodes – the starting nodes [list]

  • layer_type – the type of the nodes to collect

Returns

a list of nodes of the corresponding type
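A minimal sketch of the kind of graph traversal `find_layer_type` performs, assuming each node stores its inputs in a `parents` list. The `Node`, `Input` and `Dense` classes and `find_layer_type_sketch` are hypothetical stand-ins, not the `trw` classes, and the real function may traverse the graph differently.

```python
class Node:
    """Hypothetical minimal graph node; only `parents` is used here."""

    def __init__(self, name, parents=()):
        self.name = name
        self.parents = list(parents)


class Input(Node):
    """Hypothetical input layer."""


class Dense(Node):
    """Hypothetical dense layer."""


def find_layer_type_sketch(nodes, layer_type):
    """Collect every node reachable through `parents` that is a `layer_type`."""
    found = []
    visited = set()
    stack = list(nodes)
    while stack:
        node = stack.pop()
        if id(node) in visited:
            continue  # shared nodes are visited only once
        visited.add(id(node))
        if isinstance(node, layer_type):
            found.append(node)
        stack.extend(node.parents)
    return found


x = Input("x")
h = Dense("h", [x])
y = Dense("y", [h])
print(sorted(node.name for node in find_layer_type_sketch([y], Dense)))  # ['h', 'y']
```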

trw.simple_layers.compiled_net.nodes_mark_output_dependencies(output_nodes: list)

Marks nodes by output IDs

Parameters

output_nodes – a list of output nodes to be marked

Returns

nodes with a set of output IDs
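The marking can be pictured as follows, under the assumption (not confirmed by this page) that an output ID is the position of the output in `output_nodes` and that nodes expose a `parents` list. `mark_output_dependencies_sketch` is hypothetical and returns a mapping instead of modifying the nodes. A node marked with several IDs is shared between outputs.

```python
from collections import defaultdict


class Node:
    """Hypothetical minimal graph node with a list of `parents`."""

    def __init__(self, name, parents=()):
        self.name = name
        self.parents = list(parents)


def mark_output_dependencies_sketch(output_nodes):
    """Map each node to the IDs (positions in `output_nodes`) of the outputs
    whose calculation requires it."""
    dependencies = defaultdict(set)
    for output_id, output in enumerate(output_nodes):
        stack = [output]
        while stack:
            node = stack.pop()
            if output_id in dependencies[node]:
                continue  # already marked for this output
            dependencies[node].add(output_id)
            stack.extend(node.parents)
    return dict(dependencies)


x = Node("x")
a = Node("a", [x])
b = Node("b", [x])
deps = mark_output_dependencies_sketch([a, b])
print(sorted(deps[x]))  # [0, 1]: `x` is shared by both outputs
```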

trw.simple_layers.compiled_net.remove_weak_ref(nodes: list)

Run through all the nodes of the graph and remove the weak references.

Weak references cannot be pickled, which is an issue when importing or exporting the network with pickle.

These weak references can safely be removed and reconstructed when necessary.

Parameters

nodes – the starting nodes [list]

Returns

None
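The pickle problem is easy to reproduce. The sketch below assumes, hypothetically, that a node holds strong references to its `parents` and weak references to its `children`; pickling then fails with a `TypeError` until the weak references are cleared. `Node` and `remove_weak_ref_sketch` are illustrative only.

```python
import pickle
import weakref


class Node:
    """Hypothetical node: strong refs to `parents`, weak refs in `children`."""

    def __init__(self, name, parents=()):
        self.name = name
        self.parents = list(parents)
        self.children = []
        for parent in self.parents:
            parent.children.append(weakref.ref(self))


def remove_weak_ref_sketch(nodes):
    """Clear the `children` weak references of every node reachable from `nodes`."""
    visited = set()
    stack = list(nodes)
    while stack:
        node = stack.pop()
        if id(node) in visited:
            continue
        visited.add(id(node))
        node.children = []
        stack.extend(node.parents)


x = Node("x")
y = Node("y", [x])
try:
    pickle.dumps(y)
    pickled_with_weakrefs = True
except TypeError:
    pickled_with_weakrefs = False  # weakref objects cannot be pickled

remove_weak_ref_sketch([y])
restored = pickle.loads(pickle.dumps(y))  # succeeds once the weakrefs are gone
print(restored.parents[0].name)  # x
```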

trw.simple_layers.compiled_net.create_weak_ref(output_nodes: list)

Recursively re-create the weak references to the children of the given output_nodes

Parameters

output_nodes – a list of output nodes whose children's weak references are updated

Returns

None
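After unpickling, the weak references can be rebuilt from the strong `parents` links. This is a sketch under the same hypothetical node layout (strong `parents`, weak `children`), assuming the `children` lists start empty; `create_weak_ref_sketch` is not the `trw` implementation.

```python
import weakref


class Node:
    """Hypothetical node: strong refs to `parents`, weak refs in `children`."""

    def __init__(self, name, parents=()):
        self.name = name
        self.parents = list(parents)
        self.children = []  # assumed empty, e.g. after the weakrefs were removed


def create_weak_ref_sketch(output_nodes):
    """Walk up from the outputs and register every node as a weakly
    referenced child of each of its parents."""
    visited = set()
    stack = list(output_nodes)
    while stack:
        node = stack.pop()
        if id(node) in visited:
            continue  # process each node once so no child is registered twice
        visited.add(id(node))
        for parent in node.parents:
            parent.children.append(weakref.ref(node))
            stack.append(parent)


x = Node("x")
y = Node("y", [x])
z = Node("z", [x])
create_weak_ref_sketch([y, z])
print(sorted(ref().name for ref in x.children))  # ['y', 'z']
```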

class trw.simple_layers.compiled_net.RuntimeAction

Bases: enum.Enum

Specifies the type of runtime action (e.g., execution of a node, release of a node’s state or evaluation of a node)

EXECUTE_NODE = 1
REMOVE_STATE = 2
EVALUATE_STATE = 3
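To see why a `REMOVE_STATE` action is useful, the sketch below replays two hand-written plans for a chain `a -> b -> c` and counts how many node states are alive at once. The `(action, node_name)` plan layout and `peak_live_states` are assumptions for illustration; `EVALUATE_STATE` is not used because its exact semantics are not described on this page.

```python
import enum


class RuntimeAction(enum.Enum):
    # Mirrors the members documented above
    EXECUTE_NODE = 1
    REMOVE_STATE = 2
    EVALUATE_STATE = 3


def peak_live_states(plan):
    """Replay a plan and return the maximum number of node states held at once."""
    live, peak = set(), 0
    for action, node_name in plan:
        if action is RuntimeAction.EXECUTE_NODE:
            live.add(node_name)
            peak = max(peak, len(live))
        elif action is RuntimeAction.REMOVE_STATE:
            live.discard(node_name)
    return peak


E, R = RuntimeAction.EXECUTE_NODE, RuntimeAction.REMOVE_STATE
with_release = [(E, "a"), (E, "b"), (R, "a"), (E, "c"), (R, "b")]
without_release = [(E, "a"), (E, "b"), (E, "c")]
print(peak_live_states(with_release))     # 2
print(peak_live_states(without_release))  # 3
```

Releasing each state right after its last use keeps memory bounded by the width of the graph rather than its depth.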
class trw.simple_layers.compiled_net.CompiledNet

Bases: torch.nn.Module

Encapsulate a compiled network so that we can efficiently calculate the outputs of the network.

collect_parameters(self)

Make the parameters of each node visible to the current module

forward(self, batch)

Calculate the outputs of the network

Parameters

batch – (dict) a dictionary-like object of features

Returns

a dictionary of outputs
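A simplified picture of what a forward pass over an execution plan could look like. It assumes hypothetical `Node` objects with a callable `fn` and a `parents` list, and input nodes that read their value from `batch` by name. This is not the `CompiledNet.forward` implementation, and `EVALUATE_STATE` is ignored here.

```python
import enum


class RuntimeAction(enum.Enum):
    # Mirrors the members documented above
    EXECUTE_NODE = 1
    REMOVE_STATE = 2
    EVALUATE_STATE = 3


class Node:
    """Hypothetical node: `fn` computes its state from its parents' states."""

    def __init__(self, name, fn=None, parents=()):
        self.name = name
        self.fn = fn
        self.parents = list(parents)


class InputNode(Node):
    """Hypothetical input node: its state is read from the batch by name."""


def run_plan(plan, output_nodes, batch):
    """Execute `plan` on `batch` and return a dict of output states."""
    states = {}
    for action, node in plan:
        if action is RuntimeAction.EXECUTE_NODE:
            if isinstance(node, InputNode):
                states[node] = batch[node.name]
            else:
                states[node] = node.fn(*(states[p] for p in node.parents))
        elif action is RuntimeAction.REMOVE_STATE:
            del states[node]  # free intermediate results as early as possible
    return {node.name: states[node] for node in output_nodes}


E, R = RuntimeAction.EXECUTE_NODE, RuntimeAction.REMOVE_STATE
x = InputNode("x")
double = Node("double", lambda v: 2 * v, [x])
plus_one = Node("plus_one", lambda v: v + 1, [double])
plan = [(E, x), (E, double), (R, x), (E, plus_one), (R, double)]
print(run_plan(plan, [plus_one], {"x": 3}))  # {'plus_one': 7}
```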

__getstate__(self)
__setstate__(self, state)
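`__getstate__` and `__setstate__` are the standard pickle hooks for customizing what gets serialized. A minimal sketch of the usual pattern for weak references, using a hypothetical `Holder` class: drop the weak reference when pickling and rebuild it when unpickling. Whether `CompiledNet` does exactly this is not stated on this page.

```python
import pickle
import weakref


class Target:
    """Hypothetical object that is referenced both strongly and weakly."""


class Holder:
    """Hypothetical object holding a weak reference that pickle cannot handle."""

    def __init__(self, target):
        self.target = target
        self.target_ref = weakref.ref(target)

    def __getstate__(self):
        state = self.__dict__.copy()
        del state["target_ref"]  # weakref objects cannot be pickled
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.target_ref = weakref.ref(self.target)  # rebuild after loading


holder = Holder(Target())
restored = pickle.loads(pickle.dumps(holder))
print(restored.target_ref() is restored.target)  # True
```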
trw.simple_layers.compiled_net.compile_nn(output_nodes: list, other_outputs_to_keep_alive=None)

Compile a network to calculate output_nodes

Parameters
  • output_nodes – the output nodes to calculate. The order of the nodes determines the order of the calculation and affects the bookkeeping of shared calculations in multiple-output networks

  • other_outputs_to_keep_alive – keeps alive unused output nodes

Returns

a CompiledNet
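A minimal sketch of how such a compilation could produce an execution plan: a depth-first traversal of the outputs in the given order (so the output order decides the execution order), followed by a pass that releases each intermediate state right after its last consumer has run. `Node` and `compile_plan_sketch` are hypothetical, the graph is assumed acyclic, and `other_outputs_to_keep_alive` is left out because its exact effect is not described here.

```python
import enum


class RuntimeAction(enum.Enum):
    # Mirrors the members documented above
    EXECUTE_NODE = 1
    REMOVE_STATE = 2
    EVALUATE_STATE = 3


class Node:
    """Hypothetical minimal graph node with a list of `parents`."""

    def __init__(self, name, parents=()):
        self.name = name
        self.parents = list(parents)


def compile_plan_sketch(output_nodes):
    """Return a list of (RuntimeAction, node) steps computing `output_nodes`."""
    # 1) Post-order DFS from each output, in the given order. A dict serves as
    #    an ordered set so the execution order is deterministic.
    seen = {}

    def visit(node):
        if node in seen:
            return  # shared node: already scheduled
        for parent in node.parents:
            visit(parent)
        seen[node] = None

    for output in output_nodes:
        visit(output)
    order = list(seen)

    # 2) Record the last step at which each node's state is read.
    last_use = {}
    for step, node in enumerate(order):
        for parent in node.parents:
            last_use[parent] = step

    # 3) Emit the plan, releasing a state right after its last consumer ran.
    #    Output states are never released.
    keep = set(output_nodes)
    plan = []
    for step, node in enumerate(order):
        plan.append((RuntimeAction.EXECUTE_NODE, node))
        for parent in dict.fromkeys(node.parents):
            if last_use[parent] == step and parent not in keep:
                plan.append((RuntimeAction.REMOVE_STATE, parent))
    return plan


x = Node("x")
a = Node("a", [x])
b = Node("b", [a])
c = Node("c", [a])
plan = compile_plan_sketch([b, c])
print([(action.name, node.name) for action, node in plan])
```

Here `a` is shared by both outputs: it is executed once, and its state is released only after `c`, its last consumer. Passing `[c, b]` instead would schedule `c` before `b` and release `a` after `b`.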