Source code for hive.agents.legal_moves_rainbow

import numpy as np
import torch

from hive.agents.rainbow import RainbowDQNAgent


class LegalMovesRainbowAgent(RainbowDQNAgent):
    """A Rainbow agent for environments that provide a mask of legal actions
    alongside each observation."""
    def create_q_networks(self, representation_net):
        """Creates the qnet and target qnet, wrapping each in a
        LegalMovesHead so that illegal actions are masked out of the
        Q-values."""
        super().create_q_networks(representation_net)
        self._qnet = LegalMovesHead(self._qnet)
        self._target_qnet = LegalMovesHead(self._target_qnet)
    def preprocess_update_info(self, update_info):
        """Extracts the fields needed by the replay buffer from update_info,
        flattening the nested observation and converting the binary action
        mask to its additive 0/-inf encoding."""
        preprocessed_update_info = {
            "observation": update_info["observation"]["observation"],
            "action": update_info["action"],
            "reward": update_info["reward"],
            "done": update_info["done"],
            "action_mask": action_encoding(
                update_info["observation"]["action_mask"]
            ),
        }
        if "agent_id" in update_info:
            preprocessed_update_info["agent_id"] = int(update_info["agent_id"])

        return preprocessed_update_info
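    # Illustrative sketch (not part of the original module; the values and
    # the `agent` instance are invented): the nested observation is
    # flattened and the binary mask becomes an additive 0/-inf mask before
    # the transition is stored.
    #
    #     >>> update_info = {
    #     ...     "observation": {
    #     ...         "observation": np.zeros(4),
    #     ...         "action_mask": np.array([1, 0]),
    #     ...     },
    #     ...     "action": 0, "reward": 1.0, "done": False,
    #     ... }
    #     >>> agent.preprocess_update_info(update_info)["action_mask"]
    #     array([  0., -inf])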
    def preprocess_update_batch(self, batch):
        """Moves a sampled batch onto the agent's device and splits it into
        the (observation, action_mask) inputs for the current and next
        states, along with the full batch."""
        for key in batch:
            batch[key] = torch.tensor(batch[key], device=self._device)
        return (
            (batch["observation"], batch["action_mask"]),
            (batch["next_observation"], batch["next_action_mask"]),
            batch,
        )
    @torch.no_grad()
    def act(self, observation):
        """Selects an epsilon-greedy action restricted to the legal moves
        given by observation["action_mask"]."""
        if self._training:
            if not self._learn_schedule.get_value():
                epsilon = 1.0
            elif not self._use_eps_greedy:
                epsilon = 0.0
            else:
                epsilon = self._epsilon_schedule.update()
            if self._logger.update_step(self._timescale):
                self._logger.log_scalar("epsilon", epsilon, self._timescale)
        else:
            epsilon = self._test_epsilon

        vectorized_observation = torch.tensor(
            np.expand_dims(observation["observation"], axis=0), device=self._device
        ).float()
        legal_moves_as_int = [
            i for i, x in enumerate(observation["action_mask"]) if x == 1
        ]
        encoded_legal_moves = torch.tensor(
            action_encoding(observation["action_mask"]), device=self._device
        ).float()
        # Illegal actions come out of the masked network as -inf, so the
        # greedy argmax below can only select a legal action.
        qvals = self._qnet(vectorized_observation, encoded_legal_moves).cpu()

        if self._rng.random() < epsilon:
            # Explore by sampling uniformly from the legal actions only.
            action = np.random.choice(legal_moves_as_int).item()
        else:
            action = torch.argmax(qvals).item()

        if (
            self._training
            and self._logger.should_log(self._timescale)
            and self._state["episode_start"]
        ):
            self._logger.log_scalar("train_qval", torch.max(qvals), self._timescale)
            self._state["episode_start"] = False

        return action
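
# Usage sketch (not part of the original module; `agent` and `obs_dim` are
# invented for illustration): act() expects a dict observation holding the
# raw observation array and a binary "action_mask" in which 1 marks a legal
# action.
#
#     obs = {
#         "observation": np.zeros(obs_dim, dtype=np.float32),
#         "action_mask": np.array([1, 0, 1]),  # action 1 is illegal
#     }
#     action = agent.act(obs)  # returns 0 or 2, never the masked action 1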

class LegalMovesHead(torch.nn.Module):
    """Wraps a Q-network so that an additive legal-moves mask (0 for legal
    actions, -inf for illegal ones) is applied to its Q-value outputs."""

    def __init__(self, base_network):
        super().__init__()
        self.base_network = base_network

    def forward(self, x, legal_moves):
        x = self.base_network(x)
        return x + legal_moves

    def dist(self, x, legal_moves):
        # The distributional output is returned unmasked; legal_moves is
        # accepted only to match the call signature of forward().
        return self.base_network.dist(x)
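
# Minimal sketch (not from the original module; the Linear layer is a
# stand-in for the real Rainbow Q-network): the head adds the encoded mask
# to the base network's output, driving the Q-values of illegal actions to
# -inf.
#
#     base = torch.nn.Linear(4, 3)
#     head = LegalMovesHead(base)
#     mask = torch.tensor([0.0, float("-inf"), 0.0])
#     head(torch.zeros(1, 4), mask)  # column 1 of the result is -inf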

def action_encoding(action_mask):
    """Encodes a binary action mask as an additive mask: 0 for legal actions
    and -inf for illegal ones."""
    encoded_action_mask = np.zeros(action_mask.shape)
    encoded_action_mask[action_mask == 0] = -np.inf
    return encoded_action_mask
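
# Example (a hedged sketch; values invented): with three actions of which
# the middle one is illegal, the encoding leaves legal Q-values unchanged
# and sends illegal ones to -inf, so a subsequent argmax can only pick a
# legal action.
#
#     >>> action_encoding(np.array([1, 0, 1]))
#     array([  0., -inf,   0.])
#     >>> qvals = torch.tensor([0.2, 9.9, 0.5])
#     >>> torch.argmax(qvals + torch.tensor(action_encoding(np.array([1, 0, 1]))))
#     tensor(2)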