Source code for hive.runners.single_agent_loop

import argparse
import copy

from hive import agents as agent_lib
from hive import envs
from hive.runners.base import Runner
from hive.runners.utils import TransitionInfo, load_config
from hive.utils import experiment, loggers, schedule, utils
from hive.utils.registry import get_parsed_args


class SingleAgentRunner(Runner):
    """Runner class used to implement a single-agent training loop."""

    def __init__(
        self,
        environment,
        agent,
        logger,
        experiment_manager,
        train_steps,
        test_frequency,
        test_episodes,
        stack_size,
        max_steps_per_episode=27000,
    ):
        """Initializes the Runner object.

        Args:
            environment (BaseEnv): Environment used in the training loop.
            agent (Agent): Agent that will interact with the environment.
            logger (ScheduledLogger): Logger object used to log metrics.
            experiment_manager (Experiment): Experiment object that saves the
                state of the training.
            train_steps (int): How many steps to train for. If this is -1,
                there is no limit for the number of training steps.
            test_frequency (int): After how many training steps to run testing
                episodes. If this is -1, testing is not run.
            test_episodes (int): How many episodes to run testing for during
                each test phase.
            stack_size (int): The number of frames in an observation sent to
                an agent.
            max_steps_per_episode (int): The maximum number of steps to run an
                episode for.
        """
        super().__init__(
            environment,
            agent,
            logger,
            experiment_manager,
            train_steps,
            test_frequency,
            test_episodes,
            max_steps_per_episode,
        )
        self._transition_info = TransitionInfo(self._agents, stack_size)
    def run_one_step(self, observation, episode_metrics):
        """Run one step of the training loop.

        Args:
            observation: Current observation that the agent should create an
                action for.
            episode_metrics (Metrics): Keeps track of metrics for current
                episode.
        """
        super().run_one_step(observation, 0, episode_metrics)
        agent = self._agents[0]
        stacked_observation = self._transition_info.get_stacked_state(
            agent, observation
        )
        action = agent.act(stacked_observation)
        next_observation, reward, done, _, other_info = self._environment.step(
            action
        )

        info = {
            "observation": observation,
            "reward": reward,
            "action": action,
            "done": done,
            "info": other_info,
        }
        if self._training:
            agent.update(copy.deepcopy(info))

        self._transition_info.record_info(agent, info)
        episode_metrics[agent.id]["reward"] += info["reward"]
        episode_metrics[agent.id]["episode_length"] += 1
        episode_metrics["full_episode_length"] += 1
        return done, next_observation
    def run_episode(self):
        """Run a single episode of the environment."""
        episode_metrics = self.create_episode_metrics()
        done = False
        observation, _ = self._environment.reset()
        self._transition_info.reset()
        self._transition_info.start_agent(self._agents[0])
        steps = 0
        # Run the loop until the episode ends or times out
        while not done and steps < self._max_steps_per_episode:
            done, observation = self.run_one_step(observation, episode_metrics)
            steps += 1
        return episode_metrics
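
# A minimal usage sketch (not part of the library source): once the
# components exist -- normally built from a config by set_up_experiment()
# below -- training is driven through the inherited Runner interface. The
# variable names and argument values here are illustrative assumptions.
#
#   runner = SingleAgentRunner(
#       environment=my_env,            # a hive BaseEnv instance
#       agent=my_agent,                # a hive Agent instance
#       logger=my_logger,              # a hive logger instance
#       experiment_manager=my_experiment,
#       train_steps=100_000,
#       test_frequency=5_000,
#       test_episodes=10,
#       stack_size=4,
#   )
#   runner.run_training()              # defined on the Runner base class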
def set_up_experiment(config):
    """Returns a :py:class:`SingleAgentRunner` object based on the config and
    any command line arguments.

    Args:
        config: Configuration for experiment.
    """

    args = get_parsed_args(
        {
            "seed": int,
            "train_steps": int,
            "test_frequency": int,
            "test_episodes": int,
            "max_steps_per_episode": int,
            "stack_size": int,
            "resume": bool,
            "run_name": str,
            "save_dir": str,
        }
    )
    config.update(args)
    full_config = utils.Chomp(copy.deepcopy(config))

    if "seed" in config:
        utils.seeder.set_global_seed(config["seed"])

    environment, full_config["environment"] = envs.get_env(
        config["environment"], "environment"
    )
    env_spec = environment.env_spec

    # Set up loggers
    logger_config = config.get("loggers", {"name": "NullLogger"})
    if logger_config is None or len(logger_config) == 0:
        logger_config = {"name": "NullLogger"}
    if isinstance(logger_config, list):
        logger_config = {
            "name": "CompositeLogger",
            "kwargs": {"logger_list": logger_config},
        }
    logger, full_config["loggers"] = loggers.get_logger(logger_config, "loggers")

    # Set up agent
    if config.get("stack_size", 1) > 1:
        config["agent"]["kwargs"]["obs_dim"] = (
            config["stack_size"] * env_spec.obs_dim[0][0],
            *env_spec.obs_dim[0][1:],
        )
    else:
        config["agent"]["kwargs"]["obs_dim"] = env_spec.obs_dim[0]
    config["agent"]["kwargs"]["act_dim"] = env_spec.act_dim[0]
    config["agent"]["kwargs"]["logger"] = logger
    if "replay_buffer" in config["agent"]["kwargs"]:
        replay_args = config["agent"]["kwargs"]["replay_buffer"]["kwargs"]
        replay_args["observation_shape"] = env_spec.obs_dim[0]
    agent, full_config["agent"] = agent_lib.get_agent(config["agent"], "agent")

    # Set up experiment manager
    saving_schedule, full_config["saving_schedule"] = schedule.get_schedule(
        config["saving_schedule"], "saving_schedule"
    )
    experiment_manager = experiment.Experiment(
        config["run_name"], config["save_dir"], saving_schedule
    )
    experiment_manager.register_experiment(
        config=full_config,
        logger=logger,
        agents=agent,
    )

    # Set up runner
    runner = SingleAgentRunner(
        environment,
        agent,
        logger,
        experiment_manager,
        config.get("train_steps", -1),
        config.get("test_frequency", -1),
        config.get("test_episodes", 1),
        config.get("stack_size", 1),
        config.get("max_steps_per_episode", 1e9),
    )
    if config.get("resume", False):
        runner.resume()

    return runner
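
# A hedged sketch of the config structure set_up_experiment() consumes,
# inferred only from the keys read above. The component names ("GymEnv",
# "DQNAgent", "ChompLogger", "PeriodicSchedule") and all values are
# illustrative assumptions, not a verified preset. Note that when
# stack_size > 1, the first axis of the agent's obs_dim is scaled by it,
# e.g. (1, 84, 84) becomes (4, 84, 84) with stack_size=4.
#
#   config = {
#       "run_name": "dqn_cartpole",
#       "save_dir": "experiments",
#       "train_steps": 100_000,
#       "test_frequency": 5_000,
#       "test_episodes": 10,
#       "stack_size": 1,
#       "environment": {"name": "GymEnv", "kwargs": {"env_name": "CartPole-v1"}},
#       "agent": {"name": "DQNAgent", "kwargs": {}},  # agent kwargs omitted
#       "loggers": [{"name": "ChompLogger"}],
#       "saving_schedule": {
#           "name": "PeriodicSchedule",
#           "kwargs": {"off_value": False, "on_value": True, "period": 10_000},
#       },
#   }
#   runner = set_up_experiment(config)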
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config")
    parser.add_argument("-p", "--preset-config")
    parser.add_argument("-a", "--agent-config")
    parser.add_argument("-e", "--env-config")
    parser.add_argument("-l", "--logger-config")
    args, _ = parser.parse_known_args()
    if args.config is None and args.preset_config is None:
        raise ValueError("Config needs to be provided")
    config = load_config(
        args.config,
        args.preset_config,
        args.agent_config,
        args.env_config,
        args.logger_config,
    )
    runner = set_up_experiment(config)
    runner.run_training()
if __name__ == "__main__":
    main()
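
# Example command lines, grounded in the argparse flags defined in main()
# above; the file paths are placeholders, not shipped presets:
#
#   python -m hive.runners.single_agent_loop -c my_config.yml
#   python -m hive.runners.single_agent_loop -p preset.yml -a agent_override.yml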