Source code for hive.runners.multi_agent_loop

import argparse
import copy

from hive import agents as agent_lib
from hive import envs
from hive.runners.base import Runner
from hive.runners.utils import TransitionInfo, load_config
from hive.utils import experiment, loggers, schedule, utils
from hive.utils.registry import get_parsed_args


[docs]class MultiAgentRunner(Runner): """Runner class used to implement a multiagent training loop.""" def __init__( self, environment, agents, logger, experiment_manager, train_steps, test_frequency, test_episodes, stack_size, self_play, max_steps_per_episode=27000, ): """Initializes the Runner object. Args: environment (BaseEnv): Environment used in the training loop. agents (list[Agent]): List of agents that interact with the environment logger (ScheduledLogger): Logger object used to log metrics. experiment_manager (Experiment): Experiment object that saves the state of the training. train_steps (int): How many steps to train for. If this is -1, there is no limit for the number of training steps. test_frequency (int): After how many training steps to run testing episodes. If this is -1, testing is not run. test_episodes (int): How many episodes to run testing for. stack_size (int): The number of frames in an observation sent to an agent. max_steps_per_episode (int): The maximum number of steps to run an episode for. """ super().__init__( environment, agents, logger, experiment_manager, train_steps, test_frequency, test_episodes, max_steps_per_episode, ) self._transition_info = TransitionInfo(self._agents, stack_size) self._self_play = self_play
[docs] def run_one_step(self, observation, turn, episode_metrics): """Run one step of the training loop. If it is the agent's first turn during the episode, do not run an update step. Otherwise, run an update step based on the previous action and accumulated reward since then. Args: observation: Current observation that the agent should create an action for. turn (int): Agent whose turn it is. episode_metrics (Metrics): Keeps track of metrics for current episode. """ super().run_one_step(observation, turn, episode_metrics) agent = self._agents[turn] if self._transition_info.is_started(agent): info = self._transition_info.get_info(agent) if self._training: agent.update(copy.deepcopy(info)) episode_metrics[agent.id]["reward"] += info["reward"] episode_metrics[agent.id]["episode_length"] += 1 episode_metrics["full_episode_length"] += 1 else: self._transition_info.start_agent(agent) stacked_observation = self._transition_info.get_stacked_state( agent, observation ) action = agent.act(stacked_observation) next_observation, reward, done, turn, other_info = self._environment.step( action ) self._transition_info.record_info( agent, { "observation": observation, "action": action, "info": other_info, }, ) if self._self_play: self._transition_info.record_info( agent, { "agent_id": agent.id, }, ) self._transition_info.update_all_rewards(reward) return done, next_observation, turn
[docs] def run_end_step(self, episode_metrics, done=True): """Run the final step of an episode. After an episode ends, iterate through agents and update then with the final step in the episode. Args: episode_metrics (Metrics): Keeps track of metrics for current episode. done (bool): Whether this step was terminal. """ for agent in self._agents: if self._transition_info.is_started(agent): info = self._transition_info.get_info(agent, done=done) if self._training: agent.update(info) episode_metrics[agent.id]["episode_length"] += 1 episode_metrics["full_episode_length"] += 1 episode_metrics[agent.id]["reward"] += info["reward"]
[docs] def run_episode(self): """Run a single episode of the environment.""" episode_metrics = self.create_episode_metrics() done = False observation, turn = self._environment.reset() self._transition_info.reset() steps = 0 # Run the loop until the episode ends or times out while not done and steps < self._max_steps_per_episode: done, observation, turn = self.run_one_step( observation, turn, episode_metrics ) steps += 1 # Run the final update. self.run_end_step(episode_metrics, done) return episode_metrics
[docs]def set_up_experiment(config): """Returns a :py:class:`MultiAgentRunner` object based on the config and any command line arguments. Args: config: Configuration for experiment. """ # Parses arguments from the command line. args = get_parsed_args( { "seed": int, "train_steps": int, "test_frequency": int, "test_episodes": int, "max_steps_per_episode": int, "stack_size": int, "resume": bool, "run_name": str, "save_dir": str, "self_play": bool, "num_agents": int, } ) config.update(args) full_config = utils.Chomp(copy.deepcopy(config)) if "seed" in config: utils.seeder.set_global_seed(config["seed"]) # Set up environment environment, full_config["environment"] = envs.get_env( config["environment"], "environment" ) env_spec = environment.env_spec # Set up loggers logger_config = config.get("loggers", {"name": "NullLogger"}) if logger_config is None or len(logger_config) == 0: logger_config = {"name": "NullLogger"} if isinstance(logger_config, list): logger_config = { "name": "CompositeLogger", "kwargs": {"logger_list": logger_config}, } logger, full_config["loggers"] = loggers.get_logger(logger_config, "loggers") # Set up agents agents = [] full_config["agents"] = [] num_agents = config["num_agents"] if config["self_play"] else len(config["agents"]) for idx in range(num_agents): if not config["self_play"] or idx == 0: agent_config = config["agents"][idx] if config.get("stack_size", 1) > 1: agent_config["kwargs"]["obs_dim"] = ( config["stack_size"] * env_spec.obs_dim[idx][0], *env_spec.obs_dim[idx][1:], ) else: agent_config["kwargs"]["obs_dim"] = env_spec.obs_dim[idx] agent_config["kwargs"]["act_dim"] = env_spec.act_dim[idx] agent_config["kwargs"]["logger"] = logger if "replay_buffer" in agent_config["kwargs"]: replay_args = agent_config["kwargs"]["replay_buffer"]["kwargs"] replay_args["observation_shape"] = env_spec.obs_dim[idx] agent, full_agent_config = agent_lib.get_agent( agent_config, f"agents.{idx}" ) agents.append(agent) full_config["agents"].append(full_agent_config) else: agents.append(copy.copy(agents[0])) agents[-1]._id = idx # Set up experiment manager saving_schedule, full_config["saving_schedule"] = schedule.get_schedule( config["saving_schedule"], "saving_schedule" ) experiment_manager = experiment.Experiment( config["run_name"], config["save_dir"], saving_schedule ) experiment_manager.register_experiment( config=full_config, logger=logger, agents=agents, ) # Set up runner runner = MultiAgentRunner( environment, agents, logger, experiment_manager, config.get("train_steps", -1), config.get("test_frequency", -1), config.get("test_episodes", 1), config.get("stack_size", 1), config.get("self_play", False), config.get("max_steps_per_episode", 1e9), ) if config.get("resume", False): runner.resume() return runner
[docs]def main(): parser = argparse.ArgumentParser() parser.add_argument("-c", "--config") parser.add_argument("-p", "--preset-config") parser.add_argument("-a", "--agent-config") parser.add_argument("-e", "--env-config") parser.add_argument("-l", "--logger-config") args, _ = parser.parse_known_args() if args.config is None and args.preset_config is None: raise ValueError("Config needs to be provided") config = load_config( args.config, args.preset_config, args.agent_config, args.env_config, args.logger_config, ) runner = set_up_experiment(config) runner.run_training()
if __name__ == "__main__": main()