hive.agents.rainbow module

class hive.agents.rainbow.RainbowDQNAgent(representation_net, obs_dim, act_dim, optimizer_fn=None, loss_fn=None, init_fn=None, id=0, replay_buffer=None, discount_rate=0.99, n_step=1, grad_clip=None, reward_clip=None, update_period_schedule=None, target_net_soft_update=False, target_net_update_fraction=0.05, target_net_update_schedule=None, epsilon_schedule=None, test_epsilon=0.001, min_replay_history=5000, batch_size=32, device='cpu', logger=None, log_frequency=100, noisy=True, std_init=0.5, use_eps_greedy=False, double=True, dueling=True, distributional=True, v_min=0, v_max=200, atoms=51)[source]

Bases: DQNAgent

An agent implementing the Rainbow algorithm.

Parameters
  • representation_net (FunctionApproximator) – A network that outputs the representations that will be used to compute Q-values (e.g. everything except the final layer of the DQN).

  • obs_dim (Tuple) – The shape of the observations.

  • act_dim (int) – The number of actions available to the agent.

  • id – Agent identifier.

  • optimizer_fn (OptimizerFn) – A function that takes in a list of parameters to optimize and returns the optimizer. If None, defaults to Adam.

  • loss_fn (LossFn) – Loss function used by the agent. If None, defaults to SmoothL1Loss.

  • init_fn (InitializationFn) – Initializes the weights of qnet using create_init_weights_fn.

  • replay_buffer (BaseReplayBuffer) – The replay buffer that the agent will push observations to and sample from during learning. If None, defaults to PrioritizedReplayBuffer.

  • discount_rate (float) – A number between 0 and 1 specifying how much future rewards are discounted by the agent.

  • n_step (int) – The horizon used in n-step returns to compute TD(n) targets.

  • grad_clip (float) – Gradients will be clipped to between [-grad_clip, grad_clip].

  • reward_clip (float) – Rewards will be clipped to between [-reward_clip, reward_clip].

  • update_period_schedule (Schedule) – Schedule determining how frequently the agent’s Q-network is updated.

  • target_net_soft_update (bool) – Whether to update the target net parameters using a weighted average of the target net parameters and the qnet parameters (soft update), rather than replacing them entirely with the qnet parameters.

  • target_net_update_fraction (float) – The weight given to the target net parameters in a soft update.

  • target_net_update_schedule (Schedule) – Schedule determining how frequently the target net is updated.

  • epsilon_schedule (Schedule) – Schedule determining the value of epsilon through the course of training.

  • test_epsilon (float) – The value of epsilon (the probability of choosing a random action) used during the testing phase.

  • min_replay_history (int) – How many observations to fill the replay buffer with before starting to learn.

  • batch_size (int) – The size of the batch sampled from the replay buffer during learning.

  • device – Device on which all computations should be run.

  • logger (ScheduledLogger) – Logger used to log agent’s metrics.

  • log_frequency (int) – How often to log the agent’s metrics.

  • noisy (bool) – Whether to use noisy linear layers for exploration.

  • std_init (float) – The range for the initialization of the standard deviation of the weights.

  • use_eps_greedy (bool) – Whether to use epsilon greedy exploration.

  • double (bool) – Whether to use double DQN.

  • dueling (bool) – Whether to use a dueling network architecture.

  • distributional (bool) – Whether to use distributional RL.

  • v_min (float) – The minimum of the support of the categorical value distribution for distributional RL.

  • v_max (float) – The maximum of the support of the categorical value distribution for distributional RL.

  • atoms (int) – Number of atoms discretizing the support range of the categorical value distribution for distributional RL.

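A minimal construction sketch (not taken from this page): only the constructor arguments come from the signature above, while the representation network below and the way it is passed are assumptions about the FunctionApproximator convention, namely a callable that builds the torso module from the observation shape.

    import torch.nn as nn

    from hive.agents.rainbow import RainbowDQNAgent

    # Hypothetical torso: everything except the final Q-value head.
    # The exact FunctionApproximator calling convention may differ in RLHive.
    def make_representation_net(obs_dim):
        return nn.Sequential(
            nn.Linear(obs_dim[0], 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
        )

    agent = RainbowDQNAgent(
        representation_net=make_representation_net,
        obs_dim=(4,),   # a 4-dimensional vector observation (illustrative)
        act_dim=2,      # number of discrete actions (illustrative)
        v_min=0,
        v_max=200,
        atoms=51,
        device="cpu",
    )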
create_q_networks(representation_net)[source]

Creates the Q-network and target Q-network. Adds the appropriate heads for DQN, Dueling DQN, Noisy Networks, and Distributional DQN.

Parameters

representation_net – A network that outputs the representations that will be used to compute Q-values (e.g. everything except the final layer of the DQN).
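To illustrate what these heads compute (a generic sketch, not the layers RLHive actually builds), a dueling head over a categorical value distribution can be written as:

    import torch.nn as nn
    import torch.nn.functional as F

    class DuelingDistributionalHead(nn.Module):
        """Illustrative dueling + distributional (C51-style) head."""

        def __init__(self, in_dim, act_dim, atoms):
            super().__init__()
            self._act_dim = act_dim
            self._atoms = atoms
            # Value stream: one distribution shared across actions.
            self.value = nn.Linear(in_dim, atoms)
            # Advantage stream: one distribution per action.
            self.advantage = nn.Linear(in_dim, act_dim * atoms)

        def forward(self, x):
            value = self.value(x).view(-1, 1, self._atoms)
            advantage = self.advantage(x).view(-1, self._act_dim, self._atoms)
            # Combine streams; subtracting the mean advantage keeps the
            # value/advantage decomposition identifiable.
            logits = value + advantage - advantage.mean(dim=1, keepdim=True)
            # Softmax over atoms yields one probability distribution per action.
            return F.softmax(logits, dim=-1)

With noisy=True the linear layers would be replaced by noisy linear layers, and with distributional=False the head would output scalar Q-values instead of atom probabilities.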

act(observation)[source]

Returns the action for the agent. If in training mode, follows an epsilon greedy policy. Otherwise, returns the action with the highest Q-value.

Parameters

observation – The current observation.
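When use_eps_greedy is True, the training-mode behavior described above reduces to standard epsilon-greedy selection, sketched below (illustrative, not the agent's exact code); with noisy=True and use_eps_greedy=False, exploration instead comes from the noisy layers and only the greedy branch applies.

    import numpy as np
    import torch

    def epsilon_greedy_act(qnet, observation, epsilon, act_dim, device="cpu"):
        # With probability epsilon, pick a uniformly random action.
        if np.random.rand() < epsilon:
            return np.random.randint(act_dim)
        # Otherwise act greedily with respect to the Q-network.
        obs = torch.as_tensor(observation, dtype=torch.float32, device=device)
        with torch.no_grad():
            q_values = qnet(obs.unsqueeze(0))
        return int(q_values.argmax(dim=-1).item())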

update(update_info)[source]

Updates the DQN agent.

Parameters

update_info – Dictionary containing all the necessary information to update the agent. Should contain a full transition, with keys for “observation”, “action”, “reward”, “next_observation”, and “done”.
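An illustrative call, with key names taken from the docstring above; the placeholder transition values and shapes are assumptions.

    import numpy as np

    # Hypothetical single transition for an agent built with obs_dim=(4,).
    update_info = {
        "observation": np.zeros(4, dtype=np.float32),
        "action": 1,
        "reward": 0.5,
        "next_observation": np.ones(4, dtype=np.float32),
        "done": False,
    }
    agent.update(update_info)  # `agent` is a RainbowDQNAgent as sketched earlier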

target_projection(target_net_inputs, next_action, reward, done)[source]

Projects the distribution of target Q-values onto the fixed support defined by v_min, v_max, and atoms.

Parameters
  • target_net_inputs – Inputs to feed into the target net to compute the projection of the target Q-values. Should be set from preprocess_update_batch().

  • next_action (Tensor) – Tensor containing next actions used to compute target distribution.

  • reward (Tensor) – Tensor containing rewards for the current batch.

  • done (Tensor) – Tensor containing whether the states in the current batch are terminal.
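For reference, the categorical (C51) projection this method performs can be sketched generically as follows; v_min, v_max, atoms, and the discount come from the constructor, and the code is a standalone sketch rather than RLHive's implementation.

    import torch

    def project_distribution(next_dist, reward, done, v_min, v_max, atoms, discount):
        """Project the Bellman-updated categorical distribution back onto the
        fixed support (generic C51 projection; a sketch, not RLHive's code).

        next_dist: (batch, atoms) probabilities for the chosen next action.
        reward, done: (batch,) tensors. With n-step returns, `reward` would be
        the accumulated n-step reward and `discount` would be discount ** n.
        """
        batch_size = next_dist.shape[0]
        delta_z = (v_max - v_min) / (atoms - 1)
        support = torch.linspace(v_min, v_max, atoms)

        # Bellman update of each atom, clamped to the support range.
        target_z = reward.float().unsqueeze(1) + (
            1.0 - done.float().unsqueeze(1)
        ) * discount * support
        target_z = target_z.clamp(v_min, v_max)

        # Split each atom's probability mass between its two nearest support points.
        b = (target_z - v_min) / delta_z
        lower = b.floor().long()
        upper = b.ceil().long()

        projected = torch.zeros(batch_size, atoms)
        projected.scatter_add_(1, lower, next_dist * (upper.float() - b))
        projected.scatter_add_(1, upper, next_dist * (b - lower.float()))
        # When b lands exactly on an atom, lower == upper and both weights are
        # zero, so assign the full mass to that atom.
        projected.scatter_add_(1, lower, next_dist * (upper == lower).float())
        return projected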