hive.agents.qnets.sequence_models module
- class hive.agents.qnets.sequence_models.SequenceFn[source]
Bases: Registrable, Module
A wrapper for callables that produce sequence functions.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- classmethod type_name()[source]
This should represent a string that denotes which type of class you are creating. For example, “logger”, “agent”, or “env”.
- class hive.agents.qnets.sequence_models.LSTMModel(rnn_input_size, rnn_hidden_size=128, num_rnn_layers=1, batch_first=True)[source]
Bases: SequenceFn
A multi-layer long short-term memory (LSTM) RNN.
- Parameters
rnn_input_size (int) – The number of expected features in the input.
rnn_hidden_size (int) – The number of features in the hidden state.
num_rnn_layers (int) – Number of recurrent layers.
batch_first (bool) – If True, the input and output tensors are provided as (batch, seq, feature).
- forward(x, hidden_state)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
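The hidden-state interface of forward mirrors torch.nn.LSTM with batch_first=True, which this class wraps. A minimal sketch of that interface (the sizes below are illustrative choices, not RLHive defaults, apart from the hidden size of 128):

```python
import torch

# Sketch of the interface LSTMModel wraps: torch.nn.LSTM with batch_first=True.
rnn = torch.nn.LSTM(input_size=16, hidden_size=128, num_layers=1, batch_first=True)

x = torch.zeros(4, 10, 16)       # (batch, seq_len, input_size)
h0 = torch.zeros(1, 4, 128)      # (num_layers, batch, hidden_size)
c0 = torch.zeros(1, 4, 128)      # LSTMs also carry a cell state

# Thread the hidden state through the call; reuse (hn, cn) on the next step.
out, (hn, cn) = rnn(x, (h0, c0))
```

Note that an LSTM's hidden state is a pair of tensors (hidden state and cell state), which is why it is threaded through forward as a tuple.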
- class hive.agents.qnets.sequence_models.GRUModel(rnn_input_size, rnn_hidden_size=128, num_rnn_layers=1, batch_first=True)[source]
Bases: SequenceFn
A multi-layer gated recurrent unit (GRU) RNN.
- Parameters
rnn_input_size (int) – The number of expected features in the input.
rnn_hidden_size (int) – The number of features in the hidden state.
num_rnn_layers (int) – Number of recurrent layers.
batch_first (bool) – If True, the input and output tensors are provided as (batch, seq, feature).
- forward(x, hidden_state)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
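Unlike the LSTM, a GRU's hidden state is a single tensor rather than a (hidden, cell) pair. A minimal sketch of the underlying torch.nn.GRU interface (sizes are illustrative, not RLHive defaults, apart from the hidden size of 128):

```python
import torch

# Sketch of the interface GRUModel wraps: torch.nn.GRU with batch_first=True.
rnn = torch.nn.GRU(input_size=16, hidden_size=128, num_layers=1, batch_first=True)

x = torch.zeros(4, 10, 16)       # (batch, seq_len, input_size)
h0 = torch.zeros(1, 4, 128)      # (num_layers, batch, hidden_size)

# A single hidden tensor is threaded through the call, no cell state.
out, hn = rnn(x, h0)
```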
- class hive.agents.qnets.sequence_models.SequenceModel(in_dim, representation_network, sequence_fn, mlp_layers=None, normalization_factor=255, noisy=False, std_init=0.5)[source]
Bases: Registrable, Module
Basic convolutional recurrent neural network architecture. Applies a number of convolutional layers (each followed by a ReLU activation), recurrent layers, and then feeds the output into an hive.agents.qnets.mlp.MLPNetwork.
Note, if channels is None, the network created for the convolutional portion of the architecture is simply a torch.nn.Identity module. If mlp_layers is None, the MLP portion of the architecture is a torch.nn.Identity module.
- Parameters
in_dim (tuple) – The tuple of observation dimensions (channels, width, height).
sequence_fn (SequenceFn) – A sequence neural network that learns a recurrent representation. Usually placed between the convolutional layers and the MLP layers.
normalization_factor (float | int) – What the input is divided by before the forward pass of the network.
noisy (bool) – Whether the MLP part of the network will use NoisyLinear layers or torch.nn.Linear layers.
std_init (float) – The range for the initialization of the standard deviation of the weights in NoisyLinear.
- classmethod type_name()[source]
This should represent a string that denotes which type of class you are creating. For example, “logger”, “agent”, or “env”.
- forward(x, hidden_state=None)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
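The normalize → representation → recurrent → MLP pipeline described above can be sketched in plain PyTorch. This is an illustrative toy, not RLHive's implementation: the class name, the use of torch.nn.Identity for the representation network, and all sizes are assumptions made for the example.

```python
import torch

class TinySequenceModel(torch.nn.Module):
    """Toy sketch of the SequenceModel pipeline: normalize, represent, recur, project."""

    def __init__(self):
        super().__init__()
        self.repr_net = torch.nn.Identity()  # stands in for the convolutional stack
        self.rnn = torch.nn.GRU(input_size=16, hidden_size=32, batch_first=True)
        self.mlp = torch.nn.Linear(32, 8)    # stands in for the MLPNetwork
        self.normalization_factor = 255

    def forward(self, x, hidden_state=None):
        x = x / self.normalization_factor            # divide input before the pass
        x = self.repr_net(x)                          # representation network
        x, hidden_state = self.rnn(x, hidden_state)   # recurrent portion
        return self.mlp(x), hidden_state              # MLP head plus new hidden state

model = TinySequenceModel()
out, h = model(torch.zeros(2, 5, 16))  # hidden_state=None lets the RNN zero-init
```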
- class hive.agents.qnets.sequence_models.DRQNNetwork(base_network, hidden_dim, out_dim, linear_fn=None)[source]
Bases: Module
Implements the standard DRQN value computation. This module returns two outputs, which correspond to the two outputs from base_network. In particular, it transforms the first output from base_network with output dimension hidden_dim to dimension out_dim, which should be equal to the number of actions. The second output of this module is the second output from base_network, which is the hidden state that will be used as the initial hidden state when computing the next action in the trajectory.
- Parameters
base_network (torch.nn.Module) – Backbone network that returns two outputs, one is the representation used to compute action values, and the other one is the hidden state used as input hidden state later.
hidden_dim (int) – Dimension of the first output of base_network.
out_dim (int) – Output dimension of the DRQN. Should be equal to the number of actions that you are computing values for.
linear_fn (torch.nn.Module) – Function that will create the torch.nn.Module that will take the output of base_network and produce the final action values. If None, a torch.nn.Linear layer will be used.
- forward(x, hidden_state=None)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
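The value head described above, a linear layer mapping the base network's features to per-action values while passing the hidden state through, can be sketched as follows. The class names and sizes here are illustrative assumptions, not RLHive's code:

```python
import torch

class ToyBase(torch.nn.Module):
    """Stand-in for base_network: returns (features, hidden_state)."""

    def __init__(self):
        super().__init__()
        self.rnn = torch.nn.GRU(input_size=16, hidden_size=32, batch_first=True)

    def forward(self, x, hidden_state=None):
        return self.rnn(x, hidden_state)

class ToyDRQNHead(torch.nn.Module):
    """Sketch of DRQNNetwork's computation: linear head over the first output."""

    def __init__(self, base_network, hidden_dim, out_dim):
        super().__init__()
        self.base_network = base_network
        # Mirrors the linear_fn=None default: a plain torch.nn.Linear layer.
        self.head = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, x, hidden_state=None):
        feats, hidden_state = self.base_network(x, hidden_state)
        # One value per action at each timestep; hidden state is passed through
        # so the caller can seed the next step in the trajectory.
        return self.head(feats), hidden_state

net = ToyDRQNHead(ToyBase(), hidden_dim=32, out_dim=4)
q_values, h = net(torch.zeros(2, 5, 16))
```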