import abc
from hive.utils.registry import Registrable
class Agent(abc.ABC, Registrable):
    """Base class for agents. Every implemented agent should be a subclass of
    this class.

    Subclasses must implement :meth:`act`, :meth:`update`, :meth:`save` and
    :meth:`load`. The training/evaluation mode toggled by :meth:`train` and
    :meth:`eval` is stored in ``self._training``; subclasses may read it.
    """

    def __init__(self, obs_dim, act_dim, id=0):
        """
        Args:
            obs_dim: Dimension of observations that agent will see.
            act_dim: Number of actions that the agent needs to choose from.
            id: Identifier for the agent. It is converted with ``str()``, so
                any value with a string form is accepted.
        """
        self._obs_dim = obs_dim
        self._act_dim = act_dim
        # Agents start out in training mode; see train() / eval().
        self._training = True
        # Normalize to str so ids compare consistently regardless of the type
        # passed in (e.g. 0 and "0" produce the same id).
        self._id = str(id)

    @property
    def id(self) -> str:
        """The agent's identifier, always a string."""
        return self._id

    @abc.abstractmethod
    def act(self, observation):
        """Returns an action for the agent to perform based on the observation.

        Args:
            observation: Current observation that agent should act on.

        Returns:
            Action for the current timestep.
        """
        pass

    @abc.abstractmethod
    def update(self, update_info):
        """Updates the agent.

        Args:
            update_info (dict): Contains information agent needs to update
                itself.
        """
        pass

    def train(self) -> None:
        """Changes the agent to training mode."""
        self._training = True

    def eval(self) -> None:
        """Changes the agent to evaluation mode."""
        self._training = False

    @abc.abstractmethod
    def save(self, dname):
        """Saves agent checkpointing information to file for future loading.

        Args:
            dname (str): directory where agent should save all relevant info.
        """
        pass

    @abc.abstractmethod
    def load(self, dname):
        """Loads agent information from file.

        Args:
            dname (str): directory where agent checkpoint info is stored.
        """
        pass

    @classmethod
    def type_name(cls) -> str:
        """
        Returns:
            "agent"
        """
        return "agent"