Source code for hive.runners.base

from abc import ABC
from typing import List
from hive.agents.agent import Agent
from hive.envs.base import BaseEnv

from hive.runners.utils import Metrics
from hive.utils import schedule
from hive.utils.experiment import Experiment
from hive.utils.loggers import ScheduledLogger


[docs]class Runner(ABC): """Base Runner class used to implement a training loop. Different types of training loops can be created by overriding the relevant functions. """ def __init__( self, environment: BaseEnv, agents: List[Agent], logger: ScheduledLogger, experiment_manager: Experiment, train_steps: int = 1000000, test_frequency: int = 10000, test_episodes: int = 1, max_steps_per_episode: int = 27000, ): """ Args: environment (BaseEnv): Environment used in the training loop. agents (list[Agent]): List of agents that interact with the environment. logger (ScheduledLogger): Logger object used to log metrics. experiment_manager (Experiment): Experiment object that saves the state of the training. train_steps (int): How many steps to train for. If this is -1, there is no limit for the number of training steps. test_frequency (int): After how many training steps to run testing episodes. If this is -1, testing is not run. test_episodes (int): How many episodes to run testing for. """ self._environment = environment if isinstance(agents, list): self._agents = agents else: self._agents = [agents] self._logger = logger self._experiment_manager = experiment_manager if train_steps == -1: self._train_schedule = schedule.ConstantSchedule(True) else: self._train_schedule = schedule.SwitchSchedule(True, False, train_steps) if test_frequency == -1: self._test_schedule = schedule.ConstantSchedule(False) else: self._test_schedule = schedule.PeriodicSchedule(False, True, test_frequency) self._test_episodes = test_episodes self._max_steps_per_episode = max_steps_per_episode self._experiment_manager.experiment_state.update( { "train_schedule": self._train_schedule, "test_schedule": self._test_schedule, } ) self._logger.register_timescale("train") self._logger.register_timescale("test") self._training = True self._save_experiment = False self._run_testing = False
[docs] def train_mode(self, training): """If training is true, sets all agents to training mode. If training is false, sets all agents to eval mode. Args: training (bool): Whether to be in training mode. """ self._training = training for agent in self._agents: agent.train() if training else agent.eval()
[docs] def create_episode_metrics(self): """Create the metrics used during the loop.""" return Metrics( self._agents, [("reward", 0), ("episode_length", 0)], [("full_episode_length", 0)], )
[docs] def run_one_step(self, observation, turn, episode_metrics): """Run one step of the training loop. Args: observation: Current observation that the agent should create an action for. turn (int): Agent whose turn it is. episode_metrics (Metrics): Keeps track of metrics for current episode. """ if self._training: self._train_schedule.update() self._logger.update_step("train") self._run_testing = self._test_schedule.update() or self._run_testing self._save_experiment = ( self._experiment_manager.update_step() or self._save_experiment )
[docs] def run_end_step(self, episode_metrics, done): """Run the final step of an episode. Args: episode_metrics (Metrics): Keeps track of metrics for current episode. done (bool): Whether this step was terminal. """ return NotImplementedError
[docs] def run_episode(self): """Run a single episode of the environment.""" return NotImplementedError
[docs] def run_training(self): """Run the training loop.""" self.train_mode(True) while self._train_schedule.get_value(): # Run training episode if not self._training: self.train_mode(True) episode_metrics = self.run_episode() if self._logger.should_log("train"): episode_metrics = episode_metrics.get_flat_dict() self._logger.log_metrics(episode_metrics, "train") # Run test episodes if self._run_testing: test_metrics = self.run_testing() self._logger.update_step("test") self._logger.log_metrics(test_metrics, "test") self._run_testing = False # Save experiment state if self._save_experiment: self._experiment_manager.save() self._save_experiment = False # Run a final test episode and save the experiment. test_metrics = self.run_testing() self._logger.update_step("test") self._logger.log_metrics(test_metrics, "test") self._experiment_manager.save()
[docs] def run_testing(self): """Run a testing phase.""" self.train_mode(False) aggregated_episode_metrics = self.create_episode_metrics().get_flat_dict() episodes = 0 while episodes <= self._test_episodes: episode_metrics = self.run_episode() episodes += 1 for metric, value in episode_metrics.get_flat_dict().items(): aggregated_episode_metrics[metric] += value / self._test_episodes return aggregated_episode_metrics
[docs] def resume(self): """Resume a saved experiment.""" self._experiment_manager.resume() self._train_schedule = self._experiment_manager.experiment_state[ "train_schedule" ] self._test_schedule = self._experiment_manager.experiment_state["test_schedule"]