Source code for hive.agents.qnets.td3_heads

from typing import Tuple, Union
import numpy as np
import torch

from hive.agents.qnets.base import FunctionApproximator
from hive.agents.qnets.utils import calculate_output_dim


class TD3ActorNetwork(torch.nn.Module):
    """A module that implements the TD3 actor computation. It puts together the
    :obj:`representation_network` and :obj:`actor_net`, and adds a final
    :py:class:`~torch.nn.Linear` layer to compute the action."""

    def __init__(
        self,
        representation_network: torch.nn.Module,
        actor_net: FunctionApproximator,
        network_output_shape: Union[int, Tuple[int]],
        action_shape: Tuple[int],
        use_tanh=True,
    ) -> None:
        """
        Args:
            representation_network (torch.nn.Module): Network that encodes the
                observations.
            actor_net (FunctionApproximator): Function that takes in the shape of
                the encoded observations and creates a network. This network takes
                the encoded observations from :obj:`representation_network` and
                outputs the representations used to compute the actions (i.e.
                everything except the last layer).
            network_output_shape: Expected output shape of the representation
                network.
            action_shape: Required shape of the output action.
        """
        super().__init__()
        self._action_shape = action_shape
        if actor_net is None:
            actor_network = torch.nn.Identity()
        else:
            actor_network = actor_net(network_output_shape)
        feature_dim = np.prod(
            calculate_output_dim(actor_network, network_output_shape)
        )
        actor_modules = [
            representation_network,
            actor_network,
            torch.nn.Flatten(),
            torch.nn.Linear(feature_dim, np.prod(action_shape)),
        ]
        if use_tanh:
            actor_modules.append(torch.nn.Tanh())
        self.actor = torch.nn.Sequential(*actor_modules)
    def forward(self, x):
        x = self.actor(x)
        return torch.reshape(x, (x.size(0), *self._action_shape))
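

# Usage sketch (not part of the original module): one possible way to assemble a
# TD3ActorNetwork with a hypothetical two-layer MLP as ``actor_net``. The helper
# name, layer sizes, and shapes below are illustrative assumptions, not values
# prescribed by RLHive.
def _example_actor_usage():
    obs_dim, action_dim = 17, 6

    def mlp_actor(input_shape):
        # ``actor_net`` is called with the encoded observation shape and must
        # return the network body (everything except the final linear layer).
        return torch.nn.Sequential(
            torch.nn.Linear(int(np.prod(input_shape)), 256),
            torch.nn.ReLU(),
        )

    actor = TD3ActorNetwork(
        representation_network=torch.nn.Identity(),
        actor_net=mlp_actor,
        network_output_shape=(obs_dim,),
        action_shape=(action_dim,),
    )
    obs = torch.randn(32, obs_dim)
    # Output has shape (32, 6); values lie in [-1, 1] because use_tanh=True.
    actions = actor(obs)
    return actions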
class TD3CriticNetwork(torch.nn.Module):
    """A module that implements the TD3 critic computation. It puts together the
    :obj:`representation_network` and :obj:`n_critics` copies of :obj:`critic_net`,
    each followed by a final :py:class:`~torch.nn.Linear` layer that computes the
    value of the action."""

    def __init__(
        self,
        representation_network: torch.nn.Module,
        critic_net: FunctionApproximator,
        network_output_shape: Union[int, Tuple[int]],
        n_critics: int,
        action_shape: Tuple[int],
    ) -> None:
        """
        Args:
            representation_network (torch.nn.Module): Network that encodes the
                observations.
            critic_net (FunctionApproximator): Function that takes in the shape of
                the encoded observations and creates a network. This network takes
                two inputs: the encoded observations from
                :obj:`representation_network` and actions. It outputs the
                representations used to compute the values of the actions (i.e.
                everything except the last layer).
            network_output_shape: Expected output shape of the representation
                network.
            n_critics: How many copies of the critic to create. They will all use
                the shared representation from the :obj:`representation_network`.
            action_shape: Expected shape of actions.
        """
        super().__init__()
        self.network = representation_network
        if critic_net is None:
            critic_net = lambda x: torch.nn.Identity()
        self._n_critics = n_critics
        input_shape = (np.prod(network_output_shape) + np.prod(action_shape),)
        critics = [critic_net(input_shape) for _ in range(n_critics)]
        feature_dim = np.prod(
            calculate_output_dim(critics[0], input_shape=input_shape)
        )
        self._critics = torch.nn.ModuleList(
            [
                torch.nn.Sequential(
                    critic,
                    torch.nn.Flatten(),
                    torch.nn.Linear(feature_dim, 1),
                )
                for critic in critics
            ]
        )
    def forward(self, obs, actions):
        obs = self.network(obs)
        obs = torch.flatten(obs, start_dim=1)
        actions = torch.flatten(actions, start_dim=1)
        x = torch.cat([obs, actions], dim=1)
        return [critic(x) for critic in self._critics]
    def q1(self, obs, actions):
        """Returns the value according to only the first critic."""
        obs = self.network(obs)
        # Flatten both inputs so they match the concatenated input shape the
        # critics were built for, as in forward().
        obs = torch.flatten(obs, start_dim=1)
        actions = torch.flatten(actions, start_dim=1)
        x = torch.cat([obs, actions], dim=1)
        return self._critics[0](x)
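

# Usage sketch (not part of the original module): one possible way to build a
# twin-critic TD3CriticNetwork with a hypothetical one-hidden-layer MLP as
# ``critic_net``. All sizes below are illustrative assumptions.
def _example_critic_usage():
    obs_dim, action_dim = 17, 6

    def mlp_critic(input_shape):
        # ``critic_net`` is called with the shape of the concatenated
        # (encoded observation, action) vector.
        return torch.nn.Sequential(
            torch.nn.Linear(int(np.prod(input_shape)), 256),
            torch.nn.ReLU(),
        )

    critic = TD3CriticNetwork(
        representation_network=torch.nn.Identity(),
        critic_net=mlp_critic,
        network_output_shape=(obs_dim,),
        n_critics=2,
        action_shape=(action_dim,),
    )
    obs = torch.randn(32, obs_dim)
    actions = torch.randn(32, action_dim)
    q_values = critic(obs, actions)  # list of 2 tensors, each of shape (32, 1)
    q1_value = critic.q1(obs, actions)  # first critic only, shape (32, 1)
    return q_values, q1_value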