hive.agents.rainbow module
- class hive.agents.rainbow.RainbowDQNAgent(representation_net, obs_dim, act_dim, optimizer_fn=None, loss_fn=None, init_fn=None, id=0, replay_buffer=None, discount_rate=0.99, n_step=1, grad_clip=None, reward_clip=None, update_period_schedule=None, target_net_soft_update=False, target_net_update_fraction=0.05, target_net_update_schedule=None, epsilon_schedule=None, test_epsilon=0.001, min_replay_history=5000, batch_size=32, device='cpu', logger=None, log_frequency=100, noisy=True, std_init=0.5, use_eps_greedy=False, double=True, dueling=True, distributional=True, v_min=0, v_max=200, atoms=51)
Bases: DQNAgent
An agent implementing the Rainbow algorithm. A construction sketch follows the parameter list below.
- Parameters
representation_net (FunctionApproximator) – A network that outputs the representations that will be used to compute Q-values (e.g. everything except the final layer of the DQN).
obs_dim (Tuple) – The shape of the observations.
act_dim (int) – The number of actions available to the agent.
id – Agent identifier.
optimizer_fn (OptimizerFn) – A function that takes in a list of parameters to optimize and returns the optimizer. If None, defaults to Adam.
loss_fn (LossFn) – Loss function used by the agent. If None, defaults to SmoothL1Loss.
init_fn (InitializationFn) – Initializes the weights of qnet using create_init_weights_fn.
replay_buffer (BaseReplayBuffer) – The replay buffer that the agent will push observations to and sample from during learning. If None, defaults to PrioritizedReplayBuffer.
discount_rate (float) – A number between 0 and 1 specifying how much future rewards are discounted by the agent.
n_step (int) – The horizon used in n-step returns to compute TD(n) targets.
grad_clip (float) – Gradients will be clipped to [-grad_clip, grad_clip].
reward_clip (float) – Rewards will be clipped to [-reward_clip, reward_clip].
update_period_schedule (Schedule) – Schedule determining how frequently the agent’s Q-network is updated.
target_net_soft_update (bool) – Whether the target net parameters are replaced wholesale by the qnet parameters or updated as a weighted average of the target net and qnet parameters.
target_net_update_fraction (float) – The weight given to the target net parameters in a soft update.
target_net_update_schedule (Schedule) – Schedule determining how frequently the target net is updated.
epsilon_schedule (Schedule) – Schedule determining the value of epsilon through the course of training.
test_epsilon (float) – The epsilon (probability of choosing a random action) used during the testing phase.
min_replay_history (int) – How many observations to fill the replay buffer with before starting to learn.
batch_size (int) – The size of the batch sampled from the replay buffer during learning.
device – Device on which all computations should be run.
logger (ScheduledLogger) – Logger used to log agent’s metrics.
log_frequency (int) – How often to log the agent’s metrics.
noisy (bool) – Whether to use noisy linear layers for exploration.
std_init (float) – The value used to initialize the standard deviation of the noisy layer weights.
use_eps_greedy (bool) – Whether to use epsilon greedy exploration.
double (bool) – Whether to use double DQN.
dueling (bool) – Whether to use a dueling network architecture.
distributional (bool) – Whether to use distributional RL.
v_min (float) – The minimum of the support of the categorical value distribution for distributional RL.
v_max (float) – The maximum of the support of the categorical value distribution for distributional RL.
atoms (int) – Number of atoms discretizing the support range of the categorical value distribution for distributional RL.
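Given the signature above, construction looks roughly as follows. This is a minimal sketch: the lambda stands in for a FunctionApproximator and assumes it is invoked with the observation shape, and the observation and action dimensions are placeholder values.

```python
import torch

from hive.agents.rainbow import RainbowDQNAgent

# Placeholder network builder; assumes the FunctionApproximator is
# called with the observation shape. Substitute a real builder.
representation_net = lambda obs_dim: torch.nn.Sequential(
    torch.nn.Linear(obs_dim[0], 256),
    torch.nn.ReLU(),
)

agent = RainbowDQNAgent(
    representation_net=representation_net,
    obs_dim=(4,),  # hypothetical observation shape
    act_dim=2,     # hypothetical action count
    v_min=0,
    v_max=200,
    atoms=51,
)
```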
- create_q_networks(representation_net)
Creates the Q-network and target Q-network. Adds the appropriate heads for DQN, Dueling DQN, Noisy Networks, and Distributional DQN.
- Parameters
representation_net – A network that outputs the representations that will be used to compute Q-values (e.g. everything except the final layer of the DQN).
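As an illustration of how these heads compose, the sketch below shows a dueling head over a categorical value distribution, using the standard aggregation Q = V + A - mean(A). It is a stand-in for the idea, not RLHive's actual head classes.

```python
import torch
import torch.nn as nn

class DuelingDistributionalHead(nn.Module):
    """Hypothetical head: value and advantage streams over `atoms` bins."""

    def __init__(self, in_dim, act_dim, atoms=51):
        super().__init__()
        self._act_dim, self._atoms = act_dim, atoms
        self.value = nn.Linear(in_dim, atoms)                # V(s) logits
        self.advantage = nn.Linear(in_dim, act_dim * atoms)  # A(s, a) logits

    def forward(self, x):
        value = self.value(x).view(-1, 1, self._atoms)
        adv = self.advantage(x).view(-1, self._act_dim, self._atoms)
        # Dueling aggregation per atom: Q = V + A - mean_a(A).
        logits = value + adv - adv.mean(dim=1, keepdim=True)
        # Softmax over atoms yields the categorical value distribution.
        return torch.softmax(logits, dim=-1)
```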
- act(observation)
Returns the action for the agent. If in training mode, follows an epsilon greedy policy. Otherwise, returns the action with the highest Q-value.
- Parameters
observation – The current observation.
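The epsilon-greedy rule described above can be sketched as follows; q_values and epsilon are supplied by the caller, and the real agent additionally accounts for noisy networks and its epsilon schedule.

```python
import numpy as np

def epsilon_greedy(q_values, epsilon, rng):
    # With probability epsilon, explore with a uniformly random action;
    # otherwise exploit the action with the highest Q-value.
    if rng.random() < epsilon:
        return int(rng.integers(len(q_values)))
    return int(np.argmax(q_values))

rng = np.random.default_rng(0)
action = epsilon_greedy(np.array([0.1, 0.9, 0.3]), epsilon=0.05, rng=rng)
```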
- update(update_info)[source]
Updates the DQN agent.
- Parameters
update_info – Dictionary containing all the information needed to update the agent. Should contain a full transition, with keys for "observation", "action", "reward", "next_observation", and "done".
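For example, a single transition might be passed in as follows (shapes and dtypes here are assumptions for illustration):

```python
import numpy as np

update_info = {
    "observation": np.zeros(4, dtype=np.float32),      # current observation
    "action": 1,                                       # action taken
    "reward": 0.5,                                     # reward received
    "next_observation": np.ones(4, dtype=np.float32),  # resulting observation
    "done": False,                                     # episode not terminal
}
agent.update(update_info)  # `agent` as constructed in the sketch above
```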
- target_projection(target_net_inputs, next_action, reward, done)
Project distribution of target Q-values.
- Parameters
target_net_inputs – Inputs to feed into the target net to compute the projection of the target Q-values. Should be set from preprocess_update_batch().
next_action (Tensor) – Tensor containing next actions used to compute target distribution.
reward (Tensor) – Tensor containing rewards for the current batch.
done (Tensor) – Tensor containing whether the states in the current batch are terminal.
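This is the standard categorical (C51) projection. The sketch below is a self-contained version of that operation, assuming reward already holds the accumulated n-step return and done is a 0/1 terminal flag; it illustrates the math rather than reproducing RLHive's exact implementation.

```python
import torch

def project_distribution(next_dist, reward, done,
                         v_min=0.0, v_max=200.0, discount=0.99, n_step=1):
    # next_dist: (batch, atoms) target-net probabilities at next_action.
    # reward: (batch,) n-step returns; done: (batch,) terminal flags.
    atoms = next_dist.shape[1]
    delta_z = (v_max - v_min) / (atoms - 1)
    support = torch.linspace(v_min, v_max, atoms)

    # Bellman-update every atom of the support, clipped to [v_min, v_max].
    target_z = reward.unsqueeze(1) \
        + (1.0 - done.float().unsqueeze(1)) * (discount ** n_step) * support
    target_z = target_z.clamp(v_min, v_max)

    # Split each projected atom's probability mass between its two
    # nearest neighbours on the fixed support.
    b = (target_z - v_min) / delta_z
    lower, upper = b.floor().long(), b.ceil().long()
    lower[(upper > 0) & (lower == upper)] -= 1          # atom hit a grid point
    upper[(lower < atoms - 1) & (lower == upper)] += 1

    projected = torch.zeros_like(next_dist)
    projected.scatter_add_(1, lower, next_dist * (upper.float() - b))
    projected.scatter_add_(1, upper, next_dist * (b - lower.float()))
    return projected
```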