# Source code for hive.agents.random

import os

import numpy as np
import torch

from hive.agents.agent import Agent
from hive.utils.utils import seeder


class RandomAgent(Agent):
    """An agent that takes a uniformly random action at each timestep."""

    def __init__(self, obs_dim, act_dim, id=0, logger=None):
        """
        Args:
            obs_dim: The shape of the observations.
            act_dim (int): The number of actions available to the agent.
            id: Agent identifier.
            logger (ScheduledLogger): Logger used to log agent's metrics.
                NOTE(review): accepted for interface compatibility but not
                used anywhere in this class.
        """
        super().__init__(obs_dim=obs_dim, act_dim=act_dim, id=id)
        # Draw the seed from the project-wide seeder rather than OS entropy.
        self._rng = np.random.default_rng(seed=seeder.get_new_seed())

    @torch.no_grad()
    def act(self, observation):
        """Returns a random action for the agent.

        Args:
            observation: Ignored; the chosen action does not depend on it.

        Returns:
            An integer drawn uniformly from ``[0, self._act_dim)`` using the
            agent's own RNG.
        """
        return self._rng.integers(self._act_dim)

    def update(self, update_info):
        """Does nothing: a random agent has no parameters to learn."""
        pass

    def save(self, dname):
        """Saves the agent's RNG state to ``<dname>/agent.pt``.

        Only the bit generator's state (a plain dict of str/int values) is
        stored, not the ``numpy.random.Generator`` object itself. That way
        ``torch.load`` can read the checkpoint with ``weights_only=True``,
        which is the default since PyTorch 2.6.
        """
        torch.save(
            {"rng": self._rng.bit_generator.state},
            os.path.join(dname, "agent.pt"),
        )

    def load(self, dname):
        """Restores the agent's RNG state from ``<dname>/agent.pt``.

        Accepts the current format (a bit generator state dict written by
        :meth:`save`). It also accepts legacy checkpoints that pickled the
        ``Generator`` object directly.

        NOTE: a legacy checkpoint holds an arbitrary pickled object. It will
        only load on PyTorch versions / settings where ``torch.load`` does
        not enforce ``weights_only=True``.
        """
        checkpoint = torch.load(os.path.join(dname, "agent.pt"))
        rng_state = checkpoint["rng"]
        if isinstance(rng_state, np.random.Generator):
            # Legacy format: the Generator object itself was pickled.
            self._rng = rng_state
            return
        # Rebuild a generator of the same bit-generator type (e.g. "PCG64")
        # so the restore works even if the saved type differs from the
        # one currently held by self._rng.
        bit_generator_cls = getattr(np.random, rng_state["bit_generator"])
        bit_generator = bit_generator_cls()
        bit_generator.state = rng_state
        self._rng = np.random.Generator(bit_generator)