hive.agents.drqn module
- class hive.agents.drqn.DRQNAgent(observation_space, action_space, representation_net, sequence_fn, id=0, optimizer_fn=None, loss_fn=None, init_fn=None, replay_buffer=None, max_seq_len=1, discount_rate=0.99, n_step=1, grad_clip=None, reward_clip=None, update_period_schedule=None, target_net_soft_update=False, target_net_update_fraction=0.05, target_net_update_schedule=None, epsilon_schedule=None, test_epsilon=0.001, min_replay_history=5000, batch_size=32, device='cpu', logger=None, log_frequency=100, store_hidden=True, burn_frames=0, **kwargs)[source]
Bases: DQNAgent
An agent implementing the DRQN (Deep Recurrent Q-Network) algorithm. Uses an epsilon-greedy exploration policy.
- Parameters
observation_space (gym.spaces.Box) – Observation space for the agent.
action_space (gym.spaces.Discrete) – Action space for the agent.
representation_net (SequenceFunctionApproximator) – A network that outputs the representations that will be used to compute Q-values (e.g. everything except the final layer of the DRQN), as well as the hidden states of the recurrent component. The structure should be similar to ConvRNNNetwork, i.e., it should have a recurrent module placed between the convolutional layers and MLP layers. It should also define a method that initializes the hidden state of the recurrent module if the computation requires hidden states as input/output.
sequence_fn (SequenceFn) – The recurrent component of the network, usually placed between the convolutional layers and MLP layers.
id – Agent identifier.
optimizer_fn (OptimizerFn) – A function that takes in a list of parameters to optimize and returns the optimizer. If None, defaults to Adam.
loss_fn (LossFn) – Loss function used by the agent. If None, defaults to SmoothL1Loss.
init_fn (InitializationFn) – Initializes the weights of qnet using create_init_weights_fn.
replay_buffer (BaseReplayBuffer) – The replay buffer that the agent will push observations to and sample from during learning. If None, defaults to RecurrentReplayBuffer.
max_seq_len (int) – The number of consecutive transitions in a sequence.
discount_rate (float) – A number between 0 and 1 specifying how much future rewards are discounted by the agent.
n_step (int) – The horizon used in n-step returns to compute TD(n) targets.
grad_clip (float) – Gradients will be clipped to the range [-grad_clip, grad_clip].
reward_clip (float) – Rewards will be clipped to the range [-reward_clip, reward_clip].
update_period_schedule (Schedule) – Schedule determining how frequently the agent’s Q-network is updated.
target_net_soft_update (bool) – Whether the target net parameters are replaced by the qnet parameters completely or using a weighted average of the target net parameters and the qnet parameters.
target_net_update_fraction (float) – The weight given to the Q-network parameters (with 1 minus this given to the target net parameters) in a soft update.
target_net_update_schedule (Schedule) – Schedule determining how frequently the target net is updated.
epsilon_schedule (Schedule) – Schedule determining the value of epsilon through the course of training.
test_epsilon (float) – The value of epsilon (probability of choosing a random action) used during the testing phase.
min_replay_history (int) – How many observations to fill the replay buffer with before starting to learn.
batch_size (int) – The size of the batch sampled from the replay buffer during learning.
device – Device on which all computations should be run.
logger (ScheduledLogger) – Logger used to log agent’s metrics.
log_frequency (int) – How often to log the agent’s metrics.
store_hidden (bool) – Whether to store the agent’s hidden states in the replay buffer and use them to initialize the recurrent network during updates.
burn_frames (int) – The number of frames at the start of each sampled sequence used only to burn in the hidden state; no loss is computed over these frames.
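The interaction between target_net_soft_update and target_net_update_fraction can be sketched as follows. This is an illustrative sketch, not RLHive's implementation: parameters are plain floats rather than tensors, the function name update_target is hypothetical, and it assumes the common polyak-averaging convention in which the fraction weights the new Q-network parameters.

```python
def update_target(target_params, qnet_params, soft_update, update_fraction):
    """Return new target-network parameters (sketch).

    Hard update: copy the Q-network parameters outright.
    Soft update: per-parameter weighted average, with update_fraction
    as the weight on the new Q-network parameters (assumed convention).
    """
    if not soft_update:
        return list(qnet_params)
    return [
        (1.0 - update_fraction) * t + update_fraction * q
        for t, q in zip(target_params, qnet_params)
    ]
```

With the default update_fraction of 0.05, the target network trails the Q-network slowly, which stabilizes the TD targets.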
- create_q_networks(representation_net, sequence_fn)[source]
Creates the Q-network and target Q-network.
- Parameters
representation_net – A network that outputs the representations that will be used to compute Q-values (e.g. everything except the final layer of the DRQN).
sequence_fn – The recurrent component of the Q-network.
- preprocess_update_info(update_info, hidden_state)[source]
Preprocesses the update_info before it goes into the replay buffer. Clips the reward in update_info.
- Parameters
update_info – Contains the information from the current timestep that the agent should use to update itself.
hidden_state – The recurrent hidden state associated with the transition.
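The reward-clipping step can be sketched as below. This is an illustrative sketch, not RLHive's code; clip_reward is a hypothetical helper, and the "reward" key matches the transition keys listed under update().

```python
def clip_reward(update_info, reward_clip):
    """Clamp the reward in update_info to [-reward_clip, reward_clip].

    A no-op when reward_clip is None (clipping disabled).
    """
    if reward_clip is not None:
        r = update_info["reward"]
        update_info["reward"] = max(-reward_clip, min(reward_clip, r))
    return update_info
```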
- preprocess_update_batch(batch)[source]
Preprocess the batch sampled from the replay buffer.
- Parameters
batch – Batch sampled from the replay buffer for the current update.
- Returns
(tuple) Inputs used to calculate current state values.
(tuple) Inputs used to calculate next state values.
Preprocessed batch.
- Return type
(tuple)
- act(observation, agent_traj_state=None)[source]
Returns the action for the agent. If in training mode, follows an epsilon greedy policy. Otherwise, returns the action with the highest Q-value.
- Parameters
observation – The current observation.
agent_traj_state – Contains necessary state information for the agent to process the current trajectory. This should be updated and returned.
- Returns
action
agent trajectory state
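The epsilon-greedy rule that act() follows, with the recurrent hidden state threaded through explicitly, can be sketched as below. This is an illustrative sketch, not RLHive's code: epsilon_greedy_act and q_fn are hypothetical stand-ins, where q_fn maps (observation, hidden_state) to (q_values, next_hidden_state).

```python
import random

def epsilon_greedy_act(q_fn, observation, hidden_state, epsilon, num_actions, rng=random):
    """Pick an action epsilon-greedily while advancing the hidden state."""
    q_values, next_hidden = q_fn(observation, hidden_state)
    if rng.random() < epsilon:
        action = rng.randrange(num_actions)  # explore: uniform random action
    else:
        # exploit: action with the highest Q-value
        action = max(range(num_actions), key=lambda a: q_values[a])
    return action, next_hidden
```

Note that the network is evaluated even when a random action is taken, so the hidden state stays consistent with the observed trajectory.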
- update(update_info, agent_traj_state=None)[source]
Updates the DRQN agent.
- Parameters
update_info – dictionary containing all the necessary information from the environment to update the agent. Should contain a full transition, with keys for “observation”, “action”, “reward”, “next_observation”, “terminated”, and “truncated”.
agent_traj_state – Contains necessary state information for the agent to process the current trajectory. This should be updated and returned.
- Returns
action
agent trajectory state
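The TD target implied by the discount_rate and n_step parameters can be sketched as below. This is an illustrative sketch, not RLHive's code; n_step_td_target is a hypothetical helper, rewards holds the n rewards following the current state, and bootstrap_q stands in for the target network's value estimate for the state n steps ahead.

```python
def n_step_td_target(rewards, bootstrap_q, discount_rate, terminal=False):
    """Compute an n-step TD target: the sum of discounted rewards plus a
    discounted bootstrap value (omitted if the episode terminated)."""
    target = 0.0
    for i, r in enumerate(rewards):
        target += (discount_rate ** i) * r
    if not terminal:
        target += (discount_rate ** len(rewards)) * bootstrap_q
    return target
```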