hive.agents.qnets.sequence_models module

class hive.agents.qnets.sequence_models.SequenceFn[source]

Bases: Registrable, Module

A wrapper for callables that produce sequence functions.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

classmethod type_name()[source]

This should return a string that denotes which type of class you are creating. For example, “logger”, “agent”, or “env”.

abstract init_hidden(batch_size)[source]
get_hidden_spec()[source]
training: bool
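A minimal stand-in illustrating the `init_hidden` / `forward` contract that `SequenceFn` subclasses follow, written with plain torch and without hive's `Registrable` machinery. The class name and internals here are illustrative assumptions, not part of the library.

```python
import torch
import torch.nn as nn


class ToyRNNFn(nn.Module):
    """Hypothetical sketch of the SequenceFn interface using a bare nn.RNN."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self._hidden_size = hidden_size
        self._rnn = nn.RNN(input_size, hidden_size, batch_first=True)

    def init_hidden(self, batch_size):
        # One recurrent layer, so the hidden state has shape
        # (num_layers, batch, hidden_size).
        return torch.zeros(1, batch_size, self._hidden_size)

    def forward(self, x, hidden_state):
        # Returns (output sequence, updated hidden state), like the
        # SequenceFn subclasses documented below.
        return self._rnn(x, hidden_state)


fn = ToyRNNFn(input_size=4, hidden_size=8)
h0 = fn.init_hidden(batch_size=2)
out, h1 = fn(torch.zeros(2, 5, 4), h0)  # x is (batch, seq, features)
```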
class hive.agents.qnets.sequence_models.LSTMModel(rnn_input_size, rnn_hidden_size=128, num_rnn_layers=1, batch_first=True)[source]

Bases: SequenceFn

A multi-layer long short-term memory (LSTM) RNN.

Parameters
  • rnn_input_size (int) – The number of expected features in the input x.

  • rnn_hidden_size (int) – The number of features in the hidden state h.

  • num_rnn_layers (int) – Number of recurrent layers.

  • batch_first (bool) – If True, then the input and output tensors are provided as (batch, seq, feature).

forward(x, hidden_state)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

init_hidden(batch_size)[source]
get_hidden_spec()[source]
training: bool
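The shapes `LSTMModel` works with can be seen with a bare `torch.nn.LSTM` configured with the same defaults (hidden size 128, one layer, `batch_first=True`); `LSTMModel` wraps such a module and adds `init_hidden` / `get_hidden_spec`. The input size and batch/sequence lengths below are illustrative.

```python
import torch
import torch.nn as nn

# Same defaults as LSTMModel: rnn_hidden_size=128, num_rnn_layers=1,
# batch_first=True. rnn_input_size=16 is an arbitrary example value.
lstm = nn.LSTM(input_size=16, hidden_size=128, num_layers=1, batch_first=True)

batch, seq = 4, 10
x = torch.randn(batch, seq, 16)
# An LSTM hidden state is an (h, c) pair, each of shape
# (num_layers, batch, hidden_size) — this is what init_hidden produces.
h0 = (torch.zeros(1, batch, 128), torch.zeros(1, batch, 128))
out, (h1, c1) = lstm(x, h0)
```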
class hive.agents.qnets.sequence_models.GRUModel(rnn_input_size, rnn_hidden_size=128, num_rnn_layers=1, batch_first=True)[source]

Bases: SequenceFn

A multi-layer gated recurrent unit (GRU) RNN.

Parameters
  • rnn_input_size (int) – The number of expected features in the input x.

  • rnn_hidden_size (int) – The number of features in the hidden state h.

  • num_rnn_layers (int) – Number of recurrent layers.

  • batch_first (bool) – If True, then the input and output tensors are provided as (batch, seq, feature).

forward(x, hidden_state)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

init_hidden(batch_size)[source]
get_hidden_spec()[source]
training: bool
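The same sketch for the GRU case, again with a bare `torch.nn.GRU` and the class's defaults. The practical difference from the LSTM above is that a GRU's hidden state is a single tensor rather than an (h, c) pair.

```python
import torch
import torch.nn as nn

# Same defaults as GRUModel: rnn_hidden_size=128, num_rnn_layers=1,
# batch_first=True. rnn_input_size=16 is an arbitrary example value.
gru = nn.GRU(input_size=16, hidden_size=128, num_layers=1, batch_first=True)

x = torch.randn(4, 10, 16)
h0 = torch.zeros(1, 4, 128)  # single tensor, unlike the LSTM's (h, c) pair
out, h1 = gru(x, h0)
```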
class hive.agents.qnets.sequence_models.SequenceModel(in_dim, representation_network, sequence_fn, mlp_layers=None, normalization_factor=255, noisy=False, std_init=0.5)[source]

Bases: Registrable, Module

Basic convolutional recurrent neural network architecture. Applies a number of convolutional layers (each followed by a ReLU activation), then recurrent layers, and then feeds the output into a hive.agents.qnets.mlp.MLPNetwork.

Note, if channels is None, the network created for the convolution portion of the architecture is simply a torch.nn.Identity module. If mlp_layers is None, the MLP portion of the architecture is a torch.nn.Identity module.

Parameters
  • in_dim (tuple) – The tuple of observations dimension (channels, width, height).

  • representation_network (torch.nn.Module) – The network that computes the representation (e.g. the convolutional layers) applied before the sequence function.

  • sequence_fn (SequenceFn) – A sequence neural network that learns recurrent representation. Usually placed between the convolutional layers and mlp layers.

  • mlp_layers (tuple[int]) – The sizes of the hidden layers of the MLP portion of the network.

  • normalization_factor (float | int) – What the input is divided by before the forward pass of the network.

  • noisy (bool) – Whether the MLP part of the network will use NoisyLinear layers or torch.nn.Linear layers.

  • std_init (float) – The range for the initialization of the standard deviation of the weights in NoisyLinear.

classmethod type_name()[source]

This should return a string that denotes which type of class you are creating. For example, “logger”, “agent”, or “env”.

forward(x, hidden_state=None)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_hidden_spec()[source]
training: bool
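The conv → recurrent → MLP pipeline that `SequenceModel` wires together can be sketched by hand in plain torch. Everything here (class name, layer sizes, 8×8 inputs) is an assumption for illustration; the real class also builds its pieces from `representation_network`, `sequence_fn`, and `mlp_layers`, and supports `noisy` NoisyLinear layers.

```python
import torch
import torch.nn as nn


class TinySequenceModel(nn.Module):
    """Hypothetical sketch of SequenceModel's conv -> RNN -> MLP pipeline."""

    def __init__(self, in_channels=3, rnn_hidden=32, out_features=16):
        super().__init__()
        # Convolutional feature extractor, applied per frame.
        self._conv = nn.Sequential(
            nn.Conv2d(in_channels, 8, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Flatten(),
        )
        conv_out = 8 * 3 * 3  # flattened size for 8x8 inputs with the conv above
        self._rnn = nn.LSTM(conv_out, rnn_hidden, batch_first=True)
        self._mlp = nn.Sequential(nn.Linear(rnn_hidden, out_features), nn.ReLU())

    def forward(self, x, hidden_state=None):
        # x: (batch, seq, channels, height, width). Dividing by 255 plays
        # the role of normalization_factor.
        b, t = x.shape[:2]
        x = x / 255.0
        # Fold the time axis into the batch axis for the conv, then unfold.
        feats = self._conv(x.reshape(b * t, *x.shape[2:])).reshape(b, t, -1)
        out, hidden_state = self._rnn(feats, hidden_state)
        return self._mlp(out), hidden_state


model = TinySequenceModel()
y, h = model(torch.zeros(2, 5, 3, 8, 8))
```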
class hive.agents.qnets.sequence_models.DRQNNetwork(base_network, hidden_dim, out_dim, linear_fn=None)[source]

Bases: Module

Implements the standard DRQN value computation. This module returns two outputs, which correspond to the two outputs from base_network. In particular, it transforms the first output from base_network with output dimension hidden_dim to dimension out_dim, which should be equal to the number of actions. The second output of this module is the second output from base_network, which is the hidden state that will be used as the initial hidden state when computing the next action in the trajectory.

Parameters
  • base_network (torch.nn.Module) – Backbone network that returns two outputs: one is the representation used to compute action values, and the other is the hidden state used as the input hidden state at the next step.

  • hidden_dim (int) – Dimension of the first output of base_network.

  • out_dim (int) – Output dimension of the DRQN. Should be equal to the number of actions that you are computing values for.

  • linear_fn (torch.nn.Module) – Function that will create the torch.nn.Module that will take the output of base_network and produce the final action values. If None, a torch.nn.Linear layer will be used.

forward(x, hidden_state=None)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_hidden_spec()[source]
training: bool
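The value computation described above can be sketched as a small head on top of a base network that returns `(representation, hidden_state)`. Here a bare `torch.nn.LSTM` stands in for the base network, and the class name and sizes are assumptions for illustration.

```python
import torch
import torch.nn as nn


class TinyDRQNHead(nn.Module):
    """Hypothetical sketch of DRQNNetwork's value head."""

    def __init__(self, base_network, hidden_dim, out_dim):
        super().__init__()
        self._base = base_network
        # With linear_fn=None, a plain Linear maps hidden_dim -> out_dim
        # (the number of actions).
        self._head = nn.Linear(hidden_dim, out_dim)

    def forward(self, x, hidden_state=None):
        # First base output: representation; second: hidden state that
        # seeds the next step in the trajectory.
        representation, hidden_state = self._base(x, hidden_state)
        return self._head(representation), hidden_state


base = nn.LSTM(input_size=8, hidden_size=64, batch_first=True)
drqn = TinyDRQNHead(base, hidden_dim=64, out_dim=4)  # 4 actions
q_values, hidden = drqn(torch.zeros(2, 7, 8))
```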