trw.simple_layers.compiled_net
Caution: this module MUST use ordered sets because the order of execution is important. Otherwise there may be inconsistencies between the compiled network and the executed network.
Module Contents
Classes
RuntimeAction | Specifies the type of runtime action (e.g., execution of a node, release of node's state or evaluation of a node)
CompiledNet(remove_checks=False) | Encapsulate a compiled network so that we can efficiently calculate the outputs
EmptyModule | Defines an empty module (i.e., equivalent to None but for nn.Module so that it can be stored in a nn.ModuleList)
WrapperModule(module) | nn.ModuleList can only store nn.Module so create a wrapper that can be stored in a nn.ModuleList
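Functions
find_layer_type(nodes: list, layer_type) | Find all layers of a given type
nodes_mark_output_dependencies(output_nodes: list) | Marks nodes by output IDs
remove_weak_ref(nodes: list) | Run through all the nodes of the graph and remove weakref.
create_weak_ref(output_nodes: list) | Re-create the weakref references of the children for the given output_nodes recursively
_prepare_inputs(action, states) |
compile_nn(output_nodes: list, other_outputs_to_keep_alive=None, remove_checks=False) | Compile a network to calculate output_nodes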
- trw.simple_layers.compiled_net.find_layer_type(nodes: list, layer_type)
Find all layers of a given type
- Parameters
nodes – (list) the starting nodes
layer_type – the type of the nodes to collect
- Returns
a list of nodes of the corresponding type
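The traversal described above can be sketched on a toy graph. This is only an illustration: the `Node`, `Input` and `Dense` classes, the `parents` attribute and the helper `find_nodes_of_type` are hypothetical stand-ins rather than the trw.simple_layers classes, and walking from the starting nodes towards their parents is an assumption.

```python
class Node:
    """Hypothetical graph node keeping strong references to its parents."""
    def __init__(self, name, parents=()):
        self.name = name
        self.parents = list(parents)


class Input(Node):
    pass


class Dense(Node):
    pass


def find_nodes_of_type(nodes, node_type):
    """Collect every reachable node that is an instance of node_type."""
    found = {}        # dict keeps insertion order, used here as an ordered set
    visited = set()   # ids of the nodes already processed
    stack = list(nodes)
    while stack:
        node = stack.pop()
        if id(node) in visited:
            continue
        visited.add(id(node))
        if isinstance(node, node_type):
            found[id(node)] = node
        stack.extend(node.parents)
    return list(found.values())


image = Input('image')
hidden = Dense('hidden', [image])
out_a = Dense('out_a', [hidden])
out_b = Dense('out_b', [hidden, image])

inputs = find_nodes_of_type([out_a, out_b], Input)
dense_layers = find_nodes_of_type([out_a, out_b], Dense)
```

Although `image` is reachable through several paths, it is collected only once.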
- trw.simple_layers.compiled_net.nodes_mark_output_dependencies(output_nodes: list)
Mark nodes with the IDs of the outputs that depend on them
- Parameters
output_nodes – a list of output nodes to be marked
- Returns
nodes with a set of output IDs
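As a rough illustration of this kind of marking, the sketch below tags each node of a toy graph with the set of outputs that need it. The `Node` class, the `mark_output_dependencies` helper and the choice of the output's list index as its ID are assumptions made for the example, not the trw implementation.

```python
class Node:
    """Hypothetical graph node keeping strong references to its parents."""
    def __init__(self, name, parents=()):
        self.name = name
        self.parents = list(parents)


def mark_output_dependencies(output_nodes):
    """Map each reachable node name to the set of output indices that need it."""
    marks = {}
    for output_id, output in enumerate(output_nodes):
        stack = [output]
        while stack:
            node = stack.pop()
            ids = marks.setdefault(node.name, set())
            if output_id in ids:
                continue  # already visited for this output
            ids.add(output_id)
            stack.extend(node.parents)
    return marks


image = Node('image')
shared = Node('shared', [image])
out_a = Node('out_a', [shared])
out_b = Node('out_b', [shared])

marks = mark_output_dependencies([out_a, out_b])
```

In a scheme like this, a node marked by a single output can be discarded once that output is calculated, while a node marked by several outputs must be kept until the last of them is done.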
- trw.simple_layers.compiled_net.remove_weak_ref(nodes: list)
Traverse all the nodes of the graph and remove the weakref references.
Weakrefs are an issue when importing or exporting the network with pickle. They can safely be removed and reconstructed if necessary.
- Parameters
nodes – (list) the starting nodes
- Returns
None
- trw.simple_layers.compiled_net.create_weak_ref(output_nodes: list)
Recursively re-create the weakref references of the children for the given output_nodes
- Parameters
output_nodes – a list of output nodes whose children weak references are updated
- Returns
None
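The reason for removing and re-creating weak references can be shown with a small sketch. It assumes a hypothetical `Node` class that holds strong references to its parents and weak references to its children; the helpers `strip_children` and `rebuild_children` are illustrative names, not the functions documented above.

```python
import pickle
import weakref


class Node:
    """Hypothetical node: strong references to parents, weak ones to children."""
    def __init__(self, name, parents=()):
        self.name = name
        self.parents = list(parents)
        self.children = []
        for parent in self.parents:
            parent.children.append(weakref.ref(self))


def strip_children(output_nodes):
    """Drop the weak references of every node reachable through `parents`."""
    stack = list(output_nodes)
    while stack:
        node = stack.pop()
        node.children = []
        stack.extend(node.parents)


def rebuild_children(output_nodes):
    """Re-create child weak references from the strong parent links.

    Assumes the `children` lists were emptied beforehand.
    """
    visited = set()
    stack = list(output_nodes)
    while stack:
        node = stack.pop()
        if id(node) in visited:
            continue
        visited.add(id(node))
        for parent in node.parents:
            parent.children.append(weakref.ref(node))
        stack.extend(node.parents)


image = Node('image')
out = Node('out', [image])

# A weak reference object cannot be serialized by pickle
try:
    pickle.dumps(image.children[0])
    weakref_is_picklable = True
except (TypeError, pickle.PicklingError):
    weakref_is_picklable = False

strip_children([out])
children_after_strip = list(image.children)
rebuild_children([out])
```

Because the parent links carry all the information, the child references are redundant and can be rebuilt at any time after loading.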
- class trw.simple_layers.compiled_net.RuntimeAction
Bases: enum.Enum
Specifies the type of runtime action (e.g., execution of a node, release of a node's state or evaluation of a node)
- EXECUTE_NODE = 1
- REMOVE_STATE = 2
- EVALUATE_STATE = 3
- trw.simple_layers.compiled_net._prepare_inputs(action, states)
- class trw.simple_layers.compiled_net.CompiledNet(remove_checks=False)
Bases: torch.nn.Module
Encapsulate a compiled network so that the outputs of the network can be calculated efficiently.
- forward(self, batch)
Calculate the outputs of a network
- Parameters
batch – (dict) a dictionary-like collection of features
- Returns
a dictionary of outputs
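One way to picture how a compiled network could turn a batch into a dictionary of outputs is a small interpreter over a precomputed list of actions. The sketch below is an assumption-laden illustration: the `Action` enum only mirrors the member names of RuntimeAction, and `run_schedule`, `functions` and `inputs_of` are hypothetical names. The actual semantics of each action in CompiledNet may differ.

```python
from enum import Enum


class Action(Enum):
    # Same member names as RuntimeAction; the semantics below are assumed
    EXECUTE_NODE = 1
    REMOVE_STATE = 2
    EVALUATE_STATE = 3


def run_schedule(schedule, functions, inputs_of, batch):
    """Interpret a precomputed list of (action, node_name) pairs."""
    states = dict(batch)  # intermediate results keyed by node name
    outputs = {}
    for action, name in schedule:
        if action is Action.EXECUTE_NODE:
            args = [states[parent] for parent in inputs_of[name]]
            states[name] = functions[name](*args)
        elif action is Action.EVALUATE_STATE:
            outputs[name] = states[name]
        elif action is Action.REMOVE_STATE:
            del states[name]  # free the state once nothing else needs it
    return outputs, states


functions = {'double': lambda x: 2 * x, 'plus_one': lambda x: x + 1}
inputs_of = {'double': ['x'], 'plus_one': ['double']}
schedule = [
    (Action.EXECUTE_NODE, 'double'),
    (Action.REMOVE_STATE, 'x'),
    (Action.EXECUTE_NODE, 'plus_one'),
    (Action.REMOVE_STATE, 'double'),
    (Action.EVALUATE_STATE, 'plus_one'),
]

outputs, remaining = run_schedule(schedule, functions, inputs_of, {'x': 3})
```

Since the schedule is fixed ahead of time, the forward pass does not need to traverse the graph again, and intermediate states are released as early as possible.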
- __getstate__(self)
- __setstate__(self, state)
- class trw.simple_layers.compiled_net.EmptyModule
Bases: torch.nn.Module
Defines an empty module (i.e., the equivalent of None for nn.Module, so that it can be stored in an nn.ModuleList)
- class trw.simple_layers.compiled_net.WrapperModule(module)
Bases: torch.nn.Module
nn.ModuleList can only store nn.Module instances, so this wrapper allows an object to be stored in an nn.ModuleList
- forward(self, i)
- trw.simple_layers.compiled_net.compile_nn(output_nodes: list, other_outputs_to_keep_alive=None, remove_checks=False)
Compile a network to calculate output_nodes
- Parameters
output_nodes – the output nodes to calculate. The order of the nodes determines the order of the calculation and affects the bookkeeping of the calculations shared between outputs in multi-output networks
other_outputs_to_keep_alive – keeps alive unused output nodes
remove_checks – if True, some runtime checks will be disabled. This can be useful, for example, with an FCNN, where the output shape depends on the input shape
- Returns
a CompiledNet
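To see why the order of output_nodes matters and why ordered containers are needed, consider the following minimal sketch of deriving an execution order from a list of outputs. It is not the compile_nn algorithm: `compile_order` and `parents_of` are hypothetical names, and the depth-first ordering is an assumption chosen for simplicity.

```python
def compile_order(output_names, parents_of):
    """Return node names in execution order: parents first, each node once."""
    order = []    # a list keeps a deterministic execution order
    seen = set()  # used for membership tests only, never iterated

    def visit(name):
        if name in seen:
            return
        seen.add(name)
        for parent in parents_of.get(name, []):
            visit(parent)
        order.append(name)

    for name in output_names:
        visit(name)
    return order


parents_of = {
    'shared': ['image'],
    'out_a': ['shared'],
    'out_b': ['shared'],
}

order_ab = compile_order(['out_a', 'out_b'], parents_of)
order_ba = compile_order(['out_b', 'out_a'], parents_of)
```

The shared node is scheduled once in both cases, but the outputs are calculated in the order given. Iterating over an unordered set instead of a list could produce a different schedule from one run to the next.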