Welcome to RLHive’s documentation!

Quickstart

Installation

RLHive is available through pip! For the basic RLHive package, simply run pip install rlhive.

You can also install dependencies necessary for the environments that RLHive comes with by running pip install rlhive[<env_names>], where <env_names> is a comma-separated list made up of the following:

  • atari

  • gym_minigrid

  • pettingzoo

In addition to these environments, Minatar and Marlgrid are also supported, but need to be installed separately.

To install Minatar, run pip install MinAtar@git+https://github.com/kenjyoung/MinAtar.git@8b39a18a60248ede15ce70142b557f3897c4e1eb

To install Marlgrid, run pip install marlgrid@https://github.com/kandouss/marlgrid/archive/refs/heads/master.zip

Running an experiment

There are several ways to run an experiment with RLHive. If you just want to run a preset config, you can run your experiment directly from the command line, with a config file path relative to the hive/configs folder. These examples run a DQN on the Atari game Asterix according to the Dopamine configuration, and a simplified Rainbow agent for Hanabi trained with self-play according to DeepMind’s configuration:

hive_single_agent_loop -p atari/dqn.yml
hive_multi_agent_loop -p hanabi/rainbow.yml

If you want to run an experiment with components that are all available in RLHive, but not part of a preset, you can create your own config file and run that instead! Make sure you look at the examples here and the tutorial here to format it properly:

hive_single_agent_loop -c <config-file>
hive_multi_agent_loop -c <config-file>

Finally, if you want to use your own custom components, you can simply register them with RLHive and run your config normally:

import hive
from hive.runners.utils import load_config
from hive.runners.single_agent_loop import set_up_experiment

class CustomAgent(hive.agents.Agent):
    # Definition of Agent
    pass

hive.registry.register('CustomAgent', CustomAgent, CustomAgent)

# Either load your own full config file that includes CustomAgent
config = load_config(config='custom_full_config.yml')
runner = set_up_experiment(config)
runner.run_training()

# Or load a preset config and just replace the agent config
config = load_config(preset_config='atari/dqn.yml', agent_config='custom_agent_config.yml')
runner = set_up_experiment(config)
runner.run_training()

Tutorials

Agent

Agent API

Interacting with the agent happens primarily through two functions: agent.act() and agent.update(). agent.act() takes in an observation and returns an action, while agent.update() takes in a dictionary consisting of the information relevant to the most recent transition and updates the agent.
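
For example, a minimal sketch of this interaction loop for a single-agent environment (assuming the environment follows the BaseEnv interface described in the Environments tutorial, and using the update_info keys of the tabular agent below plus "done") might look like:

def run_steps(agent, env, num_steps):
    """Sketch of the act/update loop for a single-agent environment."""
    # Assumes env.reset() returns (observation, turn) and env.step() returns
    # (observation, reward, done, turn, info), as described in the
    # Environments tutorial.
    observation, _ = env.reset()
    for _ in range(num_steps):
        action = agent.act(observation)
        next_observation, reward, done, _, _ = env.step(action)
        agent.update({
            "observation": observation,
            "action": action,
            "reward": reward,
            "done": done,
            "next_observation": next_observation,
        })
        observation = env.reset()[0] if done else next_observation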

Creating an Agent

Let’s create a new tabular Q-learning agent with discount factor gamma and learning rate alpha, for some environment with one hot observations. We want this agent to have an epsilon greedy policy, with the exploration rate decaying over explore_steps from 1.0 to some value final_epsilon. First, we define the constructor:

import numpy as np
import os
import pickle

import hive
from hive.agents.agent import Agent
from hive.utils.schedule import LinearSchedule

class TabularQLearningAgent(Agent):
    def __init__(self, obs_dim, act_dim, gamma, alpha, explore_steps, final_epsilon, id=0):
        super().__init__(obs_dim, act_dim, id=id)
        self._q_values = np.zeros((obs_dim, act_dim))
        self._gamma = gamma
        self._alpha = alpha
        self._epsilon_schedule = LinearSchedule(1.0, final_epsilon, explore_steps)

In this constructor, we created a numpy array to keep track of the Q-values for every state-action pair, and a linear decay schedule for the epsilon exploration rate. Next, let’s create the act function:

def act(self, observation):
    # Return a random action if exploring
    if np.random.rand() < self._epsilon_schedule.update():
        return np.random.randint(self._act_dim)
    else:
        state = np.argmax(observation) # Convert from one-hot

        # Break ties randomly between all actions with max values
        max_value = np.amax(self._q_values[state])
        best_actions = np.where(self._q_values[state] == max_value)[0]
        return np.random.choice(best_actions)

Now, we write our update function, which updates the state of our agent:

def update(self, update_info):
    state = np.argmax(update_info["observation"])
    next_state = np.argmax(update_info["next_observation"])

    self._q_values[state, update_info["action"]] += self._alpha * (
        update_info["reward"]
        + self._gamma * np.amax(self._q_values[next_state])
        - self._q_values[state, update_info["action"]]
    )

Now, we can use this agent directly with the single-agent or multi-agent runners. Note that act and update are framework agnostic, so you could implement them with any (deep) learning framework, although most of our implemented agents are written in PyTorch.

If we write a save and load function for this agent, we can also take advantage of checkpointing and resuming in the runner:

def save(self, dname):
    np.save(os.path.join(dname, "qvalues.npy"), self._q_values)
    with open(os.path.join(dname, "state.p"), "wb") as f:
        pickle.dump({"schedule": self._epsilon_schedule}, f)

def load(self, dname):
    self._q_values = np.load(os.path.join(dname, "qvalues.npy"))
    with open(os.path.join(dname, "state.p"), "rb") as f:
        self._epsilon_schedule = pickle.load(f)["schedule"]

Finally, we register our agent class, so that it can be found when setting up experiments through the yaml config files and command line.

hive.registry.register('TabularQLearningAgent', TabularQLearningAgent, Agent)

Configuration

RLHive was written to allow for fast configuration and iteration on configurations. The base configuration of all experiment parameters is done through YAML files. The majority of these parameters can be overridden through the command line.

Any object registered with RLHive can be configured directly through YAML files or through the command line, without having to add any extra argument parsers.

YAML files

Let’s do a basic example of configuring an agent through YAML files. Specifically, let’s create a DQNAgent.

agent:
  name: DQNAgent
  kwargs:
    representation_net:
      name: MLPNetwork
      kwargs:
        hidden_units: [256, 256]
    discount_rate: .9
    replay_buffer:
      name: CircularReplayBuffer
    reward_clip: 1.0

In this example, DQNAgent, MLPNetwork, and CircularReplayBuffer are all classes registered with RLHive, so we can configure them directly. When the registry getter function for agents, get_agent(), is called with this config dictionary (with the missing required arguments, such as obs_dim and act_dim, filled in), it builds all the inner RLHive objects automatically. This works by using the type annotations on the constructors of the objects, so for the internal objects to be created recursively, those arguments need to be annotated correctly.
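
For instance, a minimal sketch of building this agent in Python (assuming the YAML above is saved as agent_config.yml and that the agent getter is exposed as hive.registry.get_agent) might look like:

import yaml

import hive

# Load the YAML config above and fill in the required arguments that are
# not part of the file (the runners normally do this for you).
with open("agent_config.yml") as f:
    config = yaml.safe_load(f)
config["agent"]["kwargs"]["obs_dim"] = (4,)
config["agent"]["kwargs"]["act_dim"] = 2

# Builds the DQNAgent, MLPNetwork, and CircularReplayBuffer recursively.
agent = hive.registry.get_agent(config["agent"], "ag")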

Overriding from command lines

When using the registry getter functions, RLHive automatically checks any command line arguments passed to see if they match/override any default or YAML-configured arguments. With each getter function, you provide a config and a prefix. That prefix is prepended to the argument names when searching the command line. For example, with the above config, if it were loaded and the get_agent() method were called as follows:

agent = get_agent(config['agent'], 'ag')

Then, to override the discount_rate, you could pass the following argument to your python script: --ag.discount_rate .95. This can go arbitrarily deep into registered RLHive classes. For example, if you wanted to change the capacity of the replay buffer, you could pass --ag.replay_buffer.capacity 100000.

If the type annotation of an argument arg is List[C], where C is a registered RLHive class, then you can override an argument foo of an individual object configured through YAML by passing --arg.0.foo <value>, where 0 is the index of that object in the list.

Note that, as of this version, you must have configured the object in the YAML file in order to override its parameters through the command line.

Using the DQN/Rainbow Agents

The DQNAgent and RainbowDQNAgent are written to allow for easy extensions and adaptation to your applications. We outline a few different use cases here.

Using a different network architecture

Different types of network architectures can be used with DQNAgent and RainbowDQNAgent through the representation_net parameter in the constructor. This network should not include the final layer that computes the Q-values; instead, it computes the representations that are fed into the layer computing the final Q-values. This is because the only difference between many variations of the DQN algorithm is often how the final Q-values are computed, with the rest of the architecture unchanged.

You can modify the architecture of the representation network from the config, or create a completely new architecture better suited to your needs. From the config, two different types of network architectures are supported:

  • ConvNetwork: Networks with convolutional layers, followed by an MLP

  • MLPNetwork: An MLP with only linear layers

See this page for details on how to configure the network.
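
As a rough example, a representation_net entry for ConvNetwork (shown here as the equivalent Python config dictionary; the parameter names come from the ConvNetwork constructor in the API reference, while the specific values are arbitrary) could look like:

representation_net_config = {
    "name": "ConvNetwork",
    "kwargs": {
        "channels": [32, 64, 64],   # output channels for each conv layer
        "kernel_sizes": [8, 4, 3],
        "strides": [4, 2, 1],
        "paddings": [0, 1, 1],
        "mlp_layers": [512],        # MLP layers after the conv layers
    },
}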

To use an architecture not supported by the above classes, simply write the PyTorch module implementing the architecture, and register the class wrapped in the FunctionApproximator wrapper. The only requirement is that this class takes the input dimension as its first positional argument:

import torch

import hive
from hive.agents.qnets import FunctionApproximator

class CustomArchitecture(torch.nn.Module):
    def __init__(self, in_dim, hidden_units):
        super().__init__()
        self.network = torch.nn.Sequential(
            torch.nn.Linear(in_dim, hidden_units),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_units, hidden_units)
        )

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        return self.network(x)

hive.registry.register(
    'CustomArchitecture',
    FunctionApproximator(CustomArchitecture),
    FunctionApproximator
)

Adding in different Rainbow components

The Rainbow architecture is composed of several different components, namely:

  • Double Q-learning

  • Prioritized Replay

  • Dueling Networks

  • Multi-step Learning

  • Distributional RL

  • Noisy Networks

Each of these components can be used independently with our RainbowDQNAgent class. To use Prioritized Replay, pass a PrioritizedReplayBuffer to the replay_buffer parameter of RainbowDQNAgent. Details on how to use the other components of Rainbow can be found in the API documentation of RainbowDQNAgent; a config sketch enabling each component is shown below.
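
For reference, here is a rough config fragment (shown as a Python config dictionary; every argument below appears in the RainbowDQNAgent signature in the API reference, while the values themselves are arbitrary):

rainbow_agent_config = {
    "name": "RainbowDQNAgent",
    "kwargs": {
        "representation_net": {"name": "MLPNetwork", "kwargs": {"hidden_units": [256, 256]}},
        "replay_buffer": {"name": "PrioritizedReplayBuffer"},  # Prioritized Replay
        "double": True,           # Double Q-learning
        "dueling": True,          # Dueling Networks
        "n_step": 3,              # Multi-step Learning
        "distributional": True,   # Distributional RL
        "noisy": True,            # Noisy Networks
        "use_eps_greedy": False,  # rely on noisy networks for exploration
    },
}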

Custom Input Observations

The current implementations of DQNAgent and RainbowDQNAgent handle the standard case where observations are a single numpy array and no extra inputs are needed during the update phase other than action, reward, and done. When this is not the case and you need to handle more complex inputs, you can do so by overriding the methods of DQNAgent. Let’s walk through the example of LegalMovesRainbowAgent. This agent takes in a list of legal moves on each turn and only selects from those.

class LegalMovesHead(torch.nn.Module):
    def __init__(self, base_network):
        super().__init__()
        self.base_network = base_network

    def forward(self, x, legal_moves):
        x = self.base_network(x)
        return x + legal_moves

    def dist(self, x, legal_moves):
        return self.base_network.dist(x)

class LegalMovesRainbowAgent(RainbowDQNAgent):
    """A Rainbow agent which supports games with legal actions."""

    def create_q_networks(self, representation_net):
        """Creates the qnet and target qnet."""
        super().create_q_networks(representation_net)
        self._qnet = LegalMovesHead(self._qnet)
        self._target_qnet = LegalMovesHead(self._target_qnet)

This defines a wrapper around the Q-networks used by the agent that takes an encoding of the legal moves, where illegal moves have value \(-\infty\) and legal moves have value \(0\). The wrapper then adds this encoding to the values generated by the base Q-networks. Overriding create_q_networks() allows you to modify the base Q-networks by adding this wrapper.

def preprocess_update_batch(self, batch):
    for key in batch:
        batch[key] = torch.tensor(batch[key], device=self._device)
    return (
        (batch["observation"], batch["action_mask"]),
        (batch["next_observation"], batch["next_action_mask"]),
        batch,
    )

Now, since the Q-networks expect an extra parameter (the legal moves action mask), we override the preprocess_update_batch() method, which takes a batch sampled from the replay buffer and defines the inputs that will be used to compute the values of the current state and the next state during the update step.

def preprocess_update_info(self, update_info):
    preprocessed_update_info = {
        "observation": update_info["observation"]["observation"],
        "action": update_info["action"],
        "reward": update_info["reward"],
        "done": update_info["done"],
        "action_mask": action_encoding(update_info["observation"]["action_mask"]),
    }
    if "agent_id" in update_info:
        preprocessed_update_info["agent_id"] = int(update_info["agent_id"])
    return preprocessed_update_info

We must also make sure that the action encoding for each transition is added to the replay buffer in the first place. To do that, we override the preprocess_update_info() method, which should return a dictionary with keys and values corresponding to the items you wish to store in the replay buffer. Note that these keys need to be specified when you create the replay buffer; see Replays for more information.

@torch.no_grad()
def act(self, observation):
    if self._training:
        if not self._learn_schedule.get_value():
            epsilon = 1.0
        elif not self._use_eps_greedy:
            epsilon = 0.0
        else:
            epsilon = self._epsilon_schedule.update()
        if self._logger.update_step(self._timescale):
            self._logger.log_scalar("epsilon", epsilon, self._timescale)
    else:
        epsilon = self._test_epsilon

    vectorized_observation = torch.tensor(
        np.expand_dims(observation["observation"], axis=0), device=self._device
    ).float()
    legal_moves_as_int = [
        i for i, x in enumerate(observation["action_mask"]) if x == 1
    ]
    encoded_legal_moves = torch.tensor(
        action_encoding(observation["action_mask"]), device=self._device
    ).float()
    qvals = self._qnet(vectorized_observation, encoded_legal_moves).cpu()

    if self._rng.random() < epsilon:
        action = np.random.choice(legal_moves_as_int).item()
    else:
        action = torch.argmax(qvals).item()

    return action

Finally, you also need to override the act() method to extract and use the extra information.

Environments

Installing Environments

We support several environments in RLHive, namely:

  • Atari

  • Gym classic control

  • Minatar (simplified Atari)

  • Minigrid (single-agent grid world)

  • Marlgrid (multi-agent)

  • Pettingzoo (multi-agent)

While gym comes installed with the base package, the other environments need to be installed separately. See Installation for more details.

Creating an Environment

RLHive Environments

Every environment used in RLHive should be a subclass of hive.envs.base.BaseEnv. It should provide a reset function that resets the environment to a new episode and returns a tuple of (observation, turn), and a step function that takes in an action, performs the step in the environment, and returns a tuple of (observation, reward, done, turn, info). All of these values correspond to their canonical meanings; turn is the index of the agent whose turn it is (in multi-agent environments).

The reward return value can be a single number, an array, or a dictionary. If it’s a number, then that same reward will be given to every single agent. If it’s an array, the agents get the reward corresponding to their index in the runner. If it’s a dictionary, the keys should be the agent ids, and the value the reward for that agent.

Each environment should also provide an EnvSpec object that provides information about the environment, such as the expected observation shape and action dimension for each agent. These should be lists with one element per agent. See GymEnv for an example.
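
As an illustration, here is a minimal single-agent environment sketch. The EnvSpec import path and constructor arguments (env_name, obs_dim, act_dim) are assumptions; see GymEnv for how RLHive actually constructs its specs.

import numpy as np

from hive.envs.base import BaseEnv
from hive.envs.env_spec import EnvSpec  # assumed module path

class GuessTheCoinEnv(BaseEnv):
    """Toy one-step environment: guess which side a coin landed on."""

    def __init__(self):
        # obs_dim and act_dim are lists with one entry per agent (assumed
        # EnvSpec signature).
        env_spec = EnvSpec("GuessTheCoin", obs_dim=[(1,)], act_dim=[2])
        super().__init__(env_spec, num_players=1)

    def reset(self):
        self._coin = np.random.randint(2)
        observation = np.zeros(1, dtype=np.float32)
        return observation, 0  # (observation, turn)

    def step(self, action):
        reward = 1.0 if action == self._coin else 0.0
        observation = np.array([self._coin], dtype=np.float32)
        # One-step episodes: done is True, turn is 0, info is empty.
        return observation, reward, True, 0, {}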

Gym environments

If your environment is a gym environment, and you do not need to preprocess the observations generated by the environment, then you can directly use the GymEnv. Just make sure you register your environment with gym, and pass the name of the environment to the GymEnv constructor.
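
A minimal sketch of using it (the hive.envs.gym_env import path is an assumption; GymEnv may also be exposed directly from hive.envs):

from hive.envs.gym_env import GymEnv  # assumed module path

# Wrap an environment registered with gym and interact with it through the
# (observation, turn) / (observation, reward, done, turn, info) interface.
env = GymEnv("CartPole-v1")
observation, turn = env.reset()
observation, reward, done, turn, info = env.step(0)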

If you need to add extra preprocessing or change the default way that environment/EnvSpec creation is done, you can simply subclass this class and override either create_env() and/or create_env_spec(), as in AtariEnv.

Parallel Step Environments

Multi-agent environments usually come in two flavors: sequential step environments, where each agent takes its action one at a time, and parallel step environments, where all agents step at the same time. The MultiAgentRunner class expects only sequential step environments. Fortunately, we can convert a parallel step environment into a sequential step one by simply generating the action for each agent one at a time and then passing all the actions to the parallel step environment at once. To facilitate this, we provide a utility class ParallelEnv. Simply write the logic for your parallel step environment as normal, and then create a sequential step version of the environment by subclassing both ParallelEnv and the parallel step environment, making sure to put ParallelEnv first in the superclass list.

from hive.envs.base import BaseEnv, ParallelEnv

class ParallelStepEnvironment(BaseEnv):
    # Write the logic needed for the parallel step environment. Assume the step
    # function gets an array of actions as its input, and should return an array
    # containing the observations for each agent, as well as the other return
    # values expected by the environment.
    pass

class SequentialStepEnvironment(ParallelEnv, ParallelStepEnvironment):
    # Any other logic needed to create the environment.
    pass

Loggers

Using a Logger

RLHive currently provides several types of loggers (see hive.utils.loggers for the full list); the examples below use ChompLogger, WandbLogger, and CompositeLogger.

All of these loggers are ScheduledLoggers. When creating these loggers, you provide the different timescales that you want to track with the logger. Timescales could correspond to any loop variable, such as "train_step", "agent_step", "test_iteration", etc.

With ScheduledLoggers, each timescale is associated with a Schedule object. By updating these schedules in some loop, you can control the logging frequency for that timescale. The loggers also keep track of how many times each timescale was updated, and those values are logged alongside any metric you log (i.e., if "timescale_1" was updated 7 times, the logger will log an additional key-value pair of "timescale_1": 7 with each metric). This allows you to see the trends of any metric across any timescale.

Timescales can be registered when creating the logger, or later on.

For an example:

from hive.utils.loggers import ChompLogger, CompositeLogger, WandbLogger
from hive.utils.schedule import ConstantSchedule, PeriodicSchedule

logger = ChompLogger(
    ['train_step', 'test_step'], # Timescales to track
    [ # How often to log train_step and test_step
        PeriodicSchedule(False, True, 10), # train_step is logged once every 10 times
        ConstantSchedule(True) # test_step is always logged
    ]
)

# You can register timescales after logger creation as well. This particular timescale
# is never scheduled to log.
logger.register_timescale('dummy_timescale', ConstantSchedule(False))

for _ in range(total_training_time):
    metrics_to_log = run_training_step()
    if logger.update_step('train_step'): # Evaluates to True once every 10 times it's hit.
        logger.log_metrics(metrics_to_log, 'training_metrics')


    other_metric = do_something_else()
    if logger.should_log('train_step'): # Checks schedule for this timescale without updating
        logger.log_scalar('other_metric', other_metric, 'training_metrics')


    if should_run_testing():
        test_metrics = run_testing()
        if logger.update_step('test_step'): # Evaluates to True every time
            logger.log_metrics(test_metrics, 'testing_metrics')

Note, the schedules associated with each timescale are not hard constraints on when you can log. They are merely a convenience to help you keep track of when to log.

Composite Logger

If you want to log to multiple sources, without the hassle of keeping track of multiple loggers, you can create a CompositeLogger object initialized with each of the individual loggers that you want to use. For example:

from hive.utils.loggers import ChompLogger, CompositeLogger, WandbLogger

logger = CompositeLogger([
    ChompLogger('train'),
    WandbLogger('train')
])

You can now use this logger as above, and it will log to both ChompLogger and WandbLogger.

Registration

Registering new Classes

To register a new class with the Registry, you need to make sure that it or one of its ancestors subclassed hive.utils.registry.Registrable and provided a definition for hive.utils.registry.Registrable.type_name(). The return value of this function is used to create the getter function for that type (registry.get_{type_name}).

You can either register several different classes at once:

registry.register_all(
    Agent,
    {
        "DQNAgent": DQNAgent,
        "LegalMovesRainbowAgent": LegalMovesRainbowAgent,
        "RainbowDQNAgent": RainbowDQNAgent,
        "RandomAgent": RandomAgent,
    },
)

or one at a time:

registry.register("DQNAgent", DQNAgent, Agent)
registry.register("LegalMovesRainbowAgent", LegalMovesRainbowAgent, Agent)
registry.register("RainbowDQNAgent", RainbowDQNAgent, Agent)
registry.register("RandomAgent", RandomAgent, Agent)

After a class has been registered, you can pass a config dictionary to the getter function for that type to create the object.
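
For example, a small sketch (assuming the global registry is importable from hive.utils.registry) of creating a registered agent from a config dictionary:

from hive.utils.registry import registry  # assumed location of the global registry

# Since Agent.type_name() is "agent", the corresponding getter is
# registry.get_agent. The "agent" string is the prefix used when checking
# for command-line overrides.
agent = registry.get_agent(
    {"name": "RandomAgent", "kwargs": {"obs_dim": (4,), "act_dim": 2}},
    "agent",
)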

Callables

There are several cases where we want to partially parameterize some function or constructor, but not pass the fully created object in as an argument. One example is optimizers. You might want to pass a learning rate, but you cannot create the final optimizer object until you’ve created the parameters you want to optimize. To deal with such cases, we provide a CallableType class, which can be used to register and wrap any callable. For example, with optimizers, we have:

class OptimizerFn(CallableType):
    """A wrapper for callables that produce optimizer functions.

    These wrapped callables can be partially initialized through configuration
    files or command line arguments.
    """

    @classmethod
    def type_name(cls):
        """
        Returns:
            "optimizer_fn"
        """
        return "optimizer_fn"

registry.register_all(
    OptimizerFn,
    {
        "Adadelta": OptimizerFn(optim.Adadelta),
        "Adagrad": OptimizerFn(optim.Adagrad),
        "Adam": OptimizerFn(optim.Adam),
        "Adamax": OptimizerFn(optim.Adamax),
        "AdamW": OptimizerFn(optim.AdamW),
        "ASGD": OptimizerFn(optim.ASGD),
        "LBFGS": OptimizerFn(optim.LBFGS),
        "RMSprop": OptimizerFn(optim.RMSprop),
        "RMSpropTF": OptimizerFn(RMSpropTF),
        "Rprop": OptimizerFn(optim.Rprop),
        "SGD": OptimizerFn(optim.SGD),
        "SparseAdam": OptimizerFn(optim.SparseAdam),
    },
)

With this, we can now make use of the configurability of RLHive objects while still passing callables as arguments.
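
As a sketch, configuring a DQNAgent’s optimizer through its optimizer_fn parameter (shown as a Python config dictionary) fixes the learning rate now, while the agent supplies the network parameters later when it actually builds the optimizer:

agent_config = {
    "name": "DQNAgent",
    "kwargs": {
        # "Adam" refers to the OptimizerFn registered above; lr is bound now,
        # and the parameters are passed in later by the agent (roughly
        # optimizer = optimizer_fn(qnet.parameters())).
        "optimizer_fn": {
            "name": "Adam",
            "kwargs": {"lr": 1e-4},
        },
    },
}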

Replays

RLHive currently provides several types of replay buffers.

The main replay buffer classes that you will likely use/extend are CircularReplayBuffer and PrioritizedReplayBuffer. By default, these classes expect the arguments "observation", "action", "reward", and "done" when adding to the buffer. You can also provide alternative shapes/dtypes for these keys, and the buffer will try to automatically cast the objects you add to the buffer.

Along with these default keys, you can also store extra keys in the buffer. When creating the buffer, provide a dictionary with key-value pairs key: (type, shape). When adding, you can directly provide this key as an argument to the add() method, and it will automatically be added to the batch dictionary that you sample.
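
A rough sketch of storing an extra key (the constructor argument names, in particular extra_storage_types, and the add()/sample() signatures are assumptions; the key: (type, shape) dictionary format is the one described above):

import numpy as np

from hive.replays import CircularReplayBuffer  # assumed import path

buffer = CircularReplayBuffer(
    capacity=10000,
    observation_shape=(4,),
    observation_dtype=np.float32,
    extra_storage_types={"action_mask": (np.int8, (5,))},  # key: (type, shape)
)

buffer.add(
    observation=np.zeros(4, dtype=np.float32),
    action=1,
    reward=0.5,
    done=False,
    action_mask=np.ones(5, dtype=np.int8),  # extra key passed directly to add()
)

# After enough transitions have been added, sampled batches also contain
# the "action_mask" key.
batch = buffer.sample(batch_size=32)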

Runners

We provide two different Runner classes: SingleAgentRunner and MultiAgentRunner. The setup for both Runner classes can be viewed in their respective files, in the set_up_experiment() functions. The get_parsed_args() function can be used to get any arguments from the command line that are not part of the signatures of already registered RLHive class constructors.

Metrics and TransitionInfo

The Metrics class can be used to keep track of metrics for single/multiple agents across an episode.

# Create the Metrics object. The first set of metrics is individual to
# each agent, the second is common for all agents. The metrics can
# be initialized either with a value or with a callable that takes no arguments
metrics = Metrics(
    [agent1, agent2],
    [("reward", 0), ("episode_traj", lambda: [])],
    [("full_episode_length", 0)],
)

# Add metrics
metrics[agent1.id]["reward"] += 1
metrics[agent2.id]["episode_traj"].append(0)
metrics["full_episode_length"] += 1

# Convert to flat dictionary for easy logging. Adds agent ids as prefixes
# for agent-specific metrics
flat_metrics = metrics.get_flat_dict()

# Reinitialize/reset all metrics
metrics.reset_metrics()

The TransitionInfo class can be used to keep track of the information needed by the agent to construct its next state for acting or its next transition for updating. It also handles state stacking and padding.

transition_info = TransitionInfo([agent1, agent2], stack_size)

# Set the start flag for agent1.
transition_info.start(agent1)

# Get stacked observation for agent1. If not enough observations have been
# recorded, it will pad with 0s
stacked_observation = transition_info.get_stacked_state(
    agent1, observation
)

# Record update information about the agent
transition_info.record_info(agent, info)

# Get the update information for the agent, with done set to the value passed
info = transition_info.get_info(agent, done=done)

RLHive API

hive package

Subpackages

hive.agents package
Subpackages
hive.agents.qnets package
Subpackages
hive.agents.qnets.atari package
Submodules
hive.agents.qnets.atari.nature_atari_dqn module
class hive.agents.qnets.atari.nature_atari_dqn.NatureAtariDQNModel(in_dim)[source]

Bases: ConvNetwork

The convolutional network used to train the DQN agent in the original Nature paper: https://www.nature.com/articles/nature14236

Parameters

in_dim (tuple) – The tuple of observations dimension (channels, width, height).

training: bool
Module contents
Submodules
hive.agents.qnets.base module
class hive.agents.qnets.base.FunctionApproximator(fn)[source]

Bases: CallableType

A wrapper for callables that produce function approximators.

For example, FunctionApproximator(create_neural_network) or FunctionApproximator(MyNeuralNetwork) where create_neural_network is a function that creates a neural network module and MyNeuralNetwork is a class that defines your function approximator.

These wrapped callables can be partially initialized through configuration files or command line arguments.

Parameters

fn – callable to be wrapped.

classmethod type_name()[source]
Returns

“function”

hive.agents.qnets.conv module
class hive.agents.qnets.conv.ConvNetwork(in_dim, channels=None, mlp_layers=None, kernel_sizes=1, strides=1, paddings=0, normalization_factor=255, noisy=False, std_init=0.5)[source]

Bases: Module

Basic convolutional neural network architecture. Applies a number of convolutional layers (each followed by a ReLU activation), and then feeds the output into a hive.agents.qnets.mlp.MLPNetwork.

Note, if channels is None, the network created for the convolution portion of the architecture is simply a torch.nn.Identity module. If mlp_layers is None, the MLP portion of the architecture is a torch.nn.Identity module.

Parameters
  • in_dim (tuple) – The tuple of observations dimension (channels, width, height).

  • channels (list) – The size of output channel for each convolutional layer.

  • mlp_layers (list) – The number of neurons for each mlp layer after the convolutional layers.

  • kernel_sizes (list | int) – The kernel size for each convolutional layer

  • strides (list | int) – The stride used for each convolutional layer.

  • paddings (list | int) – The size of the padding used for each convolutional layer.

  • normalization_factor (float | int) – What the input is divided by before the forward pass of the network.

  • noisy (bool) – Whether the MLP part of the network will use NoisyLinear layers or torch.nn.Linear layers.

  • std_init (float) – The range for the initialization of the standard deviation of the weights in NoisyLinear.

training: bool
forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

hive.agents.qnets.mlp module
class hive.agents.qnets.mlp.MLPNetwork(in_dim, hidden_units=256, noisy=False, std_init=0.5)[source]

Bases: Module

Basic MLP neural network architecture.

Contains a series of torch.nn.Linear or NoisyLinear layers, each of which is followed by a ReLU.

Parameters
  • in_dim (tuple[int]) – The shape of input observations.

  • hidden_units (int | list[int]) – The number of neurons for each mlp layer.

  • noisy (bool) – Whether the MLP should use NoisyLinear layers or normal torch.nn.Linear layers.

  • std_init (float) – The range for the initialization of the standard deviation of the weights in NoisyLinear.

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
hive.agents.qnets.noisy_linear module
class hive.agents.qnets.noisy_linear.NoisyLinear(in_dim, out_dim, std_init=0.5)[source]

Bases: Module

NoisyLinear Layer. Implements the layer described in https://arxiv.org/abs/1706.10295.

Parameters
  • in_dim (int) – The dimension of the input.

  • out_dim (int) – The desired dimension of the output.

  • std_init (float) – The range for the initialization of the standard deviation of the weights.

forward(inp)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
hive.agents.qnets.qnet_heads module
class hive.agents.qnets.qnet_heads.DQNNetwork(base_network, hidden_dim, out_dim, linear_fn=None)[source]

Bases: Module

Implements the standard DQN value computation. Transforms output from base_network with output dimension hidden_dim to dimension out_dim, which should be equal to the number of actions.

Parameters
  • base_network (torch.nn.Module) – Backbone network that computes the representations that are used to compute action values.

  • hidden_dim (int) – Dimension of the output of the network.

  • out_dim (int) – Output dimension of the DQN. Should be equal to the number of actions that you are computing values for.

  • linear_fn (torch.nn.Module) – Function that will create the torch.nn.Module that will take the output of network and produce the final action values. If None, a torch.nn.Linear layer will be used.

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class hive.agents.qnets.qnet_heads.DuelingNetwork(base_network, hidden_dim, out_dim, linear_fn=None, atoms=1)[source]

Bases: Module

Computes action values using Dueling Networks (https://arxiv.org/abs/1511.06581). In dueling, we have two heads: one for estimating the advantage function and one for estimating the value function.

Parameters
  • base_network (torch.nn.Module) – Backbone network that computes the representations that are shared by the two estimators.

  • hidden_dim (int) – Dimension of the output of the base_network.

  • out_dim (int) – Output dimension of the Dueling DQN. Should be equal to the number of actions that you are computing values for.

  • linear_fn (torch.nn.Module) – Function that will create the torch.nn.Module that will take the output of network and produce the final action values. If None, a torch.nn.Linear layer will be used.

  • atoms (int) – Multiplier for the dimension of the output. For standard dueling networks, this should be 1. Used by DistributionalNetwork.

init_networks()[source]
forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class hive.agents.qnets.qnet_heads.DistributionalNetwork(base_network, out_dim, vmin=0, vmax=200, atoms=51)[source]

Bases: Module

Computes a categorical distribution over values for each action (https://arxiv.org/abs/1707.06887).

Parameters
  • base_network (torch.nn.Module) – Backbone network that computes the representations that are used to compute the value distribution.

  • out_dim (int) – Output dimension of the Distributional DQN. Should be equal to the number of actions that you are computing values for.

  • vmin (float) – The minimum of the support of the categorical value distribution.

  • vmax (float) – The maximum of the support of the categorical value distribution.

  • atoms (int) – Number of atoms discretizing the support range of the categorical value distribution.

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

dist(x)[source]

Computes a categorical distribution over values for each action.

training: bool
hive.agents.qnets.utils module
hive.agents.qnets.utils.calculate_output_dim(net, input_shape)[source]

Calculates the resulting output shape for a given input shape and network.

Parameters
  • net (torch.nn.Module) – The network which you want to calculate the output dimension for.

  • input_shape (int | tuple[int]) – The shape of the input being fed into the net. Batch dimension should not be included.

Returns

The shape of the output of a network given an input shape. Batch dimension is not included.

hive.agents.qnets.utils.create_init_weights_fn(initialization_fn)[source]

Returns a function that wraps initialization_function() and applies it to modules that have the weight attribute.

Parameters

initialization_fn (callable) – A function that takes in a tensor and initializes it.

Returns

Function that takes in PyTorch modules and initializes their weights. Can be used as follows:

init_fn = create_init_weights_fn(variance_scaling_)
network.apply(init_fn)

hive.agents.qnets.utils.calculate_correct_fan(tensor, mode)[source]

Calculate fan of tensor.

Parameters
  • tensor (torch.Tensor) – Tensor to calculate fan of.

  • mode (str) – Which type of fan to compute. Must be one of “fan_in”, “fan_out”, and “fan_avg”.

Returns

Fan of the tensor based on the mode.

hive.agents.qnets.utils.variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='uniform')[source]

Implements the tf.keras.initializers.VarianceScaling initializer in PyTorch.

Parameters
  • tensor (torch.Tensor) – Tensor to initialize.

  • scale (float) – Scaling factor (must be positive).

  • mode (str) – Must be one of “fan_in”, “fan_out”, and “fan_avg”.

  • distribution – Random distribution to use, must be one of “truncated_normal”, “untruncated_normal” and “uniform”.

Returns

Initialized tensor.

class hive.agents.qnets.utils.InitializationFn(fn)[source]

Bases: CallableType

A wrapper for callables that produce initialization functions.

These wrapped callables can be partially initialized through configuration files or command line arguments.

Parameters

fn – callable to be wrapped.

classmethod type_name()[source]
Returns

“init_fn”

Module contents
Submodules
hive.agents.agent module
class hive.agents.agent.Agent(obs_dim, act_dim, id=0)[source]

Bases: ABC, Registrable

Base class for agents. Every implemented agent should be a subclass of this class.

Parameters
  • obs_dim – Dimension of observations that agent will see.

  • act_dim – Number of actions that the agent needs to choose from.

  • id – Identifier for the agent.

property id
abstract act(observation)[source]

Returns an action for the agent to perform based on the observation.

Parameters

observation – Current observation that agent should act on.

Returns

Action for the current timestep.

abstract update(update_info)[source]

Updates the agent.

Parameters

update_info (dict) – Contains information agent needs to update itself.

train()[source]

Changes the agent to training mode.

eval()[source]

Changes the agent to evaluation mode

abstract save(dname)[source]

Saves agent checkpointing information to file for future loading.

Parameters

dname (str) – directory where agent should save all relevant info.

abstract load(dname)[source]

Loads agent information from file.

Parameters

dname (str) – directory where agent checkpoint info is stored.

classmethod type_name()[source]
Returns

“agent”

hive.agents.dqn module
class hive.agents.dqn.DQNAgent(representation_net, obs_dim, act_dim, id=0, optimizer_fn=None, loss_fn=None, init_fn=None, replay_buffer=None, discount_rate=0.99, n_step=1, grad_clip=None, reward_clip=None, update_period_schedule=None, target_net_soft_update=False, target_net_update_fraction=0.05, target_net_update_schedule=None, epsilon_schedule=None, test_epsilon=0.001, min_replay_history=5000, batch_size=32, device='cpu', logger=None, log_frequency=100)[source]

Bases: Agent

An agent implementing the DQN algorithm. Uses an epsilon greedy exploration policy

Parameters
  • representation_net (FunctionApproximator) – A network that outputs the representations that will be used to compute Q-values (e.g. everything except the final layer of the DQN).

  • obs_dim – The shape of the observations.

  • act_dim (int) – The number of actions available to the agent.

  • id – Agent identifier.

  • optimizer_fn (OptimizerFn) – A function that takes in a list of parameters to optimize and returns the optimizer. If None, defaults to Adam.

  • loss_fn (LossFn) – Loss function used by the agent. If None, defaults to SmoothL1Loss.

  • init_fn (InitializationFn) – Initializes the weights of qnet using create_init_weights_fn.

  • replay_buffer (BaseReplayBuffer) – The replay buffer that the agent will push observations to and sample from during learning. If None, defaults to CircularReplayBuffer.

  • discount_rate (float) – A number between 0 and 1 specifying how much future rewards are discounted by the agent.

  • n_step (int) – The horizon used in n-step returns to compute TD(n) targets.

  • grad_clip (float) – Gradients will be clipped to between [-grad_clip, grad_clip].

  • reward_clip (float) – Rewards will be clipped to between [-reward_clip, reward_clip].

  • update_period_schedule (Schedule) – Schedule determining how frequently the agent’s Q-network is updated.

  • target_net_soft_update (bool) – Whether the target net parameters are replaced by the qnet parameters completely or using a weighted average of the target net parameters and the qnet parameters.

  • target_net_update_fraction (float) – The weight given to the target net parameters in a soft update.

  • target_net_update_schedule (Schedule) – Schedule determining how frequently the target net is updated.

  • epsilon_schedule (Schedule) – Schedule determining the value of epsilon through the course of training.

  • test_epsilon (float) – epsilon (probability of choosing a random action) to be used during testing phase.

  • min_replay_history (int) – How many observations to fill the replay buffer with before starting to learn.

  • batch_size (int) – The size of the batch sampled from the replay buffer during learning.

  • device – Device on which all computations should be run.

  • logger (ScheduledLogger) – Logger used to log agent’s metrics.

  • log_frequency (int) – How often to log the agent’s metrics.

create_q_networks(representation_net)[source]

Creates the Q-network and target Q-network.

Parameters

representation_net – A network that outputs the representations that will be used to compute Q-values (e.g. everything except the final layer of the DQN).

train()[source]

Changes the agent to training mode.

eval()[source]

Changes the agent to evaluation mode.

preprocess_update_info(update_info)[source]

Preprocesses the update_info before it goes into the replay buffer. Clips the reward in update_info.

Parameters

update_info – Contains the information from the current timestep that the agent should use to update itself.

preprocess_update_batch(batch)[source]

Preprocess the batch sampled from the replay buffer.

Parameters

batch – Batch sampled from the replay buffer for the current update.

Returns

  • (tuple) Inputs used to calculate current state values.

  • (tuple) Inputs used to calculate next state values

  • Preprocessed batch.

Return type

(tuple)

act(observation)[source]

Returns the action for the agent. If in training mode, follows an epsilon greedy policy. Otherwise, returns the action with the highest Q-value.

Parameters

observation – The current observation.

update(update_info)[source]

Updates the DQN agent.

Parameters

update_info – dictionary containing all the necessary information to update the agent. Should contain a full transition, with keys for “observation”, “action”, “reward”, and “done”.

save(dname)[source]

Saves agent checkpointing information to file for future loading.

Parameters

dname (str) – directory where agent should save all relevant info.

load(dname)[source]

Loads agent information from file.

Parameters

dname (str) – directory where agent checkpoint info is stored.

hive.agents.legal_moves_rainbow module
class hive.agents.legal_moves_rainbow.LegalMovesRainbowAgent(representation_net, obs_dim, act_dim, optimizer_fn=None, loss_fn=None, init_fn=None, id=0, replay_buffer=None, discount_rate=0.99, n_step=1, grad_clip=None, reward_clip=None, update_period_schedule=None, target_net_soft_update=False, target_net_update_fraction=0.05, target_net_update_schedule=None, epsilon_schedule=None, test_epsilon=0.001, min_replay_history=5000, batch_size=32, device='cpu', logger=None, log_frequency=100, noisy=True, std_init=0.5, use_eps_greedy=False, double=True, dueling=True, distributional=True, v_min=0, v_max=200, atoms=51)[source]

Bases: RainbowDQNAgent

A Rainbow agent which supports games with legal actions.

Parameters
  • representation_net (FunctionApproximator) – A network that outputs the representations that will be used to compute Q-values (e.g. everything except the final layer of the DQN).

  • obs_dim (Tuple) – The shape of the observations.

  • act_dim (int) – The number of actions available to the agent.

  • id – Agent identifier.

  • optimizer_fn (OptimizerFn) – A function that takes in a list of parameters to optimize and returns the optimizer. If None, defaults to Adam.

  • loss_fn (LossFn) – Loss function used by the agent. If None, defaults to SmoothL1Loss.

  • init_fn (InitializationFn) – Initializes the weights of qnet using create_init_weights_fn.

  • replay_buffer (BaseReplayBuffer) – The replay buffer that the agent will push observations to and sample from during learning. If None, defaults to PrioritizedReplayBuffer.

  • discount_rate (float) – A number between 0 and 1 specifying how much future rewards are discounted by the agent.

  • n_step (int) – The horizon used in n-step returns to compute TD(n) targets.

  • grad_clip (float) – Gradients will be clipped to between [-grad_clip, grad_clip].

  • reward_clip (float) – Rewards will be clipped to between [-reward_clip, reward_clip].

  • update_period_schedule (Schedule) – Schedule determining how frequently the agent’s Q-network is updated.

  • target_net_soft_update (bool) – Whether the target net parameters are replaced by the qnet parameters completely or using a weighted average of the target net parameters and the qnet parameters.

  • target_net_update_fraction (float) – The weight given to the target net parameters in a soft update.

  • target_net_update_schedule (Schedule) – Schedule determining how frequently the target net is updated.

  • epsilon_schedule (Schedule) – Schedule determining the value of epsilon through the course of training.

  • test_epsilon (float) – epsilon (probability of choosing a random action) to be used during testing phase.

  • min_replay_history (int) – How many observations to fill the replay buffer with before starting to learn.

  • batch_size (int) – The size of the batch sampled from the replay buffer during learning.

  • device – Device on which all computations should be run.

  • logger (ScheduledLogger) – Logger used to log agent’s metrics.

  • log_frequency (int) – How often to log the agent’s metrics.

  • noisy (bool) – Whether to use noisy linear layers for exploration.

  • std_init (float) – The range for the initialization of the standard deviation of the weights.

  • use_eps_greedy (bool) – Whether to use epsilon greedy exploration.

  • double (bool) – Whether to use double DQN.

  • dueling (bool) – Whether to use a dueling network architecture.

  • distributional (bool) – Whether to use the distributional RL.

  • vmin (float) – The minimum of the support of the categorical value distribution for distributional RL.

  • vmax (float) – The maximum of the support of the categorical value distribution for distributional RL.

  • atoms (int) – Number of atoms discretizing the support range of the categorical value distribution for distributional RL.

create_q_networks(representation_net)[source]

Creates the qnet and target qnet.

preprocess_update_info(update_info)[source]

Preprocesses the update_info before it goes into the replay buffer. Clips the reward in update_info.

Parameters

update_info – Contains the information from the current timestep that the agent should use to update itself.

preprocess_update_batch(batch)[source]

Preprocess the batch sampled from the replay buffer.

Parameters

batch – Batch sampled from the replay buffer for the current update.

Returns

  • (tuple) Inputs used to calculate current state values.

  • (tuple) Inputs used to calculate next state values

  • Preprocessed batch.

Return type

(tuple)

act(observation)[source]

Returns the action for the agent. If in training mode, follows an epsilon greedy policy. Otherwise, returns the action with the highest Q-value.

Parameters

observation – The current observation.

class hive.agents.legal_moves_rainbow.LegalMovesHead(base_network)[source]

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x, legal_moves)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

dist(x, legal_moves)[source]
training: bool
hive.agents.legal_moves_rainbow.action_encoding(action_mask)[source]
hive.agents.rainbow module
class hive.agents.rainbow.RainbowDQNAgent(representation_net, obs_dim, act_dim, optimizer_fn=None, loss_fn=None, init_fn=None, id=0, replay_buffer=None, discount_rate=0.99, n_step=1, grad_clip=None, reward_clip=None, update_period_schedule=None, target_net_soft_update=False, target_net_update_fraction=0.05, target_net_update_schedule=None, epsilon_schedule=None, test_epsilon=0.001, min_replay_history=5000, batch_size=32, device='cpu', logger=None, log_frequency=100, noisy=True, std_init=0.5, use_eps_greedy=False, double=True, dueling=True, distributional=True, v_min=0, v_max=200, atoms=51)[source]

Bases: DQNAgent

An agent implementing the Rainbow algorithm.

Parameters
  • representation_net (FunctionApproximator) – A network that outputs the representations that will be used to compute Q-values (e.g. everything except the final layer of the DQN).

  • obs_dim (Tuple) – The shape of the observations.

  • act_dim (int) – The number of actions available to the agent.

  • id – Agent identifier.

  • optimizer_fn (OptimizerFn) – A function that takes in a list of parameters to optimize and returns the optimizer. If None, defaults to Adam.

  • loss_fn (LossFn) – Loss function used by the agent. If None, defaults to SmoothL1Loss.

  • init_fn (InitializationFn) – Initializes the weights of qnet using create_init_weights_fn.

  • replay_buffer (BaseReplayBuffer) – The replay buffer that the agent will push observations to and sample from during learning. If None, defaults to PrioritizedReplayBuffer.

  • discount_rate (float) – A number between 0 and 1 specifying how much future rewards are discounted by the agent.

  • n_step (int) – The horizon used in n-step returns to compute TD(n) targets.

  • grad_clip (float) – Gradients will be clipped to between [-grad_clip, grad_clip].

  • reward_clip (float) – Rewards will be clipped to between [-reward_clip, reward_clip].

  • update_period_schedule (Schedule) – Schedule determining how frequently the agent’s Q-network is updated.

  • target_net_soft_update (bool) – Whether the target net parameters are replaced by the qnet parameters completely or using a weighted average of the target net parameters and the qnet parameters.

  • target_net_update_fraction (float) – The weight given to the target net parameters in a soft update.

  • target_net_update_schedule (Schedule) – Schedule determining how frequently the target net is updated.

  • epsilon_schedule (Schedule) – Schedule determining the value of epsilon through the course of training.

  • test_epsilon (float) – epsilon (probability of choosing a random action) to be used during testing phase.

  • min_replay_history (int) – How many observations to fill the replay buffer with before starting to learn.

  • batch_size (int) – The size of the batch sampled from the replay buffer during learning.

  • device – Device on which all computations should be run.

  • logger (ScheduledLogger) – Logger used to log agent’s metrics.

  • log_frequency (int) – How often to log the agent’s metrics.

  • noisy (bool) – Whether to use noisy linear layers for exploration.

  • std_init (float) – The range for the initialization of the standard deviation of the weights.

  • use_eps_greedy (bool) – Whether to use epsilon greedy exploration.

  • double (bool) – Whether to use double DQN.

  • dueling (bool) – Whether to use a dueling network architecture.

  • distributional (bool) – Whether to use the distributional RL.

  • vmin (float) – The minimum of the support of the categorical value distribution for distributional RL.

  • vmax (float) – The maximum of the support of the categorical value distribution for distributional RL.

  • atoms (int) – Number of atoms discretizing the support range of the categorical value distribution for distributional RL.

create_q_networks(representation_net)[source]

Creates the Q-network and target Q-network. Adds the appropriate heads for DQN, Dueling DQN, Noisy Networks, and Distributional DQN.

Parameters

representation_net – A network that outputs the representations that will be used to compute Q-values (e.g. everything except the final layer of the DQN).

act(observation)[source]

Returns the action for the agent. If in training mode, follows an epsilon greedy policy. Otherwise, returns the action with the highest Q-value.

Parameters

observation – The current observation.

update(update_info)[source]

Updates the DQN agent.

Parameters

update_info – dictionary containing all the necessary information to update the agent. Should contain a full transition, with keys for “observation”, “action”, “reward”, “next_observation”, and “done”.

target_projection(target_net_inputs, next_action, reward, done)[source]

Project distribution of target Q-values.

Parameters
  • target_net_inputs – Inputs to feed into the target net to compute the projection of the target Q-values. Should be set from preprocess_update_batch().

  • next_action (Tensor) – Tensor containing next actions used to compute target distribution.

  • reward (Tensor) – Tensor containing rewards for the current batch.

  • done (Tensor) – Tensor containing whether the states in the current batch are terminal.

hive.agents.random module
class hive.agents.random.RandomAgent(obs_dim, act_dim, id=0, logger=None)[source]

Bases: Agent

An agent that takes random steps at each timestep.

Parameters
  • obs_dim – The shape of the observations.

  • act_dim (int) – The number of actions available to the agent.

  • id – Agent identifier.

  • logger (ScheduledLogger) – Logger used to log agent’s metrics.

act(observation)[source]

Returns a random action for the agent.

update(update_info)[source]

Updates the agent.

Parameters

update_info (dict) – Contains information agent needs to update itself.

save(dname)[source]

Saves agent checkpointing information to file for future loading.

Parameters

dname (str) – directory where agent should save all relevant info.

load(dname)[source]

Loads agent information from file.

Parameters

dname (str) – directory where agent checkpoint info is stored.

Module contents
hive.envs package
Subpackages
hive.envs.atari package
Submodules
hive.envs.atari.atari module
Module contents
hive.envs.marlgrid package
Subpackages
hive.envs.marlgrid.ma_envs package
Submodules
hive.envs.marlgrid.ma_envs.base module
hive.envs.marlgrid.ma_envs.checkers module
hive.envs.marlgrid.ma_envs.pursuit module
hive.envs.marlgrid.ma_envs.switch module
Module contents
Submodules
hive.envs.marlgrid.marlgrid module
Module contents
hive.envs.minatar package
Submodules
hive.envs.minatar.minatar module
class hive.envs.minatar.minatar.MinAtarEnv(env_name, sticky_action_prob=0.1, difficulty_ramping=True)[source]

Bases: BaseEnv

Class for loading MinAtar environments. See https://github.com/kenjyoung/MinAtar.

Parameters
  • env_name (str) – Name of the environment

  • sticky_action_prob (float) – Probability of repeating the agent’s previous action instead of the chosen one (sticky actions, as per Machado et al.).

  • difficulty_ramping (bool) – Whether to periodically increase difficulty.

create_env_spec(env_name)[source]
reset()[source]

Resets the state of the environment.

Returns
  • observation – The initial observation of the new episode.

  • turn (int) – The index of the agent which should take its turn.

seed(seed=None)[source]

Reseeds the environment.

Parameters

seed (int) – Seed to use for environment.

step(action=None)[source]
Remarks:
  • Execute self.frame_skips steps taking the action in the environment.

  • This may execute fewer than self.frame_skip steps in the environment, if the done state is reached.

  • Furthermore, in this case the returned observation should be ignored.

Module contents
hive.envs.minigrid package
Submodules
hive.envs.minigrid.minigrid module
Module contents
hive.envs.pettingzoo package
Submodules
hive.envs.pettingzoo.pettingzoo module
class hive.envs.pettingzoo.pettingzoo.PettingZooEnv(env_name, env_family, num_players, **kwargs)[source]

Bases: BaseEnv

PettingZoo environment from https://github.com/PettingZoo-Team/PettingZoo

For now, we only support environments from PettingZoo with discrete actions.

Parameters
  • env_name (str) – Name of the environment

  • env_family (str) – Family of the environment, such as “Atari”, “Classic”, “SISL”, “Butterfly”, “MAgent”, and “MPE”.

  • num_players (int) – Number of learning agents

create_env(env_name, num_players, **kwargs)[source]
create_env_spec(env_name, **kwargs)[source]

Each family of environments has its own types of observations and actions. You can add support for more families here by modifying obs_dim and act_dim.

reset()[source]

Resets the state of the environment.

Returns

observation: The initial observation of the new episode.
turn (int): The index of the agent whose turn it is.

Return type

observation

step(action)[source]

Run one time-step of the environment using the input action.

Parameters

action – An element of environment’s action space.

Returns

observation: The next observation; an element of the environment’s observation space.
reward: The reward achieved from the transition.
done (bool): Whether the episode has ended.
turn (int): The index of the agent whose turn it is next.
info (dict): Additional custom information.

Return type

observation

render(mode='rgb_array')[source]

Displays a rendered frame from the environment.

seed(seed=None)[source]

Reseeds the environment.

Parameters

seed (int) – Seed to use for environment.

close()[source]

Additional clean up operations

Module contents
hive.envs.wrappers package
Submodules
hive.envs.wrappers.gym_wrappers module
class hive.envs.wrappers.gym_wrappers.FlattenWrapper(env)[source]

Bases: ObservationWrapper

Flatten the observation to one dimensional vector.

Wraps an environment to allow a modular transformation of the step() and reset() methods.

Parameters
  • env – The environment to wrap

  • new_step_api – Whether the wrapper’s step method will output in new or old step API

observation(obs)[source]

Returns a modified observation.

class hive.envs.wrappers.gym_wrappers.PermuteImageWrapper(env)[source]

Bases: ObservationWrapper

Changes the image format from HWC to CHW

Wraps an environment to allow a modular transformation of the step() and reset() methods.

Parameters
  • env – The environment to wrap

  • new_step_api – Whether the wrapper’s step method will output in new or old step API

observation(obs)[source]

Returns a modified observation.

Module contents
Submodules
hive.envs.base module
class hive.envs.base.BaseEnv(env_spec, num_players)[source]

Bases: ABC, Registrable

Base class for environments.

Parameters
  • env_spec (EnvSpec) – An object containing information about the environment.

  • num_players (int) – The number of players in the environment.

abstract reset()[source]

Resets the state of the environment.

Returns

observation: The initial observation of the new episode.
turn (int): The index of the agent whose turn it is.

Return type

observation

abstract step(action)[source]

Run one time-step of the environment using the input action.

Parameters

action – An element of environment’s action space.

Returns

observation: The next observation; an element of the environment’s observation space.
reward: The reward achieved from the transition.
done (bool): Whether the episode has ended.
turn (int): The index of the agent whose turn it is next.
info (dict): Additional custom information.

Return type

observation

render(mode='rgb_array')[source]

Displays a rendered frame from the environment.

abstract seed(seed=None)[source]

Reseeds the environment.

Parameters

seed (int) – Seed to use for environment.

save(save_dir)[source]

Saves the environment.

Parameters

save_dir (str) – Location to save environment state.

load(load_dir)[source]

Loads the environment.

Parameters

load_dir (str) – Location to load environment state from.

close()[source]

Additional clean up operations

property env_spec
classmethod type_name()[source]

Returns: “env”

class hive.envs.base.ParallelEnv(env_name, num_players)[source]

Bases: BaseEnv

Base class for environments that make all agents step in parallel.

ParallelEnv takes an environment that expects an array of actions at each step to execute in parallel, and allows you to instead pass it a single action at each step.

This class makes use of Python’s multiple inheritance pattern. Specifically, when writing your parallel environment, it should extend both this class and the class that implements the step method that takes in actions for all agents.

If environment class A has the logic for the step function that takes in the array of actions, and environment class B is your parallel step version of that environment, class B should be defined as:

class B(ParallelEnv, A):
    ...

The order in which you list the classes is important. ParallelEnv must come before A in the order.

Parameters
  • env_spec (EnvSpec) – An object containing information about the environment.

  • num_players (int) – The number of players in the environment.

reset()[source]

Resets the state of the environment.

Returns

observation: The initial observation of the new episode.
turn (int): The index of the agent whose turn it is.

Return type

observation

step(action)[source]

Run one time-step of the environment using the input action.

Parameters

action – An element of environment’s action space.

Returns

observation: The next observation; an element of the environment’s observation space.
reward: The reward achieved from the transition.
done (bool): Whether the episode has ended.
turn (int): The index of the agent whose turn it is next.
info (dict): Additional custom information.

Return type

observation

hive.envs.env_spec module
class hive.envs.env_spec.EnvSpec(env_name, obs_dim, act_dim, env_info=None)[source]

Bases: object

Object used to store information about environment configuration. Every environment should create an EnvSpec object.

Parameters
  • env_name – Name of the environment

  • obs_dim – Dimensionality of observations from environment. This can be a simple integer, or a complex object depending on the types of observations expected.

  • act_dim – Dimensionality of action space.

  • env_info – Any other info relevant to this environment. This can include items such as random seeds or parameters used to create the environment

property env_name
property obs_dim
property act_dim
property env_info
hive.envs.gym_env module
class hive.envs.gym_env.GymEnv(env_name, num_players=1, **kwargs)[source]

Bases: BaseEnv

Class for loading gym environments.

Parameters
  • env_name (str) – Name of the environment (NOTE: make sure it is available in gym.envs.registry.all())

  • num_players (int) – Number of players for the environment.

  • kwargs – Any arguments you want to pass to create_env() or create_env_spec() can be passed as keyword arguments to this constructor.

create_env(env_name, **kwargs)[source]

Function used to create the environment. Subclasses can override this method if they are using a gym style environment that needs special logic.

Parameters

env_name (str) – Name of the environment

create_env_spec(env_name, **kwargs)[source]

Function used to create the specification. Subclasses can override this method if they are using a gym style environment that needs special logic.

Parameters

env_name (str) – Name of the environment

reset()[source]

Resets the state of the environment.

Returns

observation: The initial observation of the new episode.
turn (int): The index of the agent whose turn it is.

Return type

observation

step(action)[source]

Run one time-step of the environment using the input action.

Parameters

action – An element of environment’s action space.

Returns

observation: The next observation; an element of the environment’s observation space.
reward: The reward achieved from the transition.
done (bool): Whether the episode has ended.
turn (int): The index of the agent whose turn it is next.
info (dict): Additional custom information.

Return type

observation

render(mode='rgb_array')[source]

Displays a rendered frame from the environment.

seed(seed=None)[source]

Reseeds the environment.

Parameters

seed (int) – Seed to use for environment.

close()[source]

Additional clean up operations

Module contents
hive.replays package
Submodules
hive.replays.circular_replay module
class hive.replays.circular_replay.CircularReplayBuffer(capacity=10000, stack_size=1, n_step=1, gamma=0.99, observation_shape=(), observation_dtype=<class 'numpy.uint8'>, action_shape=(), action_dtype=<class 'numpy.int8'>, reward_shape=(), reward_dtype=<class 'numpy.float32'>, extra_storage_types=None, num_players_sharing_buffer=None)[source]

Bases: BaseReplayBuffer

An efficient version of a circular replay buffer that only stores each observation once.

Constructor for CircularReplayBuffer.

Parameters
  • capacity (int) – Total number of observations that can be stored in the buffer. Note, this is not the same as the number of transitions that can be stored in the buffer.

  • stack_size (int) – The number of frames to stack to create an observation.

  • n_step (int) – Horizon used to compute n-step return reward

  • gamma (float) – Discounting factor used to compute n-step return reward

  • observation_shape – Shape of observations that will be stored in the buffer.

  • observation_dtype – Type of observations that will be stored in the buffer. This can either be the type itself or string representation of the type. The type can be either a native python type or a numpy type. If a numpy type, a string of the form np.uint8 or numpy.uint8 is acceptable.

  • action_shape – Shape of actions that will be stored in the buffer.

  • action_dtype – Type of actions that will be stored in the buffer. Format is described in the description of observation_dtype.

  • reward_shape – Shape of rewards that will be stored in the buffer.

  • reward_dtype – Type of rewards that will be stored in the buffer. Format is described in the description of observation_dtype.

  • extra_storage_types (dict) – A dictionary describing extra items to store in the buffer. The mapping should be from the name of the item to a (type, shape) tuple.

  • num_players_sharing_buffer (int) – Number of agents that share their buffers. It is used for self-play.

size()[source]

Returns the number of transitions stored in the buffer.

add(observation, action, reward, done, **kwargs)[source]

Adds a transition to the buffer. The required components of a transition are given as positional arguments. The user can pass additional components to store in the buffer as kwargs as long as they were defined in the specification in the constructor.
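
As an illustration, the sketch below creates a buffer with a hypothetical extra “action_mask” item declared through extra_storage_types and adds a single transition. The shapes and the extra item name are made up for the example.

import numpy as np
from hive.replays.circular_replay import CircularReplayBuffer

# Hypothetical buffer storing 84x84 uint8 frames plus an extra "action_mask"
# item declared through extra_storage_types.
buffer = CircularReplayBuffer(
    capacity=100000,
    stack_size=4,
    observation_shape=(84, 84),
    observation_dtype=np.uint8,
    extra_storage_types={"action_mask": (np.uint8, (6,))},
)

buffer.add(
    observation=np.zeros((84, 84), dtype=np.uint8),
    action=0,
    reward=1.0,
    done=False,
    action_mask=np.ones(6, dtype=np.uint8),  # extra item passed as a kwarg
)

# Once enough transitions have been added, minibatches can be sampled:
# batch = buffer.sample(batch_size=32)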

sample(batch_size)[source]

Sample transitions from the buffer. For a given transition, if its done is True, the next_observation value should not be taken to have any meaning.

Parameters

batch_size (int) – Number of transitions to sample.

save(dname)[source]

Save the replay buffer.

Parameters

dname (str) – directory where to save buffer. Should already have been created.

load(dname)[source]

Load the replay buffer.

Parameters

dname (str) – directory where to load buffer from.

class hive.replays.circular_replay.SimpleReplayBuffer(capacity=100000.0, compress=False, seed=42, **kwargs)[source]

Bases: BaseReplayBuffer

A simple circular replay buffer.

Parameters
  • capacity (int) – replay buffer capacity

  • compress (bool) – if False, convert data to float32; otherwise keep it as int8.

  • seed (int) – Seed for a pseudo-random number generator.

add(observation, action, reward, done, **kwargs)[source]

Adds transition to the buffer

Parameters
  • observation – The current observation

  • action – The action taken on the current observation

  • reward – The reward from taking action at current observation

  • done – If current observation was the last observation in the episode

sample(batch_size=32)[source]

sample a minibatch

Parameters

batch_size (int) – The number of examples to sample.

size()[source]

returns the number of transitions stored in the replay buffer

save(dname)[source]

Saves buffer checkpointing information to file for future loading.

Parameters

dname (str) – directory name where agent should save all relevant info.

load(dname)[source]

Loads buffer from file.

Parameters

dname (str) – directory name where buffer checkpoint info is stored.

Returns

True if successfully loaded the buffer. False otherwise.

hive.replays.circular_replay.str_to_dtype(dtype)[source]
hive.replays.legal_moves_replay module
class hive.replays.legal_moves_replay.LegalMovesBuffer(capacity, beta=0.5, stack_size=1, n_step=1, gamma=0.9, observation_shape=(), observation_dtype=<class 'numpy.uint8'>, action_shape=(), action_dtype=<class 'numpy.int8'>, reward_shape=(), reward_dtype=<class 'numpy.float32'>, extra_storage_types=None, action_dim=None, num_players_sharing_buffer=None)[source]

Bases: PrioritizedReplayBuffer

A prioritized replay buffer for games with legal moves (such as Hanabi), which adds a next_action_mask to the sampled batch.

Parameters
  • capacity (int) – Total number of observations that can be stored in the buffer. Note, this is not the same as the number of transitions that can be stored in the buffer.

  • beta (float) – Parameter controlling level of prioritization.

  • stack_size (int) – The number of frames to stack to create an observation.

  • n_step (int) – Horizon used to compute n-step return reward

  • gamma (float) – Discounting factor used to compute n-step return reward

  • observation_shape (Tuple) – Shape of observations that will be stored in the buffer.

  • observation_dtype (type) – Type of observations that will be stored in the buffer. This can either be the type itself or string representation of the type. The type can be either a native python type or a numpy type. If a numpy type, a string of the form np.uint8 or numpy.uint8 is acceptable.

  • action_shape (Tuple) – Shape of actions that will be stored in the buffer.

  • action_dtype (type) – Type of actions that will be stored in the buffer. Format is described in the description of observation_dtype.

  • reward_shape (Tuple) – Shape of rewards that will be stored in the buffer.

  • reward_dtype (type) – Type of rewards that will be stored in the buffer. Format is described in the description of observation_dtype.

  • extra_storage_types (dict) – A dictionary describing extra items to store in the buffer. The mapping should be from the name of the item to a (type, shape) tuple.

  • num_players_sharing_buffer (int) – Number of agents that share their buffers. It is used for self-play.

sample(batch_size)[source]

Sample transitions from the buffer, adding next_action_mask to the batch for environments with legal moves.

hive.replays.prioritized_replay module
class hive.replays.prioritized_replay.PrioritizedReplayBuffer(capacity, beta=0.5, stack_size=1, n_step=1, gamma=0.9, observation_shape=(), observation_dtype=<class 'numpy.uint8'>, action_shape=(), action_dtype=<class 'numpy.int8'>, reward_shape=(), reward_dtype=<class 'numpy.float32'>, extra_storage_types=None, num_players_sharing_buffer=None)[source]

Bases: CircularReplayBuffer

Implements a replay buffer with prioritized sampling. See https://arxiv.org/abs/1511.05952

Parameters
  • capacity (int) – Total number of observations that can be stored in the buffer. Note, this is not the same as the number of transitions that can be stored in the buffer.

  • beta (float) – Parameter controlling level of prioritization.

  • stack_size (int) – The number of frames to stack to create an observation.

  • n_step (int) – Horizon used to compute n-step return reward

  • gamma (float) – Discounting factor used to compute n-step return reward

  • observation_shape (Tuple) – Shape of observations that will be stored in the buffer.

  • observation_dtype (type) – Type of observations that will be stored in the buffer. This can either be the type itself or string representation of the type. The type can be either a native python type or a numpy type. If a numpy type, a string of the form np.uint8 or numpy.uint8 is acceptable.

  • action_shape (Tuple) – Shape of actions that will be stored in the buffer.

  • action_dtype (type) – Type of actions that will be stored in the buffer. Format is described in the description of observation_dtype.

  • reward_shape (Tuple) – Shape of rewards that will be stored in the buffer.

  • reward_dtype (type) – Type of rewards that will be stored in the buffer. Format is described in the description of observation_dtype.

  • extra_storage_types (dict) – A dictionary describing extra items to store in the buffer. The mapping should be from the name of the item to a (type, shape) tuple.

  • num_players_sharing_buffer (int) – Number of agents that share their buffers. It is used for self-play.

set_beta(beta)[source]
sample(batch_size)[source]

Sample transitions from the buffer. For a given transition, if its done is True, the next_observation value should not be taken to have any meaning.

Parameters

batch_size (int) – Number of transitions to sample.

update_priorities(indices, priorities)[source]

Update the priorities of the transitions at the specified indices.

Parameters
  • indices – Which transitions to update priorities for. Can be numpy array or torch tensor.

  • priorities – What the priorities should be updated to. Can be numpy array or torch tensor.

save(dname)[source]

Save the replay buffer.

Parameters

dname (str) – directory where to save buffer. Should already have been created.

load(dname)[source]

Load the replay buffer.

Parameters

dname (str) – directory where to load buffer from.

class hive.replays.prioritized_replay.SumTree(capacity)[source]

Bases: object

Data structure used to implement prioritized sampling. It is implemented as a tree where the value of each node is the sum of the values of the subtree of the node.
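
To make the idea concrete, here is a minimal, self-contained sketch of such a sum tree. It is a simplified illustration rather than RLHive's implementation; it assumes a power-of-two capacity and samples a single index at a time.

import numpy as np

class TinySumTree:
    """Illustrative sum tree: leaves hold priorities, internal nodes hold subtree sums."""

    def __init__(self, capacity):
        self.capacity = capacity  # assumed to be a power of two
        self.tree = np.zeros(2 * capacity)  # tree[1] is the root; leaves start at index `capacity`

    def set_priority(self, index, priority):
        node = index + self.capacity
        delta = priority - self.tree[node]
        # Propagate the change upwards so every node stays the sum of its subtree.
        while node >= 1:
            self.tree[node] += delta
            node //= 2

    def sample(self):
        # Draw a value in [0, total priority) and walk down the tree,
        # descending right whenever the value exceeds the left subtree's sum.
        value = np.random.uniform(0, self.tree[1])
        node = 1
        while node < self.capacity:
            left = 2 * node
            if value < self.tree[left]:
                node = left
            else:
                value -= self.tree[left]
                node = left + 1
        return node - self.capacity  # index sampled proportionally to its priority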

set_priority(indices, priorities)[source]

Sets the priorities for the given indices.

Parameters
  • indices (np.ndarray) – Which transitions to update priorities for.

  • priorities (np.ndarray) – What the priorities should be updated to.

sample(batch_size)[source]

Sample elements from the sum tree with probability proportional to their priority.

Parameters

batch_size (int) – The number of elements to sample.

stratified_sample(batch_size)[source]

Performs stratified sampling using the sum tree.

Parameters

batch_size (int) – The number of elements to sample.

extract(queries)[source]

Get the elements in the sum tree that correspond to the query. For each query, the element that is selected is the one with the greatest sum of “previous” elements in the tree, but also such that the sum is not a greater proportion of the total sum of priorities than the query.

Parameters

queries (np.ndarray) – Queries to extract. Each element should be between 0 and 1.

get_priorities(indices)[source]

Get the priorities of the elements at the given indices.

Parameters

indices (np.ndarray) – The indices to query.

save(dname)[source]
load(dname)[source]
hive.replays.replay_buffer module
class hive.replays.replay_buffer.BaseReplayBuffer[source]

Bases: ABC, Registrable

Base class for replay buffers. Every implemented buffer should be a subclass of this class.

abstract add(**data)[source]

Adds data to the buffer

Parameters

data – Data to add to the replay buffer. Subclasses can define this method’s signature based on their use case.

abstract sample(batch_size)[source]

sample a minibatch

Parameters

batch_size (int) – the number of transitions to sample.

abstract size()[source]

Returns the number of transitions stored in the buffer.

abstract save(dname)[source]

Saves buffer checkpointing information to file for future loading.

Parameters

dname (str) – directory where agent should save all relevant info.

abstract load(dname)[source]

Loads buffer from file.

Parameters

dname (str) – directory name where buffer checkpoint info is stored.

Returns

True if successfully loaded the buffer. False otherwise.

classmethod type_name()[source]

Returns: “replay”

Module contents
hive.runners package
Submodules
hive.runners.base module
class hive.runners.base.Runner(environment, agents, logger, experiment_manager, train_steps=1000000, test_frequency=10000, test_episodes=1, max_steps_per_episode=27000)[source]

Bases: ABC

Base Runner class used to implement a training loop.

Different types of training loops can be created by overriding the relevant functions.

Parameters
  • environment (BaseEnv) – Environment used in the training loop.

  • agents (list[Agent]) – List of agents that interact with the environment.

  • logger (ScheduledLogger) – Logger object used to log metrics.

  • experiment_manager (Experiment) – Experiment object that saves the state of the training.

  • train_steps (int) – How many steps to train for. If this is -1, there is no limit for the number of training steps.

  • test_frequency (int) – After how many training steps to run testing episodes. If this is -1, testing is not run.

  • test_episodes (int) – How many episodes to run testing for.

  • max_steps_per_episode (int) – The maximum number of steps to run an episode for.

train_mode(training)[source]

If training is true, sets all agents to training mode. If training is false, sets all agents to eval mode.

Parameters

training (bool) – Whether to be in training mode.

create_episode_metrics()[source]

Create the metrics used during the loop.

run_one_step(observation, turn, episode_metrics)[source]

Run one step of the training loop.

Parameters
  • observation – Current observation that the agent should create an action for.

  • turn (int) – Agent whose turn it is.

  • episode_metrics (Metrics) – Keeps track of metrics for current episode.

run_end_step(episode_metrics, done)[source]

Run the final step of an episode.

Parameters
  • episode_metrics (Metrics) – Keeps track of metrics for current episode.

  • done (bool) – Whether this step was terminal.

run_episode()[source]

Run a single episode of the environment.

run_training()[source]

Run the training loop.

run_testing()[source]

Run a testing phase.

resume()[source]

Resume a saved experiment.

hive.runners.multi_agent_loop module
class hive.runners.multi_agent_loop.MultiAgentRunner(environment, agents, logger, experiment_manager, train_steps, test_frequency, test_episodes, stack_size, self_play, max_steps_per_episode=27000)[source]

Bases: Runner

Runner class used to implement a multiagent training loop.

Initializes the Runner object.

Parameters
  • environment (BaseEnv) – Environment used in the training loop.

  • agents (list[Agent]) – List of agents that interact with the environment

  • logger (ScheduledLogger) – Logger object used to log metrics.

  • experiment_manager (Experiment) – Experiment object that saves the state of the training.

  • train_steps (int) – How many steps to train for. If this is -1, there is no limit for the number of training steps.

  • test_frequency (int) – After how many training steps to run testing episodes. If this is -1, testing is not run.

  • test_episodes (int) – How many episodes to run testing for.

  • stack_size (int) – The number of frames in an observation sent to an agent.

  • max_steps_per_episode (int) – The maximum number of steps to run an episode for.

run_one_step(observation, turn, episode_metrics)[source]

Run one step of the training loop.

If it is the agent’s first turn during the episode, do not run an update step. Otherwise, run an update step based on the previous action and accumulated reward since then.

Parameters
  • observation – Current observation that the agent should create an action for.

  • turn (int) – Agent whose turn it is.

  • episode_metrics (Metrics) – Keeps track of metrics for current episode.

run_end_step(episode_metrics, done=True)[source]

Run the final step of an episode.

After an episode ends, iterate through agents and update them with the final step in the episode.

Parameters
  • episode_metrics (Metrics) – Keeps track of metrics for current episode.

  • done (bool) – Whether this step was terminal.

run_episode()[source]

Run a single episode of the environment.

hive.runners.multi_agent_loop.set_up_experiment(config)[source]

Returns a MultiAgentRunner object based on the config and any command line arguments.

Parameters

config – Configuration for experiment.

hive.runners.multi_agent_loop.main()[source]
hive.runners.single_agent_loop module
class hive.runners.single_agent_loop.SingleAgentRunner(environment, agent, logger, experiment_manager, train_steps, test_frequency, test_episodes, stack_size, max_steps_per_episode=27000)[source]

Bases: Runner

Runner class used to implement a single-agent training loop.

Initializes the Runner object.

Parameters
  • environment (BaseEnv) – Environment used in the training loop.

  • agent (Agent) – Agent that will interact with the environment

  • logger (ScheduledLogger) – Logger object used to log metrics.

  • experiment_manager (Experiment) – Experiment object that saves the state of the training.

  • train_steps (int) – How many steps to train for. If this is -1, there is no limit for the number of training steps.

  • test_frequency (int) – After how many training steps to run testing episodes. If this is -1, testing is not run.

  • test_episodes (int) – How many episodes to run testing for during each test phase.

  • stack_size (int) – The number of frames in an observation sent to an agent.

  • max_steps_per_episode (int) – The maximum number of steps to run an episode for.

run_one_step(observation, episode_metrics)[source]

Run one step of the training loop.

Parameters
  • observation – Current observation that the agent should create an action for.

  • episode_metrics (Metrics) – Keeps track of metrics for current episode.

run_episode()[source]

Run a single episode of the environment.

hive.runners.single_agent_loop.set_up_experiment(config)[source]

Returns a SingleAgentRunner object based on the config and any command line arguments.

Parameters

config – Configuration for experiment.

hive.runners.single_agent_loop.main()[source]
hive.runners.utils module
hive.runners.utils.load_config(config=None, preset_config=None, agent_config=None, env_config=None, logger_config=None)[source]

Used to load the config for experiments. The agent, environment, and logger components in the main config file can be overridden by the other config files passed.

Parameters
  • config (str) – Path to configuration file. Either this or preset_config must be passed.

  • preset_config (str) – Path to a preset hive config. This path should be relative to hive/configs. For example, the Atari DQN config would be atari/dqn.yml.

  • agent_config (str) – Path to agent configuration file. Overrides settings in base config.

  • env_config (str) – Path to environment configuration file. Overrides settings in base config.

  • logger_config (str) – Path to logger configuration file. Overrides settings in base config.

class hive.runners.utils.Metrics(agents, agent_metrics, episode_metrics)[source]

Bases: object

Class used to keep track of separate metrics for each agent as well as general episode metrics.

Parameters
  • agents (list[Agent]) – List of agents for which object will track metrics.

  • agent_metrics (list[(str, (callable | obj))]) – List of metrics to track for each agent. Should be a list of tuples (metric_name, metric_init) where metric_init is either the initial value of the metric or a callable that takes no arguments and creates the initial metric.

  • episode_metrics (list[(str, (callable | obj))]) – List of non agent specific metrics to keep track of. Should be a list of tuples (metric_name, metric_init) where metric_init is either the initial value of the metric or a callable with no arguments that creates the initial metric.
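
For example, metrics might be specified as in the sketch below. The metric names are hypothetical; RandomAgent is used only to have concrete Agent objects to track.

from hive.agents.random import RandomAgent
from hive.runners.utils import Metrics

agents = [RandomAgent(obs_dim=(4,), act_dim=2, id=i) for i in range(2)]
episode_metrics = Metrics(
    agents,
    agent_metrics=[("reward", 0.0), ("episode_length", 0)],
    episode_metrics=[("full_episode_length", 0)],
)
episode_metrics.reset_metrics()
flat = episode_metrics.get_flat_dict()  # agent metric names are prefixed with each agent's id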

reset_metrics()[source]

Resets all metrics to their initial values.

get_flat_dict()[source]

Get a flat dictionary version of the metrics. Each agent metric will be prefixed by the agent id.

class hive.runners.utils.TransitionInfo(agents, stack_size)[source]

Bases: object

Used to keep track of the most recent transition for each agent.

Any info that the agent needs to remember for updating can be stored here. Should be completely reset between episodes. After any info is extracted, it is automatically removed from the object. Also keeps track of which agents have started their episodes.

This object also handles padding and stacking observations for agents.

Parameters
  • agents (list[Agent]) – list of agents that will be kept track of.

  • stack_size (int) – How many observations will be stacked.

reset()[source]

Reset the object by clearing all info.

is_started(agent)[source]

Check if agent has started its episode.

Parameters

agent (Agent) – Agent to check.

start_agent(agent)[source]

Set the agent’s start flag to true.

Parameters

agent (Agent) – Agent to start.

record_info(agent, info)[source]

Update some information for the agent.

Parameters
  • agent (Agent) – Agent to update.

  • info (dict) – Info to add to the agent’s state.

update_reward(agent, reward)[source]

Add a reward to the agent.

Parameters
  • agent (Agent) – Agent to update.

  • reward (float) – Reward to add to agent.

update_all_rewards(rewards)[source]

Update the rewards for all agents. If rewards is list, it updates the rewards according to the order of agents provided in the initializer. If rewards is a dict, the keys should be the agent ids for the agents and the values should be the rewards for those agents. If rewards is a float or int, every agent is updated with that reward.

Parameters

rewards (float | list | np.ndarray | dict) – Rewards to update agents with.
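
For example, all of the following forms are accepted. This is a sketch; the agents and the reward values are hypothetical, and the dictionary keys are assumed to match the ids given to the agents.

from hive.agents.random import RandomAgent
from hive.runners.utils import TransitionInfo

agents = [RandomAgent(obs_dim=(4,), act_dim=2, id=i) for i in range(2)]
transition_info = TransitionInfo(agents, stack_size=1)

transition_info.update_all_rewards(1.0)               # same reward for every agent
transition_info.update_all_rewards([0.0, 1.0])        # ordered as agents were passed to the initializer
transition_info.update_all_rewards({0: 0.0, 1: 1.0})  # keyed by agent id (the ids given above)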

get_info(agent, done=False)[source]

Get all the info for the agent, and reset the info for that agent. Also adds a done value to the info dictionary that is based on the done parameter to the function.

Parameters
  • agent (Agent) – Agent to get transition update info for.

  • done (bool) – Whether this transition is terminal.

get_stacked_state(agent, observation)[source]

Create a stacked state for the agent. The previous observations recorded by this agent are stacked with the current observation. If not enough observations have been recorded, zero arrays are appended.

Parameters
  • agent (Agent) – Agent to get stacked state for.

  • observation – Current observation.

hive.runners.utils.zeros_like(x)[source]

Create a zero state like some state. This handles slightly more complex objects such as lists and dictionaries of numpy arrays and torch Tensors.

Parameters

x (np.ndarray | torch.Tensor | dict | list) – State used to define structure/state of zero state.

hive.runners.utils.concatenate(xs)[source]

Concatenates numpy arrays or dictionaries of numpy arrays.

Parameters

xs (list) – List of objects to concatenate.

Module contents
hive.utils package
Submodules
hive.utils.experiment module

Implementation of a simple experiment class.

class hive.utils.experiment.Experiment(name, dir_name, schedule)[source]

Bases: object

Implementation of a simple experiment class.

Initializes an experiment object.

The experiment state is an exposed property of objects of this class. It can be used to keep track of objects that need to be saved as part of the experiment but don’t fit into one of the standard categories. One example of this is the various schedules used in the Runner class.

Parameters
  • name (str) – Name of the experiment.

  • dir_name (str) – Absolute path to the directory to save/load the experiment.

  • schedule (Schedule) – Schedule that determines when the experiment should be saved.

register_experiment(config=None, logger=None, agents=None, environment=None)[source]

Registers all the components of an experiment.

Parameters
  • config (Chomp) – a config dictionary.

  • logger (Logger) – a logger object.

  • agents (Agent | list[Agent]) – either an agent object or a list of agents.

  • environment (BaseEnv) – an environment object.

update_step()[source]

Updates the step of the saving schedule for the experiment.

should_save()[source]

Returns whether you should save the experiment at the current step.

save(tag='current')[source]

Saves the experiment.

Parameters

tag (str) – Tag to prefix the folder.

is_resumable(tag='current')[source]

Returns true if the experiment is resumable.

Parameters

tag (str) – Tag for the saved experiment.

resume(tag='current')[source]

Resumes the experiment from a checkpoint.

Parameters

tag (str) – Tag for the saved experiment.

hive.utils.loggers module
class hive.utils.loggers.Logger(timescales=None)[source]

Bases: ABC, Registrable

Abstract class for logging in hive.

Constructor for base Logger class. Every Logger must call this constructor in its own constructor

Parameters

timescales (str | list(str)) – The different timescales at which logger needs to log. If only logging at one timescale, it is acceptable to only pass a string.

register_timescale(timescale)[source]

Register a new timescale with the logger.

Parameters

timescale (str) – Timescale to register.

abstract log_config(config)[source]

Log the config.

Parameters

config (dict) – Config parameters.

abstract log_scalar(name, value, prefix)[source]

Log a scalar variable.

Parameters
  • name (str) – Name of the metric to be logged.

  • value (float) – Value to be logged.

  • prefix (str) – Prefix to append to metric name.

abstract log_metrics(metrics, prefix)[source]

Log a dictionary of values.

Parameters
  • metrics (dict) – Dictionary of metrics to be logged.

  • prefix (str) – Prefix to append to metric name.

abstract save(dir_name)[source]

Saves the current state of the log files.

Parameters

dir_name (str) – Name of the directory to save the log files.

abstract load(dir_name)[source]

Loads the log files from given directory.

Parameters

dir_name (str) – Name of the directory to load the log file from.

classmethod type_name()[source]

This should be a string that denotes which type of class you are creating. For example, “logger”, “agent”, or “env”.

class hive.utils.loggers.ScheduledLogger(timescales=None, logger_schedules=None)[source]

Bases: Logger

Abstract class that manages a schedule for logging.

The update_step method should be called for each step in the loop to update the logger’s schedule. The should_log method can be used to check whether the logger should log anything.

This schedule is not strictly enforced! It is still possible to log something even if should_log returns false. These functions are just for the purpose of convenience.

Any timescales not assigned a schedule in logger_schedules will be assigned a ConstantSchedule(True).

Parameters
  • timescales (str|list[str]) – The different timescales at which logger needs to log. If only logging at one timescale, it is acceptable to only pass a string.

  • logger_schedules (Schedule|list|dict) – Schedules used to keep track of when to log. If a single schedule, it is copied for each timescale. If a list of schedules, the schedules are matched up in order with the list of timescales provided. If a dictionary, the keys should be the timescale and the values should be the schedule.
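
A sketch of pairing timescales with schedules when constructing a scheduled logger, using ChompLogger as the concrete subclass; the timescale names are hypothetical.

from hive.utils.loggers import ChompLogger
from hive.utils.schedule import ConstantSchedule, PeriodicSchedule

# Log "train" metrics once every 100 steps and "test" metrics every step.
logger = ChompLogger(
    timescales=["train", "test"],
    logger_schedules={
        "train": PeriodicSchedule(False, True, 100),
        "test": ConstantSchedule(True),
    },
)

logger.update_step("train")
if logger.should_log("train"):
    logger.log_scalar("loss", 0.5, prefix="train")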

register_timescale(timescale, schedule=None)[source]

Register a new timescale.

Parameters
  • timescale (str) – Timescale to register.

  • schedule (Schedule) – Schedule to use for this timescale.

update_step(timescale)[source]

Update the step and schedule for a given timescale.

Parameters

timescale (str) – A registered timescale.

should_log(timescale)[source]

Check if you should log for a given timescale.

Parameters

timescale (str) – A registered timescale.

save(dir_name)[source]

Saves the current state of the log files.

Parameters

dir_name (str) – Name of the directory to save the log files.

load(dir_name)[source]

Loads the log files from given directory.

Parameters

dir_name (str) – Name of the directory to load the log file from.

class hive.utils.loggers.NullLogger(timescales=None, logger_schedules=None)[source]

Bases: ScheduledLogger

A null logger that does not log anything.

Used if you don’t want to log anything, but still want to use parts of the framework that ask for a logger.

Any timescales not assigned a schedule in logger_schedules will be assigned a ConstantSchedule(True).

Parameters
  • timescales (str|list[str]) – The different timescales at which logger needs to log. If only logging at one timescale, it is acceptable to only pass a string.

  • logger_schedules (Schedule|list|dict) – Schedules used to keep track of when to log. If a single schedule, it is copied for each timescale. If a list of schedules, the schedules are matched up in order with the list of timescales provided. If a dictionary, the keys should be the timescale and the values should be the schedule.

log_config(config)[source]

Log the config.

Parameters

config (dict) – Config parameters.

log_scalar(name, value, timescale)[source]

Log a scalar variable.

Parameters
  • name (str) – Name of the metric to be logged.

  • value (float) – Value to be logged.

  • prefix (str) – Prefix to append to metric name.

log_metrics(metrics, timescale)[source]

Log a dictionary of values.

Parameters
  • metrics (dict) – Dictionary of metrics to be logged.

  • prefix (str) – Prefix to append to metric name.

save(dir_name)[source]

Saves the current state of the log files.

Parameters

dir_name (str) – Name of the directory to save the log files.

load(dir_name)[source]

Loads the log files from given directory.

Parameters

dir_name (str) – Name of the directory to load the log file from.

class hive.utils.loggers.WandbLogger(timescales=None, logger_schedules=None, project=None, name=None, dir=None, mode=None, id=None, resume=None, start_method=None, **kwargs)[source]

Bases: ScheduledLogger

A Wandb logger.

This logger can be used to log to wandb. It assumes that wandb is configured locally on your system. Multiple timescales/loggers can be implemented by instantiating multiple loggers with different logger_names. These should still have the same project and run names.

Check the wandb documentation for more details on the parameters.

Parameters
  • timescales (str|list[str]) – The different timescales at which logger needs to log. If only logging at one timescale, it is acceptable to only pass a string.

  • logger_schedules (Schedule|list|dict) – Schedules used to keep track of when to log. If a single schedule, it is copied for each timescale. If a list of schedules, the schedules are matched up in order with the list of timescales provided. If a dictionary, the keys should be the timescale and the values should be the schedule.

  • project (str) – Name of the project. Wandb’s dash groups all runs with the same project name together.

  • name (str) – Name of the run. Used to identify the run on the wandb dash.

  • dir (str) – Local directory where wandb saves logs.

  • mode (str) – The mode of logging. Can be “online”, “offline”, or “disabled”. In offline mode, wandb writes all data to disk for later syncing to a server; in disabled mode, all calls to the wandb API become no-ops while core functionality is maintained.

  • id (str, optional) – A unique ID for this run, used for resuming. It must be unique in the project, and if you delete a run you can’t reuse the ID.

  • resume (bool, str, optional) – Sets the resuming behavior. Options are the same as mentioned in Wandb’s doc.

  • start_method (str) – The start method to use for wandb’s process. See https://docs.wandb.ai/guides/track/launch#init-start-error.

  • **kwargs – You can pass any other arguments to wandb’s init method as keyword arguments. Note, these arguments can’t be overridden from the command line.

log_config(config)[source]

Log the config.

Parameters

config (dict) – Config parameters.

log_scalar(name, value, prefix)[source]

Log a scalar variable.

Parameters
  • name (str) – Name of the metric to be logged.

  • value (float) – Value to be logged.

  • prefix (str) – Prefix to append to metric name.

log_metrics(metrics, prefix)[source]

Log a dictionary of values.

Parameters
  • metrics (dict) – Dictionary of metrics to be logged.

  • prefix (str) – Prefix to append to metric name.

class hive.utils.loggers.ChompLogger(timescales=None, logger_schedules=None)[source]

Bases: ScheduledLogger

This logger uses the Chomp data structure to store all logged values which are then directly saved to disk.

Any timescales not assigned a schedule in logger_schedules will be assigned a ConstantSchedule(True).

Parameters
  • timescales (str|list[str]) – The different timescales at which logger needs to log. If only logging at one timescale, it is acceptable to only pass a string.

  • logger_schedules (Schedule|list|dict) – Schedules used to keep track of when to log. If a single schedule, it is copied for each timescale. If a list of schedules, the schedules are matched up in order with the list of timescales provided. If a dictionary, the keys should be the timescale and the values should be the schedule.

log_config(config)[source]

Log the config.

Parameters

config (dict) – Config parameters.

log_scalar(name, value, prefix)[source]

Log a scalar variable.

Parameters
  • name (str) – Name of the metric to be logged.

  • value (float) – Value to be logged.

  • prefix (str) – Prefix to append to metric name.

log_metrics(metrics, prefix)[source]

Log a dictionary of values.

Parameters
  • metrics (dict) – Dictionary of metrics to be logged.

  • prefix (str) – Prefix to append to metric name.

save(dir_name)[source]

Saves the current state of the log files.

Parameters

dir_name (str) – Name of the directory to save the log files.

load(dir_name)[source]

Loads the log files from given directory.

Parameters

dir_name (str) – Name of the directory to load the log file from.

class hive.utils.loggers.CompositeLogger(logger_list)[source]

Bases: Logger

This Logger aggregates multiple loggers together.

This logger is a convenience that allows logging to multiple loggers without having to keep track of each one separately. When timescales are updated, this logger updates the timescale for each of its component loggers. When logging, it logs to each of its component loggers, skipping any ScheduledLogger that is not scheduled to log for that timescale.

Constructor for base Logger class. Every Logger must call this constructor in its own constructor

Parameters

timescales (str | list(str)) – The different timescales at which logger needs to log. If only logging at one timescale, it is acceptable to only pass a string.
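
For example, to log to both wandb and a Chomp file through a single object, something like the following sketch could be used (wandb must be configured locally for the WandbLogger part to work; the project and run names are hypothetical).

from hive.utils.loggers import ChompLogger, CompositeLogger, WandbLogger

logger = CompositeLogger([
    WandbLogger(timescales="train", project="my_project", name="run_0"),
    ChompLogger(timescales="train"),
])
logger.update_step("train")
logger.log_scalar("reward", 1.0, prefix="train")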

register_timescale(timescale, schedule=None)[source]

Register a new timescale with the logger.

Parameters

timescale (str) – Timescale to register.

log_config(config)[source]

Log the config.

Parameters

config (dict) – Config parameters.

log_scalar(name, value, prefix)[source]

Log a scalar variable.

Parameters
  • name (str) – Name of the metric to be logged.

  • value (float) – Value to be logged.

  • prefix (str) – Prefix to append to metric name.

log_metrics(metrics, prefix)[source]

Log a dictionary of values.

Parameters
  • metrics (dict) – Dictionary of metrics to be logged.

  • prefix (str) – Prefix to append to metric name.

update_step(timescale)[source]

Update the step and schedule for a given timescale for every ScheduledLogger.

Parameters

timescale (str) – A registered timescale.

should_log(timescale)[source]

Check if you should log for a given timescale. If any logger in the list is scheduled to log, returns True.

Parameters

timescale (str) – A registered timescale.

save(dir_name)[source]

Saves the current state of the log files.

Parameters

dir_name (str) – Name of the directory to save the log files.

load(dir_name)[source]

Loads the log files from given directory.

Parameters

dir_name (str) – Name of the directory to load the log file from.

hive.utils.registry module
class hive.utils.registry.Registrable[source]

Bases: object

Class used to denote which types of objects can be registered in the RLHive Registry. These objects can also be configured directly from the command line, and recursively built from the config, assuming type annotations are present.

classmethod type_name()[source]

This should be a string that denotes which type of class you are creating. For example, “logger”, “agent”, or “env”.

class hive.utils.registry.CallableType(fn)[source]

Bases: Registrable

A wrapper that allows any callable to be registered in the RLHive Registry. Specifically, it maps the arguments and annotations of the wrapped function to the resulting callable, allowing any argument names and type annotations of the underlying function to be present for the outer wrapper. When called with some arguments, this object returns a partial function with those arguments assigned.

By default, the type_name is “callable”, but if you want to create specific types of callables, you can simply create a subclass and override the type_name method. See hive.utils.utils.OptimizerFn.

Parameters

fn – callable to be wrapped.

classmethod type_name()[source]

This should be a string that denotes which type of class you are creating. For example, “logger”, “agent”, or “env”.

class hive.utils.registry.Registry[source]

Bases: object

This is the Registry class for RLHive. It allows you to register different types of Registrable classes and objects and generates constructors for those classes in the form of get_{type_name}.

These constructors allow you to construct objects from dictionary configs. These configs should have two fields: name, which corresponds to the name used when registering a class in the registry, and kwargs, which corresponds to the keyword arguments that will be passed to the constructor of the object. These constructors can also build objects recursively, i.e. if a config contains the config for another Registrable object, this will be automatically created before being passed to the constructor of the original object. These constructors also allow you to directly specify/override arguments for object constructors directly from the command line. These parameters are specified in dot notation. They also are able to handle lists and dictionaries of Registrable objects.

For example, let’s consider the following scenario: Your agent class has an argument arg1 which is annotated to be List[Class1], Class1 is Registrable, and the Class1 constructor takes an argument arg2. If the passed yml config lists two different Class1 object configs, the constructor will check to see if both --agent.arg1.0.arg2 and --agent.arg1.1.arg2 have been passed.

The parameters passed in the command line will be parsed according to the type annotation of the corresponding low level constructor. If it is not one of int, float, str, or bool, it simply loads the string into python using a yaml loader.

Each constructor returns the object, as well as a dictionary config with all the parameters used to create the object and any Registrable objects created in the process.
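
As a sketch, constructing a registered agent from such a dictionary config through the generated get_agent constructor might look like the following, assuming RandomAgent is registered under the name "RandomAgent"; the kwargs shown are hypothetical.

import hive

# A dictionary config with the name/kwargs fields described above.
config = {
    "name": "RandomAgent",
    "kwargs": {"obs_dim": (4,), "act_dim": 2},
}
# Each generated get_{type_name} constructor returns the constructed object
# and the full config that was used to build it.
agent, full_config = hive.registry.get_agent(config)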

register(name, constructor, type)[source]

Register a Registrable class/object with RLHive.

Parameters
  • name (str) – Name of the class/object being registered.

  • constructor (callable) – Callable that will be passed all kwargs from configs and be analyzed to get type annotations.

  • type (type) – Type of class/object being registered. Should be subclass of Registrable.

register_all(base_class, class_dict)[source]

Bulk register function.

Parameters
  • base_class (type) – Corresponds to the type of the register function

  • class_dict (dict[str, callable]) – A dictionary mapping from name to constructor.

get_agent(object_or_config, prefix=None)
get_env(object_or_config, prefix=None)
get_function(object_or_config, prefix=None)
get_init_fn(object_or_config, prefix=None)
get_logger(object_or_config, prefix=None)
get_loss_fn(object_or_config, prefix=None)
get_optimizer_fn(object_or_config, prefix=None)
get_replay(object_or_config, prefix=None)
get_schedule(object_or_config, prefix=None)
hive.utils.registry.construct_objects(object_constructor, config, prefix=None)[source]

Helper function that constructs any objects specified in the config that are registrable.

Returns the object, as well as a dictionary config with all the parameters used to create the object and any Registrable objects created in the process.

Parameters
  • object_constructor (callable) – constructor of object that corresponds to config. The signature of this function will be analyzed to see if there are any Registrable objects that might be specified in the config.

  • config (dict) – The kwargs for the object being created. May contain configs for other Registrable objects that need to be recursively created.

  • prefix (str) – Prefix that is attached to the argument names when looking for command line arguments.

hive.utils.registry.get_callable_parsed_args(callable, prefix=None)[source]

Helper function that extracts the command line arguments for a given function.

Parameters
  • callable (callable) – function whose arguments will be inspected to extract arguments from the command line.

  • prefix (str) – Prefix that is attached to the argument names when looking for command line arguments.

hive.utils.registry.get_parsed_args(arguments, prefix=None)[source]

Helper function that takes a dictionary mapping argument names to types, and extracts command line arguments for those arguments. If the dictionary contains a key-value pair “bar”: int, and the prefix passed is “foo”, this function will look for a command line argument “--foo.bar”. If present, it will cast it to an int.

If the type for a given argument is not one of int, float, str, or bool, it simply loads the string into python using a yaml loader.

Parameters
  • arguments (dict[str, type]) – dictionary mapping argument names to types

  • prefix (str) – prefix that is attached to each argument name before searching for command line arguments.

hive.utils.schedule module
class hive.utils.schedule.Schedule[source]

Bases: ABC, Registrable

abstract get_value()[source]

Returns the current value of the variable we are tracking

abstract update()[source]

Update the value of the variable we are tracking and return the updated value. The first call to update will return the initial value of the schedule.

classmethod type_name()[source]

This should be a string that denotes which type of class you are creating. For example, “logger”, “agent”, or “env”.

class hive.utils.schedule.LinearSchedule(init_value, end_value, steps)[source]

Bases: Schedule

Defines a linear schedule between two values over some number of steps.

If updated more than the defined number of steps, the schedule stays at the end value.

Parameters
  • init_value (int | float) – Starting value for schedule.

  • end_value (int | float) – End value for schedule.

  • steps (int) – Number of steps for schedule. Should be positive.
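
A usage sketch, with hypothetical values:

from hive.utils.schedule import LinearSchedule

# Decay an exploration rate from 1.0 to 0.1 over 10000 steps. update() is
# called once per step and returns the new value (the first call returns the
# initial value); get_value() reads the current value without advancing.
epsilon_schedule = LinearSchedule(1.0, 0.1, 10000)
for _ in range(100):
    epsilon = epsilon_schedule.update()
current_epsilon = epsilon_schedule.get_value()
# After 10000 updates the schedule stays at the end value, 0.1.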

get_value()[source]

Returns the current value of the variable we are tracking

update()[source]

Update the value of the variable we are tracking and return the updated value. The first call to update will return the initial value of the schedule.

class hive.utils.schedule.ConstantSchedule(value)[source]

Bases: Schedule

Returns a constant value over the course of the schedule

Parameters

value – The value to be returned.

get_value()[source]

Returns the current value of the variable we are tracking

update()[source]

Update the value of the variable we are tracking and return the updated value. The first call to update will return the initial value of the schedule.

class hive.utils.schedule.SwitchSchedule(off_value, on_value, steps)[source]

Bases: Schedule

Returns one value for the first part of the schedule. After the defined number of steps is reached, switches to returning a second value.

Parameters
  • off_value – The value to be returned in the first part of the schedule.

  • on_value – The value to be returned in the second part of the schedule.

  • steps (int) – The number of steps after which to switch from the off value to the on value.

get_value()[source]

Returns the current value of the variable we are tracking

update()[source]

Update the value of the variable we are tracking and return the updated value. The first call to update will return the initial value of the schedule.

class hive.utils.schedule.DoublePeriodicSchedule(off_value, on_value, off_period, on_period)[source]

Bases: Schedule

Returns off value for off period, then switches to returning on value for on period. Alternates between the two.

Parameters
  • on_value – The value to be returned for the on period.

  • off_value – The value to be returned for the off period.

  • on_period (int) – the number of steps in the on period.

  • off_period (int) – the number of steps in the off period.

get_value()[source]

Returns the current value of the variable we are tracking

update()[source]

Update the value of the variable we are tracking and return the updated value. The first call to update will return the initial value of the schedule.

class hive.utils.schedule.PeriodicSchedule(off_value, on_value, period)[source]

Bases: DoublePeriodicSchedule

Returns one value on the first step of each period of a predefined number of steps. Returns another value otherwise.

Parameters
  • on_value – The value to be returned on the first step of each period.

  • off_value – The value to be returned for every other step in the period.

  • period (int) – The number of steps in the period.
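
A usage sketch, e.g. for deciding when to sync a target network:

from hive.utils.schedule import PeriodicSchedule

# A flag that is True once every 100 steps.
update_target_schedule = PeriodicSchedule(False, True, 100)
for step in range(1000):
    if update_target_schedule.update():
        pass  # e.g. copy the online network's weights into the target network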

hive.utils.torch_utils module
hive.utils.torch_utils.numpify(t)[source]

Convert object to a numpy array.

Parameters

t (np.ndarray | torch.Tensor | obj) – Object to convert to an np.ndarray.

class hive.utils.torch_utils.RMSpropTF(params, lr=0.01, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0.0, centered=False, decoupled_decay=False, lr_in_momentum=True)[source]

Bases: Optimizer

Direct cut-paste from rwightman/pytorch-image-models. https://github.com/rwightman/pytorch-image-models/blob/f7d210d759beb00a3d0834a3ce2d93f6e17f3d38/timm/optim/rmsprop_tf.py Licensed under Apache 2.0, https://github.com/rwightman/pytorch-image-models/blob/master/LICENSE

Implements RMSprop algorithm (TensorFlow style epsilon)

NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt and a few other modifications to closer match Tensorflow for matching hyper-params. Noteworthy changes include:

  1. Epsilon applied inside square-root

  2. square_avg initialized to ones

  3. LR scaling of update accumulated in momentum buffer

Proposed by G. Hinton in his course. The centered version first appears in Generating Sequences With Recurrent Neural Networks.

Parameters
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float, optional) – learning rate (default: 1e-2)

  • momentum (float, optional) – momentum factor (default: 0)

  • alpha (float, optional) – smoothing (decay) constant (default: 0.9)

  • eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-10)

  • centered (bool, optional) – if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • decoupled_decay (bool, optional) – decoupled weight decay as per https://arxiv.org/abs/1711.05101

  • lr_in_momentum (bool, optional) – learning rate scaling is included in the momentum buffer update as per defaults in Tensorflow

step(closure=None)[source]

Performs a single optimization step.

Parameters

closure (callable, optional) – A closure that reevaluates the model and returns the loss.

hive.utils.utils module
hive.utils.utils.create_folder(folder)[source]

Creates a folder.

Parameters

folder (str) – Folder to create.

class hive.utils.utils.Seeder[source]

Bases: object

Class used to manage seeding in RLHive. It sets the seed for all the frameworks that RLHive currently uses. It also deterministically provides new seeds based on the global seed, in case any other objects in RLHive (such as the agents) need their own seed.

set_global_seed(seed)[source]

Sets the global seed for all frameworks RLHive uses. This reduces some sources of randomness in experiments. To get reproducible results, you must also run on the same machine and set the environment variable CUBLAS_WORKSPACE_CONFIG to “:4096:8” or “:16:8” before starting the experiment.

Parameters

seed (int) – Global seed.

get_new_seed()[source]

Each time it is called, it increments the current_seed and returns it.
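
A minimal sketch of the intended usage (the seed values are arbitrary): seed everything once, then draw per-agent seeds from the same source.

from hive.utils.utils import Seeder

seeder = Seeder()
seeder.set_global_seed(42)  # seeds NumPy, PyTorch, and Python's random module
agent_seeds = [seeder.get_new_seed() for _ in range(2)]  # one distinct seed per agent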

class hive.utils.utils.Chomp[source]

Bases: dict

An extension of the dictionary class that allows access through dot notation and easy saving/loading.

save(filename)[source]

Saves the object using pickle.

Parameters

filename (str) – Filename to save object.

load(filename)[source]

Loads the object.

Parameters

filename (str) – Where to load object from.
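
For illustration, a short sketch of dot-notation access and pickle-based saving/loading (the key and filename are made up):

from hive.utils.utils import Chomp

metrics = Chomp()
metrics.episode_returns = [1.0, 3.5, 2.0]  # equivalent to metrics["episode_returns"]
metrics.save("metrics.p")

restored = Chomp()
restored.load("metrics.p")  # repopulates the dictionary from the pickle file
print(restored.episode_returns)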

class hive.utils.utils.OptimizerFn(fn)[source]

Bases: CallableType

A wrapper for callables that produce optimizer functions.

These wrapped callables can be partially initialized through configuration files or command line arguments.

Parameters

fn – callable to be wrapped.

classmethod type_name()[source]
Returns

“optimizer_fn”

class hive.utils.utils.LossFn(fn)[source]

Bases: CallableType

A wrapper for callables that produce loss functions.

These wrapped callables can be partially initialized through configuration files or command line arguments.

Parameters

fn – callable to be wrapped.

classmethod type_name()[source]
Returns

“loss_fn”
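
As a hedged sketch, both wrappers simply take the callable to wrap; the particular PyTorch classes below are illustrative. These wrapped callables are typically registered with the RLHive registry so that configuration files can partially initialize them.

import torch

from hive.utils.utils import LossFn, OptimizerFn

optimizer_fn = OptimizerFn(torch.optim.Adam)  # wraps an optimizer constructor
loss_fn = LossFn(torch.nn.SmoothL1Loss)       # wraps a loss constructor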

hive.utils.visualization module
hive.utils.visualization.find_single_run_data(run_folder)[source]

Looks for a chomp logger data file in run_folder and its subdirectories. Once it finds one, it loads the file and returns the data.

Parameters

run_folder (str) – Which folder to search for the chomp logger data.

Returns

The Chomp object containing the logger data.

hive.utils.visualization.find_all_runs_data(runs_folder)[source]

Iterates through each directory in runs_folder, finds one chomp logger data file in each directory, and concatenates the data together under each key present in the data.

Parameters

  • runs_folder (str) – Folder containing subfolders, from each of which one chomp logger data file is loaded.

Returns

Dictionary with each key corresponding to a key in the chomp logger data files; the value for each key is the data from all runs concatenated together.

hive.utils.visualization.find_all_experiments_data(experiments_folder, runs_folders)[source]

Finds and loads all log data in a folder.

Assuming the directory structure is as follows:

experiment
├── config_1_runs
|   ├──seed0/
|   ├──seed1/
|       .
|       .
|   └──seedn/
├── config_2_runs
|   ├──seed0/
|   ├──seed1/
|       .
|       .
|   └──seedn/
├── config_3_runs
|   ├──seed0/
|   ├──seed1/
|       .
|       .
|   └──seedn/

Where there is a chomp logger data file under each seed directory, passing “experiment” and the list of config folders (“config_1_runs”, “config_2_runs”, “config_3_runs”) will load all the data.

Parameters
  • experiments_folder (str) – Root folder with all experiments data.

  • runs_folders (list[str]) – List of folders under root folder to load data from.
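
Using the layout above, a minimal call (folder names taken from the illustrative tree) would be:

from hive.utils.visualization import find_all_experiments_data

data = find_all_experiments_data(
    "experiment",
    ["config_1_runs", "config_2_runs", "config_3_runs"],
)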

hive.utils.visualization.standardize_data(experiment_data, x_key, y_key, num_sampled_points=1000, drop_last=True)[source]

Extracts the given keys from the data and standardizes it across runs by sampling equally spaced points along the x data and interpolating the y data of each run.

Parameters
  • experiment_data – Data object in the format of find_all_experiments_data().

  • x_key (str) – Key for x axis.

  • y_key (str) – Key for y axis.

  • num_sampled_points (int) – How many points to sample along x axis.

  • drop_last (bool) – Whether to drop the last point in the data.

hive.utils.visualization.find_and_standardize_data(experiments_folder, runs_folders, x_key, y_key, num_sampled_points, drop_last)[source]

Finds and standardizes the data in experiments_folder.

Parameters
  • experiments_folder (str) – Root folder with all experiments data.

  • runs_folders (list[str]) – List of folders under root folder to load data from.

  • x_key (str) – Key for x axis.

  • y_key (str) – Key for y axis.

  • num_sampled_points (int) – How many points to sample along x axis.

  • drop_last (bool) – Whether to drop the last point in the data.

hive.utils.visualization.generate_lineplot(x_datas, y_datas, smoothing_fn=None, line_labels=None, xlabel=None, ylabel=None, cmap_name=None, output_file='output.png')[source]

Aggregates the data and generates a lineplot.

hive.utils.visualization.plot_results(experiments_folder, x_key, y_key, runs_folders=None, drop_last=True, run_names=None, x_label=None, y_label=None, cmap_name=None, smoothing_fn=None, num_sampled_points=100, output_file='output.png')[source]

Finds, standardizes, and plots experiment results as a lineplot, saving the figure to output_file.
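
A hedged end-to-end sketch using the directory layout shown earlier; the metric keys and labels are illustrative and depend on what your logger actually recorded.

from hive.utils.visualization import plot_results

plot_results(
    "experiment",
    x_key="train_steps",        # assumed logger key
    y_key="episode_returns",    # assumed logger key
    runs_folders=["config_1_runs", "config_2_runs", "config_3_runs"],
    run_names=["config 1", "config 2", "config 3"],
    x_label="Training steps",
    y_label="Return",
    output_file="returns.png",
)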

hive.utils.visualization.create_exponential_smoothing_fn(smoothing=0.1)[source]
hive.utils.visualization.create_moving_average_smoothing_fn(running_average=10)[source]
hive.utils.visualization.get_smoothing_fn(smoothing_fn, smoothing_fn_kwargs)[source]

Notes

Reproducibility

Achieving reproducibility in deep RL is difficult. Even when the random seed is fixed, libraries such as PyTorch use algorithms and implementations that are nondeterministic. PyTorch has several options that allow the user to turn off some aspects of this nondeterminism, but behavior is still usually only replicable if the runs are executed on the same hardware.

We provide a global seeding class Seeder that allows the user to set a global seed for all packages currently used by the framework (NumPy, PyTorch, and Python’s random package). It also sets the PyTorch options to turn off nondeterminism. When using this seeding functionality, before starting a run, you must set the environment variable CUBLAS_WORKSPACE_CONFIG to either ":16:8" (limits performance) or ":4096:8" (uses slightly more memory). See the PyTorch reproducibility documentation for more details.

The Seeder class also provides a function get_new_seed() that returns a new random seed each time it is called, which is useful in multi-agent setups where you want each agent to be seeded differently.
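
A minimal sketch of a reproducible launch script under these constraints; setting the environment variable inside Python must happen before any CUDA work starts, and exporting it in the shell before launching works equally well.

import os

# Must be set before CUDA kernels run (alternatively: export it in the shell).
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

from hive.utils.utils import Seeder

seeder = Seeder()
seeder.set_global_seed(123)
agent_seeds = [seeder.get_new_seed() for _ in range(4)]  # e.g., one per agent in a multi-agent run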

Roadmap

We have quite a lot planned for RLHive, so stay tuned! Here is what we are currently planning to add:

  • Policy gradient methods (PPO)

  • RNN support

  • rliable

  • Lifelong RL support

  • JAX support

  • Continuous Actions

  • Dreamerv2/MBRL

  • Better config system

  • RL Debugging tools

  • Log videos of trajectories

  • Allow changing objects from command line

  • Hyperparameter search integration

  • Add logger based on Python logging

If you have a feature request that isn’t listed here, check out our Contributing page!

Contributing

Core RLHive

Did you spot a bug, or is there something that you think should be added to the main RLHive package? Check our Issues on GitHub or our roadmap to see if it’s already on our radar. If not, you can either open an issue to let us know, or fork our repo and create a pull request with your own feature or bug fix.

Creating Issues

We’d love to hear from you on how we can improve RLHive! When creating issues, please follow these guidelines:

  • Tag your issue with bug, feature request, or question to help us effectively sort through the issues.

  • Include the version of RLHive you are running (run pip list | grep rlhive)

Creating PRs

When contributing to RLHive, please follow these guidelines:

  • Create a separate PR for each feature you are adding.

  • When you are done writing the feature, create a pull request to the dev branch of the main repo.

  • Each pull request must pass all unit tests and linting checks before being merged. Please run these checks locally before creating PRs/pushing changes to the PR to minimize unnecessary runs on our CI system. You can run the following commands from the root directory of the project:

    • Unit Tests: python -m pytest tests/

    • Linter: black .

  • Information (such as installation instructions and editor integrations) for the black formatter is available here.

  • Make sure your code is documented using Google style docstrings (a short example follows this list); for more examples, see the rest of our repository.
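
For reference, here is a hedged sketch of the expected docstring format on a small, purely illustrative helper:

import numpy as np

def epsilon_greedy(q_values, epsilon):
    """Selects an action from q_values using an epsilon-greedy policy.

    Args:
        q_values (np.ndarray): Estimated value of each action.
        epsilon (float): Probability of selecting a random action.

    Returns:
        int: Index of the selected action.
    """
    if np.random.rand() < epsilon:
        return int(np.random.randint(len(q_values)))
    return int(np.argmax(q_values))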

Contrib RLHive

We want to encourage users to contribute their own custom components to RLHive. This could be things like agents, environments, runners, replays, optimizers, and anything else. This will allow everyone to easily use and build on your work.

To do this, we have a contrib directory that will be part of the package. After you add new components and any relevant citation information to this folder, they will ship with subsequent releases of RLHive, giving your work a potentially larger audience. Note that we will actively maintain only the code in the main package, not the contrib package, and we can only commit to giving minimal feedback during the review stage. If a contribution becomes widely adopted by the community, we may move it into the main package and maintain it actively.

When submitting the PR for your contributions, you must include results generated with the RLHive package using your new components, as evidence of the correctness of your implementation.
