hive.replays.replay_buffer module

class hive.replays.replay_buffer.BaseReplayBuffer

Bases: ABC, Registrable

Base class for replay buffers. Every implemented buffer should be a subclass of this class.
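Because the class derives from ABC, Python refuses to instantiate a subclass until every abstract method is overridden. The sketch below uses a hypothetical stand-in, BufferInterface, that mirrors the documented methods. It is not the hive class itself, and the Registrable mixin is omitted. It shows that an incomplete subclass fails at construction time, while a complete one (the hypothetical NullBuffer) can be created.

```python
from abc import ABC, abstractmethod


class BufferInterface(ABC):
    """Hypothetical stand-in mirroring BaseReplayBuffer's abstract methods
    (the Registrable mixin is intentionally left out)."""

    @abstractmethod
    def add(self, **data): ...

    @abstractmethod
    def sample(self, batch_size): ...

    @abstractmethod
    def size(self): ...

    @abstractmethod
    def save(self, dname): ...

    @abstractmethod
    def load(self, dname): ...


class IncompleteBuffer(BufferInterface):
    # Only add() is implemented; the other abstract methods are missing.
    def add(self, **data):
        pass


class NullBuffer(BufferInterface):
    # Every abstract method is implemented, so instances can be created.
    def add(self, **data):
        pass

    def sample(self, batch_size):
        return []

    def size(self):
        return 0

    def save(self, dname):
        pass

    def load(self, dname):
        return False


try:
    IncompleteBuffer()
    instantiated = True
except TypeError:
    # ABC raises TypeError when abstract methods remain unimplemented.
    instantiated = False
```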

abstract add(**data)

Adds data to the buffer.

Parameters

data – data to add to the replay buffer. Subclasses can define this method's signature to suit their use case.
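One way a subclass might narrow the generic add(**data) signature is sketched below. The class name TransitionBuffer and the field names (observation, action, reward, done) are hypothetical choices for illustration, not fields required by hive. Explicit parameters document which fields are mandatory, and **extra leaves room for optional ones.

```python
class TransitionBuffer:
    """Hypothetical buffer that replaces add(**data) with explicit fields."""

    def __init__(self):
        self.transitions = []

    def add(self, observation, action, reward, done, **extra):
        # Required fields are named parameters; any optional fields
        # (e.g. a hypothetical "info" entry) are collected in **extra.
        transition = {
            "observation": observation,
            "action": action,
            "reward": reward,
            "done": done,
        }
        transition.update(extra)
        self.transitions.append(transition)


buffer = TransitionBuffer()
buffer.add(observation=[0.0, 1.0], action=1, reward=0.5, done=False, info="x")
```

Callers can still pass every field as a keyword argument, so this narrowed signature remains compatible with code that calls add(**data).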

abstract sample(batch_size)

Samples a minibatch of transitions from the buffer.

Parameters

batch_size (int) – the number of transitions to sample.
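The interface leaves the sampling strategy to each subclass. A minimal sketch of one common choice, uniform sampling without replacement, is shown below, assuming transitions are stored as a list of dicts. The helper sample_uniform and its dict-of-lists batch layout are assumptions for illustration, not hive's actual sampler or return format.

```python
import random


def sample_uniform(transitions, batch_size, rng=random):
    """Hypothetical uniform sampler: draws batch_size distinct transitions
    and regroups them as a dict mapping each field to a list of values."""
    if not 0 < batch_size <= len(transitions):
        raise ValueError("batch_size must be in [1, number of stored transitions]")
    batch = rng.sample(transitions, batch_size)  # sampling without replacement
    # Convert the list of per-transition dicts into per-field lists.
    return {key: [t[key] for t in batch] for key in batch[0]}


transitions = [{"obs": i, "reward": float(i)} for i in range(10)]
batch = sample_uniform(transitions, 4, rng=random.Random(0))
```

Passing a seeded random.Random instance as rng makes the draw reproducible, which can help when debugging training runs.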

abstract size()

Returns the number of transitions stored in the buffer.
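For a fixed-capacity buffer, size() reports how many transitions are currently stored, not how many have ever been added. The sketch below assumes a hypothetical BoundedBuffer backed by collections.deque with maxlen, which silently evicts the oldest entry once full. This eviction policy is an illustrative assumption.

```python
from collections import deque


class BoundedBuffer:
    """Hypothetical fixed-capacity buffer; the oldest transition is evicted
    once capacity is reached."""

    def __init__(self, capacity):
        self._storage = deque(maxlen=capacity)

    def add(self, **data):
        self._storage.append(data)

    def size(self):
        # Number of transitions currently stored, never more than capacity.
        return len(self._storage)


buf = BoundedBuffer(capacity=3)
for step in range(5):
    buf.add(step=step)
print(buf.size())  # 3
```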

abstract save(dname)

Saves the buffer's checkpoint information to disk so it can be loaded later.

Parameters

dname (str) – directory where the buffer should save all relevant information.

abstract load(dname)

Loads the buffer from a checkpoint directory.

Parameters

dname (str) – directory name where buffer checkpoint info is stored.

Returns

True if the buffer was loaded successfully, False otherwise.
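A save/load pair that follows this contract could look like the sketch below. It assumes the buffer's state is a picklable list and that the checkpoint lives in a single file named transitions.pkl inside dname. Both are hypothetical choices, not hive's on-disk format. load() leaves the current state unchanged and returns False when the checkpoint is missing or unreadable.

```python
import os
import pickle
import tempfile


class PickleBuffer:
    """Hypothetical buffer that checkpoints its transitions with pickle."""

    CHECKPOINT = "transitions.pkl"  # file name is an illustrative assumption

    def __init__(self):
        self.transitions = []

    def save(self, dname):
        os.makedirs(dname, exist_ok=True)  # create the directory if needed
        with open(os.path.join(dname, self.CHECKPOINT), "wb") as f:
            pickle.dump(self.transitions, f)

    def load(self, dname):
        try:
            with open(os.path.join(dname, self.CHECKPOINT), "rb") as f:
                self.transitions = pickle.load(f)
        except (OSError, pickle.UnpicklingError, EOFError):
            # Missing or corrupt checkpoint: keep current state, report failure.
            return False
        return True


with tempfile.TemporaryDirectory() as tmp:
    original = PickleBuffer()
    original.transitions = [{"reward": 1.0}]
    original.save(os.path.join(tmp, "replay"))

    restored = PickleBuffer()
    ok = restored.load(os.path.join(tmp, "replay"))
    missing = PickleBuffer().load(os.path.join(tmp, "does_not_exist"))
```

pickle should only be used to load checkpoints from trusted sources, since unpickling can execute arbitrary code.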

classmethod type_name()

Returns: “replay”
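Because BaseReplayBuffer is also Registrable, type_name() gives every replay buffer subclass a shared category name. The sketch below illustrates how such a name could be used to group registered classes. The Registrable stand-in, the registry dict and the register() helper are all hypothetical and do not reproduce hive's actual registry API.

```python
class Registrable:
    """Hypothetical stand-in: each registrable base reports a type name."""

    @classmethod
    def type_name(cls):
        raise NotImplementedError


class ReplayBase(Registrable):
    @classmethod
    def type_name(cls):
        return "replay"


class MyBuffer(ReplayBase):
    pass  # inherits type_name() == "replay" from ReplayBase


# Hypothetical registry: type name -> {registered name: class}
registry = {}


def register(name, cls):
    # Group classes under the category reported by their type_name().
    registry.setdefault(cls.type_name(), {})[name] = cls


register("MyBuffer", MyBuffer)
```

Subclasses inherit type_name() unchanged, so every concrete buffer lands under the same "replay" category without overriding it.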