hive.replays.prioritized_replay module

class hive.replays.prioritized_replay.PrioritizedReplayBuffer(capacity, beta=0.5, stack_size=1, n_step=1, gamma=0.9, observation_shape=(), observation_dtype=<class 'numpy.uint8'>, action_shape=(), action_dtype=<class 'numpy.int8'>, reward_shape=(), reward_dtype=<class 'numpy.float32'>, extra_storage_types=None, num_players_sharing_buffer=None)

Bases: CircularReplayBuffer

Implements a replay buffer with prioritized sampling (Prioritized Experience Replay). See https://arxiv.org/abs/1511.05952

Parameters
  • capacity (int) – Total number of observations that can be stored in the buffer. Note that this is not the same as the number of transitions that can be stored in the buffer.

  • beta (float) – Exponent controlling the strength of the importance-sampling correction applied to prioritized samples (β in the paper linked above).

  • stack_size (int) – The number of frames to stack to create an observation.

  • n_step (int) – Horizon used to compute the n-step return reward.

  • gamma (float) – Discount factor used to compute the n-step return reward.

  • observation_shape (Tuple) – Shape of observations that will be stored in the buffer.

  • observation_dtype (type) – Type of observations that will be stored in the buffer. This can be either the type itself or a string representation of the type. The type can be either a native Python type or a numpy type. If it is a numpy type, a string of the form np.uint8 or numpy.uint8 is acceptable.

  • action_shape (Tuple) – Shape of actions that will be stored in the buffer.

  • action_dtype (type) – Type of actions that will be stored in the buffer. Format is described in the description of observation_dtype.

  • reward_shape (Tuple) – Shape of rewards that will be stored in the buffer.

  • reward_dtype (type) – Type of rewards that will be stored in the buffer. Format is described in the description of observation_dtype.

  • extra_storage_types (dict) – A dictionary describing extra items to store in the buffer. The mapping should be from the name of the item to a (type, shape) tuple.

  • num_players_sharing_buffer (int) – Number of agents that share this buffer. Used for self-play.
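The n_step and gamma parameters define a truncated discounted return, R = r_t + γ·r_{t+1} + … + γ^(n−1)·r_{t+n−1}. The following is a minimal sketch of that sum, assuming accumulation stops at the first transition whose done flag is True; the function name n_step_return is hypothetical, and the exact way this buffer handles episode boundaries and any bootstrap term is an assumption not confirmed here.

```python
def n_step_return(rewards, dones, gamma, n_step):
    """Discounted sum of up to n_step rewards starting at rewards[0] (illustrative sketch).

    Assumption: accumulation stops right after the first transition whose done
    flag is True. No bootstrap value (e.g. gamma**n * max Q) is added here.
    """
    total = 0.0
    for k in range(min(n_step, len(rewards))):
        total += (gamma ** k) * rewards[k]
        if dones[k]:
            # The episode ended inside the horizon; later rewards belong to a new episode.
            break
    return total


# With gamma=0.9 and three unit rewards: 1 + 0.9 + 0.81 = 2.71.
print(n_step_return([1.0, 1.0, 1.0], [False, False, False], gamma=0.9, n_step=3))
```

With n_step=1 this reduces to the single immediate reward, which matches the buffer's default.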

set_beta(beta)

Set the value of beta used by the buffer.

Parameters

beta (float) – New value of beta.

sample(batch_size)

Sample transitions from the buffer. For a given transition, if its done flag is True, its next_observation value should not be treated as meaningful.

Parameters

batch_size (int) – Number of transitions to sample.
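A sketch of proportional prioritized sampling with importance-sampling weights, following the paper linked above: transition i is drawn with probability P(i) = p_i / Σ_k p_k and weighted by w_i = (N·P(i))^(−β), normalized so the largest weight in the batch is 1. This is not the library's code: the function name sample_with_is_weights is hypothetical, the stored priorities are assumed to already include the paper's α exponent, and whether this buffer's beta plays exactly this role is an assumption.

```python
import numpy as np


def sample_with_is_weights(priorities, batch_size, beta, rng):
    """Proportional sampling plus importance-sampling weights (illustrative sketch)."""
    priorities = np.asarray(priorities, dtype=np.float64)
    n = len(priorities)
    probs = priorities / priorities.sum()  # P(i) = p_i / sum_k p_k
    # Sample with replacement; zero-priority entries are never drawn.
    indices = rng.choice(n, size=batch_size, p=probs)
    # w_i = (N * P(i)) ** (-beta): frequently sampled items get smaller weights.
    weights = (n * probs[indices]) ** (-beta)
    weights /= weights.max()  # largest weight in the batch becomes 1
    return indices, weights


rng = np.random.default_rng(0)
idx, w = sample_with_is_weights([1.0, 1.0, 2.0, 0.0], batch_size=8, beta=0.5, rng=rng)
print(idx, w)
```

Setting beta to 0 makes every weight 1 (no correction), while larger beta values correct more strongly for the bias introduced by prioritization.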

update_priorities(indices, priorities)

Update the priorities of the transitions at the specified indices.

Parameters
  • indices – Which transitions to update priorities for. Can be a numpy array or a torch tensor.

  • priorities – What the priorities should be updated to. Can be a numpy array or a torch tensor.
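Because a batch is sampled with replacement, the same index can appear more than once in indices. A minimal sketch of preparing such an update, assuming one priority per index is kept (here the first occurrence, an arbitrary choice) so that a vectorized tree update never writes the same leaf twice; to_numpy and dedupe_priority_update are hypothetical helper names, not part of this module.

```python
import numpy as np


def to_numpy(x):
    """Convert a torch tensor or array-like to a numpy array (illustrative sketch)."""
    # Torch tensors expose .detach().cpu().numpy(); anything else goes through np.asarray.
    if hasattr(x, "detach"):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def dedupe_priority_update(indices, priorities):
    """Keep a single priority per index, taken from its first occurrence."""
    indices = to_numpy(indices)
    priorities = to_numpy(priorities)
    # return_index gives the position of the first occurrence of each unique value.
    unique_indices, first_pos = np.unique(indices, return_index=True)
    return unique_indices, priorities[first_pos]


print(dedupe_priority_update([5, 2, 5, 7], [0.1, 0.2, 0.3, 0.4]))
```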

save(dname)

Save the replay buffer.

Parameters

dname (str) – Directory in which to save the buffer. It must already exist.

load(dname)

Load the replay buffer.

Parameters

dname (str) – Directory from which to load the buffer.

class hive.replays.prioritized_replay.SumTree(capacity)

Bases: object

Data structure used to implement prioritized sampling. It is implemented as a tree whose leaves hold the priorities and where the value of each internal node is the sum of the values in its subtree, so the root holds the total priority.
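A minimal sketch of a sum tree, assuming an array-backed layout: the leaf count is rounded up to a power of two, node i has children 2i and 2i+1, and the root is node 1. The class name SumTreeSketch and this layout are illustrative assumptions, not RLHive's internals; the method names mirror the API documented here, so set_priority updates a leaf and its ancestors, get_priorities reads leaves, and extract descends from the root.

```python
import numpy as np


class SumTreeSketch:
    """Array-backed sum tree (illustrative sketch, not the library's implementation)."""

    def __init__(self, capacity):
        # Round the number of leaves up to a power of two so the tree is complete.
        self.num_leaves = 1
        while self.num_leaves < capacity:
            self.num_leaves *= 2
        # Index 0 is unused; leaves occupy [num_leaves, 2 * num_leaves).
        self.nodes = np.zeros(2 * self.num_leaves, dtype=np.float64)

    def total(self):
        return float(self.nodes[1])  # the root holds the sum of all priorities

    def set_priority(self, indices, priorities):
        for idx, priority in zip(np.atleast_1d(indices), np.atleast_1d(priorities)):
            node = int(idx) + self.num_leaves
            self.nodes[node] = priority
            node //= 2
            # Recompute every ancestor from its two children up to the root.
            while node >= 1:
                self.nodes[node] = self.nodes[2 * node] + self.nodes[2 * node + 1]
                node //= 2

    def get_priorities(self, indices):
        return self.nodes[np.asarray(indices) + self.num_leaves]

    def extract(self, queries):
        # Each query is a fraction in [0, 1) of the total priority.
        result = []
        for q in np.atleast_1d(queries):
            value = q * self.total()
            node = 1
            while node < self.num_leaves:
                left = 2 * node
                if value < self.nodes[left]:
                    node = left  # target lies inside the left subtree
                else:
                    value -= self.nodes[left]  # skip the left subtree's mass
                    node = left + 1
            result.append(node - self.num_leaves)
        return np.array(result, dtype=np.int64)


tree = SumTreeSketch(4)
tree.set_priority(np.arange(4), np.array([1.0, 2.0, 3.0, 4.0]))
print(tree.total(), tree.extract(np.array([0.0, 0.15, 0.4, 0.7, 0.99])))
```

Both updates and queries touch one node per level, so each costs O(log capacity) instead of the O(capacity) needed to rescan a flat list of priorities.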

set_priority(indices, priorities)

Sets the priorities for the given indices.

Parameters
  • indices (np.ndarray) – Which transitions to update priorities for.

  • priorities (np.ndarray) – What the priorities should be updated to.

sample(batch_size)

Sample elements from the sum tree with probability proportional to their priority.

Parameters

batch_size (int) – The number of elements to sample.

stratified_sample(batch_size)

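Performs stratified sampling using the sum tree: the total priority is divided into batch_size equal segments, and one element is sampled from each segment.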

Parameters

batch_size (int) – The number of elements to sample.
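Stratified sampling differs from plain proportional sampling only in how the query fractions are drawn. A sketch, assuming both methods pass fractions in [0, 1) to extract: stratified_queries draws one uniform value inside each of batch_size equal-width strata, while plain_queries draws independent uniform values. Both function names are hypothetical.

```python
import numpy as np


def stratified_queries(batch_size, rng):
    """One uniform query from each of batch_size equal-width strata of [0, 1)."""
    # Query k lies in [k / batch_size, (k + 1) / batch_size).
    return (np.arange(batch_size) + rng.random(batch_size)) / batch_size


def plain_queries(batch_size, rng):
    """Independent uniform queries on [0, 1), as unstratified sampling would use."""
    return rng.random(batch_size)


rng = np.random.default_rng(1)
print(stratified_queries(4, rng))
print(plain_queries(4, rng))
```

Stratification guarantees the batch covers the whole priority range, which reduces the variance of the batch composition compared with independent draws that can cluster on a few high-priority elements.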

extract(queries)

Get the elements in the sum tree that correspond to the queries. For a query q, the selected element is the one with the greatest sum of priorities of the elements before it, subject to that sum not exceeding a fraction q of the total priority.

Parameters

queries (np.ndarray) – Queries to extract. Each element should be between 0 and 1.
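The selection rule can be checked against a flat O(n) reference built from prefix sums. This sketch, with the hypothetical name extract_reference, selects for each query q the element i satisfying p_0 + … + p_(i−1) ≤ q·total < p_0 + … + p_i, which skips zero-priority elements. It assumes queries lie in [0, 1); a query of exactly 1 would return an out-of-range index.

```python
import numpy as np


def extract_reference(priorities, queries):
    """Flat reference for the selection rule a sum-tree descent implements (illustrative)."""
    priorities = np.asarray(priorities, dtype=np.float64)
    cumulative = np.cumsum(priorities)  # cumulative[i] = p_0 + ... + p_i
    targets = np.asarray(queries, dtype=np.float64) * cumulative[-1]
    # side="right" returns the first index whose inclusive prefix sum exceeds the target.
    return np.searchsorted(cumulative, targets, side="right")


# Priorities [1, 0, 2, 1] have prefix sums [1, 1, 3, 4]; index 1 has zero priority.
print(extract_reference([1.0, 0.0, 2.0, 1.0], [0.0, 0.25, 0.5, 0.9]))
```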

get_priorities(indices)

Get the priorities of the elements at indices.

Parameters

indices (np.ndarray) – The indices to query.

save(dname)

Save the sum tree.

load(dname)

Load the sum tree.