Registration

Registering new Classes

To register a new class with the Registry, you need to make sure that it or one of its ancestors subclassed hive.utils.registry.Registrable and provided a definition for hive.utils.registry.Registrable.type_name(). The return value of this function is used to create the getter function for that type (registry.get_{type_name}).

You can either register several different classes at once:

registry.register_all(
    Agent,
    {
        "DQNAgent": DQNAgent,
        "LegalMovesRainbowAgent": LegalMovesRainbowAgent,
        "RainbowDQNAgent": RainbowDQNAgent,
        "RandomAgent": RandomAgent,
    },
)

or one at a time:

registry.register("DQNAgent", DQNAgent, Agent)
registry.register("LegalMovesRainbowAgent", LegalMovesRainbowAgent, Agent)
registry.register("RainbowDQNAgent", RainbowDQNAgent, Agent)
registry.register("RandomAgent", RandomAgent, Agent)

After a class has been registered, you can use pass a config dictionary to the getter function for that type to create the object.

Callables

There are several cases where we want to parameterize some function or constructor partway, but not pass the fully created object in as an argument. One example is optimizers. You might want to pass a learning rate, but you cannot create the final optimizer object until you’ve created the parameters you want to optimize. To deal with such cases, we provide a CallableType class, which can be used to register and wrap any callable. For example, with optimizers, we have:

class OptimizerFn(CallableType):
    """A wrapper for callables that produce optimizer functions.

    These wrapped callables can be partially initialized through configuration
    files or command line arguments.
    """

    @classmethod
    def type_name(cls):
        """
        Returns:
            "optimizer_fn"
        """
        return "optimizer_fn"

registry.register_all(
    OptimizerFn,
    {
        "Adadelta": OptimizerFn(optim.Adadelta),
        "Adagrad": OptimizerFn(optim.Adagrad),
        "Adam": OptimizerFn(optim.Adam),
        "Adamax": OptimizerFn(optim.Adamax),
        "AdamW": OptimizerFn(optim.AdamW),
        "ASGD": OptimizerFn(optim.ASGD),
        "LBFGS": OptimizerFn(optim.LBFGS),
        "RMSprop": OptimizerFn(optim.RMSprop),
        "RMSpropTF": OptimizerFn(RMSpropTF),
        "Rprop": OptimizerFn(optim.Rprop),
        "SGD": OptimizerFn(optim.SGD),
        "SparseAdam": OptimizerFn(optim.SparseAdam),
    },
)

With this, we can now make use of the configurability of RLHive objects while still passing callables as arguments.