import os
import pickle
from typing import Dict, Tuple
import numpy as np
from hive.replays.circular_replay import CircularReplayBuffer
from hive.utils.torch_utils import numpify
from hive.utils.utils import seeder
class PrioritizedReplayBuffer(CircularReplayBuffer):
    """Implements a replay with prioritized sampling. See
    https://arxiv.org/abs/1511.05952

    Priorities are kept in a :class:`SumTree` alongside the circular storage of
    the parent class. Transitions are drawn with probability proportional to
    their stored priority, and every sampled batch carries importance weights
    under the ``"weights"`` key.
    """

    def __init__(
        self,
        capacity: int,
        alpha: float = 0.5,
        beta: float = 0.5,
        stack_size: int = 1,
        n_step: int = 1,
        gamma: float = 0.9,
        observation_shape: Tuple = (),
        observation_dtype: type = np.uint8,
        action_shape: Tuple = (),
        action_dtype: type = np.int8,
        reward_shape: Tuple = (),
        reward_dtype: type = np.float32,
        extra_storage_types: Dict = None,
        num_players_sharing_buffer=None,
    ):
        """
        Args:
            capacity (int): Total number of observations that can be stored in the
                buffer. Note, this is not the same as the number of transitions that
                can be stored in the buffer.
            alpha (float): Parameter controlling level of prioritization. Raw
                priorities are raised to this power before being stored.
            beta (float): Parameter controlling level of correction for
                prioritization. Used as the exponent of the importance weights.
            stack_size (int): The number of frames to stack to create an observation.
            n_step (int): Horizon used to compute n-step return reward
            gamma (float): Discounting factor used to compute n-step return reward
            observation_shape: Shape of observations that will be stored in the buffer.
            observation_dtype: Type of observations that will be stored in the buffer.
                This can either be the type itself or string representation of the
                type. The type can be either a native python type or a numpy type. If
                a numpy type, a string of the form np.uint8 or numpy.uint8 is
                acceptable.
            action_shape: Shape of actions that will be stored in the buffer.
            action_dtype: Type of actions that will be stored in the buffer. Format is
                described in the description of observation_dtype.
            reward_shape: Shape of rewards that will be stored in the buffer.
            reward_dtype: Type of rewards that will be stored in the buffer. Format is
                described in the description of observation_dtype.
            extra_storage_types (dict): A dictionary describing extra items to store
                in the buffer. The mapping should be from the name of the item to a
                (type, shape) tuple.
            num_players_sharing_buffer (int): Number of agents that share their
                buffers. It is used for self-play.
        """
        super().__init__(
            capacity=capacity,
            stack_size=stack_size,
            n_step=n_step,
            gamma=gamma,
            observation_shape=observation_shape,
            observation_dtype=observation_dtype,
            action_shape=action_shape,
            action_dtype=action_dtype,
            reward_shape=reward_shape,
            reward_dtype=reward_dtype,
            extra_storage_types=extra_storage_types,
            num_players_sharing_buffer=num_players_sharing_buffer,
        )
        # One leaf per storage slot; self._capacity is set by the parent class.
        self._sum_tree = SumTree(self._capacity)
        self._alpha = alpha
        self._beta = beta

    def set_beta(self, beta):
        """Replace the exponent used for the importance weights in ``sample``.

        Args:
            beta (float): New value of the correction parameter.
        """
        self._beta = beta

    def _add_transition(self, priority=None, **transition):
        """Record a priority for the slot at the cursor, then store the transition.

        Args:
            priority: Raw priority of the transition. It is raised to the power
                ``alpha`` before being stored. If None, the largest priority the
                sum tree has recorded so far is used as is.
            **transition: Forwarded unchanged to the parent ``_add_transition``.
        """
        if priority is None:
            priority = self._sum_tree.max_recorded_priority
        else:
            priority = priority**self._alpha
        # The priority is written before the parent call, i.e. while
        # self._cursor still refers to the slot about to be filled
        # (NOTE(review): assumes the parent advances the cursor -- confirm).
        self._sum_tree.set_priority(self._cursor, priority)
        super()._add_transition(**transition)

    def _pad_buffer(self, pad_length):
        """Append ``pad_length`` all-zero transitions with raw priority 0.

        Args:
            pad_length (int): Number of padding transitions to add.
        """
        for _ in range(pad_length):
            transition = {
                key: np.zeros_like(self._storage[key][0]) for key in self._storage
            }
            # 0 ** alpha is 0 for alpha > 0, so the sum tree gives padding no
            # sampling mass.
            transition["priority"] = 0
            self._add_transition(**transition)

    def _sample_indices(self, batch_size):
        """Draw ``batch_size`` indices that survive ``_filter_transitions``.

        The first draw is stratified. Indices rejected by the filter are
        replaced by plain proportional draws until enough remain.

        Args:
            batch_size (int): Number of indices to return.
        """
        indices = self._sum_tree.stratified_sample(batch_size)
        indices = self._filter_transitions(indices)
        while len(indices) < batch_size:
            new_indices = self._sum_tree.sample(batch_size - len(indices))
            new_indices = self._filter_transitions(new_indices)
            indices = np.concatenate([indices, new_indices])
        return indices

    def _filter_transitions(self, indices):
        """Drop sampled indices that do not point at a usable transition.

        Args:
            indices (np.ndarray): Candidate indices drawn from the sum tree.

        Returns:
            np.ndarray: The subset of ``indices`` that passed every check.
        """
        # The parent filter is applied to indices shifted down by
        # stack_size - 1; the shift is undone on the result.
        indices = super()._filter_transitions(indices - (self._stack_size - 1)) + (
            self._stack_size - 1
        )
        if self._num_added < self._capacity:
            # Fewer additions than slots: keep indices that are at least
            # stack_size - 1 and strictly below cursor - n_step.
            indices = indices[indices < self._cursor - self._n_step]
            indices = indices[indices >= self._stack_size - 1]
        else:
            # Exclude the circular range [low, high] surrounding the cursor.
            low = (self._cursor - self._n_step) % self._capacity
            high = (self._cursor + self._stack_size - 2) % self._capacity
            if low <= high:
                # Contiguous excluded range.
                indices = indices[np.logical_or(indices < low, indices > high)]
            else:
                # Excluded range wraps past the end of the buffer.
                indices = indices[~np.logical_or(indices >= low, indices <= high)]
        return indices

    def sample(self, batch_size):
        """Sample a batch and attach importance weights to it.

        Args:
            batch_size (int): Number of transitions to sample.

        Returns:
            The batch produced by the parent ``sample`` with an extra
            ``"weights"`` entry equal to ``(1 / priority) ** beta``, rescaled so
            the largest weight in the batch is 1.
        """
        batch = super().sample(batch_size)
        indices = batch["indices"]
        priorities = self._sum_tree.get_priorities(indices)
        # The small constant avoids division by zero for zero priorities.
        weights = (1.0 / (priorities + 1e-10)) ** self._beta
        weights /= np.max(weights)
        batch["weights"] = weights
        return batch

    def update_priorities(self, indices, priorities):
        """Update the priorities of the transitions at the specified indices.

        Priorities are raised to the power ``alpha`` before being stored. If an
        index appears more than once, the priority of its first occurrence is
        used. An empty update is a no-op.

        Args:
            indices: Which transitions to update priorities for. Can be numpy array
                or torch tensor.
            priorities: What the priorities should be updated to. Can be numpy array
                or torch tensor.
        """
        indices = numpify(indices)
        if np.size(indices) == 0:
            # Nothing to update; taking a maximum over an empty array would raise.
            return
        priorities = numpify(priorities) ** self._alpha
        # The sum tree computes one delta per index, so indices must be unique.
        indices, unique_idxs = np.unique(indices, return_index=True)
        priorities = priorities[unique_idxs]
        self._sum_tree.set_priority(indices, priorities)

    def save(self, dname):
        """Save the buffer and its sum tree into the directory ``dname``.

        Args:
            dname (str): Directory to write to.
        """
        super().save(dname)
        self._sum_tree.save(dname)

    def load(self, dname):
        """Restore the buffer and its sum tree from the directory ``dname``.

        Args:
            dname (str): Directory to read from.
        """
        super().load(dname)
        self._sum_tree.load(dname)
