import argparse
import inspect
from copy import deepcopy
from functools import partial
from typing import List, Mapping, Sequence, _GenericAlias
import yaml
class Registrable:
    """Base class for anything that can be stored in the RLHive Registry.

    Subclasses can be configured directly from the command line and built
    recursively from a config, provided their constructors carry type
    annotations.
    """

    @classmethod
    def type_name(cls):
        """Return the string naming this category of class.

        Typical values are "logger", "agent", or "env". The Registry calls this
        to group registered classes, so subclasses are expected to override it;
        this base implementation always raises ValueError.
        """
        raise ValueError()
class Registry:
    """Registry of :py:class:`Registrable` classes and objects for RLHive.

    Registering a class under a type (see :py:meth:`register`) attaches a
    constructor named ``get_{type_name}`` to the ``Registry`` class.

    These constructors build objects from dictionary configs with two fields:
    ``name``, the name the class was registered under, and ``kwargs``, the
    keyword arguments passed to its constructor. Objects are built
    recursively. If ``kwargs`` contains the config of another ``Registrable``
    object, that object is constructed first and then passed to the outer
    constructor. This applies directly, or inside a list or dict annotated as
    holding Registrable objects.

    Constructor arguments can also be specified or overridden on the command
    line using dot notation. For example, suppose your agent class has an
    argument ``arg1`` annotated as ``List[Class1]``, ``Class1`` is
    ``Registrable``, and the ``Class1`` constructor takes an argument
    ``arg2``. If the yml config lists two ``Class1`` configs, the constructor
    checks whether ``--agent.arg1.0.arg2`` and ``--agent.arg1.1.arg2`` were
    passed.

    Command line values are parsed according to the type annotation of the
    corresponding constructor argument. Values annotated as ``int``,
    ``float``, ``str``, or ``bool`` are converted accordingly; anything else
    is loaded into Python with a yaml loader.

    Each ``get_{type_name}`` constructor returns a tuple of two items:

    * a ``functools.partial`` of the registered constructor with all resolved
      kwargs bound;
    * a dictionary config recording every parameter used, including the
      configs of any nested ``Registrable`` objects created along the way.
    """

    def __init__(self) -> None:
        # type_name -> {registered name -> constructor}
        self._registry = {}

    def register(self, name, constructor, type):
        """Register a Registrable class/object with RLHive.

        Also (re)defines ``get_{type.type_name()}`` on the ``Registry`` class
        itself, so that constructor is shared by every Registry instance.

        Args:
            name (str): Name of the class/object being registered.
            constructor (callable): Callable that will be passed all kwargs
                from configs and analyzed to get type annotations.
            type (type): Type of class/object being registered. Must be a
                subclass of Registrable.

        Raises:
            ValueError: If ``type`` is not a subclass of Registrable.
        """
        if not issubclass(type, Registrable):
            raise ValueError(f"{type} is not Registrable")
        type_name = type.type_name()
        self._registry.setdefault(type_name, {})

        def getter(self, object_or_config, prefix=None):
            """Build an object of the registered type from a config.

            Args:
                object_or_config: None, an already-built instance of the
                    registered type, or a config dict with a ``name`` key and
                    an optional ``kwargs`` dict.
                prefix (str): Dot-notation prefix used when looking up command
                    line overrides for the constructor's arguments.

            Returns:
                tuple: ``(constructor, config)``.

                - ``(None, {})`` for a None input.
                - ``(object_or_config, {})`` for an already-built instance.
                - Otherwise a ``functools.partial`` of the registered
                  constructor plus the expanded config used to build it.

            Raises:
                ValueError: If the object could not be created. The underlying
                    error is chained as ``__cause__``.
            """
            if object_or_config is None:
                return None, {}
            # ``type`` is the registered base class (the ``register`` argument
            # shadows the builtin), so already-built objects pass straight
            # through.
            if isinstance(object_or_config, type):
                return object_or_config, {}
            name = object_or_config["name"]
            # Work on a copy so the caller's config is never mutated. Otherwise
            # reusing a config would see command line values and partially
            # built objects left behind by a previous call. A null ``kwargs``
            # (e.g. an empty YAML mapping) is treated as no kwargs.
            kwargs = dict(object_or_config.get("kwargs") or {})
            expanded_config = deepcopy(object_or_config)
            try:
                object_class = self._registry[type_name][name]
                kwargs.update(get_callable_parsed_args(object_class, prefix=prefix))
                kwargs, kwargs_config = construct_objects(object_class, kwargs, prefix)
                expanded_config["kwargs"] = kwargs_config
                return partial(object_class, **kwargs), expanded_config
            except Exception as e:
                # Catch Exception, not everything: SystemExit from argparse
                # (e.g. ``--help``) and KeyboardInterrupt must propagate.
                # Chain the cause so the real failure stays visible.
                raise ValueError(f"Error creating {name} class") from e

        setattr(self.__class__, f"get_{type_name}", getter)
        self._registry[type_name][name] = constructor

    def register_all(self, base_class, class_dict):
        """Bulk register function.

        Args:
            base_class (type): Corresponds to the ``type`` argument of
                :py:meth:`register`.
            class_dict (dict[str, callable]): A dictionary mapping from name to
                constructor.
        """
        for name, constructor in class_dict.items():
            self.register(name, constructor, base_class)

    def __repr__(self):
        return str(self._registry)
