import abc
import gymnasium as gym
from hive.utils.registry import Registrable
[docs]class Agent(abc.ABC, Registrable):
"""Base class for agents. Every implemented agent should be a subclass of
this class.
"""
def __init__(self, observation_space: gym.Space, action_space: gym.Space, id=0):
"""
Args:
observation_space (gym.Space): Observation space for agent.
action_space (gym.Space): Action space for agent.
id: Identifier for the agent.
"""
self._observation_space = observation_space
self._action_space = action_space
self._training = True
self._id = str(id)
@property
def id(self):
return self._id
[docs] @abc.abstractmethod
def act(self, observation, agent_traj_state):
"""Returns an action for the agent to perform based on the observation.
Args:
observation: Current observation that agent should act on.
agent_traj_state: Contains necessary state information for the agent
to process current trajectory. This should be updated and returned.
Returns:
- Action for the current timestep.
- Agent trajectory state.
"""
pass
[docs] @abc.abstractmethod
def update(self, update_info, agent_traj_state):
"""
Updates the agent.
Args:
update_info (dict): Contains information from the environment agent needs
to update itself.
agent_traj_state: Contains necessary state information for the agent
to process current trajectory. This should be updated and returned.
Returns:
Agent trajectory state.
"""
pass
[docs] def train(self):
"""Changes the agent to training mode."""
self._training = True
[docs] def eval(self):
"""Changes the agent to evaluation mode"""
self._training = False
[docs] @abc.abstractmethod
def save(self, dname):
"""
Saves agent checkpointing information to file for future loading.
Args:
dname (str): directory where agent should save all relevant info.
"""
pass
[docs] @abc.abstractmethod
def load(self, dname):
"""
Loads agent information from file.
Args:
dname (str): directory where agent checkpoint info is stored.
"""
pass
[docs] @classmethod
def type_name(cls):
"""
Returns:
"agent"
"""
return "agent"