hive.agents.drqn module

class hive.agents.drqn.DRQNAgent(observation_space, action_space, representation_net, sequence_fn, id=0, optimizer_fn=None, loss_fn=None, init_fn=None, replay_buffer=None, max_seq_len=1, discount_rate=0.99, n_step=1, grad_clip=None, reward_clip=None, update_period_schedule=None, target_net_soft_update=False, target_net_update_fraction=0.05, target_net_update_schedule=None, epsilon_schedule=None, test_epsilon=0.001, min_replay_history=5000, batch_size=32, device='cpu', logger=None, log_frequency=100, store_hidden=True, burn_frames=0, **kwargs)

Bases: DQNAgent

An agent implementing the DRQN algorithm. Uses an epsilon-greedy exploration policy.

Parameters
  • observation_space (gym.spaces.Box) – Observation space for the agent.

  • action_space (gym.spaces.Discrete) – Action space for the agent.

  • representation_net (SequenceFunctionApproximator) – A network that outputs the representations that will be used to compute Q-values (e.g. everything except the final layer of the DRQN), as well as the hidden states of the recurrent component. The structure should be similar to ConvRNNNetwork, i.e., it should have a recurrent module placed between the convolutional layers and the MLP layers. It should also define a method that initializes the hidden state of the recurrent module if the computation requires hidden states as input/output.

  • id – Agent identifier.

  • optimizer_fn (OptimizerFn) – A function that takes in a list of parameters to optimize and returns the optimizer. If None, defaults to Adam.

  • loss_fn (LossFn) – Loss function used by the agent. If None, defaults to SmoothL1Loss.

  • init_fn (InitializationFn) – Initializes the weights of qnet using create_init_weights_fn.

  • replay_buffer (BaseReplayBuffer) – The replay buffer that the agent will push observations to and sample from during learning. If None, defaults to RecurrentReplayBuffer.

  • max_seq_len (int) – The number of consecutive transitions in a sequence.

  • discount_rate (float) – A number between 0 and 1 specifying how much future rewards are discounted by the agent.

  • n_step (int) – The horizon used in n-step returns to compute TD(n) targets.

  • grad_clip (float) – Gradients will be clipped to the range [-grad_clip, grad_clip].

  • reward_clip (float) – Rewards will be clipped to the range [-reward_clip, reward_clip].

  • update_period_schedule (Schedule) – Schedule determining how frequently the agent’s Q-network is updated.

  • target_net_soft_update (bool) – Whether the target net parameters are replaced completely by the qnet parameters (False) or updated to a weighted average of the target net parameters and the qnet parameters (True).

  • target_net_update_fraction (float) – The weight given to the target net parameters in a soft update.

  • target_net_update_schedule (Schedule) – Schedule determining how frequently the target net is updated.

  • epsilon_schedule (Schedule) – Schedule determining the value of epsilon through the course of training.

  • test_epsilon (float) – Epsilon (probability of choosing a random action) to be used during the testing phase.

  • min_replay_history (int) – How many observations to fill the replay buffer with before starting to learn.

  • batch_size (int) – The size of the batch sampled from the replay buffer during learning.

  • device – Device on which all computations should be run.

  • logger (ScheduledLogger) – Logger used to log agent’s metrics.

  • log_frequency (int) – How often to log the agent’s metrics.
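The discount_rate and n_step parameters combine into a TD(n) target. The sketch below is a minimal, framework-free illustration of the standard n-step DQN target, not the agent's own code: it assumes plain Python floats instead of tensors, uses a hypothetical helper name (n_step_td_target), and bootstraps from the greedy maximum over target-network Q-values at the state n steps ahead.

```python
from typing import Sequence


def n_step_td_target(
    rewards: Sequence[float],
    next_q_values: Sequence[float],
    done: bool,
    discount_rate: float = 0.99,
) -> float:
    """Hypothetical helper: compute a TD(n) target.

    rewards:       r_t, ..., r_{t+n-1}; len(rewards) is the horizon n
    next_q_values: target-network Q-values for every action at s_{t+n}
    done:          True if the episode terminated within the n steps
    """
    n = len(rewards)
    # Discounted sum of the n observed rewards.
    n_step_return = sum(discount_rate**k * r for k, r in enumerate(rewards))
    # Bootstrap from the greedy target Q-value unless the episode terminated.
    bootstrap = 0.0 if done else discount_rate**n * max(next_q_values)
    return n_step_return + bootstrap
```

For example, with rewards [1.0, 0.0, 1.0], discount_rate=0.9 and next-state Q-values [0.5, 2.0], the target is 1.0 + 0.81 + 0.729 * 2.0 = 3.268. When the episode terminates within the horizon, the bootstrap term is dropped.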

create_q_networks(representation_net, sequence_fn)

Creates the Q-network and target Q-network.

Parameters

representation_net – A network that outputs the representations that will be used to compute Q-values (e.g. everything except the final layer of the DRQN).

preprocess_update_info(update_info, hidden_state)

Preprocesses the update_info before it goes into the replay buffer. Clips the reward in update_info.

Parameters

update_info – Contains the information from the current timestep that the agent should use to update itself.
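The reward clipping described above amounts to clamping a scalar reward into [-reward_clip, reward_clip]. A minimal sketch, assuming update_info is a plain dict with a scalar "reward" entry; the helper clip_reward is hypothetical, and the handling of hidden_state is not shown.

```python
from typing import Any, Dict, Optional


def clip_reward(update_info: Dict[str, Any], reward_clip: Optional[float]) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of update_info with a clipped reward."""
    clipped = dict(update_info)  # shallow copy; the caller's dict is not modified
    if reward_clip is not None:
        # Clamp the scalar reward into [-reward_clip, reward_clip].
        clipped["reward"] = max(-reward_clip, min(reward_clip, clipped["reward"]))
    return clipped
```

This version returns a copy rather than modifying update_info in place, which keeps the original transition available to the caller; reward_clip=None leaves the reward unchanged.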

preprocess_update_batch(batch)

Preprocess the batch sampled from the replay buffer.

Parameters

batch – Batch sampled from the replay buffer for the current update.

Returns

  • (tuple) Inputs used to calculate current state values.

  • (tuple) Inputs used to calculate next state values.

  • Preprocessed batch.

Return type

(tuple)
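The three return values can be pictured with a small sketch. It is not the library's implementation: the batch is assumed to be a plain dict, the function name split_update_batch is hypothetical, and so are the "hidden_state" and "next_hidden_state" keys used here for stored recurrent states (None stands for starting from the initial hidden state).

```python
from typing import Any, Dict, Tuple

Batch = Dict[str, Any]


def split_update_batch(batch: Batch) -> Tuple[Tuple[Any, Any], Tuple[Any, Any], Batch]:
    """Hypothetical sketch: split a sampled batch into network inputs.

    Returns (inputs for current state values, inputs for next state values, batch).
    "hidden_state" and "next_hidden_state" are assumed key names; a missing key
    yields None, meaning "start from the initial hidden state".
    """
    current_inputs = (batch["observation"], batch.get("hidden_state"))
    next_inputs = (batch["next_observation"], batch.get("next_hidden_state"))
    # The batch itself is passed through so rewards, actions, etc. stay available.
    return current_inputs, next_inputs, batch
```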

act(observation, agent_traj_state=None)

Returns the action for the agent. If in training mode, follows an epsilon-greedy policy. Otherwise, returns the action with the highest Q-value.

Parameters
  • observation – The current observation.

  • agent_traj_state – Contains necessary state information for the agent to process current trajectory. This should be updated and returned.

Returns

  • action

  • agent trajectory state
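For a recurrent agent, act has to carry the hidden state from one step to the next, which is what agent_traj_state is for. The sketch below shows one way to combine that with epsilon-greedy selection. It is an illustration under stated assumptions: q_fn and init_hidden are hypothetical stand-ins for the Q-network and its hidden-state initializer, Q-values are plain lists, and agent_traj_state is assumed to be a dict with a single "hidden_state" entry.

```python
import random
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical recurrent Q-function: (observation, hidden) -> (q_values, new_hidden).
QFn = Callable[[object, object], Tuple[List[float], object]]


def epsilon_greedy_act(
    observation: object,
    q_fn: QFn,
    init_hidden: Callable[[], object],
    epsilon: float,
    agent_traj_state: Optional[Dict[str, object]] = None,
    rng: Optional[random.Random] = None,
) -> Tuple[int, Dict[str, object]]:
    if rng is None:
        rng = random.Random()
    # A new trajectory starts with no state, so initialize the recurrent hidden state.
    if agent_traj_state is None:
        agent_traj_state = {"hidden_state": init_hidden()}
    # Always run the network so the hidden state advances, even on random actions.
    q_values, new_hidden = q_fn(observation, agent_traj_state["hidden_state"])
    if rng.random() < epsilon:
        action = rng.randrange(len(q_values))  # explore: uniform random action
    else:
        action = max(range(len(q_values)), key=q_values.__getitem__)  # exploit: greedy
    return action, {"hidden_state": new_hidden}
```

Running the network before the exploration coin flip means the hidden state keeps tracking the observation history even on steps where a random action is taken.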

update(update_info, agent_traj_state=None)

Updates the DRQN agent.

Parameters
  • update_info – Dictionary containing all the necessary information from the environment to update the agent. Should contain a full transition, with keys for “observation”, “action”, “reward”, “next_observation”, “terminated”, and “truncated”.

  • agent_traj_state – Contains necessary state information for the agent to process current trajectory. This should be updated and returned.

Returns

  • action

  • agent trajectory state
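Whether a call to update actually triggers a learning step depends on min_replay_history and update_period_schedule. The helper below sketches that gating together with a check for the transition keys listed above. It is hypothetical: the class name UpdateGate is invented for illustration, and a constant integer update period stands in for the Schedule object the agent uses.

```python
from typing import Dict

# Keys that update_info must contain, as listed in the update() description.
REQUIRED_KEYS = (
    "observation",
    "action",
    "reward",
    "next_observation",
    "terminated",
    "truncated",
)


class UpdateGate:
    """Hypothetical helper that decides when a learning step is due."""

    def __init__(self, min_replay_history: int = 5000, update_period: int = 1) -> None:
        self.min_replay_history = min_replay_history
        self.update_period = update_period  # constant here; the agent uses a Schedule
        self.num_transitions = 0

    def step(self, update_info: Dict[str, object]) -> bool:
        """Count one stored transition and return True if the agent should learn now."""
        missing = [key for key in REQUIRED_KEYS if key not in update_info]
        if missing:
            raise KeyError(f"update_info is missing keys: {missing}")
        self.num_transitions += 1
        # No learning until the buffer holds at least min_replay_history transitions.
        if self.num_transitions < self.min_replay_history:
            return False
        # After warm-up, learn once every update_period transitions.
        return self.num_transitions % self.update_period == 0
```

With min_replay_history=3 and update_period=2, six consecutive calls yield False, False, False, True, False, True.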