"""Source code for hive.runners.base."""

from abc import ABC
from typing import List
from hive.agents.agent import Agent
from hive.envs.base import BaseEnv

from hive.runners.utils import Metrics
from hive.utils import schedule
from hive.utils.experiment import Experiment
from hive.utils.loggers import ScheduledLogger
from hive.utils.registry import Registrable
from hive.utils.utils import seeder


class Runner(ABC, Registrable):
    """Base Runner class used to implement a training loop.

    Different types of training loops can be created by overriding the relevant
    functions. Subclasses must implement :py:meth:`run_episode`.
    """

    def __init__(
        self,
        environment: BaseEnv,
        agents: List[Agent],
        logger: ScheduledLogger,
        experiment_manager: Experiment,
        train_steps: int,
        eval_environment: BaseEnv = None,
        test_frequency: int = -1,
        test_episodes: int = 1,
        max_steps_per_episode: int = int(1e9),
    ):
        """
        Args:
            environment (BaseEnv): Environment used in the training loop. It is
                seeded with a new seed from ``seeder``.
            agents (list[Agent]): List of agents that interact with the
                environment. A single agent that is not in a list is wrapped
                in a one-element list.
            logger (ScheduledLogger): Logger object used to log metrics.
            experiment_manager (Experiment): Experiment object that saves the
                state of the training.
            train_steps (int): How many training steps to run, counted as
                calls to :py:meth:`update_step` made while in training mode.
                If this is -1, there is no limit on the number of training
                steps.
            eval_environment (BaseEnv): Environment used for testing episodes.
                If None, :py:meth:`run_testing` does nothing.
            test_frequency (int): After how many training steps to run a
                testing phase. If this is -1, no periodic testing is run. The
                initial and final testing phases in :py:meth:`run_training`
                still happen.
            test_episodes (int): How many episodes to run in each testing
                phase. Reported test metrics are averaged over these episodes.
            max_steps_per_episode (int): Maximum number of steps in a single
                episode. This base class only stores it.
                NOTE(review): enforcement presumably happens in subclasses'
                :py:meth:`run_episode`; confirm.
        """
        self._train_environment = environment
        self._train_environment.seed(seeder.get_new_seed("environment"))
        self._eval_environment = eval_environment
        if self._eval_environment is not None:
            self._eval_environment.seed(seeder.get_new_seed("environment"))
        self._agents = agents if isinstance(agents, list) else [agents]
        self._logger = logger
        self._experiment_manager = experiment_manager

        # run_training loops while the train schedule's value is True.
        # update_step runs a testing phase whenever the test schedule's
        # update() returns True.
        if train_steps == -1:
            self._train_schedule = schedule.ConstantSchedule(True)
        else:
            self._train_schedule = schedule.SwitchSchedule(True, False, train_steps)
        if test_frequency == -1:
            self._test_schedule = schedule.ConstantSchedule(False)
        else:
            self._test_schedule = schedule.PeriodicSchedule(
                False, True, test_frequency
            )

        self._test_episodes = test_episodes
        self._max_steps_per_episode = max_steps_per_episode
        self._experiment_manager.register_experiment(
            logger=self._logger,
            agents=self._agents,
            environment=self._train_environment,
            eval_environment=self._eval_environment,
        )
        # Keep the schedules in the experiment state so that resume() can
        # read them back from there.
        self._experiment_manager.experiment_state.update(
            {
                "train_schedule": self._train_schedule,
                "test_schedule": self._test_schedule,
            }
        )
        self._logger.register_timescale("train")
        self._logger.register_timescale("test")
        self._training = True
        self._save_experiment = False
        self._run_testing = False

    def register_config(self, config):
        """Register the run configuration with the experiment manager and log it.

        Args:
            config: Configuration object, passed unchanged to the experiment
                manager's ``register_config`` and the logger's ``log_config``.
        """
        self._experiment_manager.register_config(config)
        self._logger.log_config(config)

    def train_mode(self, training):
        """If training is true, sets all agents to training mode. If training
        is false, sets all agents to eval mode.

        Args:
            training (bool): Whether to be in training mode.
        """
        self._training = training
        for agent in self._agents:
            if training:
                agent.train()
            else:
                agent.eval()

    def create_episode_metrics(self):
        """Create the metrics used during the loop."""
        return Metrics(
            self._agents,
            [("reward", 0), ("episode_length", 0)],
            [("full_episode_length", 0)],
        )

    def update_step(self):
        """Update steps for various schedules. Run testing if appropriate.

        This method has no effect unless the runner is in training mode. It
        marks the experiment for saving whenever the experiment manager's
        ``update_step`` returns True. The save itself happens in
        :py:meth:`run_training`.
        """
        if self._training:
            self._train_schedule.update()
            self._logger.update_step("train")
            if self._test_schedule.update():
                self.run_testing()
            # Call the manager's update_step first, so it always advances
            # even when a save is already pending.
            self._save_experiment = (
                self._experiment_manager.update_step() or self._save_experiment
            )

    def run_episode(self, environment):
        """Run a single episode of the environment.

        Subclasses must override this method. Callers here use the returned
        value's ``get_flat_dict()``, so implementations should return metrics
        such as those created by :py:meth:`create_episode_metrics`.

        Args:
            environment (BaseEnv): Environment in which the agent will take a
                step in.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def run_training(self):
        """Run the training loop.

        Note: to ensure that the test phase is run during training, the
        individual runners must call :py:meth:`~Runner.update_step` in their
        :py:meth:`~Runner.run_episode` methods. See
        :py:class:`~hive.runners.single_agent_loop.SingleAgentRunner` and
        :py:class:`~hive.runners.multi_agent_loop.MultiAgentRunner` for
        examples.
        """
        # Run an initial test episode.
        self.run_testing()
        self.train_mode(True)
        while self._train_schedule.get_value():
            # Run training episode.
            if not self._training:
                self.train_mode(True)
            episode_metrics = self.run_episode(self._train_environment)
            if self._logger.should_log("train"):
                flat_metrics = episode_metrics.get_flat_dict()
                self._logger.log_metrics(flat_metrics, "train")

            # Save experiment state if update_step flagged it.
            if self._save_experiment:
                self._experiment_manager.save()
                self._save_experiment = False

        # Run a final test episode and save the experiment.
        self.run_testing()
        self._experiment_manager.save()

    def run_testing(self):
        """Run a testing phase.

        Runs ``test_episodes`` episodes on the eval environment in eval mode.
        It logs the per-episode average of each metric under the "test"
        timescale, then switches back to training mode. It does nothing if
        there is no eval environment.
        """
        if self._eval_environment is None:
            return
        self.train_mode(False)
        aggregated_episode_metrics = self.create_episode_metrics().get_flat_dict()
        for _ in range(self._test_episodes):
            episode_metrics = self.run_episode(self._eval_environment)
            for metric, value in episode_metrics.get_flat_dict().items():
                aggregated_episode_metrics[metric] += value / self._test_episodes
        self._logger.update_step("test")
        self._logger.log_metrics(aggregated_episode_metrics, "test")
        self._run_testing = False
        self.train_mode(True)

    def resume(self):
        """Resume a saved experiment.

        If the experiment manager reports the experiment as resumable, it is
        restored. The train and test schedules are then always re-read from
        the experiment state.
        """
        if self._experiment_manager.is_resumable():
            self._experiment_manager.resume()
        self._train_schedule = self._experiment_manager.experiment_state[
            "train_schedule"
        ]
        self._test_schedule = self._experiment_manager.experiment_state["test_schedule"]

    @classmethod
    def type_name(cls):
        """
        Returns:
            "runner"
        """
        return "runner"