trw.simple_layers.compiled_net
Careful here: an ordered set MUST be used, because the order of execution matters. With an unordered set, the compiled network and the executed network may become inconsistent.
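This ordering concern is easy to hit in Python. The built-in `set` does not keep insertion order, and for strings its iteration order can even change between interpreter runs because of hash randomization. The following is a minimal sketch of an insertion-ordered set built on `dict` keys, which preserve insertion order since Python 3.7. This page does not say which ordered-set implementation `trw` actually uses, and the helper `ordered_unique` is hypothetical.

```python
def ordered_unique(items):
    """Deduplicate items while keeping the order in which they were first seen.

    dict keys act as an ordered set: insertion order is preserved (Python 3.7+),
    unlike the built-in set, whose iteration order is arbitrary.
    """
    return list(dict.fromkeys(items))


# Hypothetical execution order in which "conv1" is reached twice.
execution_order = ordered_unique(["input", "conv1", "relu1", "conv1", "output"])
print(execution_order)  # ['input', 'conv1', 'relu1', 'output']
```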
Module Contents
Classes
- RuntimeAction: Specifies the type of runtime action (e.g., execution of a node, release of a node's state, or evaluation of a node)
- CompiledNet: Encapsulate a compiled network so that the outputs of the network can be calculated efficiently
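Functions
- find_layer_type(nodes, layer_type): Find all layers of a given type
- nodes_mark_output_dependencies(output_nodes): Marks nodes by output IDs
- remove_weak_ref(nodes): Run through all the nodes of the graph and remove the weak references
- create_weak_ref(output_nodes): Recursively re-create the weak references to the children, starting from the given output_nodes
- compile_nn(output_nodes, other_outputs_to_keep_alive=None): Compile a network to calculate output_nodes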
- trw.simple_layers.compiled_net.find_layer_type(nodes: list, layer_type)
Find all layers of a given type
- Parameters
nodes – the starting nodes [list]
layer_type – the type of the nodes to collect
- Returns
a list of nodes of the corresponding type
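One way to implement such a search is a depth-first traversal with a visited set, so that nodes reached by several paths are reported only once. The sketch below makes two hypothetical assumptions: each node exposes its inputs through a `parents` list, and the search walks upstream from the starting nodes. The `Node`, `Input` and `Conv` classes and `find_nodes_of_type` are illustrative stand-ins, not the `trw` layer classes.

```python
class Node:
    """Hypothetical graph node; only the `parents` attribute is used here."""

    def __init__(self, name, parents=()):
        self.name = name
        self.parents = list(parents)


class Input(Node):
    pass


class Conv(Node):
    pass


def find_nodes_of_type(nodes, layer_type):
    """Return every node reachable from `nodes` (via parents) that is a `layer_type`."""
    found, seen = [], set()
    stack = list(nodes)
    while stack:
        node = stack.pop()
        if node in seen:  # shared nodes are visited only once
            continue
        seen.add(node)
        if isinstance(node, layer_type):
            found.append(node)
        stack.extend(node.parents)
    return found


x = Input("x")
c1 = Conv("c1", [x])
c2 = Conv("c2", [c1])
out = Node("out", [c1, c2])
print(sorted(n.name for n in find_nodes_of_type([out], Conv)))  # ['c1', 'c2']
```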
- trw.simple_layers.compiled_net.nodes_mark_output_dependencies(output_nodes: list)
Marks nodes by output IDs
- Parameters
output_nodes – a list of output nodes to be marked
- Returns
nodes with a set of output IDs
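This page does not define what an output ID is. For illustration, the sketch below assumes that the ID of an output is its index in `output_nodes`. It records, for every upstream node, the set of outputs that depend on that node, so a node shared by several heads carries several IDs. The `Node` class, the `parents` attribute and the returned dictionary are hypothetical. The `trw` function may instead store the IDs on the nodes themselves.

```python
from collections import defaultdict


class Node:
    """Hypothetical graph node; only the `parents` attribute is used here."""

    def __init__(self, name, parents=()):
        self.name = name
        self.parents = list(parents)


def mark_output_dependencies(output_nodes):
    """Map each node to the set of output IDs (indices in output_nodes) that need it."""
    marks = defaultdict(set)
    for output_id, output in enumerate(output_nodes):
        stack, seen = [output], set()
        while stack:
            node = stack.pop()
            if node in seen:
                continue
            seen.add(node)
            marks[node].add(output_id)
            stack.extend(node.parents)  # everything upstream is needed by this output
    return dict(marks)


x = Node("x")
shared = Node("shared", [x])
head_a = Node("head_a", [shared])
head_b = Node("head_b", [shared])
marks = mark_output_dependencies([head_a, head_b])
print(sorted(marks[shared]))  # [0, 1]
print(sorted(marks[head_b]))  # [1]
```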
- trw.simple_layers.compiled_net.remove_weak_ref(nodes: list)
Run through all the nodes of the graph and remove the weak references.
- Weak references cannot be pickled, which is an issue when importing or exporting the network with pickle.
They can safely be removed and reconstructed when necessary.
- Parameters
nodes – the starting nodes [list]
- Returns
None
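The pickling problem can be reproduced with the standard library alone. `weakref.ref` objects are not picklable, so `pickle.dumps` raises `TypeError` as soon as it reaches one. The following minimal sketch assumes hypothetical nodes that keep strong references to their `parents` and weak references to their `children`. Once the weak references are cleared, the graph pickles normally. `Node` and `clear_weak_refs` are illustrative names, not the `trw` implementation.

```python
import pickle
import weakref


class Node:
    """Hypothetical node: strong refs to parents, weak refs to children."""

    def __init__(self, name, parents=()):
        self.name = name
        self.parents = list(parents)
        self.children = []
        for parent in self.parents:
            parent.children.append(weakref.ref(self))


def clear_weak_refs(nodes):
    """Empty the `children` list of every node reachable from `nodes` via parents."""
    seen, stack = set(), list(nodes)
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        node.children = []
        stack.extend(node.parents)


x = Node("x")
y = Node("y", [x])

try:
    pickle.dumps(y)  # reaches x.children, which holds a weakref
    pickled_with_weak_refs = True
except TypeError as err:
    pickled_with_weak_refs = False
    print("pickling failed:", err)

clear_weak_refs([y])
restored = pickle.loads(pickle.dumps(y))
print(restored.parents[0].name)  # x
```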
- trw.simple_layers.compiled_net.create_weak_ref(output_nodes: list)
Recursively re-create the weak references to the children, starting from the given output_nodes
- Parameters
output_nodes – a list of output nodes whose children weak references are updated
- Returns
None
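The reverse operation can be sketched as a traversal from the output nodes toward the inputs. It registers each visited node as a weak child of each of its parents. The `Node` class, its `parents`/`children` attributes and `rebuild_weak_refs` are hypothetical. The sketch also assumes that the `children` lists start empty, for example right after unpickling.

```python
import weakref


class Node:
    """Hypothetical node whose `children` list starts empty."""

    def __init__(self, name, parents=()):
        self.name = name
        self.parents = list(parents)
        self.children = []


def rebuild_weak_refs(output_nodes):
    """Walk upstream from output_nodes, adding a weakref to each node in its parents' `children`."""
    seen, stack = set(), list(output_nodes)
    while stack:
        node = stack.pop()
        if node in seen:  # visit each node once so its edges are not registered twice
            continue
        seen.add(node)
        for parent in node.parents:
            parent.children.append(weakref.ref(node))
            stack.append(parent)


x = Node("x")
a = Node("a", [x])
b = Node("b", [x])
rebuild_weak_refs([a, b])
print(sorted(ref().name for ref in x.children))  # ['a', 'b']
```

Because the children are held weakly, a child that loses its last strong reference can still be garbage-collected, after which calling its weak reference returns `None`.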
- class trw.simple_layers.compiled_net.RuntimeAction
Bases: enum.Enum
Specifies the type of runtime action (e.g., execution of a node, release of a node's state, or evaluation of a node)
- EXECUTE_NODE = 1
- REMOVE_STATE = 2
- EVALUATE_STATE = 3
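A compiled network can be pictured as a flat plan of `(action, node)` steps. The sketch below is a hypothetical plan for `out = relu(conv(x))`. It assumes three semantics: `EXECUTE_NODE` computes and stores a node's state, `EVALUATE_STATE` exposes a state as an output, and `REMOVE_STATE` frees a state once no later step needs it. The sketch counts how many states are held at once, which shows what releasing state saves in memory. The enum values are copied from this page. The plan format and `peak_live_states` are assumptions.

```python
import enum


class RuntimeAction(enum.Enum):
    EXECUTE_NODE = 1
    REMOVE_STATE = 2
    EVALUATE_STATE = 3


# Hypothetical plan for out = relu(conv(x)).
plan = [
    (RuntimeAction.EXECUTE_NODE, "x"),
    (RuntimeAction.EXECUTE_NODE, "conv"),
    (RuntimeAction.REMOVE_STATE, "x"),  # conv was the last consumer of x
    (RuntimeAction.EXECUTE_NODE, "relu"),
    (RuntimeAction.EVALUATE_STATE, "relu"),
    (RuntimeAction.REMOVE_STATE, "conv"),  # relu was the last consumer of conv
]


def peak_live_states(steps):
    """Largest number of node states held simultaneously while following steps."""
    live, peak = set(), 0
    for action, name in steps:
        if action is RuntimeAction.EXECUTE_NODE:
            live.add(name)
        elif action is RuntimeAction.REMOVE_STATE:
            live.discard(name)
        peak = max(peak, len(live))
    return peak


without_release = [step for step in plan if step[0] is not RuntimeAction.REMOVE_STATE]
print(peak_live_states(plan))             # 2
print(peak_live_states(without_release))  # 3
```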
- class trw.simple_layers.compiled_net.CompiledNet
Bases: torch.nn.Module
Encapsulate a compiled network so that the outputs of the network can be calculated efficiently.
- collect_parameters(self)
Make the parameters of each node visible to the current module
- forward(self, batch)
Calculate the outputs of a network
- Parameters
batch – (dict) a dictionary-like collection of features
- Returns
a dictionary of outputs
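A `forward` driven by runtime actions could be a single loop over a precomputed plan. The loop reads inputs from `batch` and collects the evaluated states into the output dictionary. In this minimal sketch, plain Python numbers stand in for tensors. `run_plan`, the plan format and the `functions`/`parents` mappings are hypothetical and do not describe the `CompiledNet` internals.

```python
import enum


class RuntimeAction(enum.Enum):
    EXECUTE_NODE = 1
    REMOVE_STATE = 2
    EVALUATE_STATE = 3


def run_plan(plan, functions, parents, batch):
    """Follow plan; functions[name](batch, *parent_states) computes a node's state."""
    states, outputs = {}, {}
    for action, name in plan:
        if action is RuntimeAction.EXECUTE_NODE:
            inputs = [states[p] for p in parents.get(name, [])]
            states[name] = functions[name](batch, *inputs)
        elif action is RuntimeAction.REMOVE_STATE:
            del states[name]  # free memory once no later step needs it
        elif action is RuntimeAction.EVALUATE_STATE:
            outputs[name] = states[name]
    return outputs


functions = {
    "x": lambda batch: batch["x"],  # reads a feature from the batch
    "double": lambda batch, v: 2 * v,
    "plus_one": lambda batch, v: v + 1,
}
parents = {"double": ["x"], "plus_one": ["double"]}
plan = [
    (RuntimeAction.EXECUTE_NODE, "x"),
    (RuntimeAction.EXECUTE_NODE, "double"),
    (RuntimeAction.REMOVE_STATE, "x"),
    (RuntimeAction.EXECUTE_NODE, "plus_one"),
    (RuntimeAction.EVALUATE_STATE, "double"),
    (RuntimeAction.REMOVE_STATE, "double"),
    (RuntimeAction.EVALUATE_STATE, "plus_one"),
]
print(run_plan(plan, functions, parents, {"x": 3}))  # {'double': 6, 'plus_one': 7}
```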
- __getstate__(self)
- __setstate__(self, state)
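These two methods are Python's standard pickling hooks. This page does not document what `CompiledNet` does in them. Given the weak-reference notes above, one plausible pattern, sketched here as an assumption, is to leave the weak references out of the pickled state and rebuild them after unpickling. `Node`, `Net`, `reachable` and `link_children` are hypothetical. Dropping the weak references from a copy of the state, rather than from the live object, leaves the original graph intact.

```python
import pickle
import weakref


class Node:
    """Hypothetical node: strong refs to parents, weak refs to children."""

    def __init__(self, name, parents=()):
        self.name = name
        self.parents = list(parents)
        self.children = []

    def __getstate__(self):
        state = self.__dict__.copy()
        state["children"] = []  # drop weakrefs from the pickled copy only
        return state


def reachable(nodes):
    """All nodes reachable from `nodes` through parents, each listed once."""
    seen, stack, order = set(), list(nodes), []
    while stack:
        node = stack.pop()
        if node not in seen:
            seen.add(node)
            order.append(node)
            stack.extend(node.parents)
    return order


def link_children(outputs):
    """Register each reachable node as a weak child of its parents."""
    for node in reachable(outputs):
        for parent in node.parents:
            parent.children.append(weakref.ref(node))


class Net:
    """Hypothetical container that owns the output nodes."""

    def __init__(self, outputs):
        self.outputs = list(outputs)
        link_children(self.outputs)

    def __setstate__(self, state):
        self.__dict__.update(state)
        link_children(self.outputs)  # rebuild the weakrefs after unpickling


x = Node("x")
y = Node("y", [x])
net = Net([y])
clone = pickle.loads(pickle.dumps(net))
print(clone.outputs[0].parents[0].children[0]().name)  # y
print(len(x.children))  # 1: the original graph keeps its weakrefs
```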
- trw.simple_layers.compiled_net.compile_nn(output_nodes: list, other_outputs_to_keep_alive=None)
Compile a network to calculate output_nodes
- Parameters
output_nodes – the output nodes to calculate. The order of the nodes indicates the order of the calculation and impacts the bookkeeping of the shared calculations in multiple-output networks
other_outputs_to_keep_alive – unused output nodes to keep alive
- Returns
a CompiledNet
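A compiler of this kind can be sketched in two passes. First, a depth-first post-order walk from the outputs, taken in the given order, yields an execution order in which every parent precedes its consumers. An insertion-ordered `dict` serves as the ordered set. Second, a pass over that order finds the last consumer of each state so that a `REMOVE_STATE` step can follow it. In this sketch, nodes are plain names with a hypothetical `parents` mapping. `compile_plan` and its plan format are assumptions and do not describe what `compile_nn` returns.

```python
import enum


class RuntimeAction(enum.Enum):
    EXECUTE_NODE = 1
    REMOVE_STATE = 2
    EVALUATE_STATE = 3


def compile_plan(output_nodes, parents):
    """Turn output node names plus a parents mapping into a list of (action, name) steps."""
    order = {}  # dict as an ordered set: insertion order == execution order

    def visit(name):
        if name in order:
            return
        for parent in parents.get(name, []):
            visit(parent)  # parents are scheduled before their consumers
        order[name] = None

    for output in output_nodes:  # the order of the outputs drives the schedule
        visit(output)

    # Step index of the last node that consumes each state.
    last_use = {}
    for step, name in enumerate(order):
        for parent in parents.get(name, []):
            last_use[parent] = step

    plan = []
    for step, name in enumerate(order):
        plan.append((RuntimeAction.EXECUTE_NODE, name))
        if name in output_nodes:
            plan.append((RuntimeAction.EVALUATE_STATE, name))
        for parent in dict.fromkeys(parents.get(name, [])):
            if last_use[parent] == step:
                plan.append((RuntimeAction.REMOVE_STATE, parent))  # no later consumer
    return plan


parents = {"shared": ["x"], "head_a": ["shared"], "head_b": ["shared"]}
plan = compile_plan(["head_a", "head_b"], parents)
for action, name in plan:
    print(action.name, name)
```

In this example, the shared trunk `shared` is executed once and is released only after `head_b`, its last consumer. Reordering `output_nodes` changes the schedule, which is consistent with the remark about output order above. A recursive `visit` is fine for small graphs. Very deep networks would need an iterative traversal to stay under Python's recursion limit.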