import gymnasium as gym
from hive.agents.qnets.base import FunctionApproximator
from hive.agents.qnets.utils import InitializationFn
from hive.agents.td3 import TD3
from hive.replays import BaseReplayBuffer
from hive.utils.loggers import Logger
from hive.utils.utils import LossFn, OptimizerFn


class DDPG(TD3):
"""
    An agent implementing the DDPG algorithm. It is implemented by fixing the
    ``n_critics``, ``policy_update_frequency``, ``target_noise``, and
    ``target_noise_clip`` parameters of the :py:class:`~hive.agents.td3.TD3`
    agent to ``1``, ``1``, ``0.0``, and ``0.0``, respectively.
"""
def __init__(
self,
observation_space: gym.spaces.Box,
action_space: gym.spaces.Box,
representation_net: FunctionApproximator = None,
actor_net: FunctionApproximator = None,
critic_net: FunctionApproximator = None,
init_fn: InitializationFn = None,
actor_optimizer_fn: OptimizerFn = None,
critic_optimizer_fn: OptimizerFn = None,
critic_loss_fn: LossFn = None,
stack_size: int = 1,
replay_buffer: BaseReplayBuffer = None,
discount_rate: float = 0.99,
n_step: int = 1,
grad_clip: float = None,
reward_clip: float = None,
soft_update_fraction: float = 0.005,
batch_size: int = 64,
logger: Logger = None,
log_frequency: int = 100,
update_frequency: int = 1,
        action_noise: float = 0.0,
min_replay_history: int = 1000,
device="cpu",
id=0,
):
"""
Args:
observation_space (gym.spaces.Box): Observation space for the agent.
action_space (gym.spaces.Box): Action space for the agent.
representation_net (FunctionApproximator): The network that encodes the
observations that are then fed into the actor_net and critic_net. If
None, defaults to :py:class:`~torch.nn.Identity`.
            actor_net (FunctionApproximator): The network that takes the encoded
                observations from representation_net and outputs the
                representations used to compute the actions (i.e., everything
                except the last layer).
            critic_net (FunctionApproximator): The network that takes two inputs:
                the encoded observations from representation_net and the actions.
                It outputs the representations used to compute the values of the
                actions (i.e., everything except the last layer).
init_fn (InitializationFn): Initializes the weights of agent networks using
create_init_weights_fn.
            actor_optimizer_fn (OptimizerFn): A function that takes in the list
                of parameters of the actor and returns the optimizer for the
                actor. If None, defaults to :py:class:`~torch.optim.Adam`.
            critic_optimizer_fn (OptimizerFn): A function that takes in the list
                of parameters of the critic and returns the optimizer for the
                critic. If None, defaults to :py:class:`~torch.optim.Adam`.
critic_loss_fn (LossFn): The loss function used to optimize the critic. If
None, defaults to :py:class:`~torch.nn.MSELoss`.
stack_size (int): Number of observations stacked to create the state fed
to the agent.
replay_buffer (BaseReplayBuffer): The replay buffer that the agent will
push observations to and sample from during learning. If None,
defaults to
:py:class:`~hive.replays.circular_replay.CircularReplayBuffer`.
discount_rate (float): A number between 0 and 1 specifying how much
future rewards are discounted by the agent.
n_step (int): The horizon used in n-step returns to compute TD(n) targets.
            grad_clip (float): Gradients will be clipped to the range
                [-grad_clip, grad_clip].
            reward_clip (float): Rewards will be clipped to the range
                [-reward_clip, reward_clip].
            soft_update_fraction (float): The weight given to the target
                net parameters in a soft (Polyak) update. Also known as tau.
batch_size (int): The size of the batch sampled from the replay buffer
during learning.
logger (Logger): Logger used to log agent's metrics.
log_frequency (int): How often to log the agent's metrics.
update_frequency (int): How frequently to update the agent. A value of 1
means the agent will be updated every time update is called.
            action_noise (float): The standard deviation of the noise added to
                the actions taken by the agent during training.
            min_replay_history (int): The number of observations to add to the
                replay buffer before the agent starts learning.
device: Device on which all computations should be run.
id: Agent identifier.
"""
super().__init__(
observation_space=observation_space,
action_space=action_space,
representation_net=representation_net,
actor_net=actor_net,
critic_net=critic_net,
init_fn=init_fn,
actor_optimizer_fn=actor_optimizer_fn,
critic_optimizer_fn=critic_optimizer_fn,
critic_loss_fn=critic_loss_fn,
            n_critics=1,  # DDPG uses a single critic; TD3 uses two.
stack_size=stack_size,
replay_buffer=replay_buffer,
discount_rate=discount_rate,
n_step=n_step,
grad_clip=grad_clip,
reward_clip=reward_clip,
soft_update_fraction=soft_update_fraction,
batch_size=batch_size,
logger=logger,
log_frequency=log_frequency,
update_frequency=update_frequency,
            policy_update_frequency=1,  # Update the actor on every critic update.
action_noise=action_noise,
            target_noise=0.0,  # DDPG does not use target policy smoothing.
            target_noise_clip=0.0,
min_replay_history=min_replay_history,
device=device,
id=id,
)
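

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the agent): the update rule that the
# fixed parameters above reduce TD3 to. With n_critics=1 there is a single
# critic, with target_noise=0.0 the target action is the deterministic output
# of the target actor (no smoothing noise), and with policy_update_frequency=1
# the actor is updated after every critic update. This is a minimal plain
# PyTorch sketch; the network architectures, dimensions, and the toy batch
# below are hypothetical stand-ins, not RLHive internals.
if __name__ == "__main__":
    import torch

    obs_dim, act_dim, tau, gamma = 4, 2, 0.005, 0.99

    # Hypothetical single critic Q(s, a) and deterministic actor mu(s),
    # with target copies initialized to the same weights.
    critic = torch.nn.Linear(obs_dim + act_dim, 1)
    actor = torch.nn.Sequential(torch.nn.Linear(obs_dim, act_dim), torch.nn.Tanh())
    target_critic = torch.nn.Linear(obs_dim + act_dim, 1)
    target_actor = torch.nn.Sequential(
        torch.nn.Linear(obs_dim, act_dim), torch.nn.Tanh()
    )
    target_critic.load_state_dict(critic.state_dict())
    target_actor.load_state_dict(actor.state_dict())
    critic_opt = torch.optim.Adam(critic.parameters())
    actor_opt = torch.optim.Adam(actor.parameters())

    # A fake transition batch standing in for a replay buffer sample.
    obs, act = torch.randn(64, obs_dim), torch.randn(64, act_dim)
    rew, done = torch.randn(64, 1), torch.zeros(64, 1)
    next_obs = torch.randn(64, obs_dim)

    # Critic update: y = r + gamma * Q_target(s', mu_target(s')).
    # target_noise=0.0 means no clipped noise is added to mu_target(s').
    with torch.no_grad():
        next_act = target_actor(next_obs)
        target_q = target_critic(torch.cat([next_obs, next_act], dim=1))
        y = rew + gamma * (1 - done) * target_q
    critic_loss = torch.nn.functional.mse_loss(
        critic(torch.cat([obs, act], dim=1)), y
    )
    critic_opt.zero_grad()
    critic_loss.backward()
    critic_opt.step()

    # Actor update on every step (policy_update_frequency=1):
    # ascend Q(s, mu(s)) by minimizing its negation.
    actor_loss = -critic(torch.cat([obs, actor(obs)], dim=1)).mean()
    actor_opt.zero_grad()
    actor_loss.backward()
    actor_opt.step()

    # Soft (Polyak) target update with soft_update_fraction (tau).
    with torch.no_grad():
        for net, target in ((critic, target_critic), (actor, target_actor)):
            for p, tp in zip(net.parameters(), target.parameters()):
                tp.mul_(1 - tau).add_(tau * p)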