Source code for hive.replays.legal_moves_replay

from typing import Dict, Optional, Tuple

import numpy as np

from hive.replays.prioritized_replay import PrioritizedReplayBuffer


class LegalMovesBuffer(PrioritizedReplayBuffer):
    """A Prioritized Replay buffer for games like Hanabi that have legal
    moves and need a next_action_mask added to the batch.
    """

    def __init__(
        self,
        capacity: int,
        beta: float = 0.5,
        stack_size: int = 1,
        n_step: int = 1,
        gamma: float = 0.9,
        observation_shape: Tuple = (),
        observation_dtype: type = np.uint8,
        action_shape: Tuple = (),
        action_dtype: type = np.int8,
        reward_shape: Tuple = (),
        reward_dtype: type = np.float32,
        extra_storage_types: Optional[Dict] = None,
        action_dim: Optional[int] = None,
        num_players_sharing_buffer: Optional[int] = None,
    ):
        if extra_storage_types is None:
            extra_storage_types = {}
        # Store a legal-move mask of size action_dim alongside each
        # transition. Note: np.float was removed in NumPy 1.24; np.float64
        # is the equivalent dtype.
        extra_storage_types["action_mask"] = (np.float64, [action_dim])
        super().__init__(
            capacity=capacity,
            stack_size=stack_size,
            beta=beta,
            n_step=n_step,
            gamma=gamma,
            observation_shape=observation_shape,
            observation_dtype=observation_dtype,
            action_shape=action_shape,
            action_dtype=action_dtype,
            reward_shape=reward_shape,
            reward_dtype=reward_dtype,
            extra_storage_types=extra_storage_types,
            num_players_sharing_buffer=num_players_sharing_buffer,
        )
    def sample(self, batch_size):
        """Sample transitions from the buffer, adding next_action_mask to
        the batch for environments with legal moves.
        """
        batch = super().sample(batch_size)
        # The mask stored at the last step of each sampled n-step
        # trajectory is the action mask for the next state.
        batch["next_action_mask"] = self._get_from_storage(
            "action_mask",
            batch["indices"] + batch["trajectory_lengths"] - self._stack_size + 1,
            num_to_access=1,
        )
        return batch
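
A minimal usage sketch, not part of the module: it assumes the underlying
buffer's add method forwards extra keyword arguments (here action_mask) into
the slots declared in extra_storage_types, and that enough transitions have
been stored before sampling; the exact add signature of
PrioritizedReplayBuffer may differ. The observation and action sizes are
illustrative values for 2-player Hanabi.

import numpy as np

buffer = LegalMovesBuffer(
    capacity=10_000,
    observation_shape=(658,),   # illustrative: flat Hanabi observation vector
    observation_dtype=np.float32,
    action_dim=20,              # illustrative: size of the legal-move mask
)

obs = np.zeros(658, dtype=np.float32)
mask = np.zeros(20, dtype=np.float64)
mask[[0, 3, 7]] = 1.0           # only actions 0, 3, and 7 are legal here

# Assumed API: extras declared in extra_storage_types are passed as kwargs.
buffer.add(observation=obs, action=3, reward=0.0, done=False, action_mask=mask)

# After enough transitions have been stored:
batch = buffer.sample(batch_size=32)

# next_action_mask can then rule out illegal actions when bootstrapping, e.g.:
#   q_next[batch["next_action_mask"] == 0] = -np.inf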