hive.agents.td3 module
- class hive.agents.td3.TD3(observation_space, action_space, representation_net=None, actor_net=None, critic_net=None, init_fn=None, actor_optimizer_fn=None, critic_optimizer_fn=None, critic_loss_fn=None, n_critics=2, stack_size=1, replay_buffer=None, discount_rate=0.99, n_step=1, grad_clip=None, reward_clip=None, soft_update_fraction=0.005, batch_size=64, logger=None, log_frequency=100, update_frequency=1, policy_update_frequency=2, action_noise=0, target_noise=0.2, target_noise_clip=0.5, min_replay_history=1000, device='cpu', id=0)[source]
Bases: Agent
An agent implementing the TD3 algorithm.
- Parameters
observation_space (gym.spaces.Box) – Observation space for the agent.
action_space (gym.spaces.Box) – Action space for the agent.
representation_net (FunctionApproximator) – The network that encodes the observations that are then fed into the actor_net and critic_net. If None, defaults to Identity.
actor_net (FunctionApproximator) – The network that takes the encoded observations from representation_net and outputs the representations used to compute the actions (i.e., everything except the last layer).
critic_net (FunctionApproximator) – The network that takes two inputs: the encoded observations from representation_net and the actions. It outputs the representations used to compute the values of the actions (i.e., everything except the last layer).
init_fn (InitializationFn) – Initializes the weights of agent networks using create_init_weights_fn.
actor_optimizer_fn (OptimizerFn) – A function that takes in the list of parameters of the actor and returns the optimizer for the actor. If None, defaults to Adam.
critic_optimizer_fn (OptimizerFn) – A function that takes in the list of parameters of the critic and returns the optimizer for the critic. If None, defaults to Adam.
critic_loss_fn (LossFn) – The loss function used to optimize the critic. If None, defaults to MSELoss.
.n_critics (int) – The number of critics used by the agent to estimate Q-values. The minimum Q-value is used as the value for the next state when calculating target Q-values for the critic. The output of the first critic is used when computing the loss for the actor. For TD3, the default value is 2. For DDPG, this parameter is 1.
stack_size (int) – Number of observations stacked to create the state fed to the agent.
replay_buffer (BaseReplayBuffer) – The replay buffer that the agent will push observations to and sample from during learning. If None, defaults to CircularReplayBuffer.
discount_rate (float) – A number between 0 and 1 specifying how much future rewards are discounted by the agent.
n_step (int) – The horizon used in n-step returns to compute TD(n) targets.
grad_clip (float) – Gradients will be clipped to between [-grad_clip, grad_clip].
reward_clip (float) – Rewards will be clipped to between [-reward_clip, reward_clip].
soft_update_fraction (float) – The weight given to the target net parameters in a soft (polyak) update. Also known as tau.
batch_size (int) – The size of the batch sampled from the replay buffer during learning.
logger (Logger) – Logger used to log agent’s metrics.
log_frequency (int) – How often to log the agent’s metrics.
update_frequency (int) – How frequently to update the agent. A value of 1 means the agent will be updated every time update is called.
policy_update_frequency (int) – Relative update frequency of the actor compared to the critic. The actor will be updated every policy_update_frequency times the critic is updated.
action_noise (float) – The standard deviation for the noise added to the action taken by the agent during training.
target_noise (float) – The standard deviation of the noise added to the target policy for smoothing.
target_noise_clip (float) – The sampled target_noise is clipped to [-target_noise_clip, target_noise_clip].
min_replay_history (int) – How many observations to fill the replay buffer with before starting to learn.
device – Device on which all computations should be run.
id – Agent identifier.
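To make the roles of n_critics, target_noise, target_noise_clip, and discount_rate concrete, here is a small numpy sketch of how TD3 forms its critic targets. The helper names are illustrative, not part of the hive API, and actions are assumed to be scaled to [-1, 1] as this agent does internally.

```python
import numpy as np

def td3_target(reward, next_q_values, terminated,
               discount_rate=0.99, n_step=1):
    """Critic target: the minimum over all critics' estimates of the
    next state's value (guarding against overestimation), discounted
    over n steps and zeroed at terminal states."""
    next_value = min(next_q_values)  # one estimate per critic
    return reward + (discount_rate ** n_step) * next_value * (1 - terminated)

def smoothed_target_action(target_action, target_noise=0.2,
                           target_noise_clip=0.5, rng=None):
    """Target policy smoothing: add Gaussian noise with standard
    deviation target_noise, clip the noise to
    [-target_noise_clip, target_noise_clip], then clip the perturbed
    action back to the scaled action range [-1, 1]."""
    rng = np.random.default_rng() if rng is None else rng
    noise = np.clip(
        rng.normal(0.0, target_noise, size=np.shape(target_action)),
        -target_noise_clip, target_noise_clip)
    return np.clip(np.asarray(target_action, dtype=float) + noise, -1.0, 1.0)
```

With n_critics=2 (the TD3 default) the minimum is taken over two estimates; with n_critics=1 this reduces to the DDPG target.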
- create_networks(representation_net, actor_net, critic_net)[source]
Creates the actor and critic networks.
- Parameters
representation_net – A network that outputs the shared representations that will be used by the actor and critic networks to process observations.
actor_net – The network that will be used to compute actions.
critic_net – The network that will be used to compute values of state action pairs.
- preprocess_update_info(update_info)[source]
Preprocesses the update_info before it goes into the replay buffer. Scales the action to [-1, 1].
- Parameters
update_info – Contains the information from the current timestep that the agent should use to update itself.
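The action scaling mentioned above is a linear map from the environment's action box to [-1, 1]. A minimal sketch (the function name is illustrative, not the hive implementation):

```python
import numpy as np

def scale_action(action, low, high):
    """Linearly map an action from an environment's [low, high] box
    to [-1, 1] before it is stored in the replay buffer."""
    low, high = np.asarray(low, dtype=float), np.asarray(high, dtype=float)
    return 2.0 * (np.asarray(action, dtype=float) - low) / (high - low) - 1.0
```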
- preprocess_update_batch(batch)[source]
Preprocess the batch sampled from the replay buffer.
- Parameters
batch – Batch sampled from the replay buffer for the current update.
- Returns
(tuple) Inputs used to calculate current state values.
(tuple) Inputs used to calculate next state values.
Preprocessed batch.
- Return type
(tuple)
- act(observation, agent_traj_state=None)[source]
Returns the action for the agent. If in training mode, adds noise with standard deviation self._action_noise.
- Parameters
observation – The current observation.
agent_traj_state – Contains necessary state information for the agent to process current trajectory. This should be updated and returned.
- Returns
action
agent trajectory state
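The exploration noise added during training can be sketched as follows; this is an illustration of the behavior described above, not the hive implementation, and it assumes actions scaled to [-1, 1]:

```python
import numpy as np

def noisy_action(actor_action, action_noise=0.1, training=True, rng=None):
    """During training, perturb the deterministic actor output with
    Gaussian noise of standard deviation action_noise for exploration,
    then clip back to the scaled action range [-1, 1]. In evaluation
    mode the actor output is used unchanged (still clipped)."""
    rng = np.random.default_rng() if rng is None else rng
    action = np.asarray(actor_action, dtype=float)
    if training and action_noise > 0:
        action = action + rng.normal(0.0, action_noise, size=action.shape)
    return np.clip(action, -1.0, 1.0)
```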
- update(update_info, agent_traj_state=None)[source]
Updates the TD3 agent.
- Parameters
update_info – A dictionary containing all the necessary information from the environment to update the agent. Should contain a full transition, with keys for “observation”, “action”, “reward”, “next_observation”, “terminated”, and “truncated”.
agent_traj_state – Contains necessary state information for the agent to process current trajectory. This should be updated and returned.
- Returns
action
agent trajectory state
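After each critic update (and, every policy_update_frequency updates, an actor update), the target networks are moved toward the online networks by Polyak averaging. A minimal sketch, assuming the common convention where soft_update_fraction (tau) is the weight placed on the online parameters; the function name is illustrative, not part of the hive API:

```python
def soft_update(target_params, online_params, soft_update_fraction=0.005):
    """Polyak averaging: move each target parameter a small fraction
    (tau) toward the corresponding online parameter."""
    tau = soft_update_fraction
    return [(1 - tau) * t + tau * o for t, o in zip(target_params, online_params)]
```

With the default tau of 0.005, the target networks change slowly, which stabilizes the bootstrapped critic targets.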