Source code for hive.agents.qnets.mlp

from functools import partial
from typing import List, Tuple, Union

import numpy as np
import torch
from torch import nn

from hive.agents.qnets.noisy_linear import NoisyLinear


class MLPNetwork(nn.Module):
    """Basic MLP neural network architecture.

    Contains a series of :py:class:`torch.nn.Linear` or
    :py:class:`~hive.agents.qnets.noisy_linear.NoisyLinear` layers, each of
    which is followed by a ReLU.
    """

    def __init__(
        self,
        in_dim: Tuple[int],
        hidden_units: Union[int, List[int]] = 256,
        noisy: bool = False,
        std_init: float = 0.5,
    ):
        """
        Args:
            in_dim (tuple[int]): The shape of input observations.
            hidden_units (int | list[int]): The number of neurons for each MLP
                layer.
            noisy (bool): Whether the MLP should use
                :py:class:`~hive.agents.qnets.noisy_linear.NoisyLinear` layers
                or normal :py:class:`torch.nn.Linear` layers.
            std_init (float): The range for the initialization of the standard
                deviation of the weights in
                :py:class:`~hive.agents.qnets.noisy_linear.NoisyLinear`.
        """
        super().__init__()
        if isinstance(hidden_units, int):
            hidden_units = [hidden_units]
        linear_fn = partial(NoisyLinear, std_init=std_init) if noisy else nn.Linear
        modules = [linear_fn(np.prod(in_dim), hidden_units[0]), nn.ReLU()]
        for i in range(len(hidden_units) - 1):
            modules.append(linear_fn(hidden_units[i], hidden_units[i + 1]))
            modules.append(nn.ReLU())
        self.network = nn.Sequential(*modules)
    def forward(self, x):
        x = x.float()
        # Flatten all non-batch dimensions so multi-dimensional
        # observations match the np.prod(in_dim) input size.
        x = torch.flatten(x, start_dim=1)
        return self.network(x)
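
A minimal usage sketch (not part of the module source; the observation shape, batch size, and hidden sizes below are illustrative assumptions):

import torch

from hive.agents.qnets.mlp import MLPNetwork

# Observations of shape (2, 3) are flattened to np.prod((2, 3)) == 6 inputs.
net = MLPNetwork(in_dim=(2, 3), hidden_units=[64, 64])
obs = torch.rand(32, 2, 3)  # batch of 32 illustrative observations
features = net(obs)         # ReLU activations of the last hidden layer
assert features.shape == (32, 64)

# NoisyLinear variant, e.g. for exploration without epsilon-greedy:
noisy_net = MLPNetwork(in_dim=(2, 3), hidden_units=64, noisy=True, std_init=0.5)

Because every layer, including the last, is followed by a ReLU, the network produces hidden features rather than raw outputs, so it is typically used as a torso whose features feed a separate output head.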