Source code for hive.agents.qnets.qnet_heads

import torch
import torch.nn.functional as F
from torch import nn


class DQNNetwork(nn.Module):
    """Standard DQN head that maps backbone features to action values.

    The output of :obj:`base_network` is flattened to one feature vector per
    batch element (of size :obj:`hidden_dim`) and passed through a single
    output layer producing :obj:`out_dim` values, one per action.
    """

    def __init__(
        self,
        base_network: nn.Module,
        hidden_dim: int,
        out_dim: int,
        linear_fn: nn.Module = None,
    ):
        """
        Args:
            base_network (torch.nn.Module): Backbone network that computes the
                representations from which action values are derived.
            hidden_dim (int): Size of the flattened output of
                :obj:`base_network`.
            out_dim (int): Number of outputs of the head. Should be equal to
                the number of actions that values are computed for.
            linear_fn (torch.nn.Module): Factory called as
                ``linear_fn(hidden_dim, out_dim)`` to build the output layer.
                If :obj:`None`, :py:class:`torch.nn.Linear` is used.
        """
        super().__init__()
        self.base_network = base_network
        # Resolve the layer factory once; it is kept on the instance as well.
        if linear_fn is None:
            linear_fn = nn.Linear
        self._linear_fn = linear_fn
        self.output_layer = linear_fn(hidden_dim, out_dim)

    def forward(self, x):
        """Returns the action values computed for the input batch :obj:`x`."""
        features = self.base_network(x)
        # Collapse every non-batch dimension before the output layer.
        return self.output_layer(features.flatten(start_dim=1))
class DuelingNetwork(nn.Module):
    """Dueling-architecture head (https://arxiv.org/abs/1511.06581).

    Two output layers share the features of :obj:`base_network`: one
    estimates the advantage of each action and the other estimates the state
    value. They are combined as ``value + advantage - mean(advantage)``.
    """

    def __init__(
        self,
        base_network: nn.Module,
        hidden_dim: int,
        out_dim: int,
        linear_fn: nn.Module = None,
        atoms: int = 1,
    ):
        """
        Args:
            base_network (torch.nn.Module): Backbone network that computes the
                representations shared by the two estimators.
            hidden_dim (int): Size of the flattened output of
                :obj:`base_network`.
            out_dim (int): Number of actions that values are computed for.
            linear_fn (torch.nn.Module): Factory called as
                ``linear_fn(in_dim, out_dim)`` to build each of the two output
                layers. If :obj:`None`, :py:class:`torch.nn.Linear` is used.
            atoms (int): Multiplier for the dimension of the output. For
                standard dueling networks, this should be 1. Used by
                :py:class:`~hive.agents.qnets.qnet_heads.DistributionalNetwork`.
        """
        super().__init__()
        self.base_network = base_network
        self._hidden_dim = hidden_dim
        self._out_dim = out_dim
        self._atoms = atoms
        self._linear_fn = nn.Linear if linear_fn is None else linear_fn
        self.init_networks()

    def init_networks(self):
        """Creates the advantage and value output layers."""
        adv_width = self._out_dim * self._atoms
        val_width = 1 * self._atoms
        self.output_layer_adv = self._linear_fn(self._hidden_dim, adv_width)
        self.output_layer_val = self._linear_fn(self._hidden_dim, val_width)

    def forward(self, x):
        """Combines the value and mean-centred advantage estimates for :obj:`x`.

        The result has shape ``(batch, out_dim, atoms)``; the trailing atom
        dimension is squeezed away when ``atoms == 1``.
        """
        features = self.base_network(x).flatten(start_dim=1)
        advantage = self.output_layer_adv(features)
        value = self.output_layer_val(features)

        # Unbatched path. NOTE(review): with the default nn.Linear layers the
        # flattened features are always at least 2-D, so this looks reachable
        # only through a custom ``linear_fn`` -- confirm before removing.
        if advantage.dim() == 1:
            return value + advantage - advantage.mean(0)

        advantage = advantage.reshape(
            advantage.size(0), self._out_dim, self._atoms
        )
        value = value.reshape(value.size(0), 1, self._atoms)
        # Centre the advantages over the action dimension.
        advantage_mean = advantage.mean(dim=1, keepdim=True)
        combined = value + advantage - advantage_mean
        if self._atoms == 1:
            combined = combined.squeeze(dim=2)
        return combined
class DistributionalNetwork(nn.Module):
    """Computes a categorical distribution over values for each action
    (https://arxiv.org/abs/1707.06887).

    The output of :obj:`base_network` is interpreted as ``out_dim * atoms``
    logits per batch element. A softmax over the atom dimension yields one
    categorical distribution per action, and the expected value under that
    distribution (using fixed, evenly spaced support atoms) is the action
    value.
    """

    def __init__(
        self,
        base_network: nn.Module,
        out_dim: int,
        vmin: float = 0,
        vmax: float = 200,
        atoms: int = 51,
    ):
        """
        Args:
            base_network (torch.nn.Module): Backbone network that computes the
                representations that are used to compute the value
                distribution. Its output must be reshapeable to
                ``(-1, out_dim, atoms)``.
            out_dim (int): Output dimension of the Distributional DQN. Should
                be equal to the number of actions that you are computing
                values for.
            vmin (float): The minimum of the support of the categorical value
                distribution.
            vmax (float): The maximum of the support of the categorical value
                distribution.
            atoms (int): Number of atoms discretizing the support range of the
                categorical value distribution.
        """
        super().__init__()
        self.base_network = base_network
        # The support atoms are constants of the distribution, not weights to
        # learn. They stay a Parameter so the state-dict key, device/dtype
        # handling and ``parameters()`` ordering are unchanged, but gradient
        # tracking is disabled so an optimizer can never shift them away from
        # ``linspace(vmin, vmax, atoms)``.
        self._supports = torch.nn.Parameter(
            torch.linspace(vmin, vmax, atoms), requires_grad=False
        )
        self._out_dim = out_dim
        self._atoms = atoms

    def forward(self, x):
        """Returns the expected value of each action's distribution.

        The result has shape ``(batch, out_dim)``.
        """
        x = self.dist(x)
        # Expectation over atoms: sum_i p_i * z_i for every action.
        x = torch.sum(x * self._supports, dim=2)
        return x

    def dist(self, x):
        """Computes a categorical distribution over values for each action.

        Returns a tensor of shape ``(batch, out_dim, atoms)`` whose last
        dimension sums to 1.
        """
        x = self.base_network(x)
        # ``reshape`` (rather than ``view``) also accepts non-contiguous
        # backbone outputs and matches the reshaping used by DuelingNetwork.
        x = x.reshape(-1, self._out_dim, self._atoms)
        x = F.softmax(x, dim=-1)
        return x