def _is_registrable_class(candidate):
    """Return True if ``candidate`` is a plain class that subclasses Registrable."""
    # Parameterized generics (e.g. ``list[int]``) are rejected up front. On
    # some Python versions they pass ``isinstance(..., type)`` but make
    # ``issubclass`` raise TypeError.
    return (
        getattr(candidate, "__origin__", None) is None
        and isinstance(candidate, type)
        and issubclass(candidate, Registrable)
    )


def construct_objects(object_constructor, config, prefix=None):
    """Construct any Registrable objects specified in ``config``.

    The signature of ``object_constructor`` is inspected. An entry of
    ``config`` is built through the module registry's ``get_{type_name}``
    constructor when the matching argument is annotated as one of:

    - a :py:class:`Registrable` subclass;
    - a ``List[X]`` / ``list[X]`` of one;
    - a ``Dict[K, X]`` / ``dict[K, X]`` whose value type is one.

    Command line overrides for nested objects are looked up under
    ``{prefix}.{argument}``, ``{prefix}.{argument}.{index}``, or
    ``{prefix}.{argument}.{key}`` respectively.

    The input ``config`` is not modified; a new dict is returned.

    Args:
        object_constructor (callable): Constructor of the object that
            corresponds to ``config``. Its signature is analyzed to find
            arguments that may hold :py:class:`Registrable` configs.
        config (dict): The kwargs for the object being created. May contain
            configs for other ``Registrable`` objects that need to be created
            recursively.
        prefix (str): Prefix attached to the argument names when looking for
            command line arguments.

    Returns:
        tuple: ``(config, expanded_config)``.

        - ``config`` is a copy of the input with Registrable entries replaced
          by the constructors returned from the registry.
        - ``expanded_config`` mirrors it, with each such entry replaced by the
          full config used to build it.
    """
    signature = inspect.signature(object_constructor)
    prefix = "" if prefix is None else f"{prefix}."
    expanded_config = deepcopy(config)
    # Shallow copy: only top-level entries are replaced below, and nested
    # configs are copied by the registry getters themselves.
    config = dict(config)
    for argument, parameter in signature.parameters.items():
        if argument not in config:
            continue
        expected_type = parameter.annotation
        value = config[argument]
        # ``__origin__``/``__args__`` exist on both typing generics (List[X])
        # and builtin generics (list[X]); plain classes have neither.
        origin = getattr(expected_type, "__origin__", None)
        type_args = getattr(expected_type, "__args__", ())

        if _is_registrable_class(expected_type):
            getter = getattr(registry, f"get_{expected_type.type_name()}")
            config[argument], expanded_config[argument] = getter(
                value, f"{prefix}{argument}"
            )
        elif (
            origin in (list, List)
            and len(type_args) == 1
            and _is_registrable_class(type_args[0])
            and isinstance(value, Sequence)
        ):
            objs, obj_configs = [], []
            for idx, item in enumerate(value):
                obj, obj_config = getattr(
                    registry, f"get_{type_args[0].type_name()}"
                )(item, f"{prefix}{argument}.{idx}")
                objs.append(obj)
                obj_configs.append(obj_config)
            config[argument] = objs
            expanded_config[argument] = obj_configs
        elif (
            origin is dict
            and len(type_args) == 2
            and _is_registrable_class(type_args[1])
            and isinstance(value, Mapping)
        ):
            objs, obj_configs = {}, {}
            for key, item in value.items():
                obj, obj_config = getattr(
                    registry, f"get_{type_args[1].type_name()}"
                )(item, f"{prefix}{argument}.{key}")
                objs[key] = obj
                obj_configs[key] = obj_config
            config[argument] = objs
            expanded_config[argument] = obj_configs
    return config, expanded_config
