hive.agents.td3 module

class hive.agents.td3.TD3(observation_space, action_space, representation_net=None, actor_net=None, critic_net=None, init_fn=None, actor_optimizer_fn=None, critic_optimizer_fn=None, critic_loss_fn=None, n_critics=2, stack_size=1, replay_buffer=None, discount_rate=0.99, n_step=1, grad_clip=None, reward_clip=None, soft_update_fraction=0.005, batch_size=64, logger=None, log_frequency=100, update_frequency=1, policy_update_frequency=2, action_noise=0, target_noise=0.2, target_noise_clip=0.5, min_replay_history=1000, device='cpu', id=0)[source]

Bases: Agent

An agent implementing the TD3 algorithm.

Parameters
  • observation_space (gym.spaces.Box) – Observation space for the agent.

  • action_space (gym.spaces.Box) – Action space for the agent.

  • representation_net (FunctionApproximator) – The network that encodes the observations that are then fed into the actor_net and critic_net. If None, defaults to Identity.

  • actor_net (FunctionApproximator) – The network that takes the encoded observations from representation_net and outputs the representations used to compute the actions (i.e., everything except the final layer).

  • critic_net (FunctionApproximator) – The network that takes two inputs: the encoded observations from representation_net and actions. It outputs the representations used to compute the values of the actions (i.e., everything except the final layer).

  • init_fn (InitializationFn) – Initializes the weights of agent networks using create_init_weights_fn.

  • actor_optimizer_fn (OptimizerFn) – A function that takes in the list of parameters of the actor and returns the optimizer for the actor. If None, defaults to Adam.

  • critic_optimizer_fn (OptimizerFn) – A function that takes in the list of parameters of the critic and returns the optimizer for the critic. If None, defaults to Adam.

  • critic_loss_fn (LossFn) – The loss function used to optimize the critic. If None, defaults to MSELoss.

  • n_critics (int) – The number of critics used by the agent to estimate Q-values. The minimum Q-value is used as the value for the next state when calculating target Q-values for the critic. The output of the first critic is used when computing the loss for the actor. For TD3, the default value is 2. For DDPG, this parameter is 1.

  • stack_size (int) – Number of observations stacked to create the state fed to the agent.

  • replay_buffer (BaseReplayBuffer) – The replay buffer that the agent will push observations to and sample from during learning. If None, defaults to CircularReplayBuffer.

  • discount_rate (float) – A number between 0 and 1 specifying how much future rewards are discounted by the agent.

  • n_step (int) – The horizon used in n-step returns to compute TD(n) targets.

  • grad_clip (float) – Gradients will be clipped to the range [-grad_clip, grad_clip].

  • reward_clip (float) – Rewards will be clipped to the range [-reward_clip, reward_clip].

  • soft_update_fraction (float) – The weight given to the target net parameters in a soft (polyak) update. Also known as tau.

  • batch_size (int) – The size of the batch sampled from the replay buffer during learning.

  • logger (Logger) – Logger used to log agent’s metrics.

  • log_frequency (int) – How often to log the agent’s metrics.

  • update_frequency (int) – How frequently to update the agent. A value of 1 means the agent will be updated every time update is called.

  • policy_update_frequency (int) – Relative update frequency of the actor compared to the critic. The actor will be updated every policy_update_frequency times the critic is updated.

  • action_noise (float) – The standard deviation for the noise added to the action taken by the agent during training.

  • target_noise (float) – The standard deviation of the noise added to the target policy for smoothing.

  • target_noise_clip (float) – The sampled target_noise is clipped to [-target_noise_clip, target_noise_clip].

  • min_replay_history (int) – How many observations to fill the replay buffer with before starting to learn.

  • device – Device on which all computations should be run.

  • id – Agent identifier.
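A minimal construction sketch (illustrative, not taken from the library docs): it assumes a Gym-style continuous-control environment whose observation and action spaces are gym.spaces.Box, and it leaves representation_net, actor_net, and critic_net at their defaults. In practice these are usually supplied as FunctionApproximator configurations, and exact defaults may vary by version:

    import gym
    from hive.agents.td3 import TD3

    env = gym.make("Pendulum-v1")  # any task with Box observation/action spaces

    agent = TD3(
        observation_space=env.observation_space,
        action_space=env.action_space,
        n_critics=2,                 # TD3 default; 1 gives DDPG-style targets
        discount_rate=0.99,
        soft_update_fraction=0.005,  # tau for the polyak target update
        policy_update_frequency=2,   # actor updated once per 2 critic updates
        action_noise=0.1,            # exploration noise std during training
        target_noise=0.2,            # target policy smoothing noise std
        target_noise_clip=0.5,
        min_replay_history=1000,
        batch_size=64,
        device="cpu",
    )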

create_networks(representation_net, actor_net, critic_net)[source]

Creates the actor and critic networks.

Parameters
  • representation_net – A network that outputs the shared representations that will be used by the actor and critic networks to process observations.

  • actor_net – The network that will be used to compute actions.

  • critic_net – The network that will be used to compute the values of state-action pairs.
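An illustrative, library-independent sketch of the data flow these three networks imply: the representation network encodes the observation, the actor head consumes the encoding, and each critic head consumes the encoding concatenated with the action; the agent itself appends the final output layers. All module definitions and dimensions below are hypothetical:

    import torch
    import torch.nn as nn

    obs_dim, act_dim, hidden = 17, 6, 256        # hypothetical dimensions

    representation_net = nn.Identity()           # the documented default when None
    actor_net = nn.Sequential(                   # everything except the final action layer
        nn.Linear(obs_dim, hidden), nn.ReLU()
    )
    critic_net = nn.Sequential(                  # takes encoded observation + action
        nn.Linear(obs_dim + act_dim, hidden), nn.ReLU()
    )

    obs = torch.randn(1, obs_dim)
    action = torch.randn(1, act_dim)
    encoded = representation_net(obs)
    actor_features = actor_net(encoded)                             # fed to final action layer
    critic_features = critic_net(torch.cat([encoded, action], -1))  # fed to final Q-value layer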

train()[source]

Changes the agent to training mode.

eval()[source]

Changes the agent to evaluation mode.

scale_action(actions)[source]

Scales actions to [-1, 1].

unscale_actions(actions)[source]

Unscales actions from [-1, 1] to expected scale.
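A sketch of the affine map these two helpers describe, assuming a gym.spaces.Box with finite low/high bounds. The formulas below are the standard linear rescaling between the action space and [-1, 1], not the library's internal code:

    import numpy as np

    def scale_action(actions, low, high):
        """Map actions from [low, high] to [-1, 1]."""
        return 2.0 * (actions - low) / (high - low) - 1.0

    def unscale_actions(actions, low, high):
        """Map actions from [-1, 1] back to [low, high]."""
        return low + 0.5 * (actions + 1.0) * (high - low)

    low, high = np.array([-2.0]), np.array([2.0])
    roundtrip = unscale_actions(scale_action(np.array([1.5]), low, high), low, high)
    assert np.allclose(roundtrip, 1.5)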

preprocess_update_info(update_info)[source]

Preprocesses the update_info before it goes into the replay buffer. Scales the action to [-1, 1].

Parameters

update_info – Contains the information from the current timestep that the agent should use to update itself.

preprocess_update_batch(batch)[source]

Preprocess the batch sampled from the replay buffer.

Parameters

batch – Batch sampled from the replay buffer for the current update.

Returns

  • (tuple) Inputs used to calculate current state values.

  • (tuple) Inputs used to calculate next state values.

  • Preprocessed batch.

Return type

(tuple)

act(observation, agent_traj_state=None)[source]

Returns the action for the agent. If in training mode, adds noise with standard deviation self._action_noise.

Parameters
  • observation – The current observation.

  • agent_traj_state – Contains necessary state information for the agent to process the current trajectory. This should be updated and returned.

Returns

  • action

  • agent trajectory state
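A standalone sketch of the exploration behavior described above: in training mode, zero-mean Gaussian noise with standard deviation action_noise is added to the policy's action. Clipping the result back to the scaled action range is an assumption here, not a documented detail:

    import numpy as np

    def noisy_action(policy_action, action_noise, training):
        """Add exploration noise to a (scaled) policy action during training."""
        if training and action_noise > 0:
            policy_action = policy_action + np.random.normal(
                0.0, action_noise, size=np.shape(policy_action)
            )
        return np.clip(policy_action, -1.0, 1.0)  # assumed: keep actions in [-1, 1]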

update(update_info, agent_traj_state=None)[source]

Updates the TD3 agent.

Parameters
  • update_info – Dictionary containing all the necessary information from the environment to update the agent. Should contain a full transition, with keys for “observation”, “action”, “reward”, “next_observation”, “terminated”, and “truncated”.

  • agent_traj_state – Contains necessary state information for the agent to process the current trajectory. This should be updated and returned.

Returns

  • action

  • agent trajectory state
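A condensed, library-independent sketch of the TD3 update these parameters describe: target policy smoothing (target_noise, target_noise_clip), clipped double-Q learning via the minimum over the n_critics target critics, and polyak target updates (soft_update_fraction). The actor and target networks are updated only every policy_update_frequency critic updates, and only the first critic's Q-value feeds the actor loss, as noted in the n_critics description above. All function and network names here are hypothetical:

    import torch

    def td3_critic_target(next_obs, reward, terminated, actor_target, critic_targets,
                          discount_rate=0.99, target_noise=0.2, target_noise_clip=0.5):
        """Compute r + gamma * (1 - terminated) * min_i Q_i(s', smoothed a')."""
        with torch.no_grad():
            next_action = actor_target(next_obs)
            noise = (torch.randn_like(next_action) * target_noise).clamp(
                -target_noise_clip, target_noise_clip
            )
            next_action = (next_action + noise).clamp(-1.0, 1.0)  # scaled action range
            next_q = torch.min(
                torch.stack([q(next_obs, next_action) for q in critic_targets]), dim=0
            ).values
            return reward + discount_rate * (1.0 - terminated) * next_q

    def soft_update(target_net, online_net, tau=0.005):
        """Polyak-average target parameters toward the online network (tau = soft_update_fraction)."""
        for t, s in zip(target_net.parameters(), online_net.parameters()):
            t.data.mul_(1.0 - tau).add_(tau * s.data)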

save(dname)[source]

Saves agent checkpointing information to file for future loading.

Parameters

dname (str) – directory where agent should save all relevant info.

load(dname)[source]

Loads agent information from file.

Parameters

dname (str) – directory where agent checkpoint info is stored.
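A brief usage sketch for checkpointing; the directory name is hypothetical:

    agent.save("checkpoints/td3_run0")  # write checkpoint info to this directory
    # ... later, e.g. to resume training or evaluate ...
    agent.load("checkpoints/td3_run0")  # restore from the same directory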