hive.runners.base module

class hive.runners.base.Runner(environment, agents, logger, experiment_manager, train_steps=1000000, test_frequency=10000, test_episodes=1, max_steps_per_episode=27000)

Bases: ABC

Base Runner class used to implement a training loop.

Different types of training loops can be created by overriding the relevant functions.
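As a rough illustration of the override pattern described above, the sketch below uses a stand-in abstract class rather than hive itself. SketchRunner, CountingRunner, and stop_after are hypothetical names invented for this example, and the loop body is an assumption about how a shared episode loop might call an overridable step.

```python
from abc import ABC, abstractmethod


class SketchRunner(ABC):
    """Hypothetical stand-in for a base runner; not the hive implementation."""

    def __init__(self, max_steps_per_episode=5):
        self.max_steps_per_episode = max_steps_per_episode

    @abstractmethod
    def run_one_step(self, step):
        """Subclasses decide what one step does; return True when the episode ends."""

    def run_episode(self):
        # The shared loop calls the overridable hook once per step and
        # stops on a terminal step or when the step cap is reached.
        steps = 0
        done = False
        while not done and steps < self.max_steps_per_episode:
            done = self.run_one_step(steps)
            steps += 1
        return steps


class CountingRunner(SketchRunner):
    """Hypothetical subclass that ends the episode after a fixed number of steps."""

    def __init__(self, stop_after, **kwargs):
        super().__init__(**kwargs)
        self.stop_after = stop_after

    def run_one_step(self, step):
        return step + 1 >= self.stop_after
```

Only run_one_step is overridden here; the loop in run_episode is inherited unchanged, which is the general idea behind customizing a runner by overriding individual functions.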

Parameters
  • environment (BaseEnv) – Environment used in the training loop.

  • agents (list[Agent]) – List of agents that interact with the environment.

  • logger (ScheduledLogger) – Logger object used to log metrics.

  • experiment_manager (Experiment) – Experiment object that saves the state of the training.

  • train_steps (int) – Number of steps to train for. If this is -1, there is no limit on the number of training steps.

  • test_frequency (int) – Number of training steps between testing phases. If this is -1, testing is not run.

  • test_episodes (int) – Number of episodes to run in each testing phase.

  • max_steps_per_episode (int) – Maximum number of steps to run an episode for.
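The -1 sentinel values for train_steps and test_frequency can be read as in the minimal sketch below. The helper names should_keep_training and should_run_testing are hypothetical, and the exact step at which a testing phase triggers is an assumption, not something this page specifies.

```python
def should_keep_training(step, train_steps):
    """Return True while training should continue; -1 means no step limit."""
    return train_steps == -1 or step < train_steps


def should_run_testing(step, test_frequency):
    """Return True at each positive multiple of test_frequency; -1 disables testing."""
    if test_frequency == -1:
        return False
    # Assumption: testing is triggered after every test_frequency training steps.
    return step > 0 and step % test_frequency == 0
```

With the default values, this reading would trigger a testing phase after training steps 10000, 20000, and so on, and stop training once 1000000 steps have been taken.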

train_mode(training)

If training is true, sets all agents to training mode. If training is false, sets all agents to eval mode.

Parameters

training (bool) – Whether to be in training mode.
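A minimal sketch of this mode switch, assuming each agent exposes train() and eval() methods. StubAgent and set_train_mode are hypothetical stand-ins written for this example, not hive classes or functions.

```python
class StubAgent:
    """Hypothetical agent that exposes train() and eval() switches."""

    def __init__(self):
        self.training = True

    def train(self):
        self.training = True

    def eval(self):
        self.training = False


def set_train_mode(agents, training):
    # Put every agent in the same mode, following the description above.
    for agent in agents:
        if training:
            agent.train()
        else:
            agent.eval()
```

Calling set_train_mode(agents, False) before a testing phase and set_train_mode(agents, True) afterwards keeps all agents in a consistent mode.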

create_episode_metrics()

Create the metrics used during the loop.

run_one_step(observation, turn, episode_metrics)

Run one step of the training loop.

Parameters
  • observation – Current observation that the agent should create an action for.

  • turn (int) – Index of the agent whose turn it is.

  • episode_metrics (Metrics) – Keeps track of metrics for current episode.
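One way a turn-based step could look is sketched below. Everything in it is an assumption made for illustration: EchoAgent, TwoPlayerEnv, the four-value return of step, and the use of plain dictionaries in place of a Metrics object are all hypothetical and do not reflect the actual hive interfaces.

```python
class EchoAgent:
    """Hypothetical agent whose action is simply its own index."""

    def __init__(self, index):
        self.index = index

    def act(self, observation):
        return self.index


class TwoPlayerEnv:
    """Hypothetical environment that alternates turns between two agents."""

    def __init__(self):
        self.turn = 0

    def step(self, action):
        self.turn = (self.turn + 1) % 2
        observation = action  # the next observation is just the last action
        reward = 1.0
        done = False
        return observation, reward, done, self.turn


def run_one_step(env, agents, observation, turn, episode_metrics):
    # The agent whose turn it is picks an action for the current observation.
    action = agents[turn].act(observation)
    next_observation, reward, done, next_turn = env.step(action)
    # Plain dicts stand in for the per-agent episode metrics.
    episode_metrics[turn]["reward"] += reward
    episode_metrics[turn]["steps"] += 1
    return done, next_observation, next_turn
```

In this sketch only the acting agent's entry in episode_metrics is updated, so per-agent totals remain separate across turns.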

run_end_step(episode_metrics, done)

Run the final step of an episode.

Parameters
  • episode_metrics (Metrics) – Keeps track of metrics for current episode.

  • done (bool) – Whether this step was terminal.

run_episode()

Run a single episode of the environment.

run_training()

Run the training loop.

run_testing()

Run a testing phase.

resume()

Resume a saved experiment.