hive.agents.legal_moves_rainbow module

class hive.agents.legal_moves_rainbow.LegalMovesRainbowAgent(representation_net, obs_dim, act_dim, optimizer_fn=None, loss_fn=None, init_fn=None, id=0, replay_buffer=None, discount_rate=0.99, n_step=1, grad_clip=None, reward_clip=None, update_period_schedule=None, target_net_soft_update=False, target_net_update_fraction=0.05, target_net_update_schedule=None, epsilon_schedule=None, test_epsilon=0.001, min_replay_history=5000, batch_size=32, device='cpu', logger=None, log_frequency=100, noisy=True, std_init=0.5, use_eps_greedy=False, double=True, dueling=True, distributional=True, v_min=0, v_max=200, atoms=51)[source]

Bases: RainbowDQNAgent

A Rainbow agent which supports games with legal actions.

Parameters
  • representation_net (FunctionApproximator) – A network that outputs the representations that will be used to compute Q-values (e.g. everything except the final layer of the DQN).

  • obs_dim (Tuple) – The shape of the observations.

  • act_dim (int) – The number of actions available to the agent.

  • id – Agent identifier.

  • optimizer_fn (OptimizerFn) – A function that takes in a list of parameters to optimize and returns the optimizer. If None, defaults to Adam.

  • loss_fn (LossFn) – Loss function used by the agent. If None, defaults to SmoothL1Loss.

  • init_fn (InitializationFn) – Initializes the weights of qnet using create_init_weights_fn.

  • replay_buffer (BaseReplayBuffer) – The replay buffer that the agent will push observations to and sample from during learning. If None, defaults to PrioritizedReplayBuffer.

  • discount_rate (float) – A number between 0 and 1 specifying how much future rewards are discounted by the agent.

  • n_step (int) – The horizon used in n-step returns to compute TD(n) targets.

  • grad_clip (float) – Gradients will be clipped to between [-grad_clip, grad_clip].

  • reward_clip (float) – Rewards will be clipped to between [-reward_clip, reward_clip].

  • update_period_schedule (Schedule) – Schedule determining how frequently the agent’s Q-network is updated.

  • target_net_soft_update (bool) – Whether to update the target net parameters with a full copy of the qnet parameters (hard update) or with a weighted average of the target net and qnet parameters (soft update).

  • target_net_update_fraction (float) – The weight given to the target net parameters in a soft update.

  • target_net_update_schedule (Schedule) – Schedule determining how frequently the target net is updated.

  • epsilon_schedule (Schedule) – Schedule determining the value of epsilon through the course of training.

  • test_epsilon (float) – The value of epsilon (probability of choosing a random action) used during the testing phase.

  • min_replay_history (int) – How many observations to fill the replay buffer with before starting to learn.

  • batch_size (int) – The size of the batch sampled from the replay buffer during learning.

  • device – Device on which all computations should be run.

  • logger (ScheduledLogger) – Logger used to log agent’s metrics.

  • log_frequency (int) – How often to log the agent’s metrics.

  • noisy (bool) – Whether to use noisy linear layers for exploration.

  • std_init (float) – The initial value used for the standard deviation of the noisy layers’ weights.

  • use_eps_greedy (bool) – Whether to use epsilon greedy exploration.

  • double (bool) – Whether to use double DQN.

  • dueling (bool) – Whether to use a dueling network architecture.

  • distributional (bool) – Whether to use distributional RL.

  • v_min (float) – The minimum of the support of the categorical value distribution for distributional RL.

  • v_max (float) – The maximum of the support of the categorical value distribution for distributional RL.

  • atoms (int) – Number of atoms discretizing the support range of the categorical value distribution for distributional RL.
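
A minimal construction sketch. The observation/action sizes and the representation network below are hypothetical stand-ins; in a real setup the representation network would normally be a registered FunctionApproximator supplied through RLHive’s config system, and the call signature assumed here is an assumption of this sketch, not something this page specifies:

    import torch.nn as nn

    from hive.agents.legal_moves_rainbow import LegalMovesRainbowAgent

    # Hypothetical observation/action sizes; in practice these come from
    # the environment spec.
    OBS_DIM = (16,)
    ACT_DIM = 4

    # Stand-in representation network factory (accepts and ignores whatever
    # arguments the agent passes when building its torso).
    def representation_net(*args, **kwargs):
        return nn.Sequential(nn.Linear(OBS_DIM[0], 64), nn.ReLU())

    agent = LegalMovesRainbowAgent(
        representation_net=representation_net,
        obs_dim=OBS_DIM,
        act_dim=ACT_DIM,
        device="cpu",
        v_min=0,
        v_max=200,
        atoms=51,
    )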

create_q_networks(representation_net)[source]

Creates the qnet and target qnet.

preprocess_update_info(update_info)[source]

Preprocesses the update_info before it goes into the replay buffer. Clips the reward in update_info.

Parameters

update_info – Contains the information from the current timestep that the agent should use to update itself.
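
The reward-clipping step amounts to the following. The update_info keys shown are hypothetical; only the clipping behaviour is taken from the description above:

    import numpy as np

    # Hypothetical timestep record; only "reward" matters for this sketch.
    update_info = {
        "observation": np.zeros(16, dtype=np.float32),
        "action": 2,
        "reward": 7.3,
        "done": False,
    }

    reward_clip = 1.0  # value of the agent's reward_clip argument
    if reward_clip is not None:
        update_info["reward"] = float(
            np.clip(update_info["reward"], -reward_clip, reward_clip)
        )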

preprocess_update_batch(batch)[source]

Preprocesses the batch sampled from the replay buffer.

Parameters

batch – Batch sampled from the replay buffer for the current update.

Returns

  • (tuple) Inputs used to calculate current state values.

  • (tuple) Inputs used to calculate next state values.

  • Preprocessed batch.

Return type

(tuple)

act(observation)[source]

Returns the action for the agent. If in training mode, follows an epsilon greedy policy. Otherwise, returns the action with the highest Q-value.

Parameters

observation – The current observation.
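
A usage sketch, continuing from the construction example above. How exactly the legal-action mask travels alongside the raw observation is an assumption of this sketch, not something this page specifies:

    import numpy as np

    # Hypothetical observation carrying a binary legality mask
    # (1 = legal, 0 = illegal) alongside the raw observation.
    observation = {
        "observation": np.zeros(16, dtype=np.float32),
        "action_mask": np.array([1, 0, 1, 1], dtype=np.int8),
    }

    action = agent.act(observation)  # index of one of the legal actions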

class hive.agents.legal_moves_rainbow.LegalMovesHead(base_network)[source]

Bases: Module

Wraps a base network and applies a legal-moves encoding to its outputs, so that illegal actions are masked out.
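
The idea behind such a head is to mask the base network’s outputs so that illegal actions can never win the argmax. A self-contained sketch of that idea (illustrative only, not necessarily the exact implementation used here):

    import torch.nn as nn

    class MaskedHeadSketch(nn.Module):
        """Illustrative stand-in: adds an additive legality penalty to Q-values."""

        def __init__(self, base_network):
            super().__init__()
            self.base_network = base_network

        def forward(self, x, legal_moves):
            # legal_moves is assumed to be an additive encoding: 0 for legal
            # actions and a very large negative value for illegal ones.
            q_values = self.base_network(x)
            return q_values + legal_moves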

forward(x, legal_moves)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

dist(x, legal_moves)[source]

training: bool

hive.agents.legal_moves_rainbow.action_encoding(action_mask)[source]
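
A plausible reading of this helper, sketched under the assumption that action_mask holds 1 for legal actions and 0 for illegal ones, and that the encoding is additive as in the head sketch above:

    import numpy as np

    def action_encoding_sketch(action_mask):
        # Turn a binary legality mask into an additive penalty: 0 for legal
        # actions, -inf for illegal ones, so that adding it to a vector of
        # Q-values removes illegal actions from any argmax.
        encoding = np.zeros(np.shape(action_mask), dtype=np.float32)
        encoding[np.asarray(action_mask) == 0] = -np.inf
        return encoding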