"""Source code for hive.envs.gym_env."""

import gym

from hive.envs.base import BaseEnv
from hive.envs.env_spec import EnvSpec

class GymEnv(BaseEnv):
    """
    Class for loading gym environments.
    """

    def __init__(self, env_name, num_players=1, **kwargs):
        """
        Args:
            env_name (str): Name of the environment (NOTE: make sure it is
                available in gym.envs.registry.all())
            num_players (int): Number of players for the environment.
            kwargs: Any arguments you want to pass to :py:meth:`create_env` or
                :py:meth:`create_env_spec` can be passed as keyword arguments
                to this constructor.
        """
        # The wrapped env must exist before the spec is built, because
        # create_env_spec() reads the observation/action spaces of self._env.
        self.create_env(env_name, **kwargs)
        super().__init__(self.create_env_spec(env_name, **kwargs), num_players)

    def create_env(self, env_name, **kwargs):
        """Function used to create the environment. Subclasses can override
        this method if they are using a gym style environment that needs
        special logic.

        Args:
            env_name (str): Name of the environment
            kwargs: Extra keyword arguments forwarded from the constructor.
                The default implementation ignores them; they are accepted so
                subclasses can use them.
        """
        self._env = gym.make(env_name)

    def create_env_spec(self, env_name, **kwargs):
        """Function used to create the specification. Subclasses can override
        this method if they are using a gym style environment that needs
        special logic.

        Args:
            env_name (str): Name of the environment
            kwargs: Extra keyword arguments forwarded from the constructor.
                The default implementation ignores them.

        Returns:
            EnvSpec: Spec with one ``obs_dim`` entry (the space's ``shape``)
            and one ``act_dim`` entry (the space's ``n``) per sub-space.
        """
        # A Tuple space holds one sub-space per entry; a single space is
        # wrapped in a list so both cases are handled the same way below.
        if isinstance(self._env.observation_space, gym.spaces.Tuple):
            obs_spaces = self._env.observation_space.spaces
        else:
            obs_spaces = [self._env.observation_space]
        if isinstance(self._env.action_space, gym.spaces.Tuple):
            act_spaces = self._env.action_space.spaces
        else:
            act_spaces = [self._env.action_space]

        return EnvSpec(
            env_name=env_name,
            obs_dim=[space.shape for space in obs_spaces],
            # NOTE: reads ``space.n``, so only action spaces that expose ``n``
            # (e.g. gym.spaces.Discrete) are supported here.
            act_dim=[space.n for space in act_spaces],
        )

    def reset(self):
        """Resets the wrapped environment and starts a new episode.

        Returns:
            tuple: ``(observation, turn)`` where ``observation`` is whatever
            the wrapped env's ``reset`` returns and ``turn`` is the index of
            the player to act first (always 0 after a reset).
        """
        observation = self._env.reset()
        # Every new episode starts with the first player. Without this, the
        # round-robin index advanced in step() would carry over from the
        # previous episode whenever num_players > 1.
        self._turn = 0
        return observation, self._turn

    def step(self, action):
        """Steps the wrapped environment with ``action``.

        Returns:
            tuple: ``(observation, reward, done, turn, info)``. The first
            three items and ``info`` come from the wrapped env's ``step``.
            ``turn`` is the index of the player who acts next.
        """
        observation, reward, done, info = self._env.step(action)
        # Advance to the next player in round-robin order.
        self._turn = (self._turn + 1) % self._num_players
        return observation, reward, done, self._turn, info

    def render(self, mode="rgb_array"):
        """Renders the wrapped environment in ``mode`` and returns its result."""
        return self._env.render(mode=mode)

    def seed(self, seed=None):
        """Seeds the wrapped environment.

        Args:
            seed (int, optional): Seed forwarded to the wrapped env.

        Returns:
            Whatever the wrapped env's ``seed`` returns (for classic gym
            envs, typically the list of seeds used). Callers that ignored the
            previous ``None`` return value are unaffected.
        """
        return self._env.seed(seed=seed)

    def close(self):
        """Closes the wrapped environment."""
        self._env.close()