class SumTree:
    """Data structure used to implement prioritized sampling. It is implemented
    as a tree where the value of each node is the sum of the values of the subtree
    of the node.

    The tree is stored flat in one array: node ``i`` has children ``2 * i + 1``
    and ``2 * i + 2``, the root is node 0, and the last level holds one leaf per
    element.
    """

    def __init__(self, capacity: int):
        """
        Args:
            capacity (int): Number of elements whose priorities are tracked.
        """
        self._capacity = capacity
        # Number of levels needed for the last level to hold `capacity` leaves.
        self._depth = int(np.ceil(np.log2(capacity))) + 1
        self._tree = np.zeros(2**self._depth - 1)
        self._last_level_start = 2 ** (self._depth - 1) - 1
        # View onto the first `capacity` leaves; it shares memory with the
        # tree, so it reflects every update made through set_priority.
        self._priorities = self._tree[
            self._last_level_start : self._last_level_start + self._capacity
        ]
        self.max_recorded_priority = 1.0
        self._rng = np.random.default_rng(seed=seeder.get_new_seed("replay"))

    def set_priority(self, indices, priorities):
        """Sets the priorities for the given indices.

        Indices must not contain duplicates: one delta is computed per entry
        against the currently stored value. Empty input is a no-op.

        Args:
            indices (np.ndarray): Which transitions to update priorities for. A
                single integer index is also accepted.
            priorities (np.ndarray): What the priorities should be updated to.
        """
        if np.size(indices) == 0:
            # np.max below would raise on a zero-size array.
            return
        self.max_recorded_priority = max(
            self.max_recorded_priority, np.max(priorities)
        )
        nodes = self._last_level_start + indices
        diffs = priorities - self._tree[nodes]
        # Apply each delta to the leaf and to every ancestor up to the root.
        # np.add.at accumulates correctly when several leaves share an ancestor.
        for _ in range(self._depth):
            np.add.at(self._tree, nodes, diffs)
            nodes = (nodes - 1) // 2

    def sample(self, batch_size):
        """Sample elements from the sum tree with probability proportional to their
        priority.

        Args:
            batch_size (int): The number of elements to sample.
        """
        return self.extract(self._rng.uniform(size=batch_size))

    def stratified_sample(self, batch_size):
        """Performs stratified sampling using the sum tree.

        The unit interval is split into ``batch_size`` equal segments and one
        query is drawn uniformly from each.

        Args:
            batch_size (int): The number of elements to sample.
        """
        query_values = (
            np.arange(batch_size) + self._rng.uniform(size=batch_size)
        ) / batch_size
        return self.extract(query_values)

    def extract(self, queries):
        """Map queries in [0, 1) to element indices.

        Each query is scaled by the total priority and the tree is descended
        from the root: at every node the search goes left if the remaining
        query mass fits in the left subtree, otherwise it subtracts the left
        subtree's sum and goes right. The input array is not modified.

        Args:
            queries (np.ndarray): 1-D array of values between 0 and 1.

        Returns:
            np.ndarray: Integer element index selected for each query.
        """
        # The multiplication allocates a new array, leaving the caller's intact.
        remaining = np.asarray(queries, dtype=np.float64) * self._tree[0]
        nodes = np.zeros(remaining.shape[0], dtype=np.int64)
        for _ in range(self._depth - 1):
            nodes = 2 * nodes + 1  # move to the left child
            left_sums = self._tree[nodes]
            go_right = remaining > left_sums
            remaining[go_right] -= left_sums[go_right]
            nodes[go_right] += 1
        # Floating-point round-off can push a query onto a padding leaf past
        # the last element; clamp so the result is always a valid index.
        return np.minimum(nodes - self._last_level_start, self._capacity - 1)

    def get_priorities(self, indices):
        """Get the priorities of the elements at indices.

        Args:
            indices (np.ndarray): The indices to query.
        """
        return self._priorities[indices]

    def save(self, dname):
        """Save the tree into the directory ``dname``.

        The node values go to ``sumtree.npy``; the random generator and the
        running maximum priority go to ``sumtree.pkl``.

        Args:
            dname (str): Directory to write to.
        """
        np.save(os.path.join(dname, "sumtree.npy"), self._tree)
        state = {
            "rng": self._rng,
            "max_recorded_priority": float(self.max_recorded_priority),
        }
        with open(os.path.join(dname, "sumtree.pkl"), "wb") as f:
            pickle.dump(state, f)

    def load(self, dname):
        """Restore the tree from the directory ``dname``.

        Args:
            dname (str): Directory to read from.
        """
        self._tree = np.load(os.path.join(dname, "sumtree.npy"))
        # The tree array was replaced, so the leaf view has to be rebuilt.
        self._priorities = self._tree[
            self._last_level_start : self._last_level_start + self._capacity
        ]
        # NOTE: pickle can execute arbitrary code; only load trusted checkpoints.
        with open(os.path.join(dname, "sumtree.pkl"), "rb") as f:
            state = pickle.load(f)
        self._rng = state["rng"]
        # Checkpoints written before this field existed keep the current value.
        self.max_recorded_priority = state.get(
            "max_recorded_priority", self.max_recorded_priority
        )