Source code for hive.envs.minatar.minatar

from importlib import import_module

import numpy as np

from hive.envs.base import BaseEnv
from hive.envs.env_spec import EnvSpec


class MinAtarEnv(BaseEnv):
    """
    Class for loading MinAtar environments. See
    https://github.com/kenjyoung/MinAtar.
    """

    # Axis permutation applied to the underlying state (whose axis 2 is the
    # channel axis) so that channels come first. It is shared by the env spec
    # and by the observations so the advertised shape always matches what
    # reset()/step() actually return.
    _OBS_AXES = (2, 1, 0)

    def __init__(
        self,
        env_name,
        sticky_action_prob=0.1,
        difficulty_ramping=True,
    ):
        """
        Args:
            env_name (str): Name of the environment
            sticky_action_prob (float): Probability that the previously
                executed action is repeated instead of the requested one
                (sticky actions as per Machado et al.). Use 0 to disable.
            difficulty_ramping (bool): Whether to periodically increase
                difficulty.
        """
        env_module = import_module("minatar.environments." + env_name)
        self.env_name = env_name
        self._env = env_module.Env(ramping=difficulty_ramping)
        self.n_channels = self._env.state_shape()[2]
        self.sticky_action_prob = sticky_action_prob
        self.last_action = 0
        # Dedicated RNG for the sticky-action coin flip; reseeded in seed().
        self._sticky_rng = np.random.RandomState()
        self.visualized = False
        self.closed = False
        super().__init__(self.create_env_spec(env_name), num_players=1)

    def create_env_spec(self, env_name):
        """
        Build the EnvSpec for this environment.

        Args:
            env_name (str): Name of the environment.

        Returns:
            EnvSpec whose single observation dimension is the underlying
            state shape permuted by ``_OBS_AXES``, with ``act_dim=[6]``.
        """
        state_shape = tuple(self._env.state_shape())
        obs_dim = tuple(state_shape[axis] for axis in self._OBS_AXES)
        return EnvSpec(
            env_name=env_name,
            obs_dim=[obs_dim],
            act_dim=[6],
        )

    def _observation(self):
        """Return the current underlying state with the channel axis first."""
        return np.transpose(self._env.state(), self._OBS_AXES)

    def reset(self):
        """
        Reset the underlying environment.

        Returns:
            Tuple ``(observation, 0)``; the second element is presumably the
            id of the player to act (single-player) -- confirm against BaseEnv.
        """
        self._env.reset()
        return self._observation(), 0

    def seed(self, seed=None):
        """
        Seed the underlying environment and the sticky-action RNG.

        Args:
            seed (int): Seed forwarded to the underlying environment and used
                to reseed the sticky-action RNG.
        """
        self._env.seed(seed=seed)
        self._sticky_rng.seed(seed)

    def step(self, action=None):
        """
        Take one step in the environment.

        With probability ``sticky_action_prob`` the previously executed action
        is repeated instead of ``action``.

        Args:
            action: Action to execute. Must not be None.

        Returns:
            Tuple ``(observation, reward, done, None, info)`` where reward is
            a float and info is an empty dict.
        """
        assert action is not None, "step() requires an action"
        # Sticky actions: occasionally ignore the requested action and repeat
        # the last executed one. rand() is in [0, 1), so a probability of 0
        # never triggers and the requested action is always used.
        if self._sticky_rng.rand() < self.sticky_action_prob:
            action = self.last_action
        self.last_action = action

        reward, done = self._env.act(action)
        return self._observation(), float(reward), done, None, {}