import copy
import gymnasium as gym
import numpy as np
import torch
from hive.agents.dqn import DQNAgent
from hive.agents.qnets.base import FunctionApproximator
from hive.agents.qnets.sequence_models import DRQNNetwork, SequenceFn, SequenceModel
from hive.agents.qnets.utils import (
InitializationFn,
calculate_output_dim,
create_init_weights_fn,
apply_to_tensor,
)
from hive.replays.recurrent_replay import RecurrentReplayBuffer
from hive.utils.loggers import Logger, NullLogger
from hive.utils.schedule import (
LinearSchedule,
PeriodicSchedule,
Schedule,
SwitchSchedule,
)
from hive.utils.utils import LossFn, OptimizerFn, seeder
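# Illustrative construction sketch (the env and network constructors below are
# hypothetical placeholders, not part of the library):
#
#     agent = DRQNAgent(
#         observation_space=env.observation_space,
#         action_space=env.action_space,
#         representation_net=my_representation_net,  # e.g. builds the conv/MLP encoder
#         sequence_fn=my_sequence_fn,                 # e.g. builds an LSTM/GRU module
#         max_seq_len=10,
#     )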
class DRQNAgent(DQNAgent):
"""An agent implementing the DRQN algorithm. Uses an epsilon greedy
exploration policy.
"""
def __init__(
self,
observation_space: gym.spaces.Box,
action_space: gym.spaces.Discrete,
representation_net: FunctionApproximator,
sequence_fn: SequenceFn,
id=0,
optimizer_fn: OptimizerFn = None,
loss_fn: LossFn = None,
init_fn: InitializationFn = None,
replay_buffer: RecurrentReplayBuffer = None,
max_seq_len: int = 1,
discount_rate: float = 0.99,
n_step: int = 1,
grad_clip: float = None,
reward_clip: float = None,
update_period_schedule: Schedule = None,
target_net_soft_update: bool = False,
target_net_update_fraction: float = 0.05,
target_net_update_schedule: Schedule = None,
epsilon_schedule: Schedule = None,
test_epsilon: float = 0.001,
min_replay_history: int = 5000,
batch_size: int = 32,
device="cpu",
logger: Logger = None,
log_frequency: int = 100,
store_hidden: bool = True,
burn_frames: int = 0,
**kwargs,
):
"""
Args:
observation_space (gym.spaces.Box): Observation space for the agent.
action_space (gym.spaces.Discrete): Action space for the agent.
representation_net (FunctionApproximator): A network that encodes the
observations before they are passed to the recurrent module (e.g.
everything before the recurrent and final layers of the DRQN, such
as the convolutional layers).
sequence_fn (SequenceFn): The recurrent module (e.g. an LSTM or GRU)
that is applied to the outputs of representation_net when the
Q-network is assembled in create_q_networks. It should be able to
initialize its own hidden state when none is provided, and, if
hidden states are to be stored in the replay buffer, expose their
spec through the assembled network's get_hidden_spec method.
id: Agent identifier.
optimizer_fn (OptimizerFn): A function that takes in a list of parameters
to optimize and returns the optimizer. If None, defaults to
:py:class:`~torch.optim.Adam`.
loss_fn (LossFn): Loss function used by the agent. If None, defaults to
:py:class:`~torch.nn.SmoothL1Loss`.
init_fn (InitializationFn): Initializes the weights of qnet using
create_init_weights_fn.
replay_buffer (RecurrentReplayBuffer): The replay buffer that the agent will
push observations to and sample from during learning. If None,
defaults to
:py:class:`~hive.replays.recurrent_replay.RecurrentReplayBuffer`.
max_seq_len (int): The number of consecutive transitions in a sequence.
discount_rate (float): A number between 0 and 1 specifying how much
future rewards are discounted by the agent.
n_step (int): The horizon used in n-step returns to compute TD(n) targets.
grad_clip (float): Gradients will be clipped to between
[-grad_clip, grad_clip].
reward_clip (float): Rewards will be clipped to between
[-reward_clip, reward_clip].
update_period_schedule (Schedule): Schedule determining how frequently
the agent's Q-network is updated.
target_net_soft_update (bool): Whether the target net parameters are
replaced by the qnet parameters completely or using a weighted
average of the target net parameters and the qnet parameters.
target_net_update_fraction (float): The weight given to the target
net parameters in a soft update.
target_net_update_schedule (Schedule): Schedule determining how frequently
the target net is updated.
epsilon_schedule (Schedule): Schedule determining the value of epsilon
through the course of training.
test_epsilon (float): The epsilon (probability of choosing a random
action) used during the testing phase.
min_replay_history (int): How many observations to fill the replay buffer
with before starting to learn.
batch_size (int): The size of the batch sampled from the replay buffer
during learning.
device: Device on which all computations should be run.
logger (Logger): Logger used to log agent's metrics.
log_frequency (int): How often to log the agent's metrics.
store_hidden (bool): Whether to store the recurrent hidden states in
the replay buffer alongside each transition and reuse them when
computing updates, rather than relying on the network's default
initial hidden state.
burn_frames (int): Number of timesteps at the start of each sampled
sequence that are used only to build up the recurrent hidden state
("burn-in") and are masked out of the loss.
"""
self._max_seq_len = max_seq_len
super(DQNAgent, self).__init__(
observation_space=observation_space, action_space=action_space, id=id
)
self._state_size = (
self._observation_space.shape[0],
*self._observation_space.shape[1:],
)
self._init_fn = create_init_weights_fn(init_fn)
self._device = torch.device("cpu" if not torch.cuda.is_available() else device)
self.create_q_networks(representation_net, sequence_fn)
if optimizer_fn is None:
optimizer_fn = torch.optim.Adam
self._optimizer = optimizer_fn(self._qnet.parameters())
self._rng = np.random.default_rng(seed=seeder.get_new_seed("agent"))
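# get_hidden_spec() describes the network's recurrent hidden state: for each
# named hidden tensor it gives (replay storage spec, batch dimension). If the
# network exposes no spec, hidden states are not stored in the replay buffer.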
hidden_spec = self._qnet.get_hidden_spec()
if not store_hidden or hidden_spec is None:
store_hidden = False
self._hidden_replay_spec = None
self._hidden_batch_spec = None
else:
self._hidden_replay_spec = {key: hidden_spec[key][0] for key in hidden_spec}
self._hidden_batch_spec = {key: hidden_spec[key][1] for key in hidden_spec}
if replay_buffer is None:
replay_buffer = RecurrentReplayBuffer
self._replay_buffer = replay_buffer(
observation_shape=self._observation_space.shape,
observation_dtype=self._observation_space.dtype,
action_shape=self._action_space.shape,
action_dtype=self._action_space.dtype,
max_seq_len=max_seq_len,
hidden_spec=self._hidden_replay_spec,
)
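# n-step returns span n environment steps, so the bootstrapped value is
# discounted by gamma ** n_step rather than by gamma.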
self._discount_rate = discount_rate**n_step
self._grad_clip = grad_clip
self._reward_clip = reward_clip
self._target_net_soft_update = target_net_soft_update
self._target_net_update_fraction = target_net_update_fraction
if loss_fn is None:
loss_fn = torch.nn.SmoothL1Loss
self._loss_fn = loss_fn(reduction="none")
self._batch_size = batch_size
self._logger = logger
if self._logger is None:
self._logger = NullLogger([])
self._timescale = self.id
self._logger.register_timescale(
self._timescale, PeriodicSchedule(False, True, log_frequency)
)
if update_period_schedule is None:
self._update_period_schedule = PeriodicSchedule(False, True, 1)
else:
self._update_period_schedule = update_period_schedule()
if target_net_update_schedule is None:
self._target_net_update_schedule = PeriodicSchedule(False, True, 10000)
else:
self._target_net_update_schedule = target_net_update_schedule()
if epsilon_schedule is None:
self._epsilon_schedule = LinearSchedule(1, 0.1, 100000)
else:
self._epsilon_schedule = epsilon_schedule()
self._test_epsilon = test_epsilon
self._learn_schedule = SwitchSchedule(False, True, min_replay_history)
self._training = False
self._store_hidden = store_hidden
self._burn_frames = burn_frames
def create_q_networks(self, representation_net, sequence_fn):
"""Creates the Q-network and target Q-network.
Args:
representation_net: A network that outputs the representations that will
be used to compute Q-values (e.g. everything before the recurrent
module and the final layer of the DRQN).
sequence_fn: The recurrent module (e.g. an LSTM or GRU) applied to the
outputs of representation_net.
"""
network = SequenceModel(
self._state_size, representation_net(self._state_size), sequence_fn
)
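# Use calculate_output_dim to find the flattened feature size of the composed
# network; DRQNNetwork adds the final linear layer that maps it to Q-values.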
network_output_dim = np.prod(
calculate_output_dim(network, (1,) + self._state_size)[0]
)
self._qnet = DRQNNetwork(network, network_output_dim, self._action_space.n).to(
self._device
)
self._qnet.apply(self._init_fn)
self._target_qnet = copy.deepcopy(self._qnet).requires_grad_(False)
def preprocess_update_info(self, update_info, hidden_state):
"""Preprocesses the :obj:`update_info` before it goes into the replay buffer.
Clips the reward in update_info.
Args:
update_info: Contains the information from the current timestep that the
agent should use to update itself.
hidden_state: The recurrent hidden state for the current timestep. It is
detached, moved to the CPU, and stored with the transition when
store_hidden is True.
"""
preprocessed_update_info = super().preprocess_update_info(update_info)
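# If hidden states are stored, detach them from the graph, move them to the
# CPU, and convert them to numpy before writing them to the replay buffer.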
if self._store_hidden:
preprocessed_update_info.update(
apply_to_tensor(hidden_state, lambda x: x.detach().cpu().numpy())
)
return preprocessed_update_info
def preprocess_update_batch(self, batch):
"""Preprocess the batch sampled from the replay buffer.
Args:
batch: Batch sampled from the replay buffer for the current update.
Returns:
(tuple):
- (tuple) Inputs used to calculate current state values.
- (tuple) Inputs used to calculate next state values.
- Preprocessed batch.
"""
for key in batch:
batch[key] = torch.tensor(batch[key], device=self._device)
if self._store_hidden:
for key in self._hidden_replay_spec:
if self._hidden_batch_spec[key] >= 0:
# The replay buffer batches hidden states along the first dimension, but
# the network expects them batched along the dimension in _hidden_batch_spec.
batch[key] = torch.cat(
list(batch[key]), dim=self._hidden_batch_spec[key]
)
batch[f"next_{key}"] = torch.cat(
list(batch[f"next_{key}"]), dim=self._hidden_batch_spec[key]
)
return (
(
batch["observation"],
{key: batch[key] for key in self._hidden_replay_spec},
),
(
batch["next_observation"],
{key: batch[f"next_{key}"] for key in self._hidden_replay_spec},
),
batch,
)
else:
return (batch["observation"],), (batch["next_observation"],), batch
@torch.no_grad()
def act(self, observation, agent_traj_state=None):
"""Returns the action for the agent. If in training mode, follows an epsilon
greedy policy. Otherwise, returns the action with the highest Q-value.
Args:
observation: The current observation.
agent_traj_state: Contains necessary state information for the agent
to process current trajectory. This should be updated and returned.
Returns:
- action
- agent trajectory state
"""
# Determine and log the value of epsilon
if self._training:
if not self._learn_schedule.get_value():
epsilon = 1.0
else:
epsilon = self._epsilon_schedule.update()
if self._logger.update_step(self._timescale):
self._logger.log_scalar("epsilon", epsilon, self._timescale)
else:
epsilon = self._test_epsilon
# Sample action. With epsilon probability choose random action,
# otherwise select the action with the highest q-value.
# Insert batch_size and sequence_len dimensions to observation
observation = torch.tensor(
np.expand_dims(observation, axis=(0, 1)), device=self._device
).float()
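# The recurrent hidden state is threaded through agent_traj_state between
# calls; at the start of an episode it is None and the network falls back to
# its initial hidden state.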
hidden_state = (
None if agent_traj_state is None else agent_traj_state["hidden_state"]
)
qvals, hidden_state = self._qnet(observation, hidden_state)
if self._rng.random() < epsilon:
action = self._rng.integers(self._action_space.n)
else:
# Note: ties between maximal Q-values are not explicitly handled (argmax picks one)
action = torch.argmax(qvals).item()
if agent_traj_state is None:
if self._training and self._logger.should_log(self._timescale):
self._logger.log_scalar("train_qval", torch.max(qvals), self._timescale)
return action, {"hidden_state": hidden_state}
def update(self, update_info, agent_traj_state=None):
"""
Updates the DRQN agent.
Args:
update_info: dictionary containing all the necessary information
from the environment to update the agent. Should contain a full
transition, with keys for "observation", "action", "reward",
"next_observation", "terminated", and "truncated".
agent_traj_state: Contains necessary state information for the agent
to process current trajectory. This should be updated and returned.
Returns:
- agent trajectory state
"""
if not self._training:
return agent_traj_state
# Add the most recent transition to the replay buffer.
self._replay_buffer.add(
**self.preprocess_update_info(
update_info, hidden_state=agent_traj_state["hidden_state"]
)
)
# Update the Q-network based on a sampled batch from the replay buffer, but
# only once the learn schedule has started, the buffer is non-empty, and the
# update period schedule fires.
if (
self._learn_schedule.update()
and self._replay_buffer.size() > 0
and self._update_period_schedule.update()
):
batch = self._replay_buffer.sample(batch_size=self._batch_size)
(
current_state_inputs,
next_state_inputs,
batch,
) = self.preprocess_update_batch(batch)
# Compute predicted Q values
self._optimizer.zero_grad()
pred_qvals, _ = self._qnet(*current_state_inputs)
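# Reshape the outputs to (batch_size, max_seq_len, num_actions) and gather the
# Q-value of the action taken at each timestep.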
pred_qvals = pred_qvals.view(self._batch_size, self._max_seq_len, -1)
actions = batch["action"].long()
pred_qvals = torch.gather(pred_qvals, -1, actions.unsqueeze(-1)).squeeze(-1)
# Compute 1-step Q targets
next_qvals, _ = self._target_qnet(*next_state_inputs)
next_qvals = next_qvals.view(self._batch_size, self._max_seq_len, -1)
next_qvals, _ = torch.max(next_qvals, dim=-1)
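# TD target: reward + gamma**n_step * max_a' Q_target(s', a'), zeroed out on
# terminal transitions.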
q_targets = batch["reward"] + self._discount_rate * next_qvals * (
1 - batch["terminated"]
)
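# Burn-in: the first burn_frames timesteps of each sequence only warm up the
# recurrent state and are masked out of the loss, together with the padded
# timesteps flagged by batch["mask"].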
if self._burn_frames > 0:
interm_loss = self._loss_fn(pred_qvals, q_targets)
mask = torch.zeros(
self._replay_buffer._max_seq_len,
device=self._device,
dtype=torch.float,
)
mask[self._burn_frames :] = 1.0
mask = mask.unsqueeze(0).repeat(len(batch["reward"]), 1)
mask = mask * batch["mask"]
interm_loss *= mask
loss = interm_loss.sum() / mask.sum()
else:
interm_loss = self._loss_fn(pred_qvals, q_targets)
interm_loss *= batch["mask"]
loss = interm_loss.sum() / batch["mask"].sum()
if self._logger.should_log(self._timescale):
self._logger.log_scalar("train_loss", loss, self._timescale)
loss.backward()
if self._grad_clip is not None:
torch.nn.utils.clip_grad_value_(
self._qnet.parameters(), self._grad_clip
)
self._optimizer.step()
# Update target network
if self._target_net_update_schedule.update():
self._update_target()
return agent_traj_state