Welcome to RLHive’s documentation!
Quickstart
Installation
RLHive is available through pip! For the basic RLHive package, simply run:
pip install rlhive
You can also install the dependencies necessary for the environments that RLHive comes with by running:
pip install rlhive[<env_names>]
where <env_names> is a comma-separated list made up of the following:
atari
gym_minigrid
pettingzoo
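For example, to install the Atari and MiniGrid dependencies together (quote the argument if your shell expands square brackets):
pip install rlhive[atari,gym_minigrid]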
In addition to these environments, Minatar and Marlgrid are also supported, but need to be installed separately.
To install Minatar, run
pip install MinAtar@git+https://github.com/kenjyoung/MinAtar.git@8b39a18a60248ede15ce70142b557f3897c4e1eb
To install Marlgrid, run
pip install marlgrid@https://github.com/kandouss/marlgrid/archive/refs/heads/master.zip
Running an experiment
There are several ways to run an experiment with RLHive. If you want to just run a preset config, you can directly run your experiment from the command line, with a config file path relative to the hive/configs folder. These examples run a DQN on the Atari game Asterix according to the Dopamine configuration, and a simplified Rainbow agent for Hanabi trained using self-play according to DeepMind's configuration:
hive_single_agent_loop -p atari/dqn.yml
hive_multi_agent_loop -p hanabi/rainbow.yml
If you want to run an experiment with components that are all available in RLHive but have no preset, you can create your own config file and run that instead! Make sure you look at the examples here and the tutorial here to properly format it:
hive_single_agent_loop -c <config-file>
hive_multi_agent_loop -c <config-file>
Finally, if you want to use your own custom components, you can simply register them with RLHive and run your config normally:
import hive
from hive.runners.utils import load_config
from hive.runners.single_agent_loop import set_up_experiment
class CustomAgent(hive.agents.Agent):
    # Definition of Agent
    pass

hive.registry.register('CustomAgent', CustomAgent, CustomAgent)

# Either load your full custom config file that includes CustomAgent
config = load_config(config='custom_full_config.yml')
runner = set_up_experiment(config)
runner.run_training()
# Or load a preset config and just replace the agent config
config = load_config(preset_config='atari/dqn.yml', agent_config='custom_agent_config.yml')
runner = set_up_experiment(config)
runner.run_training()
Tutorials
Agent
Agent API
Interacting with the agent happens primarily through two functions: agent.act() and agent.update(). agent.act() takes in an observation and returns an action, while agent.update() takes in a dictionary consisting of the information relevant to the most recent transition and updates the agent.
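As a rough sketch of this interaction loop (the environment API is described in the Environments tutorial; the variable names and the exact keys in the update dictionary are illustrative and depend on the agent and runner you use):

observation, turn = env.reset()
done = False
while not done:
    action = agent.act(observation)
    next_observation, reward, done, turn, info = env.step(action)
    # The keys expected in update_info depend on the agent being used.
    agent.update({
        "observation": observation,
        "action": action,
        "reward": reward,
        "next_observation": next_observation,
        "done": done,
    })
    observation = next_observation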
Creating an Agent
Let's create a new tabular Q-learning agent with discount factor gamma and learning rate alpha, for some environment with one-hot observations. We want this agent to have an epsilon-greedy policy, with the exploration rate decaying over explore_steps from 1.0 to some value final_epsilon.
First, we define the constructor:
import numpy as np
import os
import pickle

import hive
from hive.agents.agent import Agent
from hive.utils.schedule import LinearSchedule

class TabularQLearningAgent(Agent):
    def __init__(self, obs_dim, act_dim, gamma, alpha, explore_steps, final_epsilon, id=0):
        super().__init__(obs_dim, act_dim, id=id)
        self._q_values = np.zeros((obs_dim, act_dim))
        self._gamma = gamma
        self._alpha = alpha
        self._epsilon_schedule = LinearSchedule(1.0, final_epsilon, explore_steps)
In this constructor, we created a numpy array to keep track of the Q-values for every state-action pair, and a linear decay schedule for the epsilon exploration rate. Next, let’s create the act function:
def act(self, observation):
    # Return a random action if exploring
    if np.random.rand() < self._epsilon_schedule.update():
        return np.random.randint(self._act_dim)
    else:
        state = np.argmax(observation)  # Convert from one-hot
        # Break ties randomly between all actions with max values
        max_value = np.amax(self._q_values[state])
        best_actions = np.where(self._q_values[state] == max_value)[0]
        return np.random.choice(best_actions)
Now, we write our update function, which updates the state of our agent:
def update(self, update_info):
    state = np.argmax(update_info["observation"])
    next_state = np.argmax(update_info["next_observation"])
    self._q_values[state, update_info["action"]] += self._alpha * (
        update_info["reward"]
        + self._gamma * np.amax(self._q_values[next_state])
        - self._q_values[state, update_info["action"]]
    )
Now, we can directly use this agent with the single-agent or multi-agent runners.
Note that act and update are framework agnostic, so you could implement them with any
(deep) learning framework, although most of our implemented agents are written in PyTorch.
If we write a save and load function for this agent, we can also take advantage of checkpointing and resuming in the runner:
def save(self, dname):
    np.save(os.path.join(dname, "qvalues.npy"), self._q_values)
    with open(os.path.join(dname, "state.p"), "wb") as f:
        pickle.dump({"schedule": self._epsilon_schedule}, f)

def load(self, dname):
    self._q_values = np.load(os.path.join(dname, "qvalues.npy"))
    with open(os.path.join(dname, "state.p"), "rb") as f:
        self._epsilon_schedule = pickle.load(f)["schedule"]
Finally, we register our agent class, so that it can be found when setting up experiments through the yaml config files and command line.
hive.registry.register('TabularQLearningAgent', TabularQLearningAgent, Agent)
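Once registered, the agent can be configured by name from a YAML file in the same name/kwargs format described in the Configuration tutorial below. The hyperparameter values in this sketch are purely illustrative:

agent:
  name: TabularQLearningAgent
  kwargs:
    gamma: 0.99
    alpha: 0.1
    explore_steps: 10000
    final_epsilon: 0.1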
Configuration
RLHive was written to allow for fast configuration and iteration. The base configuration of all experiment parameters is done through YAML files. The majority of these parameters can be overridden through the command line.
Any object registered with RLHive can be configured directly through YAML files or through the command line, without having to add any extra argument parsers.
YAML files
Let's do a basic example of configuring an agent through YAML files. Specifically, let's create a DQNAgent.
agent:
  name: DQNAgent
  kwargs:
    representation_net:
      name: MLPNetwork
      kwargs:
        hidden_units: [256, 256]
    discount_rate: .9
    replay_buffer:
      name: CircularReplayBuffer
    reward_clip: 1.0
In this example, DQNAgent, MLPNetwork, and CircularReplayBuffer are all classes registered with RLHive. Thus, we can do this configuration directly. When the registry getter function for agents, get_agent(), is then called with this config dictionary (with the missing required arguments such as obs_dim and act_dim filled in), it will build all the inner RLHive objects automatically.
This works by using the type annotations on the constructors of the objects, so
to recursively create the internal objects, those arguments need to be annotated
correctly.
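For example, a constructor annotated as in the following sketch lets the registry build the replay buffer and fill in the float arguments from a nested config (the class and argument names here are illustrative, and the annotation should be a registered RLHive type):

from hive.agents.agent import Agent
from hive.replays.circular_replay import CircularReplayBuffer

class MyAgent(Agent):
    def __init__(
        self,
        obs_dim,
        act_dim,
        replay_buffer: CircularReplayBuffer = None,  # annotated, so it can be built from config
        discount_rate: float = 0.99,
        id=0,
    ):
        super().__init__(obs_dim, act_dim, id=id)
        self._replay_buffer = replay_buffer
        self._discount_rate = discount_rate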
Overriding from command lines
When using the registry getter functions, RLHive automatically checks any command line arguments passed to see if they match/override any default or YAML-configured arguments. With each getter function, you provide a config and a prefix. That prefix is prepended to any argument names when searching the command line. For example, with the above config, if it were loaded and the get_agent() method was called as follows:
agent = get_agent(config['agent'], 'ag')
Then, to override the discount_rate, you could pass the following argument to your python script: --ag.discount_rate .95. This can go arbitrarily deep into registered RLHive classes. For example, if you wanted to change the capacity of the replay buffer, you could pass --ag.replay_buffer.capacity 100000.
If the type annotation of an argument arg is List[C], where C is a registered RLHive class, then you can override an argument foo of an individual object configured through YAML by passing --arg.0.foo <value>, where 0 is the index of the object in the list.
Note as of this version, you must have configured the object in the YAML file in order to override its parameters through the command line.
Using the DQN/Rainbow Agents
The DQNAgent
and RainbowDQNAgent
are written to allow for easy extension and adaptation to your applications. We outline a few different use cases here.
Using a different network architecture
Using different types of network architectures with
DQNAgent
and RainbowDQNAgent
is done using the representation_net
parameter in the constructor. This network
should not include the final layer that computes the Q-values; it computes the representations that are fed into the layer which computes the final Q-values. This is because the different variations of the DQN algorithm often differ only in how the final Q-values are computed, with the rest of the architecture unchanged.
You can modify the architecture of the representation network from the config, or create a completely new architecture better suited to your needs. From the config, two different types of network architectures are supported:
ConvNetwork: Networks with convolutional layers, followed by an MLP
MLPNetwork: An MLP with only linear layers
See this page for details on how to configure the network.
To use an architecture not supported by the above classes, simply write the PyTorch module implementing the architecture, and register the class wrapped in the FunctionApproximator wrapper. The only requirement is that this class should take in the input dimension as the first positional argument:
import torch

import hive
from hive.agents.qnets import FunctionApproximator

class CustomArchitecture(torch.nn.Module):
    def __init__(self, in_dim, hidden_units):
        super().__init__()
        self.network = torch.nn.Sequential(
            torch.nn.Linear(in_dim, hidden_units),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_units, hidden_units),
        )

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        return self.network(x)

hive.registry.register(
    'CustomArchitecture',
    FunctionApproximator(CustomArchitecture),
    FunctionApproximator
)
Adding in different Rainbow components
The Rainbow architecture is composed of several different components, namely:
Double Q-learning
Prioritized Replay
Dueling Networks
Multi-step Learning
Distributional RL
Noisy Networks
Each of these components can be independently used with our
RainbowDQNAgent
class. To use Prioritized Replay,
you must pass a PrioritizedReplayBuffer
to the replay_buffer
parameter of
RainbowDQNAgent
. The details for how to use the other
components of Rainbow are found in the API documentation of
RainbowDQNAgent
.
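For example, an agent config that switches on prioritized replay might look like the following sketch (the kwargs shown are illustrative and other required arguments are omitted):

agent:
  name: RainbowDQNAgent
  kwargs:
    representation_net:
      name: ConvNetwork
    replay_buffer:
      name: PrioritizedReplayBuffer
    double: true
    dueling: true
    distributional: true
    n_step: 3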
Custom Input Observations
The current implementations of DQNAgent
and RainbowDQNAgent
handle the standard case of
observations being a single numpy array, and no extra inputs being necessary during
the update phase other than action
, reward
, and done
. In the situation
where this is not the case, and you need to handle more complex inputs, you can do so
by overriding the methods of DQNAgent
. Let’s walk through
the example of LegalMovesRainbowAgent
.
This agent takes in a list of legal moves on each turn and only selects from those.
class LegalMovesHead(torch.nn.Module):
    def __init__(self, base_network):
        super().__init__()
        self.base_network = base_network

    def forward(self, x, legal_moves):
        x = self.base_network(x)
        return x + legal_moves

    def dist(self, x, legal_moves):
        return self.base_network.dist(x)

class LegalMovesRainbowAgent(RainbowDQNAgent):
    """A Rainbow agent which supports games with legal actions."""

    def create_q_networks(self, representation_net):
        """Creates the qnet and target qnet."""
        super().create_q_networks(representation_net)
        self._qnet = LegalMovesHead(self._qnet)
        self._target_qnet = LegalMovesHead(self._target_qnet)
This defines a wrapper around the Q-networks used by the agent that takes an
encoding of the legal moves where illegal moves have value \(-\infty\)
and legal moves have value \(0\). The wrapper then adds this encoding
to the values generated by the base Q-networks. Overriding
create_q_networks()
allows you to modify the
base Q-networks by adding this wrapper.
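The action_encoding helper used in the snippets below is not part of this walkthrough; a minimal sketch consistent with the description above is:

import numpy as np

def action_encoding(action_mask):
    # Convert a binary legal-move mask into an additive encoding:
    # legal moves map to 0 and illegal moves map to -infinity.
    encoding = np.zeros(len(action_mask), dtype=np.float32)
    encoding[np.asarray(action_mask) == 0] = -np.inf
    return encoding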
def preprocess_update_batch(self, batch):
    for key in batch:
        batch[key] = torch.tensor(batch[key], device=self._device)
    return (
        (batch["observation"], batch["action_mask"]),
        (batch["next_observation"], batch["next_action_mask"]),
        batch,
    )
Now, since the Q-networks expect an extra parameter (the legal moves action mask),
we override the preprocess_update_batch()
method,
which takes a batch sampled from the replay buffer and defines the inputs that will
be used to compute the values of the current state and the next state during the update
step.
def preprocess_update_info(self, update_info):
    preprocessed_update_info = {
        "observation": update_info["observation"]["observation"],
        "action": update_info["action"],
        "reward": update_info["reward"],
        "done": update_info["done"],
        "action_mask": action_encoding(update_info["observation"]["action_mask"]),
    }
    if "agent_id" in update_info:
        preprocessed_update_info["agent_id"] = int(update_info["agent_id"])

    return preprocessed_update_info
We must also make sure that the action encoding for each transition is added to the
replay buffer in the first place. To do that, we override the
preprocess_update_info()
method, which should return
a dictionary with keys and values corresponding to the items you wish to store into the
replay buffer. Note that these keys need to be specified when you create the replay buffer; see Replays for more information.
@torch.no_grad()
def act(self, observation):
    if self._training:
        if not self._learn_schedule.get_value():
            epsilon = 1.0
        elif not self._use_eps_greedy:
            epsilon = 0.0
        else:
            epsilon = self._epsilon_schedule.update()
        if self._logger.update_step(self._timescale):
            self._logger.log_scalar("epsilon", epsilon, self._timescale)
    else:
        epsilon = self._test_epsilon

    vectorized_observation = torch.tensor(
        np.expand_dims(observation["observation"], axis=0), device=self._device
    ).float()
    legal_moves_as_int = [
        i for i, x in enumerate(observation["action_mask"]) if x == 1
    ]
    encoded_legal_moves = torch.tensor(
        action_encoding(observation["action_mask"]), device=self._device
    ).float()
    qvals = self._qnet(vectorized_observation, encoded_legal_moves).cpu()

    if self._rng.random() < epsilon:
        action = np.random.choice(legal_moves_as_int).item()
    else:
        action = torch.argmax(qvals).item()

    return action
Finally, you also need to override the act()
method
to extract and use the extra information.
Environments
Installing Environments
We support several environments in RLHive, namely:
Atari
Gym classic control
Minatar (simplified Atari)
Minigrid (single-agent grid world)
Marlgrid (multi-agent)
Pettingzoo (multi-agent)
While gym comes installed with the base package, you need to install the other environments. See Installation for more details.
Creating an Environment
RLHive Environments
Every environment used in RLHive should be a subclass of hive.envs.base.BaseEnv.
It should provide a reset
function that resets the environment to a new episode
and returns a tuple of (observation, turn)
and a step
function that takes in
an action, performs the step in the environment, and returns a tuple of
(observation, reward, done, turn, info)
. All these values correspond to their
canonical meanings, and turn
corresponds to the index
of the agent whose turn it is (in multi-agent environments).
The reward
return value can be a single number, an array, or a dictionary. If it’s
a number, then that same reward will be given to every single agent. If it’s an array,
the agents get the reward corresponding to their index in the runner. If it’s a
dictionary, the keys should be the agent ids, and the value the reward for that agent.
Each environment should also provide an EnvSpec object that provides information about the environment, such as the expected observation shape and action dimension for each agent. These should be lists with one element for each agent. See GymEnv for an example.
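As a minimal sketch of this interface, here is a toy single-agent environment (the environment itself is illustrative; only the required structure matters):

import numpy as np

from hive.envs.base import BaseEnv
from hive.envs.env_spec import EnvSpec

class CoinFlipEnv(BaseEnv):
    """Toy single-agent environment: guess the value of a random coin."""

    def __init__(self):
        env_spec = EnvSpec("CoinFlip-v0", obs_dim=[(1,)], act_dim=[2])
        super().__init__(env_spec, num_players=1)

    def reset(self):
        self._state = np.random.randint(2, size=(1,))
        return self._state, 0  # (observation, turn)

    def step(self, action):
        reward = float(action == self._state[0])
        self._state = np.random.randint(2, size=(1,))
        return self._state, reward, True, 0, {}  # (observation, reward, done, turn, info)

    def seed(self, seed=None):
        np.random.seed(seed)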
Gym environments
If your environment is a gym environment, and you do not need to preprocess the
observations generated by the environment, then you can directly use the
GymEnv
. Just make sure you register your environment
with gym
, and pass the name of the environment to the
GymEnv
constructor.
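For example, usage looks roughly like this, assuming the environment name is registered with gym:

from hive.envs.gym_env import GymEnv

env = GymEnv("CartPole-v1")
observation, turn = env.reset()
observation, reward, done, turn, info = env.step(0)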
If you need to add extra preprocessing or change the default way that
environment/EnvSpec creation is done, you can simply subclass this class and override
either create_env()
and/or
create_env_spec()
, as in
AtariEnv
.
Parallel Step Environments
Multi-agent environments usually come in two flavors: sequential step environments,
where each agent takes its action one at a time, and parallel step environments,
where each agent steps at the same time. The
MultiAgentRunner
class expects only
sequential step environments. Fortunately, we can convert between parallel step
environments and single step environments by simply generating the action for each
agent one at a time and passing all the actions to the parallel step environment all
at once. To facilitate this, we provide a utility class
ParallelEnv
. Simply write the logic for your parallel step
environment as normal, and then create a single step version of the environment by
subclassing ParallelEnv
and the parallel step environment,
making sure to put ParallelEnv
first in the
superclass list.
from hive.envs.base import BaseEnv, ParallelEnv

class ParallelStepEnvironment(BaseEnv):
    # Write the logic needed for the parallel step environment. Assume the step
    # function gets an array of actions as its input, and should return an array
    # containing the observations for each agent, as well as the other return
    # values expected by the environment.
    ...

class SequentialStepEnvironment(ParallelEnv, ParallelStepEnvironment):
    # Any other logic needed to create the environment.
    ...
Loggers
Using a Logger
RLHive currently provides 3 types of loggers:
ChompLogger: Logs metrics to a dictionary-like object that can be saved.
WandbLogger: Logs metrics to WandB.
NullLogger: Does not log metrics.
All of these loggers are ScheduledLoggers
.
When creating these loggers, you provide the different timescales that you want to
track with the logger. Timescales could correspond to any loop variable, such as
"train_step"
, "agent_step"
, "test_iteration"
, etc.
With ScheduledLoggers
, each
timescale is associated with a Schedule
object.
By updating these schedules in some loop, you can control the logging frequency
for that timescale. The loggers also keep track of how many times that timescale
was updated, and those values are logged alongside any metric you log (i.e. if
"timescale_1"
was updated 7 times, then the logger will log an additional
key-value pair of "timescale_1": 7 with each metric). This allows
you to see the trends of any metric across any timescale.
Timescales can be registered when creating the logger, or later on.
For example:
from hive.utils.loggers import ChompLogger, CompositeLogger, WandbLogger
from hive.utils.schedule import ConstantSchedule, PeriodicSchedule
logger = ChompLogger(
    ['train_step', 'test_step'],  # Timescales to track
    [  # How often to log train_step and test_step
        PeriodicSchedule(False, True, 10),  # train_step is logged once every 10 times
        ConstantSchedule(True)  # test_step is always logged
    ]
)

# You can register timescales after logger creation as well. This particular timescale
# is never scheduled to log.
logger.register_timescale('dummy_timescale', ConstantSchedule(False))

for _ in range(total_training_time):
    metrics_to_log = run_training_step()
    if logger.update_step('train_step'):  # Evaluates to True once every 10 times it's hit.
        logger.log_metrics(metrics_to_log, 'training_metrics')

    other_metric = do_something_else()
    if logger.should_log('train_step'):  # Checks schedule for this timescale without updating
        logger.log_scalar('other_metric', other_metric, 'training_metrics')

    if should_run_testing():
        test_metrics = run_testing()
        if logger.update_step('test_step'):  # Evaluates to True every time
            logger.log_metrics(test_metrics, 'testing_metrics')
Note, the schedules associated with each timescale are not hard constraints on when you can log. They are merely a convenience to help you keep track of when to log.
Composite Logger
If you want to log to multiple sources, without the hassle of keeping track of multiple
loggers, you can create a CompositeLogger
object
initialized with each of the individual loggers that you want to use. For example:
from hive.utils.loggers import ChompLogger, CompositeLogger, WandbLogger

logger = CompositeLogger([
    ChompLogger('train'),
    WandbLogger('train')
])
You can now use this logger as above, and it will log to both
ChompLogger
and
WandbLogger
.
Registration
Registering new Classes
To register a new class with the Registry, you need to make sure that it or one of its
ancestors subclassed hive.utils.registry.Registrable
and provided a definition
for hive.utils.registry.Registrable.type_name()
. The return value of this function
is used to create the getter function for that type (registry.get_{type_name}
).
You can either register several different classes at once:
registry.register_all(
    Agent,
    {
        "DQNAgent": DQNAgent,
        "LegalMovesRainbowAgent": LegalMovesRainbowAgent,
        "RainbowDQNAgent": RainbowDQNAgent,
        "RandomAgent": RandomAgent,
    },
)
or one at a time:
registry.register("DQNAgent", DQNAgent, Agent)
registry.register("LegalMovesRainbowAgent", LegalMovesRainbowAgent, Agent)
registry.register("RainbowDQNAgent", RainbowDQNAgent, Agent)
registry.register("RandomAgent", RandomAgent, Agent)
After a class has been registered, you can pass a config dictionary to the getter function for that type to create the object.
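For example, a registered agent can be created from a plain config dictionary; the config values here are illustrative:

import hive

config = {
    "name": "RandomAgent",
    "kwargs": {"obs_dim": (4,), "act_dim": 2},
}
agent = hive.registry.get_agent(config, "agent")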
Callables
There are several cases where we want to partially parameterize some function or constructor, but
not pass the fully created object in as an argument. One example is optimizers. You might want
to pass a learning rate, but you cannot create the final optimizer object until you've created
the parameters you want to optimize. To deal with such cases, we provide a
CallableType
class, which can be used to register and wrap any
callable. For example, with optimizers, we have:
class OptimizerFn(CallableType):
    """A wrapper for callables that produce optimizer functions.

    These wrapped callables can be partially initialized through configuration
    files or command line arguments.
    """

    @classmethod
    def type_name(cls):
        """
        Returns:
            "optimizer_fn"
        """
        return "optimizer_fn"

registry.register_all(
    OptimizerFn,
    {
        "Adadelta": OptimizerFn(optim.Adadelta),
        "Adagrad": OptimizerFn(optim.Adagrad),
        "Adam": OptimizerFn(optim.Adam),
        "Adamax": OptimizerFn(optim.Adamax),
        "AdamW": OptimizerFn(optim.AdamW),
        "ASGD": OptimizerFn(optim.ASGD),
        "LBFGS": OptimizerFn(optim.LBFGS),
        "RMSprop": OptimizerFn(optim.RMSprop),
        "RMSpropTF": OptimizerFn(RMSpropTF),
        "Rprop": OptimizerFn(optim.Rprop),
        "SGD": OptimizerFn(optim.SGD),
        "SparseAdam": OptimizerFn(optim.SparseAdam),
    },
)
With this, we can now make use of the configurability of RLHive objects while still passing callables as arguments.
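To make the pattern concrete, here is a rough sketch of what an agent does with a configured optimizer_fn; functools.partial stands in for the partially initialized OptimizerFn that the config machinery produces:

from functools import partial

import torch
from torch import optim

# The config machinery binds hyperparameters such as the learning rate...
optimizer_fn = partial(optim.Adam, lr=1e-4)

# ...and the agent later supplies the parameters to optimize.
network = torch.nn.Linear(4, 2)
optimizer = optimizer_fn(network.parameters())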
Replays
RLHive currently provides 4 types of Replays:
CircularReplayBuffer: An implementation of a FIFO circular replay buffer. Stores individual observations and constructs transitions on the fly when sampling to save space.
SimpleReplayBuffer: A simplified version of a FIFO circular replay buffer that stores individual transitions directly.
PrioritizedReplayBuffer: A subclass of CircularReplayBuffer that adds prioritized sampling.
LegalMovesReplayBuffer: A subclass of PrioritizedReplayBuffer that stores/handles legal moves.
The main replay buffer classes that you will likely use/extend are CircularReplayBuffer and PrioritizedReplayBuffer.
By default, these classes expect the arguments "observation", "action", "reward", and "done" when adding to the buffer. You can also provide alternative shapes/dtypes for these keys, and the buffer will try to automatically cast the objects you add to the buffer.
Along with these default keys, you can also store extra keys in the buffer. When
creating the buffer, provide a dictionary with key-value pairs key: (type, shape)
.
When adding, you can directly provide this key as an argument to the
add()
method, and it will
automatically be added to the batch dictionary that you sample.
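A rough sketch of storing an extra key, following the API described above (the key name, shapes, and dtypes are illustrative):

import numpy as np

from hive.replays.circular_replay import CircularReplayBuffer

# Store a "legal_moves" vector alongside each transition.
buffer = CircularReplayBuffer(
    capacity=10000,
    observation_shape=(4,),
    observation_dtype=np.float32,
    extra_storage_types={"legal_moves": (np.int8, (6,))},
)
buffer.add(
    observation=np.zeros(4, dtype=np.float32),
    action=1,
    reward=0.0,
    done=False,
    legal_moves=np.ones(6, dtype=np.int8),
)
# Sampled batches will then include a "legal_moves" entry.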
Runners
We provide two different Runner
classes:
SingleAgentRunner
and
MultiAgentRunner
. The setup
for both Runner classes can be viewed in their respective files with the
set_up_experiment()
functions.
The get_parsed_args()
function can be used
to get any arguments from the command line that are not part of the signatures
of already registered RLHive class constructors.
Metrics and TransitionInfo
The Metrics
class can be used to keep track
of metrics for single/multiple agents across an episode.
# Create the Metrics object. The first set of metrics is individual to
# each agent, the second is common for all agents. Each metric can
# be initialized either with a value or with a callable that takes no arguments.
metrics = Metrics(
    [agent1, agent2],
    [("reward", 0), ("episode_traj", lambda: [])],
    [("full_episode_length", 0)],
)
# Add metrics
metrics[agent1.id]["reward"] += 1
metrics[agent2.id]["episode_traj"].append(0)
metrics["full_episode_length"] += 1
# Convert to flat dictionary for easy logging. Adds agent ids as prefixes
# for agent-specific metrics
flat_metrics = metrics.get_flat_dict()
# Reinitialize/reset all metrics
metrics.reset_metrics()
The TransitionInfo class can be used to keep track of the information needed by the agent to construct its next state for acting or its next transition for updating. It also handles state stacking and padding.
transition_info = TransitionInfo([agent1, agent2], stack_size)
# Set the start flag for agent1.
transition_info.start(agent1)
# Get stacked observation for agent1. If not enough observations have been
# recorded, it will pad with 0s
stacked_observation = transition_info.get_stacked_state(
agent1, observation
)
# Record update information about the agent
transition_info.record_info(agent, info)
# Get the update information for the agent, with done set to the value passed
info = transition_info.get_info(agent, done=done)
RLHive API
hive package
Subpackages
hive.agents package
Subpackages
- class hive.agents.qnets.atari.nature_atari_dqn.NatureAtariDQNModel(in_dim)[source]
Bases:
ConvNetwork
The convolutional network used to train the DQN agent in the original Nature paper: https://www.nature.com/articles/nature14236
- Parameters
in_dim (tuple) – The tuple of observations dimension (channels, width, height).
- class hive.agents.qnets.base.FunctionApproximator(fn)[source]
Bases:
CallableType
A wrapper for callables that produce function approximators.
For example,
FunctionApproximator(create_neural_network) or FunctionApproximator(MyNeuralNetwork), where create_neural_network is a function that creates a neural network module and MyNeuralNetwork is a class that defines your function approximator.
These wrapped callables can be partially initialized through configuration files or command line arguments.
- Parameters
fn – callable to be wrapped.
- class hive.agents.qnets.conv.ConvNetwork(in_dim, channels=None, mlp_layers=None, kernel_sizes=1, strides=1, paddings=0, normalization_factor=255, noisy=False, std_init=0.5)[source]
Bases:
Module
Basic convolutional neural network architecture. Applies a number of convolutional layers (each followed by a ReLU activation), and then feeds the output into a hive.agents.qnets.mlp.MLPNetwork.
Note, if channels is None, the network created for the convolution portion of the architecture is simply a torch.nn.Identity module. If mlp_layers is None, the mlp portion of the architecture is a torch.nn.Identity module.
- Parameters
in_dim (tuple) – The tuple of observations dimension (channels, width, height).
channels (list) – The size of output channel for each convolutional layer.
mlp_layers (list) – The number of neurons for each mlp layer after the convolutional layers.
kernel_sizes (list | int) – The kernel size for each convolutional layer
strides (list | int) – The stride used for each convolutional layer.
paddings (list | int) – The size of the padding used for each convolutional layer.
normalization_factor (float | int) – What the input is divided by before the forward pass of the network.
noisy (bool) – Whether the MLP part of the network will use NoisyLinear layers or torch.nn.Linear layers.
std_init (float) – The range for the initialization of the standard deviation of the weights in NoisyLinear.
- forward(x)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class hive.agents.qnets.mlp.MLPNetwork(in_dim, hidden_units=256, noisy=False, std_init=0.5)[source]
Bases:
Module
Basic MLP neural network architecture.
Contains a series of torch.nn.Linear or NoisyLinear layers, each of which is followed by a ReLU.
- Parameters
hidden_units (int | list[int]) – The number of neurons for each mlp layer.
noisy (bool) – Whether the MLP should use NoisyLinear layers or normal torch.nn.Linear layers.
std_init (float) – The range for the initialization of the standard deviation of the weights in NoisyLinear.
- forward(x)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class hive.agents.qnets.noisy_linear.NoisyLinear(in_dim, out_dim, std_init=0.5)[source]
Bases:
Module
NoisyLinear Layer. Implements the layer described in https://arxiv.org/abs/1706.10295.
- Parameters
- forward(inp)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class hive.agents.qnets.qnet_heads.DQNNetwork(base_network, hidden_dim, out_dim, linear_fn=None)[source]
Bases:
Module
Implements the standard DQN value computation. Transforms output from base_network with output dimension hidden_dim to dimension out_dim, which should be equal to the number of actions.
- Parameters
base_network (torch.nn.Module) – Backbone network that computes the representations that are used to compute action values.
hidden_dim (int) – Dimension of the output of the network.
out_dim (int) – Output dimension of the DQN. Should be equal to the number of actions that you are computing values for.
linear_fn (torch.nn.Module) – Function that will create the torch.nn.Module that will take the output of network and produce the final action values. If None, a torch.nn.Linear layer will be used.
- forward(x)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class hive.agents.qnets.qnet_heads.DuelingNetwork(base_network, hidden_dim, out_dim, linear_fn=None, atoms=1)[source]
Bases:
Module
Computes action values using Dueling Networks (https://arxiv.org/abs/1511.06581). In dueling, we have two heads: one for estimating the advantage function and one for estimating the value function.
- Parameters
base_network (torch.nn.Module) – Backbone network that computes the representations that are shared by the two estimators.
hidden_dim (int) – Dimension of the output of the base_network.
out_dim (int) – Output dimension of the Dueling DQN. Should be equal to the number of actions that you are computing values for.
linear_fn (torch.nn.Module) – Function that will create the torch.nn.Module that will take the output of network and produce the final action values. If None, a torch.nn.Linear layer will be used.
atoms (int) – Multiplier for the dimension of the output. For standard dueling networks, this should be 1. Used by DistributionalNetwork.
- forward(x)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class hive.agents.qnets.qnet_heads.DistributionalNetwork(base_network, out_dim, vmin=0, vmax=200, atoms=51)[source]
Bases:
Module
Computes a categorical distribution over values for each action (https://arxiv.org/abs/1707.06887).
- Parameters
base_network (torch.nn.Module) – Backbone network that computes the representations that are used to compute the value distribution.
out_dim (int) – Output dimension of the Distributional DQN. Should be equal to the number of actions that you are computing values for.
vmin (float) – The minimum of the support of the categorical value distribution.
vmax (float) – The maximum of the support of the categorical value distribution.
atoms (int) – Number of atoms discretizing the support range of the categorical value distribution.
- forward(x)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- hive.agents.qnets.utils.calculate_output_dim(net, input_shape)[source]
Calculates the resulting output shape for a given input shape and network.
- Parameters
net (torch.nn.Module) – The network which you want to calculate the output dimension for.
input_shape (int | tuple[int]) – The shape of the input being fed into the
net
. Batch dimension should not be included.
- Returns
The shape of the output of a network given an input shape. Batch dimension is not included.
- hive.agents.qnets.utils.create_init_weights_fn(initialization_fn)[source]
Returns a function that wraps initialization_function() and applies it to modules that have the weight attribute.
- Parameters
initialization_fn (callable) – A function that takes in a tensor and initializes it.
- Returns
Function that takes in PyTorch modules and initializes their weights. Can be used as follows:
init_fn = create_init_weights_fn(variance_scaling_)
network.apply(init_fn)
- hive.agents.qnets.utils.calculate_correct_fan(tensor, mode)[source]
Calculate fan of tensor.
- Parameters
tensor (torch.Tensor) – Tensor to calculate fan of.
mode (str) – Which type of fan to compute. Must be one of “fan_in”, “fan_out”, and “fan_avg”.
- Returns
Fan of the tensor based on the mode.
- hive.agents.qnets.utils.variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='uniform')[source]
Implements the tf.keras.initializers.VarianceScaling initializer in PyTorch.
- Parameters
tensor (torch.Tensor) – Tensor to initialize.
scale (float) – Scaling factor (must be positive).
mode (str) – Must be one of “fan_in”, “fan_out”, and “fan_avg”.
distribution – Random distribution to use, must be one of “truncated_normal”, “untruncated_normal” and “uniform”.
- Returns
Initialized tensor.
- class hive.agents.qnets.utils.InitializationFn(fn)[source]
Bases:
CallableType
A wrapper for callables that produce initialization functions.
These wrapped callables can be partially initialized through configuration files or command line arguments.
- Parameters
fn – callable to be wrapped.
Submodules
- class hive.agents.agent.Agent(obs_dim, act_dim, id=0)[source]
Bases: ABC, Registrable
Base class for agents. Every implemented agent should be a subclass of this class.
- Parameters
obs_dim – Dimension of observations that agent will see.
act_dim – Number of actions that the agent needs to choose from.
id – Identifier for the agent.
- property id
- abstract act(observation)[source]
Returns an action for the agent to perform based on the observation.
- Parameters
observation – Current observation that agent should act on.
- Returns
Action for the current timestep.
- abstract update(update_info)[source]
Updates the agent.
- Parameters
update_info (dict) – Contains information agent needs to update itself.
- abstract save(dname)[source]
Saves agent checkpointing information to file for future loading.
- Parameters
dname (str) – directory where agent should save all relevant info.
- class hive.agents.dqn.DQNAgent(representation_net, obs_dim, act_dim, id=0, optimizer_fn=None, loss_fn=None, init_fn=None, replay_buffer=None, discount_rate=0.99, n_step=1, grad_clip=None, reward_clip=None, update_period_schedule=None, target_net_soft_update=False, target_net_update_fraction=0.05, target_net_update_schedule=None, epsilon_schedule=None, test_epsilon=0.001, min_replay_history=5000, batch_size=32, device='cpu', logger=None, log_frequency=100)[source]
Bases:
Agent
An agent implementing the DQN algorithm. Uses an epsilon greedy exploration policy
- Parameters
representation_net (FunctionApproximator) – A network that outputs the representations that will be used to compute Q-values (e.g. everything except the final layer of the DQN).
obs_dim – The shape of the observations.
act_dim (int) – The number of actions available to the agent.
id – Agent identifier.
optimizer_fn (OptimizerFn) – A function that takes in a list of parameters to optimize and returns the optimizer. If None, defaults to Adam.
loss_fn (LossFn) – Loss function used by the agent. If None, defaults to SmoothL1Loss.
init_fn (InitializationFn) – Initializes the weights of qnet using create_init_weights_fn.
replay_buffer (BaseReplayBuffer) – The replay buffer that the agent will push observations to and sample from during learning. If None, defaults to CircularReplayBuffer.
discount_rate (float) – A number between 0 and 1 specifying how much future rewards are discounted by the agent.
n_step (int) – The horizon used in n-step returns to compute TD(n) targets.
grad_clip (float) – Gradients will be clipped to between [-grad_clip, grad_clip].
reward_clip (float) – Rewards will be clipped to between [-reward_clip, reward_clip].
update_period_schedule (Schedule) – Schedule determining how frequently the agent’s Q-network is updated.
target_net_soft_update (bool) – Whether the target net parameters are replaced by the qnet parameters completely or using a weighted average of the target net parameters and the qnet parameters.
target_net_update_fraction (float) – The weight given to the target net parameters in a soft update.
target_net_update_schedule (Schedule) – Schedule determining how frequently the target net is updated.
epsilon_schedule (Schedule) – Schedule determining the value of epsilon through the course of training.
test_epsilon (float) – epsilon (probability of choosing a random action) to be used during testing phase.
min_replay_history (int) – How many observations to fill the replay buffer with before starting to learn.
batch_size (int) – The size of the batch sampled from the replay buffer during learning.
device – Device on which all computations should be run.
logger (ScheduledLogger) – Logger used to log agent’s metrics.
log_frequency (int) – How often to log the agent’s metrics.
- create_q_networks(representation_net)[source]
Creates the Q-network and target Q-network.
- Parameters
representation_net – A network that outputs the representations that will be used to compute Q-values (e.g. everything except the final layer of the DQN).
- preprocess_update_info(update_info)[source]
Preprocesses the update_info before it goes into the replay buffer. Clips the reward in update_info.
- Parameters
update_info – Contains the information from the current timestep that the agent should use to update itself.
- preprocess_update_batch(batch)[source]
Preprocess the batch sampled from the replay buffer.
- Parameters
batch – Batch sampled from the replay buffer for the current update.
- Returns
(tuple) Inputs used to calculate current state values.
(tuple) Inputs used to calculate next state values
Preprocessed batch.
- Return type
(tuple)
- act(observation)[source]
Returns the action for the agent. If in training mode, follows an epsilon greedy policy. Otherwise, returns the action with the highest Q-value.
- Parameters
observation – The current observation.
- update(update_info)[source]
Updates the DQN agent.
- Parameters
update_info – dictionary containing all the necessary information to update the agent. Should contain a full transition, with keys for “observation”, “action”, “reward”, and “done”.
- class hive.agents.legal_moves_rainbow.LegalMovesRainbowAgent(representation_net, obs_dim, act_dim, optimizer_fn=None, loss_fn=None, init_fn=None, id=0, replay_buffer=None, discount_rate=0.99, n_step=1, grad_clip=None, reward_clip=None, update_period_schedule=None, target_net_soft_update=False, target_net_update_fraction=0.05, target_net_update_schedule=None, epsilon_schedule=None, test_epsilon=0.001, min_replay_history=5000, batch_size=32, device='cpu', logger=None, log_frequency=100, noisy=True, std_init=0.5, use_eps_greedy=False, double=True, dueling=True, distributional=True, v_min=0, v_max=200, atoms=51)[source]
Bases:
RainbowDQNAgent
A Rainbow agent which supports games with legal actions.
- Parameters
representation_net (FunctionApproximator) – A network that outputs the representations that will be used to compute Q-values (e.g. everything except the final layer of the DQN).
obs_dim (Tuple) – The shape of the observations.
act_dim (int) – The number of actions available to the agent.
id – Agent identifier.
optimizer_fn (OptimizerFn) – A function that takes in a list of parameters to optimize and returns the optimizer. If None, defaults to Adam.
loss_fn (LossFn) – Loss function used by the agent. If None, defaults to SmoothL1Loss.
init_fn (InitializationFn) – Initializes the weights of qnet using create_init_weights_fn.
replay_buffer (BaseReplayBuffer) – The replay buffer that the agent will push observations to and sample from during learning. If None, defaults to PrioritizedReplayBuffer.
discount_rate (float) – A number between 0 and 1 specifying how much future rewards are discounted by the agent.
n_step (int) – The horizon used in n-step returns to compute TD(n) targets.
grad_clip (float) – Gradients will be clipped to between [-grad_clip, grad_clip].
reward_clip (float) – Rewards will be clipped to between [-reward_clip, reward_clip].
update_period_schedule (Schedule) – Schedule determining how frequently the agent’s Q-network is updated.
target_net_soft_update (bool) – Whether the target net parameters are replaced by the qnet parameters completely or using a weighted average of the target net parameters and the qnet parameters.
target_net_update_fraction (float) – The weight given to the target net parameters in a soft update.
target_net_update_schedule (Schedule) – Schedule determining how frequently the target net is updated.
epsilon_schedule (Schedule) – Schedule determining the value of epsilon through the course of training.
test_epsilon (float) – epsilon (probability of choosing a random action) to be used during testing phase.
min_replay_history (int) – How many observations to fill the replay buffer with before starting to learn.
batch_size (int) – The size of the batch sampled from the replay buffer during learning.
device – Device on which all computations should be run.
logger (ScheduledLogger) – Logger used to log agent’s metrics.
log_frequency (int) – How often to log the agent’s metrics.
noisy (bool) – Whether to use noisy linear layers for exploration.
std_init (float) – The range for the initialization of the standard deviation of the weights.
use_eps_greedy (bool) – Whether to use epsilon greedy exploration.
double (bool) – Whether to use double DQN.
dueling (bool) – Whether to use a dueling network architecture.
distributional (bool) – Whether to use distributional RL.
vmin (float) – The minimum of the support of the categorical value distribution for distributional RL.
vmax (float) – The maximum of the support of the categorical value distribution for distributional RL.
atoms (int) – Number of atoms discretizing the support range of the categorical value distribution for distributional RL.
- preprocess_update_info(update_info)[source]
Preprocesses the update_info before it goes into the replay buffer. Clips the reward in update_info.
- Parameters
update_info – Contains the information from the current timestep that the agent should use to update itself.
- preprocess_update_batch(batch)[source]
Preprocess the batch sampled from the replay buffer.
- Parameters
batch – Batch sampled from the replay buffer for the current update.
- Returns
(tuple) Inputs used to calculate current state values.
(tuple) Inputs used to calculate next state values
Preprocessed batch.
- Return type
(tuple)
- class hive.agents.legal_moves_rainbow.LegalMovesHead(base_network)[source]
Bases:
Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(x, legal_moves)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class hive.agents.rainbow.RainbowDQNAgent(representation_net, obs_dim, act_dim, optimizer_fn=None, loss_fn=None, init_fn=None, id=0, replay_buffer=None, discount_rate=0.99, n_step=1, grad_clip=None, reward_clip=None, update_period_schedule=None, target_net_soft_update=False, target_net_update_fraction=0.05, target_net_update_schedule=None, epsilon_schedule=None, test_epsilon=0.001, min_replay_history=5000, batch_size=32, device='cpu', logger=None, log_frequency=100, noisy=True, std_init=0.5, use_eps_greedy=False, double=True, dueling=True, distributional=True, v_min=0, v_max=200, atoms=51)[source]
Bases:
DQNAgent
An agent implementing the Rainbow algorithm.
- Parameters
representation_net (FunctionApproximator) – A network that outputs the representations that will be used to compute Q-values (e.g. everything except the final layer of the DQN).
obs_dim (Tuple) – The shape of the observations.
act_dim (int) – The number of actions available to the agent.
id – Agent identifier.
optimizer_fn (OptimizerFn) – A function that takes in a list of parameters to optimize and returns the optimizer. If None, defaults to Adam.
loss_fn (LossFn) – Loss function used by the agent. If None, defaults to SmoothL1Loss.
init_fn (InitializationFn) – Initializes the weights of qnet using create_init_weights_fn.
replay_buffer (BaseReplayBuffer) – The replay buffer that the agent will push observations to and sample from during learning. If None, defaults to PrioritizedReplayBuffer.
discount_rate (float) – A number between 0 and 1 specifying how much future rewards are discounted by the agent.
n_step (int) – The horizon used in n-step returns to compute TD(n) targets.
grad_clip (float) – Gradients will be clipped to between [-grad_clip, grad_clip].
reward_clip (float) – Rewards will be clipped to between [-reward_clip, reward_clip].
update_period_schedule (Schedule) – Schedule determining how frequently the agent’s Q-network is updated.
target_net_soft_update (bool) – Whether the target net parameters are replaced by the qnet parameters completely or using a weighted average of the target net parameters and the qnet parameters.
target_net_update_fraction (float) – The weight given to the target net parameters in a soft update.
target_net_update_schedule (Schedule) – Schedule determining how frequently the target net is updated.
epsilon_schedule (Schedule) – Schedule determining the value of epsilon through the course of training.
test_epsilon (float) – epsilon (probability of choosing a random action) to be used during testing phase.
min_replay_history (int) – How many observations to fill the replay buffer with before starting to learn.
batch_size (int) – The size of the batch sampled from the replay buffer during learning.
device – Device on which all computations should be run.
logger (ScheduledLogger) – Logger used to log agent’s metrics.
log_frequency (int) – How often to log the agent’s metrics.
noisy (bool) – Whether to use noisy linear layers for exploration.
std_init (float) – The range for the initialization of the standard deviation of the weights.
use_eps_greedy (bool) – Whether to use epsilon greedy exploration.
double (bool) – Whether to use double DQN.
dueling (bool) – Whether to use a dueling network architecture.
distributional (bool) – Whether to use distributional RL.
vmin (float) – The minimum of the support of the categorical value distribution for distributional RL.
vmax (float) – The maximum of the support of the categorical value distribution for distributional RL.
atoms (int) – Number of atoms discretizing the support range of the categorical value distribution for distributional RL.
- create_q_networks(representation_net)[source]
Creates the Q-network and target Q-network. Adds the appropriate heads for DQN, Dueling DQN, Noisy Networks, and Distributional DQN.
- Parameters
representation_net – A network that outputs the representations that will be used to compute Q-values (e.g. everything except the final layer of the DQN).
- act(observation)[source]
Returns the action for the agent. If in training mode, follows an epsilon greedy policy. Otherwise, returns the action with the highest Q-value.
- Parameters
observation – The current observation.
- update(update_info)[source]
Updates the DQN agent.
- Parameters
update_info – dictionary containing all the necessary information to update the agent. Should contain a full transition, with keys for "observation", "action", "reward", "next_observation", and "done".
- target_projection(target_net_inputs, next_action, reward, done)[source]
Project distribution of target Q-values.
- Parameters
target_net_inputs – Inputs to feed into the target net to compute the projection of the target Q-values. Should be set from preprocess_update_batch().
next_action (Tensor) – Tensor containing next actions used to compute target distribution.
reward (Tensor) – Tensor containing rewards for the current batch.
done (Tensor) – Tensor containing whether the states in the current batch are terminal.
- class hive.agents.random.RandomAgent(obs_dim, act_dim, id=0, logger=None)[source]
Bases:
Agent
An agent that takes random steps at each timestep.
- Parameters
obs_dim – The shape of the observations.
act_dim (int) – The number of actions available to the agent.
id – Agent identifier.
logger (ScheduledLogger) – Logger used to log agent’s metrics.
- update(update_info)[source]
Updates the agent.
- Parameters
update_info (dict) – Contains information agent needs to update itself.
Module contents
hive.envs package
Subpackages
- class hive.envs.minatar.minatar.MinAtarEnv(env_name, sticky_action_prob=0.1, difficulty_ramping=True)[source]
Bases:
BaseEnv
Class for loading MinAtar environments. See https://github.com/kenjyoung/MinAtar.
- Parameters
- reset()[source]
Resets the state of the environment.
- Returns
The initial observation of the new episode. turn (int): The index of the agent which should take its turn.
- Return type
observation
- class hive.envs.pettingzoo.pettingzoo.PettingZooEnv(env_name, env_family, num_players, **kwargs)[source]
Bases:
BaseEnv
PettingZoo environment from https://github.com/PettingZoo-Team/PettingZoo
For now, we only support environments from PettingZoo with discrete actions.
- Parameters
- create_env_spec(env_name, **kwargs)[source]
Each family of environments has its own types of observations and actions. You can add support for more families here by modifying obs_dim and act_dim.
- reset()[source]
Resets the state of the environment.
- Returns
The initial observation of the new episode. turn (int): The index of the agent which should take its turn.
- Return type
observation
- step(action)[source]
Run one time-step of the environment using the input action.
- Parameters
action – An element of environment’s action space.
- Returns
Indicates the next state that is an element of environment’s observation space. reward: A reward achieved from the transition. done (bool): Indicates whether the episode has ended. turn (int): Indicates which agent should take its turn. info (dict): Additional custom information.
- Return type
observation
- class hive.envs.wrappers.gym_wrappers.FlattenWrapper(env)[source]
Bases:
ObservationWrapper
Flatten the observation to a one-dimensional vector.
Wraps an environment to allow a modular transformation of the
step()
andreset()
methods.- Parameters
env – The environment to wrap
new_step_api – Whether the wrapper’s step method will output in new or old step API
- class hive.envs.wrappers.gym_wrappers.PermuteImageWrapper(env)[source]
Bases:
ObservationWrapper
Changes the image format from HWC to CHW
Wraps an environment to allow a modular transformation of the
step()
andreset()
methods.- Parameters
env – The environment to wrap
new_step_api – Whether the wrapper’s step method will output in new or old step API
Submodules
- class hive.envs.base.BaseEnv(env_spec, num_players)[source]
Bases:
ABC, Registrable
Base class for environments.
- Parameters
- abstract reset()[source]
Resets the state of the environment.
- Returns
The initial observation of the new episode. turn (int): The index of the agent which should take its turn.
- Return type
observation
- abstract step(action)[source]
Run one time-step of the environment using the input action.
- Parameters
action – An element of environment’s action space.
- Returns
Indicates the next state that is an element of environment’s observation space. reward: A reward achieved from the transition. done (bool): Indicates whether the episode has ended. turn (int): Indicates which agent should take its turn. info (dict): Additional custom information.
- Return type
observation
- abstract seed(seed=None)[source]
Reseeds the environment.
- Parameters
seed (int) – Seed to use for environment.
- save(save_dir)[source]
Saves the environment.
- Parameters
save_dir (str) – Location to save environment state.
- load(load_dir)[source]
Loads the environment.
- Parameters
load_dir (str) – Location to load environment state from.
- property env_spec
- class hive.envs.base.ParallelEnv(env_name, num_players)[source]
Bases:
BaseEnv
Base class for environments that make all agents step in parallel.
ParallelEnv takes an environment that expects an array of actions at each step to execute in parallel, and allows you to instead pass it a single action at each step.
This class makes use of Python’s multiple inheritance pattern. Specifically, when writing your parallel environment, it should extend both this class and the class that implements the step method that takes in actions for all agents.
If environment class A has the logic for the step function that takes in the array of actions, and environment class B is your parallel step version of that environment, class B should be defined as:
class B(ParallelEnv, A): ...
The order in which you list the classes is important. ParallelEnv must come before A in the order.
- Parameters
- reset()[source]
Resets the state of the environment.
- Returns
The initial observation of the new episode. turn (int): The index of the agent which should take its turn.
- Return type
observation
- step(action)[source]
Run one time-step of the environment using the input action.
- Parameters
action – An element of environment’s action space.
- Returns
Indicates the next state that is an element of environment’s observation space. reward: A reward achieved from the transition. done (bool): Indicates whether the episode has ended. turn (int): Indicates which agent should take its turn. info (dict): Additional custom information.
- Return type
observation
- class hive.envs.env_spec.EnvSpec(env_name, obs_dim, act_dim, env_info=None)[source]
Bases:
object
Object used to store information about environment configuration. Every environment should create an EnvSpec object.
- Parameters
env_name – Name of the environment
obs_dim – Dimensionality of observations from environment. This can be a simple integer, or a complex object depending on the types of observations expected.
act_dim – Dimensionality of action space.
env_info – Any other info relevant to this environment. This can include items such as random seeds or parameters used to create the environment
- property env_name
- property obs_dim
- property act_dim
- property env_info
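For illustration, an EnvSpec for a simple single-player environment might be created as follows; the values below are made up and only show the shape of the call.

    from hive.envs.env_spec import EnvSpec

    spec = EnvSpec(
        env_name="CartPole-v1",   # illustrative name
        obs_dim=(4,),             # simple observation shape (illustrative)
        act_dim=2,                # number of discrete actions (illustrative)
        env_info={"seed": 42},    # any extra creation parameters
    )
    print(spec.env_name, spec.obs_dim, spec.act_dim, spec.env_info)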
- class hive.envs.gym_env.GymEnv(env_name, num_players=1, **kwargs)[source]
Bases:
BaseEnv
Class for loading gym environments.
- Parameters
env_name (str) – Name of the environment (NOTE: make sure it is available in gym.envs.registry.all())
num_players (int) – Number of players for the environment.
kwargs – Any arguments you want to pass to create_env() or create_env_spec() can be passed as keyword arguments to this constructor.
- create_env(env_name, **kwargs)[source]
Function used to create the environment. Subclasses can override this method if they are using a gym style environment that needs special logic.
- Parameters
env_name (str) – Name of the environment
- create_env_spec(env_name, **kwargs)[source]
Function used to create the specification. Subclasses can override this method if they are using a gym style environment that needs special logic.
- Parameters
env_name (str) – Name of the environment
- reset()[source]
Resets the state of the environment.
- Returns
The initial observation of the new episode. turn (int): The index of the agent whose turn it is to act.
- Return type
observation
- step(action)[source]
Run one time-step of the environment using the input action.
- Parameters
action – An element of environment’s action space.
- Returns
observation: The next state, an element of the environment’s observation space.
reward: The reward achieved from the transition.
done (bool): Indicates whether the episode has ended.
turn (int): Indicates which agent should take the next turn.
info (dict): Additional custom information.
- Return type
observation
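A minimal usage sketch for GymEnv, assuming "CartPole-v1" is registered with gym; the unpacking below follows the reset()/step() return values documented above.

    from hive.envs.gym_env import GymEnv

    env = GymEnv("CartPole-v1")

    # reset() yields the initial observation and the index of the agent whose
    # turn it is (always 0 for a single-player gym environment).
    observation, turn = env.reset()
    done = False
    while not done:
        action = 0  # fixed, illustrative action from CartPole's {0, 1} action space
        observation, reward, done, turn, info = env.step(action)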
Module contents
hive.replays package
Submodules
- class hive.replays.circular_replay.CircularReplayBuffer(capacity=10000, stack_size=1, n_step=1, gamma=0.99, observation_shape=(), observation_dtype=<class 'numpy.uint8'>, action_shape=(), action_dtype=<class 'numpy.int8'>, reward_shape=(), reward_dtype=<class 'numpy.float32'>, extra_storage_types=None, num_players_sharing_buffer=None)[source]
Bases:
BaseReplayBuffer
An efficient version of a circular replay buffer that only stores each observation once.
Constructor for CircularReplayBuffer.
- Parameters
capacity (int) – Total number of observations that can be stored in the buffer. Note, this is not the same as the number of transitions that can be stored in the buffer.
stack_size (int) – The number of frames to stack to create an observation.
n_step (int) – Horizon used to compute n-step return reward
gamma (float) – Discounting factor used to compute n-step return reward
observation_shape – Shape of observations that will be stored in the buffer.
observation_dtype – Type of observations that will be stored in the buffer. This can either be the type itself or string representation of the type. The type can be either a native python type or a numpy type. If a numpy type, a string of the form np.uint8 or numpy.uint8 is acceptable.
action_shape – Shape of actions that will be stored in the buffer.
action_dtype – Type of actions that will be stored in the buffer. Format is described in the description of observation_dtype.
reward_shape – Shape of rewards that will be stored in the buffer.
reward_dtype – Type of rewards that will be stored in the buffer. Format is described in the description of observation_dtype.
extra_storage_types (dict) – A dictionary describing extra items to store in the buffer. The mapping should be from the name of the item to a (type, shape) tuple.
num_players_sharing_buffer (int) – Number of agents that share their buffers. It is used for self-play.
- add(observation, action, reward, done, **kwargs)[source]
Adds a transition to the buffer. The required components of a transition are given as positional arguments. The user can pass additional components to store in the buffer as kwargs as long as they were defined in the specification in the constructor.
- sample(batch_size)[source]
Sample transitions from the buffer. For a given transition, if its done is True, the next_observation value should not be taken to have any meaning.
- Parameters
batch_size (int) – Number of transitions to sample.
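A usage sketch for CircularReplayBuffer with illustrative shapes and dtypes; the exact contents of the batch returned by sample() are an assumption here, not documented above.

    import numpy as np
    from hive.replays.circular_replay import CircularReplayBuffer

    buffer = CircularReplayBuffer(
        capacity=10000,
        observation_shape=(4,),
        observation_dtype=np.float32,
        action_shape=(),
        action_dtype=np.int8,
    )

    # Transitions are added one at a time as the agent interacts with the environment.
    for _ in range(100):
        buffer.add(
            observation=np.zeros(4, dtype=np.float32),
            action=0,
            reward=1.0,
            done=False,
        )

    # Once enough transitions are stored, minibatches can be sampled for updates.
    batch = buffer.sample(batch_size=32)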
- class hive.replays.circular_replay.SimpleReplayBuffer(capacity=100000.0, compress=False, seed=42, **kwargs)[source]
Bases:
BaseReplayBuffer
A simple circular replay buffer.
- Parameters
- add(observation, action, reward, done, **kwargs)[source]
Adds transition to the buffer
- Parameters
observation – The current observation
action – The action taken on the current observation
reward – The reward from taking action at current observation
done – If current observation was the last observation in the episode
- sample(batch_size=32)[source]
sample a minibatch
- Parameters
batch_size (int) – The number of examples to sample.
- class hive.replays.legal_moves_replay.LegalMovesBuffer(capacity, beta=0.5, stack_size=1, n_step=1, gamma=0.9, observation_shape=(), observation_dtype=<class 'numpy.uint8'>, action_shape=(), action_dtype=<class 'numpy.int8'>, reward_shape=(), reward_dtype=<class 'numpy.float32'>, extra_storage_types=None, action_dim=None, num_players_sharing_buffer=None)[source]
Bases:
PrioritizedReplayBuffer
A prioritized replay buffer for games with legal moves, like Hanabi, which adds a next_action_mask to each sampled batch.
- Parameters
capacity (int) – Total number of observations that can be stored in the buffer. Note, this is not the same as the number of transitions that can be stored in the buffer.
beta (float) – Parameter controlling level of prioritization.
stack_size (int) – The number of frames to stack to create an observation.
n_step (int) – Horizon used to compute n-step return reward
gamma (float) – Discounting factor used to compute n-step return reward
observation_shape (Tuple) – Shape of observations that will be stored in the buffer.
observation_dtype (type) – Type of observations that will be stored in the buffer. This can either be the type itself or string representation of the type. The type can be either a native python type or a numpy type. If a numpy type, a string of the form np.uint8 or numpy.uint8 is acceptable.
action_shape (Tuple) – Shape of actions that will be stored in the buffer.
action_dtype (type) – Type of actions that will be stored in the buffer. Format is described in the description of observation_dtype.
reward_shape (Tuple) – Shape of rewards that will be stored in the buffer.
reward_dtype (type) – Type of rewards that will be stored in the buffer. Format is described in the description of observation_dtype.
extra_storage_types (dict) – A dictionary describing extra items to store in the buffer. The mapping should be from the name of the item to a (type, shape) tuple.
num_players_sharing_buffer (int) – Number of agents that share their buffers. It is used for self-play.
- class hive.replays.prioritized_replay.PrioritizedReplayBuffer(capacity, beta=0.5, stack_size=1, n_step=1, gamma=0.9, observation_shape=(), observation_dtype=<class 'numpy.uint8'>, action_shape=(), action_dtype=<class 'numpy.int8'>, reward_shape=(), reward_dtype=<class 'numpy.float32'>, extra_storage_types=None, num_players_sharing_buffer=None)[source]
Bases:
CircularReplayBuffer
Implements a replay buffer with prioritized sampling. See https://arxiv.org/abs/1511.05952
- Parameters
capacity (int) – Total number of observations that can be stored in the buffer. Note, this is not the same as the number of transitions that can be stored in the buffer.
beta (float) – Parameter controlling level of prioritization.
stack_size (int) – The number of frames to stack to create an observation.
n_step (int) – Horizon used to compute n-step return reward
gamma (float) – Discounting factor used to compute n-step return reward
observation_shape (Tuple) – Shape of observations that will be stored in the buffer.
observation_dtype (type) – Type of observations that will be stored in the buffer. This can either be the type itself or string representation of the type. The type can be either a native python type or a numpy type. If a numpy type, a string of the form np.uint8 or numpy.uint8 is acceptable.
action_shape (Tuple) – Shape of actions that will be stored in the buffer.
action_dtype (type) – Type of actions that will be stored in the buffer. Format is described in the description of observation_dtype.
reward_shape (Tuple) – Shape of rewards that will be stored in the buffer.
reward_dtype (type) – Type of rewards that will be stored in the buffer. Format is described in the description of observation_dtype.
extra_storage_types (dict) – A dictionary describing extra items to store in the buffer. The mapping should be from the name of the item to a (type, shape) tuple.
num_players_sharing_buffer (int) – Number of agents that share their buffers. It is used for self-play.
- sample(batch_size)[source]
Sample transitions from the buffer. For a given transition, if its done is True, the next_observation value should not be taken to have any meaning.
- Parameters
batch_size (int) – Number of transitions to sample.
- update_priorities(indices, priorities)[source]
Update the priorities of the transitions at the specified indices.
- Parameters
indices – Which transitions to update priorities for. Can be numpy array or torch tensor.
priorities – What the priorities should be updated to. Can be numpy array or torch tensor.
- class hive.replays.prioritized_replay.SumTree(capacity)[source]
Bases:
object
Data structure used to implement prioritized sampling. It is implemented as a tree where the value of each node is the sum of the values of the subtree of the node.
- set_priority(indices, priorities)[source]
Sets the priorities for the given indices.
- Parameters
indices (np.ndarray) – Which transitions to update priorities for.
priorities (np.ndarray) – What the priorities should be updated to.
- sample(batch_size)[source]
Sample elements from the sum tree with probability proportional to their priority.
- Parameters
batch_size (int) – The number of elements to sample.
- stratified_sample(batch_size)[source]
Performs stratified sampling using the sum tree.
- Parameters
batch_size (int) – The number of elements to sample.
- extract(queries)[source]
Get the elements in the sum tree that correspond to each query. For each query, the selected element is the one with the greatest sum of “previous” elements in the tree such that this sum is not a greater proportion of the total sum of priorities than the query.
- Parameters
queries (np.ndarray) – Queries to extract. Each element should be between 0 and 1.
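The extraction rule above can be hard to parse in prose. The following standalone sketch uses a flat cumulative-sum array rather than a tree, so it is not RLHive's implementation, but it applies the same selection rule.

    import numpy as np

    def prefix_sum_extract(priorities, queries):
        # For each query q in [0, 1], pick the element whose sum of preceding
        # priorities is largest while still being at most q * total_priority.
        priorities = np.asarray(priorities, dtype=np.float64)
        cumulative = np.cumsum(priorities)
        thresholds = np.asarray(queries, dtype=np.float64) * cumulative[-1]
        # The number of prefix sums <= threshold gives the selected index; clip
        # so that a query of exactly 1.0 maps to the last element.
        indices = np.searchsorted(cumulative, thresholds, side="right")
        return np.minimum(indices, len(priorities) - 1)

    # Example: with priorities [1, 2, 3], a query of 0.5 selects index 2 because
    # the priorities before it sum to 3, exactly half of the total (6).
    print(prefix_sum_extract([1.0, 2.0, 3.0], [0.0, 0.5, 0.99]))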
- class hive.replays.replay_buffer.BaseReplayBuffer[source]
Bases:
ABC
,Registrable
Base class for replay buffers. Every implemented buffer should be a subclass of this class.
- abstract add(**data)[source]
Adds data to the buffer
- Parameters
data – data to add to the replay buffer. Subclasses can define the signature of this method based on their use case.
- abstract sample(batch_size)[source]
sample a minibatch
- Parameters
batch_size (int) – the number of transitions to sample.
- abstract save(dname)[source]
Saves buffer checkpointing information to file for future loading.
- Parameters
dname (str) – directory where agent should save all relevant info.
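A toy subclass sketch implementing the abstract methods listed above; the real base class may impose additional requirements (for example from Registrable) that are omitted here.

    import os
    import pickle
    import numpy as np
    from hive.replays.replay_buffer import BaseReplayBuffer

    class ListReplayBuffer(BaseReplayBuffer):
        # Toy buffer that stores whole transitions in a Python list.
        def __init__(self, capacity=1000):
            self._capacity = capacity
            self._storage = []

        def add(self, **data):
            # Drop the oldest transition once capacity is reached.
            if len(self._storage) >= self._capacity:
                self._storage.pop(0)
            self._storage.append(data)

        def sample(self, batch_size):
            indices = np.random.randint(len(self._storage), size=batch_size)
            return [self._storage[i] for i in indices]

        def save(self, dname):
            with open(os.path.join(dname, "buffer.pkl"), "wb") as f:
                pickle.dump(self._storage, f)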
Module contents
hive.runners package
Submodules
- class hive.runners.base.Runner(environment, agents, logger, experiment_manager, train_steps=1000000, test_frequency=10000, test_episodes=1, max_steps_per_episode=27000)[source]
Bases:
ABC
Base Runner class used to implement a training loop.
Different types of training loops can be created by overriding the relevant functions.
- Parameters
environment (BaseEnv) – Environment used in the training loop.
agents (list[Agent]) – List of agents that interact with the environment.
logger (ScheduledLogger) – Logger object used to log metrics.
experiment_manager (Experiment) – Experiment object that saves the state of the training.
train_steps (int) – How many steps to train for. If this is -1, there is no limit for the number of training steps.
test_frequency (int) – After how many training steps to run testing episodes. If this is -1, testing is not run.
test_episodes (int) – How many episodes to run testing for.
- train_mode(training)[source]
If training is true, sets all agents to training mode. If training is false, sets all agents to eval mode.
- Parameters
training (bool) – Whether to be in training mode.
- class hive.runners.multi_agent_loop.MultiAgentRunner(environment, agents, logger, experiment_manager, train_steps, test_frequency, test_episodes, stack_size, self_play, max_steps_per_episode=27000)[source]
Bases:
Runner
Runner class used to implement a multiagent training loop.
Initializes the Runner object.
- Parameters
environment (BaseEnv) – Environment used in the training loop.
agents (list[Agent]) – List of agents that interact with the environment
logger (ScheduledLogger) – Logger object used to log metrics.
experiment_manager (Experiment) – Experiment object that saves the state of the training.
train_steps (int) – How many steps to train for. If this is -1, there is no limit for the number of training steps.
test_frequency (int) – After how many training steps to run testing episodes. If this is -1, testing is not run.
test_episodes (int) – How many episodes to run testing for.
stack_size (int) – The number of frames in an observation sent to an agent.
max_steps_per_episode (int) – The maximum number of steps to run an episode for.
- run_one_step(observation, turn, episode_metrics)[source]
Run one step of the training loop.
If it is the agent’s first turn during the episode, do not run an update step. Otherwise, run an update step based on the previous action and accumulated reward since then.
- hive.runners.multi_agent_loop.set_up_experiment(config)[source]
Returns a
MultiAgentRunner
object based on the config and any command line arguments.- Parameters
config – Configuration for experiment.
- class hive.runners.single_agent_loop.SingleAgentRunner(environment, agent, logger, experiment_manager, train_steps, test_frequency, test_episodes, stack_size, max_steps_per_episode=27000)[source]
Bases:
Runner
Runner class used to implement a single-agent training loop.
Initializes the Runner object.
- Parameters
environment (BaseEnv) – Environment used in the training loop.
agent (Agent) – Agent that will interact with the environment
logger (ScheduledLogger) – Logger object used to log metrics.
experiment_manager (Experiment) – Experiment object that saves the state of the training.
train_steps (int) – How many steps to train for. If this is -1, there is no limit for the number of training steps.
test_frequency (int) – After how many training steps to run testing episodes. If this is -1, testing is not run.
test_episodes (int) – How many episodes to run testing for during each test phase.
stack_size (int) – The number of frames in an observation sent to an agent.
max_steps_per_episode (int) – The maximum number of steps to run an episode for.
- hive.runners.single_agent_loop.set_up_experiment(config)[source]
Returns a
SingleAgentRunner
object based on the config and any command line arguments.- Parameters
config – Configuration for experiment.
- hive.runners.utils.load_config(config=None, preset_config=None, agent_config=None, env_config=None, logger_config=None)[source]
Used to load configs for experiments. The agent, environment, and logger components in the main config file can be overridden by separate config files.
- Parameters
config (str) – Path to configuration file. Either this or preset_config must be passed.
preset_config (str) – Path to a preset hive config. This path should be relative to hive/configs. For example, the Atari DQN config would be atari/dqn.yml.
agent_config (str) – Path to agent configuration file. Overrides settings in base config.
env_config (str) – Path to environment configuration file. Overrides settings in base config.
logger_config (str) – Path to logger configuration file. Overrides settings in base config.
- class hive.runners.utils.Metrics(agents, agent_metrics, episode_metrics)[source]
Bases:
object
Class used to keep track of separate metrics for each agent as well as general episode metrics.
- Parameters
agents (list[Agent]) – List of agents for which object will track metrics.
agent_metrics (list[(str, (callable | obj))]) – List of metrics to track for each agent. Should be a list of tuples (metric_name, metric_init) where metric_init is either the initial value of the metric or a callable that takes no arguments and creates the initial metric.
episode_metrics (list[(str, (callable | obj))]) – List of non agent specific metrics to keep track of. Should be a list of tuples (metric_name, metric_init) where metric_init is either the initial value of the metric or a callable with no arguments that creates the initial metric.
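For illustration, the metric specifications described above are (name, initial value) tuples, where the initial value may also be a zero-argument callable; the metric names below are made up.

    from hive.runners.utils import Metrics

    agents = []  # replace with your actual list of Agent objects

    metrics = Metrics(
        agents,
        agent_metrics=[("reward", 0.0), ("episode_length", 0)],
        episode_metrics=[("full_episode_length", 0), ("agent_order", list)],
    )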
- class hive.runners.utils.TransitionInfo(agents, stack_size)[source]
Bases:
object
Used to keep track of the most recent transition for each agent.
Any info that the agent needs to remember for updating can be stored here. Should be completely reset between episodes. After any info is extracted, it is automatically removed from the object. Also keeps track of which agents have started their episodes.
This object also handles padding and stacking observations for agents.
- Parameters
- is_started(agent)[source]
Check if agent has started its episode.
- Parameters
agent (Agent) – Agent to check.
- start_agent(agent)[source]
Set the agent’s start flag to true.
- Parameters
agent (Agent) – Agent to start.
- update_all_rewards(rewards)[source]
Update the rewards for all agents. If rewards is a list, it updates the rewards according to the order of agents provided in the initializer. If rewards is a dict, the keys should be the agent ids and the values should be the rewards for those agents. If rewards is a float or int, every agent is updated with that reward.
- get_info(agent, done=False)[source]
Get all the info for the agent, and reset the info for that agent. Also adds a done value to the info dictionary that is based on the done parameter to the function.
- get_stacked_state(agent, observation)[source]
Create a stacked state for the agent. The previous observations recorded by this agent are stacked with the current observation. If not enough observations have been recorded, zero arrays are appended.
- Parameters
agent (Agent) – Agent to get stacked state for.
observation – Current observation.
- hive.runners.utils.zeros_like(x)[source]
Create a zero state like some state. This handles slightly more complex objects such as lists and dictionaries of numpy arrays and torch Tensors.
- Parameters
x (np.ndarray | torch.Tensor | dict | list) – State used to define structure/state of zero state.
Module contents
hive.utils package
Submodules
Implementation of a simple experiment class.
- class hive.utils.experiment.Experiment(name, dir_name, schedule)[source]
Bases:
object
Implementation of a simple experiment class.
Initializes an experiment object.
The experiment state is an exposed property of objects of this class. It can be used to keep track of objects that need to be saved as part of the experiment but don’t fit into one of the standard categories. One example of this is the various schedules used in the Runner class.
- Parameters
- register_experiment(config=None, logger=None, agents=None, environment=None)[source]
Registers all the components of an experiment.
- save(tag='current')[source]
Saves the experiment.
- Parameters
tag (str) – Tag to prefix the folder.
- class hive.utils.loggers.Logger(timescales=None)[source]
Bases:
ABC
,Registrable
Abstract class for logging in hive.
Constructor for the base Logger class. Every Logger must call this constructor in its own constructor.
- Parameters
timescales (str | list(str)) – The different timescales at which logger needs to log. If only logging at one timescale, it is acceptable to only pass a string.
- register_timescale(timescale)[source]
Register a new timescale with the logger.
- Parameters
timescale (str) – Timescale to register.
- abstract save(dir_name)[source]
Saves the current state of the log files.
- Parameters
dir_name (str) – Name of the directory to save the log files.
- class hive.utils.loggers.ScheduledLogger(timescales=None, logger_schedules=None)[source]
Bases:
Logger
Abstract class that manages a schedule for logging.
The update_step method should be called for each step in the loop to update the logger’s schedule. The should_log method can be used to check whether the logger should log anything.
This schedule is not strictly enforced! It is still possible to log something even if should_log returns False; these methods are provided purely for convenience.
Any timescale not assigned a schedule in logger_schedules will be assigned a ConstantSchedule(True).
- Parameters
timescales (str|list[str]) – The different timescales at which logger needs to log. If only logging at one timescale, it is acceptable to only pass a string.
logger_schedules (Schedule|list|dict) – Schedules used to keep track of when to log. If a single schedule, it is copied for each timescale. If a list of schedules, the schedules are matched up in order with the list of timescales provided. If a dictionary, the keys should be the timescale and the values should be the schedule.
- update_step(timescale)[source]
Update the step and schedule for a given timescale.
- Parameters
timescale (str) – A registered timescale.
- should_log(timescale)[source]
Check if you should log for a given timescale.
- Parameters
timescale (str) – A registered timescale.
- class hive.utils.loggers.NullLogger(timescales=None, logger_schedules=None)[source]
Bases:
ScheduledLogger
A null logger that does not log anything.
Used if you don’t want to log anything, but still want to use parts of the framework that ask for a logger.
Any timescale not assigned a schedule in logger_schedules will be assigned a ConstantSchedule(True).
- Parameters
timescales (str|list[str]) – The different timescales at which logger needs to log. If only logging at one timescale, it is acceptable to only pass a string.
logger_schedules (Schedule|list|dict) – Schedules used to keep track of when to log. If a single schedule, it is copied for each timescale. If a list of schedules, the schedules are matched up in order with the list of timescales provided. If a dictionary, the keys should be the timescale and the values should be the schedule.
- class hive.utils.loggers.WandbLogger(timescales=None, logger_schedules=None, project=None, name=None, dir=None, mode=None, id=None, resume=None, start_method=None, **kwargs)[source]
Bases:
ScheduledLogger
A Wandb logger.
This logger can be used to log to wandb. It assumes that wandb is configured locally on your system. Multiple timescales/loggers can be implemented by instantiating multiple loggers with different logger_names. These should still have the same project and run names.
Check the wandb documentation for more details on the parameters.
- Parameters
timescales (str|list[str]) – The different timescales at which logger needs to log. If only logging at one timescale, it is acceptable to only pass a string.
logger_schedules (Schedule|list|dict) – Schedules used to keep track of when to log. If a single schedule, it is copied for each timescale. If a list of schedules, the schedules are matched up in order with the list of timescales provided. If a dictionary, the keys should be the timescale and the values should be the schedule.
project (str) – Name of the project. Wandb’s dash groups all runs with the same project name together.
name (str) – Name of the run. Used to identify the run on the wandb dash.
dir (str) – Local directory where wandb saves logs.
mode (str) – The mode of logging. Can be “online”, “offline” or “disabled”. In offline mode, wandb writes all data to disk for later syncing to a server; in disabled mode, it makes all calls to the wandb API no-ops while maintaining core functionality.
id (str, optional) – A unique ID for this run, used for resuming. It must be unique in the project, and if you delete a run you can’t reuse the ID.
resume (bool, str, optional) – Sets the resuming behavior. Options are the same as mentioned in Wandb’s doc.
start_method (str) – The start method to use for wandb’s process. See https://docs.wandb.ai/guides/track/launch#init-start-error.
**kwargs – You can pass any other arguments to wandb’s init method as keyword arguments. Note, these arguments can’t be overridden from the command line.
- class hive.utils.loggers.ChompLogger(timescales=None, logger_schedules=None)[source]
Bases:
ScheduledLogger
This logger uses the Chomp data structure to store all logged values which are then directly saved to disk.
Any timescale not assigned a schedule in logger_schedules will be assigned a ConstantSchedule(True).
- Parameters
timescales (str|list[str]) – The different timescales at which logger needs to log. If only logging at one timescale, it is acceptable to only pass a string.
logger_schedules (Schedule|list|dict) – Schedules used to keep track of when to log. If a single schedule, it is copied for each timescale. If a list of schedules, the schedules are matched up in order with the list of timescales provided. If a dictionary, the keys should be the timescale and the values should be the schedule.
- class hive.utils.loggers.CompositeLogger(logger_list)[source]
Bases:
Logger
This Logger aggregates multiple loggers together.
This logger is for convenience and allows logging through multiple loggers without having to keep track of each one separately. When timescales are updated, this logger updates the timescale for each of its component loggers. When logging, it logs to each of its component loggers, unless that logger is a ScheduledLogger whose schedule says it should not log for the timescale.
Constructor for the base Logger class. Every Logger must call this constructor in its own constructor.
- Parameters
timescales (str | list(str)) – The different timescales at which logger needs to log. If only logging at one timescale, it is acceptable to only pass a string.
- register_timescale(timescale, schedule=None)[source]
Register a new timescale with the logger.
- Parameters
timescale (str) – Timescale to register.
- update_step(timescale)[source]
Update the step and schedule for a given timescale for every ScheduledLogger.
- Parameters
timescale (str) – A registered timescale.
- should_log(timescale)[source]
Check if you should log for a given timescale. If any logger in the list is scheduled to log, returns True.
- Parameters
timescale (str) – A registered timescale.
- class hive.utils.registry.Registrable[source]
Bases:
object
Class used to denote which types of objects can be registered in the RLHive Registry. These objects can also be configured directly from the command line, and recursively built from the config, assuming type annotations are present.
- class hive.utils.registry.CallableType(fn)[source]
Bases:
Registrable
A wrapper that allows any callable to be registered in the RLHive Registry. Specifically, it maps the arguments and annotations of the wrapped function to the resulting callable, allowing the argument names and type annotations of the underlying function to be visible on the outer wrapper. When called with some arguments, this object returns a partial function with those arguments assigned.
By default, the type_name is “callable”, but if you want to create specific types of callables, you can simply create a subclass and override the type_name method. See hive.utils.utils.OptimizerFn.
- Parameters
fn – callable to be wrapped.
- class hive.utils.registry.Registry[source]
Bases:
object
This is the Registry class for RLHive. It allows you to register different types of Registrable classes and objects and generates constructors for those classes in the form of get_{type_name}.
These constructors allow you to construct objects from dictionary configs. These configs should have two fields: name, which corresponds to the name used when registering a class in the registry, and kwargs, which corresponds to the keyword arguments that will be passed to the constructor of the object. These constructors can also build objects recursively, i.e. if a config contains the config for another Registrable object, that object is automatically created before being passed to the constructor of the original object. These constructors also allow you to directly specify or override arguments for object constructors from the command line. These parameters are specified in dot notation. They can also handle lists and dictionaries of Registrable objects.
For example, consider the following scenario: your agent class has an argument arg1 which is annotated to be List[Class1], Class1 is Registrable, and the Class1 constructor takes an argument arg2. If the passed yml config lists two different Class1 object configs, the constructor will check to see if both --agent.arg1.0.arg2 and --agent.arg1.1.arg2 have been passed.
The parameters passed on the command line will be parsed according to the type annotation of the corresponding low-level constructor. If the annotation is not one of int, float, str, or bool, the string is simply loaded into python using a yaml loader.
Each constructor returns the object, as well as a dictionary config with all the parameters used to create the object and any Registrable objects created in the process of creating this object.
- get_agent(object_or_config, prefix=None)
- get_env(object_or_config, prefix=None)
- get_function(object_or_config, prefix=None)
- get_init_fn(object_or_config, prefix=None)
- get_logger(object_or_config, prefix=None)
- get_loss_fn(object_or_config, prefix=None)
- get_optimizer_fn(object_or_config, prefix=None)
- get_replay(object_or_config, prefix=None)
- get_schedule(object_or_config, prefix=None)
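For illustration, a dictionary config in the name/kwargs form described above might look like the following; the registered names and arguments are made up, and the trailing comment shows the corresponding dot-notation command line override.

    # Illustrative only: names and kwargs are made up, not guaranteed registry entries.
    agent_config = {
        "name": "DQNAgent",
        "kwargs": {
            "discount_rate": 0.99,
            # A nested Registrable config: it is built recursively from its own
            # name/kwargs before being passed to the agent constructor.
            "optimizer_fn": {
                "name": "Adam",
                "kwargs": {"lr": 0.0001},
            },
        },
    }

    # The nested argument could then be overridden from the command line in dot
    # notation, e.g.: --agent.optimizer_fn.lr 0.001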
- hive.utils.registry.construct_objects(object_constructor, config, prefix=None)[source]
Helper function that constructs any objects specified in the config that are registrable.
Returns the object, as well as a dictionary config with all the parameters used to create the object and any Registrable objects created in the process of creating this object.
- Parameters
object_constructor (callable) – Constructor of the object that corresponds to config. The signature of this function will be analyzed to see if there are any Registrable objects that might be specified in the config.
config (dict) – The kwargs for the object being created. May contain configs for other Registrable objects that need to be recursively created.
prefix (str) – Prefix that is attached to the argument names when looking for command line arguments.
- hive.utils.registry.get_callable_parsed_args(callable, prefix=None)[source]
Helper function that extracts the command line arguments for a given function.
- Parameters
callable (callable) – function whose arguments will be inspected to extract arguments from the command line.
prefix (str) – Prefix that is attached to the argument names when looking for command line arguments.
- hive.utils.registry.get_parsed_args(arguments, prefix=None)[source]
Helper function that takes a dictionary mapping argument names to types, and extracts command line arguments for those arguments. If the dictionary contains a key-value pair “bar”: int, and the prefix passed is “foo”, this function will look for a command line argument “--foo.bar”. If present, it will cast it to an int.
If the type for a given argument is not one of int, float, str, or bool, it simply loads the string into python using a yaml loader.
- class hive.utils.schedule.Schedule[source]
Bases:
ABC
,Registrable
- class hive.utils.schedule.LinearSchedule(init_value, end_value, steps)[source]
Bases:
Schedule
Defines a linear schedule between two values over some number of steps.
If updated more than the defined number of steps, the schedule stays at the end value.
- Parameters
- class hive.utils.schedule.ConstantSchedule(value)[source]
Bases:
Schedule
Returns a constant value over the course of the schedule
- Parameters
value – The value to be returned.
- class hive.utils.schedule.SwitchSchedule(off_value, on_value, steps)[source]
Bases:
Schedule
Returns one value for the first part of the schedule. After the defined number of steps is reached, switches to returning a second value.
- Parameters
off_value – The value to be returned in the first part of the schedule.
on_value – The value to be returned in the second part of the schedule.
steps (int) – The number of steps after which to switch from the off value to the on value.
- class hive.utils.schedule.DoublePeriodicSchedule(off_value, on_value, off_period, on_period)[source]
Bases:
Schedule
Returns off value for off period, then switches to returning on value for on period. Alternates between the two.
- Parameters
- class hive.utils.schedule.PeriodicSchedule(off_value, on_value, period)[source]
Bases:
DoublePeriodicSchedule
Returns one value on the first step of each period of a predefined number of steps. Returns another value otherwise.
- Parameters
on_value – The value to be returned on the first step of each period.
off_value – The value to be returned for every other step in the period.
period (int) – The number of steps in the period.
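Since the Schedule interface itself is not documented above, here is a standalone illustration of the periodic-schedule idea rather than RLHive's class: it returns the on value on the first step of each period and the off value otherwise.

    class ToyPeriodicSchedule:
        # Standalone illustration only; not RLHive's PeriodicSchedule.
        def __init__(self, off_value, on_value, period):
            self._off_value = off_value
            self._on_value = on_value
            self._period = period
            self._steps = 0

        def update(self):
            value = self._on_value if self._steps % self._period == 0 else self._off_value
            self._steps += 1
            return value

    # Such a schedule could, for example, gate how often a target network is synced:
    schedule = ToyPeriodicSchedule(False, True, 4)
    print([schedule.update() for _ in range(8)])  # [True, False, False, False, True, ...]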
- hive.utils.torch_utils.numpify(t)[source]
Convert object to a numpy array.
- Parameters
t (np.ndarray | torch.Tensor | obj) – Object to convert to np.ndarray.
- class hive.utils.torch_utils.RMSpropTF(params, lr=0.01, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0.0, centered=False, decoupled_decay=False, lr_in_momentum=True)[source]
Bases:
Optimizer
Direct cut-paste from rwightman/pytorch-image-models. https://github.com/rwightman/pytorch-image-models/blob/f7d210d759beb00a3d0834a3ce2d93f6e17f3d38/timm/optim/rmsprop_tf.py Licensed under Apache 2.0, https://github.com/rwightman/pytorch-image-models/blob/master/LICENSE
Implements RMSprop algorithm (TensorFlow style epsilon)
NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt and a few other modifications to closer match Tensorflow for matching hyper-params. Noteworthy changes include:
Epsilon applied inside square-root
square_avg initialized to ones
LR scaling of update accumulated in momentum buffer
Proposed by G. Hinton in his course. The centered version first appears in Generating Sequences With Recurrent Neural Networks.
- Parameters
params (iterable) – iterable of parameters to optimize or dicts defining parameter groups
lr (float, optional) – learning rate (default: 1e-2)
momentum (float, optional) – momentum factor (default: 0)
alpha (float, optional) – smoothing (decay) constant (default: 0.9)
eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-10)
centered (bool, optional) – if True, compute the centered RMSProp, where the gradient is normalized by an estimation of its variance
weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
decoupled_decay (bool, optional) – decoupled weight decay as per https://arxiv.org/abs/1711.05101
lr_in_momentum (bool, optional) – learning rate scaling is included in the momentum buffer update as per defaults in Tensorflow
- hive.utils.utils.create_folder(folder)[source]
Creates a folder.
- Parameters
folder (str) – Folder to create.
- class hive.utils.utils.Seeder[source]
Bases:
object
Class used to manage seeding in RLHive. It sets the seed for all the frameworks that RLHive currently uses. It also deterministically provides new seeds based on the global seed, in case any other objects in RLHive (such as the agents) need their own seed.
- class hive.utils.utils.Chomp[source]
Bases:
dict
An extension of the dictionary class that allows for accessing through dot notation and easy saving/loading.
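A small sketch of the dot-notation access Chomp provides; the save/load helpers are omitted because their exact signatures are not shown in these docs.

    from hive.utils.utils import Chomp

    metrics = Chomp()
    metrics["train_loss"] = [0.9, 0.7, 0.5]  # normal dict-style write
    print(metrics.train_loss)                # dot-notation read of the same key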
- class hive.utils.utils.OptimizerFn(fn)[source]
Bases:
CallableType
A wrapper for callables that produce optimizer functions.
These wrapped callables can be partially initialized through configuration files or command line arguments.
- Parameters
fn – callable to be wrapped.
- class hive.utils.utils.LossFn(fn)[source]
Bases:
CallableType
A wrapper for callables that produce loss functions.
These wrapped callables can be partially initialized through configuration files or command line arguments.
- Parameters
fn – callable to be wrapped.
- hive.utils.visualization.find_single_run_data(run_folder)[source]
Looks for a chomp logger data file in run_folder and its subdirectories. Once it finds one, it loads the file and returns the data.
- Parameters
run_folder (str) – Which folder to search for the chomp logger data.
- Returns
The Chomp object containing the logger data.
- hive.utils.visualization.find_all_runs_data(runs_folder)[source]
Iterates through each directory in runs_folder, finds one chomp logger data file in each directory, and concatenates the data together under each key present in the data.
- Parameters
runs_folder (str) – Folder which contains subfolders, from each of which one chomp logger data file is loaded.
- Returns
Dictionary with each key corresponding to a key in the chomp logger data files. The values are the data from each run folder concatenated together under that key.
- hive.utils.visualization.find_all_experiments_data(experiments_folder, runs_folders)[source]
Finds and loads all log data in a folder.
Assuming the directory structure is as follows:
experiment
├── config_1_runs
|   ├── seed0/
|   ├── seed1/
|   .
|   .
|   └── seedn/
├── config_2_runs
|   ├── seed0/
|   ├── seed1/
|   .
|   .
|   └── seedn/
├── config_3_runs
|   ├── seed0/
|   ├── seed1/
|   .
|   .
|   └── seedn/
where there is some chomp logger data file under each seed directory, passing “experiment” and the list of config folders (“config_1_runs”, “config_2_runs”, “config_3_runs”) will load all the data.
- hive.utils.visualization.standardize_data(experiment_data, x_key, y_key, num_sampled_points=1000, drop_last=True)[source]
Extracts given keys from data, and standardizes the data across runs by sampling equally spaced points along x data, and interpolating y data of each run.
- Parameters
experiment_data – Data object in the format of
find_all_experiments_data()
.x_key (str) – Key for x axis.
y_key (str) – Key for y axis.
num_sampled_points (int) – How many points to sample along x axis.
drop_last (bool) – Whether to drop the last point in the data.
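As a standalone illustration of the standardization idea (not the RLHive code), two runs logged at different x positions can be resampled onto a shared grid with interpolation.

    import numpy as np

    run_x = [np.array([0, 10, 25, 40]), np.array([0, 15, 30, 45])]
    run_y = [np.array([0.1, 0.4, 0.6, 0.8]), np.array([0.0, 0.3, 0.5, 0.9])]

    num_sampled_points = 5
    # Sample equally spaced x points that every run covers, then interpolate each run.
    grid = np.linspace(0, min(x[-1] for x in run_x), num_sampled_points)
    standardized = np.stack([np.interp(grid, x, y) for x, y in zip(run_x, run_y)])
    print(grid)
    print(standardized)  # one interpolated row per run, ready for mean/std curves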
- hive.utils.visualization.find_and_standardize_data(experiments_folder, runs_folders, x_key, y_key, num_sampled_points, drop_last)[source]
Finds and standardizes the data in experiments_folder.
- Parameters
experiments_folder (str) – Root folder with all experiments data.
runs_folders (list[str]) – List of folders under root folder to load data from.
x_key (str) – Key for x axis.
y_key (str) – Key for y axis.
num_sampled_points (int) – How many points to sample along x axis.
drop_last (bool) – Whether to drop the last point in the data.
- hive.utils.visualization.generate_lineplot(x_datas, y_datas, smoothing_fn=None, line_labels=None, xlabel=None, ylabel=None, cmap_name=None, output_file='output.png')[source]
Aggregates data and generates lineplot.
Module contents
Module contents
Notes
Reproducibility
Achieving reproducibility in deep RL is difficult. Even when the random seed is fixed, libraries such as PyTorch use algorithms and implementations that are nondeterministic. PyTorch has several options that allow the user to turn off some aspects of this nondeterminism, but behavior is still usually only replicable if the runs are executed on the same hardware.
We provide a global seeding class Seeder
that allows the user to set a global seed for all packages currently
used by the framework (NumPy, PyTorch, and Python’s random package). It also
sets the PyTorch options to turn off nondeterminism. When using this seeding
functionality, before starting a run, you must set the environment variable
CUBLAS_WORKSPACE_CONFIG
to either ":16:8"
(limits performance) or
":4096:8"
(uses slightly more memory). See
this page
for more details.
The Seeder
class also provides a function
get_new_seed()
that provides a new
random seed each time it is called, which is useful in multi-agent setups where
you want each agent to be seeded differently.
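The Seeder API itself is not reproduced here, but the underlying calls such a global seeding step relies on are standard; a sketch of the settings mentioned above might look like the following.

    import os
    import random

    import numpy as np
    import torch

    # Must be set before any CUDA kernels are launched; ":16:8" limits performance,
    # ":4096:8" uses slightly more memory.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Turn off the remaining sources of nondeterminism in PyTorch.
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False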
Roadmap
We have quite a lot planned for RLHive, so stay tuned! Here is what we are currently planning to add:
Policy gradient methods (PPO)
RNN support
rliable
Lifelong RL support
JAX support
Continuous Actions
Dreamerv2/MBRL
Better config system
RL Debugging tools
Log videos of trajectories
Allow changing objects from command line
Hyperparameter search integration
Add logger based on Python logging
If you have a feature request that isn’t listed here, check out our Contributing page!
Contributing
Core RLHive
Did you spot a bug, or is there something that you think should be added to the main RLHive package? Check our Issues on Github or our roadmap to see if it’s already on our radar. If not, you can either open an issue to let us know, or you can fork our repo and create a pull request with your own feature/bug fix.
Creating Issues
We’d love to hear from you on how we can improve RLHive! When creating issues, please follow these guidelines:
Tag your issue with bug, feature request, or question to help us effectively sort through the issues.
Include the version of RLHive you are running (run
pip list | grep rlhive
)
Creating PRs
When contributing to RLHive, please follow these guidelines:
Create a separate PR for each feature you are adding.
When you are done writing the feature, create a pull request to the dev branch of the main repo.
Each pull request must pass all unit tests and linting checks before being merged. Please run these checks locally before creating PRs/pushing changes to the PR to minimize unnecessary runs on our CI system. You can run the following commands from the root directory of the project:
Unit Tests:
python -m pytest tests/
Linter:
black .
Information (such as installation instructions and editor integrations) for the black formatter is available here.
Make sure your code is documented using Google style docstrings. For examples, see the rest of our repository.
Contrib RLHive
We want to encourage users to contribute their own custom components to RLHive. This could be things like agents, environments, runners, replays, optimizers, and anything else. This will allow everyone to easily use and build on your work.
To do this, we have a contrib directory that will be part of the package. After you add new components and any relevant citation information to this directory, new versions of RLHive will be released with these components, giving your work a potentially larger audience. Note that we will only actively maintain the code in the main package, not the contrib package, and we can only commit to giving minimal feedback during the review stage. If a contribution becomes widely adopted by the community, we may move it to the main package and actively maintain it.
When submitting the PR for your contributions, you must include results generated with the RLHive package using your new components, as evidence of the correctness of your implementation.