def get_callable_parsed_args(callable, prefix=None):
    """Read command line overrides for the parameters of ``callable``.

    Every parameter of ``callable`` except one named ``self`` is handed,
    together with its ``inspect.Parameter``, to :py:func:`get_parsed_args`.

    Args:
        callable (callable): Function or class whose signature is inspected
            to decide which command line arguments to look for.
        prefix (str): Prefix attached to each parameter name when searching
            the command line.

    Returns:
        dict: Whatever :py:func:`get_parsed_args` returns for those parameters.
    """
    wanted = {}
    for param_name, param in inspect.signature(callable).parameters.items():
        if param_name != "self":
            wanted[param_name] = param
    return get_parsed_args(wanted, prefix)
def get_parsed_args(arguments, prefix=None):
    """Extract command line values for the given arguments.

    Takes a dictionary mapping argument names to types. If it contains the
    pair ``"bar": int`` and ``prefix`` is ``"foo"``, this function looks for
    the command line argument ``--foo.bar`` and, if present, casts it to an
    int. Unrecognized command line arguments are ignored.

    Conversion by type:

    - ``int``, ``str``, ``float``: the type is called on the raw string.
    - ``bool``: False if the lowercased string is a prefix of "false"
      (including the empty string) or equals "0"; True otherwise.
    - anything else: the string is loaded with ``yaml.safe_load``.

    Args:
        arguments (dict[str, type | inspect.Parameter]): Maps argument names to
            types. ``inspect.Parameter`` values use their annotation as the
            type.
        prefix (str): Prefix attached to each argument name before searching
            for command line arguments.

    Returns:
        dict: Maps each argument that was supplied on the command line to its
        converted value. Arguments not supplied are omitted.
    """
    prefix = "" if prefix is None else f"{prefix}."
    parser = argparse.ArgumentParser()
    for argument in arguments:
        # Use an explicit dest equal to the bare argument name. argparse would
        # otherwise derive the dest from the flag, converting "-" to "_", so a
        # prefix containing "-" (e.g. a dict key such as "train-env") could not
        # be matched back to its entry in ``arguments``.
        parser.add_argument(f"--{prefix}{argument}", dest=argument)
    namespace, _ = parser.parse_known_args()

    parsed_args = {}
    for argument, raw_value in vars(namespace).items():
        if raw_value is None:
            # Not supplied on the command line.
            continue
        expected_type = arguments[argument]
        if isinstance(expected_type, inspect.Parameter):
            expected_type = expected_type.annotation
        if expected_type in [int, str, float]:
            parsed_args[argument] = expected_type(raw_value)
        elif expected_type is bool:
            value = str(raw_value).lower()
            parsed_args[argument] = not ("false".startswith(value) or value == "0")
        else:
            parsed_args[argument] = yaml.safe_load(raw_value)
    return parsed_args
# Module-level Registry instance. construct_objects looks up the generated
# `get_{type_name}` constructors on this object when building nested
# Registrable arguments.
registry = Registry()