Source code for hive.agents.qnets.sequence_models

import abc

import numpy as np
import torch
from torch import nn

from hive.agents.qnets.mlp import MLPNetwork
from hive.agents.qnets.utils import calculate_output_dim
from hive.utils.registry import Registrable, registry


class SequenceFn(Registrable, nn.Module):
    """A wrapper for callables that produce sequence functions."""

    @classmethod
    def type_name(cls):
        return "SequenceFn"

    @abc.abstractmethod
    def init_hidden(self, batch_size):
        raise NotImplementedError

    def get_hidden_spec(self):
        return None
class LSTMModel(SequenceFn):
    """A multi-layer long short-term memory (LSTM) RNN."""

    def __init__(
        self,
        rnn_input_size,
        rnn_hidden_size=128,
        num_rnn_layers=1,
        batch_first=True,
    ):
        """
        Args:
            rnn_input_size (int): The number of expected features in the input x.
            rnn_hidden_size (int): The number of features in the hidden state h.
            num_rnn_layers (int): Number of recurrent layers.
            batch_first (bool): If True, then the input and output tensors are
                provided as (batch, seq, feature) instead of (seq, batch, feature).
        """
        super().__init__()
        self._rnn_hidden_size = rnn_hidden_size
        self._num_rnn_layers = num_rnn_layers
        self.core = nn.LSTM(
            input_size=rnn_input_size,
            hidden_size=self._rnn_hidden_size,
            num_layers=self._num_rnn_layers,
            batch_first=batch_first,
        )
        self._device = next(self.core.parameters()).device

    def forward(self, x, hidden_state):
        x, hidden_state = self.core(
            x, (hidden_state["hidden_state"], hidden_state["cell_state"])
        )
        return x, {"hidden_state": hidden_state[0], "cell_state": hidden_state[1]}

    def _apply(self, *args, **kwargs):
        ret = super()._apply(*args, **kwargs)
        self._device = next(self.core.parameters()).device
        return ret

    def init_hidden(self, batch_size):
        return {
            "hidden_state": torch.zeros(
                (self._num_rnn_layers, batch_size, self._rnn_hidden_size),
                dtype=torch.float32,
                device=self._device,
            ),
            "cell_state": torch.zeros(
                (self._num_rnn_layers, batch_size, self._rnn_hidden_size),
                dtype=torch.float32,
                device=self._device,
            ),
        }

    def get_hidden_spec(self):
        return {
            "hidden_state": (
                (np.float32, (self._num_rnn_layers, 1, self._rnn_hidden_size)),
                1,
            ),
            "cell_state": (
                (np.float32, (self._num_rnn_layers, 1, self._rnn_hidden_size)),
                1,
            ),
        }
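
# --- Illustrative usage sketch (not part of the original module). A minimal
# example, assuming a batch of 4 sequences of length 10 with 32 input features;
# all sizes here are hypothetical.
def _example_lstm_usage():
    lstm = LSTMModel(rnn_input_size=32, rnn_hidden_size=64, num_rnn_layers=1)
    hidden = lstm.init_hidden(batch_size=4)
    x = torch.zeros(4, 10, 32)  # (batch, seq, feature) since batch_first=True
    out, hidden = lstm(x, hidden)
    # out: (4, 10, 64); hidden["hidden_state"] and hidden["cell_state"]: (1, 4, 64)
    return out, hidden
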
class GRUModel(SequenceFn):
    """A multi-layer gated recurrent unit (GRU) RNN."""

    def __init__(
        self,
        rnn_input_size,
        rnn_hidden_size=128,
        num_rnn_layers=1,
        batch_first=True,
    ):
        """
        Args:
            rnn_input_size (int): The number of expected features in the input x.
            rnn_hidden_size (int): The number of features in the hidden state h.
            num_rnn_layers (int): Number of recurrent layers.
            batch_first (bool): If True, then the input and output tensors are
                provided as (batch, seq, feature) instead of (seq, batch, feature).
        """
        super().__init__()
        self._rnn_hidden_size = rnn_hidden_size
        self._num_rnn_layers = num_rnn_layers
        self.core = nn.GRU(
            input_size=rnn_input_size,
            hidden_size=self._rnn_hidden_size,
            num_layers=self._num_rnn_layers,
            batch_first=batch_first,
        )
        self._device = next(self.core.parameters()).device
    def forward(self, x, hidden_state):
        # nn.GRU expects the hidden tensor itself, not the wrapping dict.
        x, hidden_state = self.core(x, hidden_state["hidden_state"])
        return x, {"hidden_state": hidden_state}
    def _apply(self, *args, **kwargs):
        ret = super()._apply(*args, **kwargs)
        self._device = next(self.core.parameters()).device
        return ret

    def init_hidden(self, batch_size):
        return {
            "hidden_state": torch.zeros(
                (self._num_rnn_layers, batch_size, self._rnn_hidden_size),
                dtype=torch.float32,
                device=self._device,
            )
        }
    def get_hidden_spec(self):
        # Return a dict (not a 1-tuple wrapping a dict) so the spec has the same
        # structure as LSTMModel.get_hidden_spec.
        return {
            "hidden_state": (
                (np.float32, (self._num_rnn_layers, 1, self._rnn_hidden_size)),
                1,
            )
        }
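
# --- Illustrative usage sketch (not part of the original module). GRUModel
# mirrors LSTMModel but keeps only a single "hidden_state" entry; sizes below
# are hypothetical.
def _example_gru_usage():
    gru = GRUModel(rnn_input_size=32, rnn_hidden_size=64)
    hidden = gru.init_hidden(batch_size=4)
    x = torch.zeros(4, 10, 32)  # (batch, seq, feature)
    out, hidden = gru(x, hidden)
    # out: (4, 10, 64); hidden["hidden_state"]: (1, 4, 64)
    return out, hidden
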
class SequenceModel(Registrable, nn.Module):
    """Basic convolutional recurrent neural network architecture.

    Applies a number of convolutional layers (each followed by a ReLU activation),
    recurrent layers, and then feeds the output into an
    :py:class:`hive.agents.qnets.mlp.MLPNetwork`.

    Note: if :obj:`channels` is :const:`None`, the network created for the
    convolution portion of the architecture is simply an
    :py:class:`torch.nn.Identity` module. If :obj:`mlp_layers` is :const:`None`,
    the mlp portion of the architecture is an :py:class:`torch.nn.Identity`
    module.
    """
    @classmethod
    def type_name(cls):
        return "SequenceModel"
    def __init__(
        self,
        in_dim,
        representation_network: torch.nn.Module,
        sequence_fn: SequenceFn,
        mlp_layers=None,
        normalization_factor=255,
        noisy=False,
        std_init=0.5,
    ):
        """
        Args:
            in_dim (tuple): The tuple of observations dimension
                (channels, width, height).
            representation_network (torch.nn.Module): Network that encodes each
                observation before it is passed to the recurrent layers.
            sequence_fn (SequenceFn): A sequence neural network that learns
                recurrent representation. Usually placed between the convolutional
                layers and mlp layers.
            mlp_layers (list): The sizes of the hidden layers of the MLP applied
                after the recurrent layers. If :const:`None`, no MLP is applied.
            normalization_factor (float | int): What the input is divided by before
                the forward pass of the network.
            noisy (bool): Whether the MLP part of the network will use
                :py:class:`~hive.agents.qnets.noisy_linear.NoisyLinear` layers or
                :py:class:`torch.nn.Linear` layers.
            std_init (float): The range for the initialization of the standard
                deviation of the weights in
                :py:class:`~hive.agents.qnets.noisy_linear.NoisyLinear`.
        """
        super().__init__()
        self._normalization_factor = normalization_factor
        self.representation_network = representation_network

        # RNN Layers
        conv_output_size = calculate_output_dim(self.representation_network, in_dim)
        self.sequence_fn = sequence_fn(
            rnn_input_size=np.prod(conv_output_size),
        )

        if mlp_layers is not None:
            # MLP Layers
            sequence_output_size, _ = calculate_output_dim(
                self.sequence_fn, conv_output_size
            )
            self.mlp = MLPNetwork(
                sequence_output_size,
                mlp_layers,
                noisy=noisy,
                std_init=std_init,
            )
        else:
            self.mlp = nn.Identity()
    def forward(self, x, hidden_state=None):
        B, L = x.shape[0], x.shape[1]
        x = x.reshape(B * L, *x.shape[2:])
        x = self.representation_network(x)
        x = x.view(B, L, -1)

        # Sequence models with hidden state
        if hidden_state is None:
            hidden_state = self.sequence_fn.init_hidden(batch_size=B)
        x, hidden_state = self.sequence_fn(x, hidden_state)

        out = self.mlp(x.reshape((B * L, -1)))
        return out, hidden_state
    def get_hidden_spec(self):
        return self.sequence_fn.get_hidden_spec()
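
# --- Illustrative usage sketch (not part of the original module). Composes a
# SequenceModel from a small convolutional encoder and an LSTM sequence
# function. The observation shape, layer sizes, and sequence length are
# hypothetical; note that sequence_fn is passed as a creator (e.g. a partial),
# not an instantiated module, since SequenceModel fills in rnn_input_size.
def _example_sequence_model_usage():
    import functools

    representation_network = nn.Sequential(
        nn.Conv2d(4, 16, kernel_size=8, stride=4),
        nn.ReLU(),
        nn.Flatten(),
    )
    model = SequenceModel(
        in_dim=(4, 84, 84),
        representation_network=representation_network,
        sequence_fn=functools.partial(LSTMModel, rnn_hidden_size=128),
        mlp_layers=[256],
    )
    x = torch.zeros(2, 5, 4, 84, 84)  # (batch, sequence_length, *observation_shape)
    out, hidden_state = model(x)
    # out: (batch * sequence_length, 256); hidden_state holds the final LSTM state.
    return out, hidden_state
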
class DRQNNetwork(nn.Module):
    """Implements the standard DRQN value computation.

    This module returns two outputs, which correspond to the two outputs from
    :obj:`base_network`. In particular, it transforms the first output from
    :obj:`base_network` with output dimension :obj:`hidden_dim` to dimension
    :obj:`out_dim`, which should be equal to the number of actions. The second
    output of this module is the second output from :obj:`base_network`, which
    is the hidden state that will be used as the initial hidden state when
    computing the next action in the trajectory.
    """

    def __init__(
        self,
        base_network: SequenceModel,
        hidden_dim: int,
        out_dim: int,
        linear_fn: nn.Module = None,
    ):
        """
        Args:
            base_network (torch.nn.Module): Backbone network that returns two
                outputs, one is the representation used to compute action values,
                and the other one is the hidden state used as the input hidden
                state later.
            hidden_dim (int): Dimension of the output of the :obj:`base_network`.
            out_dim (int): Output dimension of the DRQN. Should be equal to the
                number of actions that you are computing values for.
            linear_fn (torch.nn.Module): Function that will create the
                :py:class:`torch.nn.Module` that will take the output of
                :obj:`base_network` and produce the final action values. If
                :obj:`None`, a :py:class:`torch.nn.Linear` layer will be used.
        """
        super().__init__()
        self.base_network = base_network
        self._linear_fn = linear_fn if linear_fn is not None else nn.Linear
        self.output_layer = self._linear_fn(hidden_dim, out_dim)
    def forward(self, x, hidden_state=None):
        x, hidden_state = self.base_network(x, hidden_state)
        x = x.flatten(start_dim=1)
        return self.output_layer(x), hidden_state
    def get_hidden_spec(self):
        return self.base_network.get_hidden_spec()
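
# --- Illustrative usage sketch (not part of the original module). Wraps a
# SequenceModel in a DRQNNetwork to produce per-step action values; hidden_dim
# must match the flattened per-step output of base_network, and the shapes the
# caller supplies here are hypothetical.
def _example_drqn_usage(base_network, hidden_dim, num_actions, observation_shape):
    drqn = DRQNNetwork(base_network, hidden_dim=hidden_dim, out_dim=num_actions)
    x = torch.zeros(2, 5, *observation_shape)  # (batch, sequence_length, *obs)
    q_values, hidden_state = drqn(x)
    # q_values: (batch * sequence_length, num_actions). hidden_state can be fed
    # back in as the initial state for the next chunk of the trajectory.
    return q_values, hidden_state
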
registry.register_all(
    SequenceFn,
    {
        "LSTM": LSTMModel,
        "GRU": GRUModel,
    },
)

registry.register("SequenceModel", SequenceModel, SequenceModel)

get_sequence_fn = getattr(registry, f"get_{SequenceFn.type_name()}")
get_sequence_model = getattr(registry, f"get_{SequenceModel.type_name()}")