from importlib import import_module
import numpy as np
from hive.envs import BaseEnv
from hive.envs.env_spec import EnvSpec
class PettingZooEnv(BaseEnv):
    """
    PettingZoo environment from https://github.com/PettingZoo-Team/PettingZoo

    For now, we only support environments from PettingZoo with discrete actions.
    """

    def __init__(
        self,
        env_name,
        env_family,
        num_players,
        **kwargs,
    ):
        """
        Args:
            env_name (str): Name of the environment.
            env_family (str): Family of the environment such as "Atari",
                "Classic", "SISL", "Butterfly", "MAgent", and "MPE".
            num_players (int): Number of learning agents.
        """
        self._env_family = env_family
        self.create_env(env_name, num_players, **kwargs)
        # Build the spec once and reuse it below; previously it was
        # constructed a second time inside the super().__init__ call.
        self._env_spec = self.create_env_spec(env_name, **kwargs)
        self._seed = None
        super().__init__(self._env_spec, num_players)

    def create_env(self, env_name, num_players, **kwargs):
        """Import the pettingzoo module for ``env_name`` and instantiate the env.

        The module path is ``pettingzoo.<family>.<env_name>``, so
        ``env_family`` must match the lowercase pettingzoo package name.

        Args:
            env_name (str): Name of the environment module.
            num_players (int): Number of players, forwarded to the env
                constructor as ``players``.
        """
        env_module = import_module("pettingzoo." + self._env_family + "." + env_name)
        # NOTE(review): **kwargs is accepted but not forwarded to env();
        # confirm whether extra env options should be passed through here.
        self._env = env_module.env(players=num_players)

    def create_env_spec(self, env_name, **kwargs):
        """Build an :class:`EnvSpec` from the wrapped env's per-agent spaces.

        Each family of environments has its own type of observations and
        actions. You can add support for more families here by modifying
        observation_space and action_space.

        Args:
            env_name (str): Name recorded in the spec.

        Returns:
            EnvSpec: Spec with one observation/action space per agent.

        Raises:
            ValueError: If ``env_family`` is not one of the supported families.
        """
        if self._env_family in ["classic"]:
            # Classic envs expose a Dict space; the learnable part is under
            # the "observation" key (the rest is e.g. an action mask).
            observation_space = [
                space["observation"] for space in self._env.observation_spaces.values()
            ]
        elif self._env_family in ["sisl"]:
            observation_space = list(self._env.observation_spaces.values())
        else:
            raise ValueError(
                f"Hive does not support {self._env_family} environments from PettingZoo yet."
            )
        return EnvSpec(
            env_name=env_name,
            observation_space=observation_space,
            action_space=list(self._env.action_spaces.values()),
        )

    def reset(self):
        """Reset the env and return the first agent's observation and turn index.

        Returns:
            tuple: ``(observation, turn)`` where ``turn`` is the index of the
            agent whose move it is.
        """
        self._env.reset(seed=self._seed)
        observation, _, _, _, _ = self._env.last()
        # Seed is consumed on the first reset only; later resets are unseeded.
        self._seed = None
        # NOTE(review): assumes the observation is a dict of image-like
        # arrays; uint8 conversion would truncate float observations.
        for key in observation.keys():
            observation[key] = np.array(observation[key], dtype=np.uint8)
        self._turn = self._env.agents.index(self._env.agent_selection)
        return observation, self._turn

    def step(self, action):
        """Apply ``action`` for the current agent and advance the turn.

        Args:
            action: Action for the agent whose turn it is.

        Returns:
            tuple: ``(observation, rewards, terminated, truncated, turn, info)``
            where ``rewards`` is a list with one entry per live agent.
        """
        self._env.step(action)
        observation, _, terminated, truncated, info = self._env.last()
        # NOTE(review): assumes a fixed round-robin turn order with no agent
        # removal; if agents can die mid-episode this index can drift from
        # self._env.agent_selection — confirm against the env family used.
        self._turn = (self._turn + 1) % self._num_players
        for key in observation.keys():
            observation[key] = np.array(observation[key], dtype=np.uint8)
        return (
            observation,
            [self._env.rewards[agent] for agent in self._env.agents],
            terminated,
            truncated,
            self._turn,
            info,
        )

    def render(self, mode="rgb_array"):
        """Render the environment in the given mode."""
        return self._env.render(mode=mode)

    def seed(self, seed=None):
        """Store a seed to be applied on the next :meth:`reset`."""
        self._seed = seed

    def close(self):
        """Close the underlying pettingzoo environment."""
        self._env.close()