hive.runners.utils module
- hive.runners.utils.load_config(config=None, preset_config=None, agent_config=None, env_config=None, logger_config=None)[source]
Used to load config for experiments. The agent, environment, and logger components in the main config file can be overridden by other config files.
- Parameters
config (str) – Path to configuration file. Either this or preset_config must be passed.
preset_config (str) – Path to a preset hive config. This path should be relative to hive/configs. For example, the Atari DQN config would be atari/dqn.yml.
agent_config (str) – Path to agent configuration file. Overrides settings in base config.
env_config (str) – Path to environment configuration file. Overrides settings in base config.
logger_config (str) – Path to logger configuration file. Overrides settings in base config.
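The override behavior described above can be pictured with plain dictionaries. This is only a sketch, not the library's implementation: the helper apply_overrides, the section keys "agent", "environment", and "loggers", and the placeholder values are all hypothetical, and it assumes an override replaces a whole component section rather than merging into it.

```python
def apply_overrides(base, agent=None, env=None, loggers=None):
    """Return a copy of base with whole component sections replaced.

    Hypothetical helper: the section keys below are assumptions made
    for illustration, not names confirmed by the hive documentation.
    """
    merged = dict(base)  # shallow copy so the base config is left untouched
    overrides = (("agent", agent), ("environment", env), ("loggers", loggers))
    for key, override in overrides:
        if override is not None:
            merged[key] = override
    return merged


# Placeholder configs standing in for the parsed contents of config files.
base_config = {
    "agent": {"name": "agent_a"},
    "environment": {"name": "env_a"},
    "loggers": {"name": "logger_a"},
}
merged_config = apply_overrides(base_config, agent={"name": "agent_b"})
```

Only the agent section changes here; the environment and logger sections still come from the base config.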
- class hive.runners.utils.Metrics(agents, agent_metrics, episode_metrics)[source]
Bases:
object
Class used to keep track of separate metrics for each agent as well as general episode metrics.
- Parameters
agents (list[Agent]) – List of agents for which object will track metrics.
agent_metrics (list[(str, (callable | obj))]) – List of metrics to track for each agent. Should be a list of tuples (metric_name, metric_init) where metric_init is either the initial value of the metric or a callable that takes no arguments and creates the initial metric.
episode_metrics (list[(str, (callable | obj))]) – List of non-agent-specific metrics to keep track of. Should be a list of tuples (metric_name, metric_init) where metric_init is either the initial value of the metric or a callable that takes no arguments and creates the initial metric.
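The (metric_name, metric_init) convention can be illustrated without the library. The sketch below is an assumption about how such a list might be resolved, not the actual Metrics code; init_metrics and the agent ids are hypothetical names. It shows why a callable is useful: each agent gets its own fresh object instead of sharing one mutable value.

```python
def init_metrics(metric_specs):
    """Build a dict of initial metric values from (name, init) tuples.

    Hypothetical helper: callables are invoked with no arguments,
    anything else is used directly as the initial value.
    """
    return {
        name: (init() if callable(init) else init)
        for name, init in metric_specs
    }


# "reward" starts from a plain value; "losses" uses the list constructor
# so that every agent receives a separate empty list.
agent_metrics = [("reward", 0), ("losses", list)]
agent_ids = ["agent_0", "agent_1"]  # placeholder ids for illustration

per_agent = {agent_id: init_metrics(agent_metrics) for agent_id in agent_ids}
per_agent["agent_0"]["losses"].append(0.5)
```

After the append, only the list belonging to agent_0 has changed.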
- class hive.runners.utils.TransitionInfo(agents, stack_size)[source]
Bases:
object
Used to keep track of the most recent transition for each agent.
Any info that the agent needs to remember for updating can be stored here. Should be completely reset between episodes. After any info is extracted, it is automatically removed from the object. Also keeps track of which agents have started their episodes.
This object also handles padding and stacking observations for agents.
- Parameters
- is_started(agent)[source]
Check if agent has started its episode.
- Parameters
agent (Agent) – Agent to check.
- start_agent(agent)[source]
Set the agent’s start flag to true.
- Parameters
agent (Agent) – Agent to start.
- update_all_rewards(rewards)[source]
Update the rewards for all agents. If rewards is a list, it updates the rewards according to the order of agents provided in the initializer. If rewards is a dict, the keys should be the agent ids for the agents and the values should be the rewards for those agents. If rewards is a float or int, every agent is updated with that reward.
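The three accepted forms of rewards can be sketched as a small dispatch function. This is an illustration under assumptions, not the method's actual body: spread_rewards is a hypothetical name, agents are represented by their ids only, and the function returns a mapping instead of updating stored transition info.

```python
def spread_rewards(agent_ids, rewards):
    """Map a list, dict, or scalar reward onto agent ids (hypothetical helper)."""
    if isinstance(rewards, dict):
        # Keys are agent ids; agents without an entry receive nothing.
        return {aid: rewards[aid] for aid in agent_ids if aid in rewards}
    if isinstance(rewards, (int, float)):
        # A single number is given to every agent.
        return {aid: rewards for aid in agent_ids}
    # Otherwise treat rewards as a sequence ordered like agent_ids.
    return dict(zip(agent_ids, rewards))


ids = ["a", "b"]  # placeholder agent ids
from_list = spread_rewards(ids, [1.0, 2.0])
from_dict = spread_rewards(ids, {"b": 3})
from_scalar = spread_rewards(ids, 0.5)
```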
- get_info(agent, terminated=False, truncated=False)[source]
Get all the info for the agent, and reset the info for that agent. Also adds terminated and truncated values to the info dictionary, based on the corresponding parameters to the function.
- get_stacked_state(agent, observation)[source]
Create a stacked state for the agent. The previous observations recorded by this agent are stacked with the current observation. If not enough observations have been recorded, zero arrays are appended.
- Parameters
agent (Agent) – Agent to get stacked state for.
observation – Current observation.
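A minimal sketch of stacking with zero padding, using NumPy. The helper stack_with_padding is hypothetical, and several details are assumptions rather than documented behavior: observations are NumPy arrays, frames are stacked along a new leading axis, and the zero arrays fill the oldest slots.

```python
import numpy as np


def stack_with_padding(history, observation, stack_size):
    """Stack the most recent observations, zero-padding a short history.

    Hypothetical helper: history holds earlier observations, oldest first.
    """
    # Keep at most stack_size - 1 earlier frames, then add the current one.
    frames = list(history)[-(stack_size - 1):] if stack_size > 1 else []
    frames.append(observation)
    # Fill any missing slots with zero arrays shaped like the observation.
    padding = [np.zeros_like(observation)] * (stack_size - len(frames))
    return np.stack(padding + frames)


# First step of an episode: no history yet, so two of three slots are zeros.
stacked = stack_with_padding([], np.ones(2), stack_size=3)
```

With an empty history and a stack size of 3, the result has shape (3, 2), and only the last row holds the current observation.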
- hive.runners.utils.zeros_like(x)[source]
Create a zero state like some state. This handles slightly more complex objects such as lists and dictionaries of numpy arrays and torch Tensors.
- Parameters
x (np.ndarray | torch.Tensor | dict | list) – State used to define structure/state of zero state.
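The recursive idea behind a zero state for nested containers can be sketched as follows. This is not the library's code: zeros_like_nested is a hypothetical name, and torch Tensors are left out so that the example depends only on NumPy.

```python
import numpy as np


def zeros_like_nested(x):
    """Return a zero-filled copy of x with the same nesting, shapes, and dtypes.

    Hypothetical helper covering NumPy arrays, dicts, and lists only.
    """
    if isinstance(x, np.ndarray):
        return np.zeros_like(x)
    if isinstance(x, dict):
        return {key: zeros_like_nested(value) for key, value in x.items()}
    if isinstance(x, list):
        return [zeros_like_nested(value) for value in x]
    raise TypeError(f"Unsupported state type: {type(x).__name__}")


# A placeholder state mixing a dict, a list, and arrays of different dtypes.
state = {"image": np.ones((2, 2)), "extras": [np.arange(3)]}
zero_state = zeros_like_nested(state)
```

The original state is not modified; only the returned structure contains zeros.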