import copy
from typing import List
from hive.agents.agent import Agent
from hive.envs.base import BaseEnv
from hive.runners import Runner
from hive.runners.utils import TransitionInfo
from hive.utils import utils
from hive.utils.experiment import Experiment
from hive.utils.loggers import CompositeLogger, NullLogger, ScheduledLogger
[docs]class SingleAgentRunner(Runner):
"""Runner class used to implement a sinle-agent training loop."""
def __init__(
self,
environment: BaseEnv,
agent: Agent,
loggers: List[ScheduledLogger],
experiment_manager: Experiment,
train_steps: int,
eval_environment: BaseEnv = None,
test_frequency: int = -1,
test_episodes: int = 1,
stack_size: int = 1,
max_steps_per_episode: int = 1e9,
seed: int = None,
):
"""Initializes the SingleAgentRunner.
Args:
environment (BaseEnv): Environment used in the training loop.
agent (Agent): Agent that will interact with the environment
loggers (List[ScheduledLogger]): List of loggers used to log metrics.
experiment_manager (Experiment): Experiment object that saves the state of
the training.
train_steps (int): How many steps to train for. This is the number
of times that agent.update is called. If this is -1, there is no
limit for the number of training steps.
eval_environment (BaseEnv): Environment used to evaluate the agent. If
None, the ``environment`` parameter (which is a function) is
used to create a second environment.
test_frequency (int): After how many training steps to run testing
episodes. If this is -1, testing is not run.
test_episodes (int): How many episodes to run testing for duing each test
phase.
stack_size (int): The number of frames in an observation sent to an agent.
max_steps_per_episode (int): The maximum number of steps to run an episode
for.
seed (int): Seed used to set the global seed for libraries used by
Hive and seed the :py:class:`~hive.utils.utils.Seeder`.
"""
if seed is not None:
utils.seeder.set_global_seed(seed)
if eval_environment is None:
eval_environment = environment
environment = environment()
eval_environment = eval_environment() if test_frequency != -1 else None
env_spec = environment.env_spec
# Set up loggers
if loggers is None:
logger = NullLogger()
else:
logger = CompositeLogger(loggers)
agent = agent(
observation_space=env_spec.observation_space[0],
action_space=env_spec.action_space[0],
stack_size=stack_size,
logger=logger,
)
# Set up experiment manager
experiment_manager = experiment_manager()
super().__init__(
environment=environment,
eval_environment=eval_environment,
agents=[agent],
logger=logger,
experiment_manager=experiment_manager,
train_steps=train_steps,
test_frequency=test_frequency,
test_episodes=test_episodes,
max_steps_per_episode=max_steps_per_episode,
)
self._stack_size = stack_size
[docs] def run_one_step(
self,
environment,
observation,
episode_metrics,
transition_info,
agent_traj_state,
):
"""Run one step of the training loop.
Args:
observation: Current observation that the agent should create an action
for.
episode_metrics (Metrics): Keeps track of metrics for current episode.
"""
agent = self._agents[0]
stacked_observation = transition_info.get_stacked_state(agent, observation)
action, agent_traj_state = agent.act(stacked_observation, agent_traj_state)
(
next_observation,
reward,
terminated,
truncated,
_,
other_info,
) = environment.step(action)
info = {
"observation": observation,
"next_observation": next_observation,
"reward": reward,
"action": action,
"terminated": terminated,
"truncated": truncated,
"info": other_info,
}
if self._training:
agent_traj_state = agent.update(copy.deepcopy(info), agent_traj_state)
transition_info.record_info(agent, info)
episode_metrics[agent.id]["reward"] += info["reward"]
episode_metrics[agent.id]["episode_length"] += 1
episode_metrics["full_episode_length"] += 1
return terminated, truncated, next_observation, agent_traj_state
[docs] def run_end_step(
self,
environment,
observation,
episode_metrics,
transition_info,
agent_traj_state,
):
"""Run the final step of an episode.
After an episode ends, set the truncated value to true.
Args:
environment (BaseEnv): Environment in which the agent will take a step in.
observation: Current observation that the agent should create an action
for.
episode_metrics (Metrics): Keeps track of metrics for current
episode.
transition_info (TransitionInfo): Used to keep track of the most
recent transition for the agent.
agent_traj_state: Trajectory state object that will be passed to the
agent when act and update are called. The agent returns a new
trajectory state object to replace the state passed in.
"""
agent = self._agents[0]
stacked_observation = transition_info.get_stacked_state(agent, observation)
action, agent_traj_state = agent.act(stacked_observation, agent_traj_state)
next_observation, reward, terminated, _, _, other_info = environment.step(
action
)
truncated = not terminated
info = {
"observation": observation,
"next_observation": next_observation,
"reward": reward,
"action": action,
"terminated": terminated,
"truncated": truncated,
"info": other_info,
}
if self._training:
agent_traj_state = agent.update(copy.deepcopy(info), agent_traj_state)
transition_info.record_info(agent, info)
episode_metrics[agent.id]["reward"] += info["reward"]
episode_metrics[agent.id]["episode_length"] += 1
episode_metrics["full_episode_length"] += 1
return terminated, truncated, next_observation, agent_traj_state
[docs] def run_episode(self, environment):
"""Run a single episode of the environment.
Args:
environment (BaseEnv): Environment in which the agent will take a step in.
"""
episode_metrics = self.create_episode_metrics()
terminated, truncated = False, False
observation, _ = environment.reset()
transition_info = TransitionInfo(self._agents, self._stack_size)
transition_info.start_agent(self._agents[0])
agent_traj_state = None
steps = 0
# Run the loop until the episode ends or times out
while (
not (terminated or truncated)
and steps < self._max_steps_per_episode - 1
and (not self._training or self._train_schedule.get_value())
):
terminated, truncated, observation, agent_traj_state = self.run_one_step(
environment,
observation,
episode_metrics,
transition_info,
agent_traj_state,
)
steps += 1
self.update_step()
if not (terminated or truncated):
self.run_end_step(
environment,
observation,
episode_metrics,
transition_info,
agent_traj_state,
)
self.update_step()
return episode_metrics