hive.replays.circular_replay module

class hive.replays.circular_replay.CircularReplayBuffer(capacity=10000, stack_size=1, n_step=1, gamma=0.99, observation_shape=(), observation_dtype=<class 'numpy.uint8'>, action_shape=(), action_dtype=<class 'numpy.int8'>, reward_shape=(), reward_dtype=<class 'numpy.float32'>, extra_storage_types=None, num_players_sharing_buffer=None)[source]

Bases: BaseReplayBuffer

An efficient version of a circular replay buffer that only stores each observation once.

Constructor for CircularReplayBuffer.

Parameters
  • capacity (int) – Total number of observations that can be stored in the buffer. Note that this is not the same as the number of transitions that can be stored in the buffer.

  • stack_size (int) – The number of frames to stack to create an observation.

  • n_step (int) – Horizon used to compute the n-step return.

  • gamma (float) – Discount factor used to compute the n-step return.

  • observation_shape – Shape of observations that will be stored in the buffer.

  • observation_dtype – Type of observations that will be stored in the buffer. This can be either the type itself or a string representation of the type. The type can be either a native Python type or a numpy type. If a numpy type, a string of the form np.uint8 or numpy.uint8 is acceptable.

  • action_shape – Shape of actions that will be stored in the buffer.

  • action_dtype – Type of actions that will be stored in the buffer. Format is described in the description of observation_dtype.

  • reward_shape – Shape of rewards that will be stored in the buffer.

  • reward_dtype – Type of rewards that will be stored in the buffer. Format is described in the description of observation_dtype.

  • extra_storage_types (dict) – A dictionary describing extra items to store in the buffer. The mapping should be from the name of the item to a (type, shape) tuple.

  • num_players_sharing_buffer (int) – Number of agents that share their buffers. It is used for self-play.
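
How n_step and gamma combine into a single reward can be illustrated with a minimal sketch. The function name n_step_return is hypothetical and not part of Hive, and the sketch ignores episode boundaries inside the horizon, which this page does not describe.

```python
def n_step_return(rewards, gamma):
    """Discounted sum of a short reward sequence: sum over k of gamma**k * rewards[k]."""
    total = 0.0
    # Walk backwards so each step folds in the already-discounted tail.
    for reward in reversed(rewards):
        total = reward + gamma * total
    return total


# With n_step=3 and gamma=0.5, rewards (1, 1, 1) give 1 + 0.5 + 0.25.
print(n_step_return([1.0, 1.0, 1.0], gamma=0.5))  # 1.75
```

With n_step=1 the sum collapses to the immediate reward, which matches the constructor default.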

size()[source]

Returns the number of transitions stored in the buffer.

add(observation, action, reward, done, **kwargs)[source]

Adds a transition to the buffer. The required components of a transition are given as positional arguments. The user can pass additional components to store in the buffer as kwargs as long as they were defined in the specification in the constructor.
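
The relationship between the constructor specification and the kwargs accepted by add can be sketched with a toy class. TinyRingBuffer is a hypothetical stand-in, not the Hive implementation, and raising KeyError for undeclared components is a choice made for this sketch; this page does not say how the real class reacts.

```python
class TinyRingBuffer:
    """Hypothetical stand-in: fixed-capacity storage with named slots."""

    def __init__(self, capacity, extra_storage_types=None):
        self.capacity = capacity
        # Required components plus any extras declared up front.
        self.fields = ["observation", "action", "reward", "done"]
        self.fields += list(extra_storage_types or {})
        self.storage = {name: [None] * capacity for name in self.fields}
        self.cursor = 0

    def add(self, observation, action, reward, done, **kwargs):
        item = dict(observation=observation, action=action,
                    reward=reward, done=done, **kwargs)
        unknown = set(item) - set(self.fields)
        if unknown:
            raise KeyError(f"undeclared components: {sorted(unknown)}")
        for name, value in item.items():
            self.storage[name][self.cursor] = value
        # Wrap around so the oldest entry is overwritten once the buffer is full.
        self.cursor = (self.cursor + 1) % self.capacity


# Declare one extra item as name -> (type, shape), then pass it as a kwarg.
buf = TinyRingBuffer(capacity=2, extra_storage_types={"priority": (float, ())})
for step in range(3):
    buf.add(observation=step, action=0, reward=1.0, done=False, priority=0.5)
print(buf.storage["observation"])  # [2, 1]: the third add overwrote slot 0
```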

sample(batch_size)[source]

Samples transitions from the buffer. If a transition's done flag is True, its next_observation value carries no meaning and should be ignored.

Parameters

batch_size (int) – Number of transitions to sample.
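
One common way to honor the done flag is to mask out anything derived from next_observation when building learning targets. The sketch below assumes the rewards, done flags, and value estimates of the next observations have already been pulled out of a sampled batch; the function name td_targets and the one-step target form are illustrative assumptions, not part of this module.

```python
def td_targets(rewards, dones, next_values, gamma=0.99):
    """One-step targets that ignore next_values wherever the episode has ended."""
    return [
        reward + (0.0 if done else gamma * next_value)
        for reward, done, next_value in zip(rewards, dones, next_values)
    ]


targets = td_targets(
    rewards=[1.0, 2.0],
    dones=[False, True],
    next_values=[10.0, 123.0],  # the 123.0 belongs to a finished episode
    gamma=0.5,
)
print(targets)  # [6.0, 2.0]: the second target is just the reward
```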

save(dname)[source]

Save the replay buffer.

Parameters

dname (str) – Directory in which to save the buffer. The directory must already exist.
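
Because the directory must exist before save is called, the caller can create it first. The following is a minimal sketch using only the standard library; ensure_checkpoint_dir and the path components are hypothetical, and the save call is left as a comment because it needs a real buffer instance.

```python
import os
import tempfile


def ensure_checkpoint_dir(dname):
    """Create dname (and any missing parents) if needed, then return it."""
    os.makedirs(dname, exist_ok=True)
    return dname


# Demo in a temporary location; the sub-path is illustrative only.
root = tempfile.mkdtemp()
dname = ensure_checkpoint_dir(os.path.join(root, "replay", "step_1000"))
# buffer.save(dname)  # 'buffer' would be an existing CircularReplayBuffer
print(os.path.isdir(dname))  # True
```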

load(dname)[source]

Load the replay buffer.

Parameters

dname (str) – Directory from which to load the buffer.

class hive.replays.circular_replay.SimpleReplayBuffer(capacity=100000.0, compress=False, seed=42, **kwargs)[source]

Bases: BaseReplayBuffer

A simple circular replay buffer.

Parameters
  • capacity (int) – Replay buffer capacity.

  • compress (bool) – If False, convert data to float32; otherwise keep it as int8.

  • seed (int) – Seed for a pseudo-random number generator.

add(observation, action, reward, done, **kwargs)[source]

Adds a transition to the buffer.

Parameters
  • observation – The current observation.

  • action – The action taken on the current observation.

  • reward – The reward from taking the action at the current observation.

  • done – Whether the current observation was the last observation in the episode.

sample(batch_size=32)[source]

Samples a minibatch.

Parameters

batch_size (int) – The number of examples to sample.
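
The role of the seed constructor argument in sampling can be illustrated with a seeded generator that draws indices into the stored transitions. This is a sketch under assumptions: sample_indices is a hypothetical helper, and this page does not state whether SimpleReplayBuffer samples with or without replacement (the sketch samples with replacement).

```python
import random


def sample_indices(num_stored, batch_size, seed):
    """Draw batch_size indices uniformly, with replacement, from range(num_stored)."""
    rng = random.Random(seed)
    return [rng.randrange(num_stored) for _ in range(batch_size)]


first = sample_indices(num_stored=1000, batch_size=32, seed=42)
second = sample_indices(num_stored=1000, batch_size=32, seed=42)
print(first == second)  # True: the same seed reproduces the same draws
```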

size()[source]

Returns the number of transitions stored in the replay buffer.

save(dname)[source]

Saves buffer checkpointing information to file for future loading.

Parameters

dname (str) – Directory name where the buffer should save all relevant info.

load(dname)[source]

Loads buffer from file.

Parameters

dname (str) – Directory name where the buffer checkpoint info is stored.

Returns

True if successfully loaded the buffer. False otherwise.

hive.replays.circular_replay.str_to_dtype(dtype)[source]
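
This page gives no description for str_to_dtype. Going only by the observation_dtype description above (a type, or a string such as np.uint8 or numpy.uint8, or a native Python type name), a conversion of this kind might look like the sketch below. parse_dtype is a hypothetical helper written for illustration, not the library's implementation, and the set of native names it accepts is an assumption.

```python
import numpy as np


def parse_dtype(dtype):
    """Hypothetical helper: accept a type, or a string naming a Python or NumPy type."""
    if not isinstance(dtype, str):
        return dtype  # already a type object; pass it through unchanged
    for prefix in ("np.", "numpy."):
        if dtype.startswith(prefix):
            # Look the name up on the numpy module, e.g. "uint8" -> np.uint8.
            return getattr(np, dtype[len(prefix):])
    # Assumed set of native names; the real function may accept others.
    return {"int": int, "float": float, "bool": bool, "str": str}[dtype]


print(parse_dtype("np.uint8") is np.uint8)          # True
print(parse_dtype("numpy.float32") is np.float32)   # True
```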