Source code for hive.utils.schedule

import abc

from hive.utils.registry import Registrable, registry


class Schedule(abc.ABC, Registrable):
    """Abstract interface for a tracked value that changes as it is updated.

    Subclasses implement :meth:`get_value` to read the tracked value and
    :meth:`update` to advance it.
    """

    @abc.abstractmethod
    def get_value(self):
        """Returns the current value of the variable we are tracking."""

    @abc.abstractmethod
    def update(self):
        """Update the value of the variable we are tracking and return the
        updated value. The first call to update will return the initial value
        of the schedule."""

    @classmethod
    def type_name(cls):
        """Returns the string ``"schedule"``, the type name of this class."""
        return "schedule"
class LinearSchedule(Schedule):
    """Moves linearly from a start value to an end value in fixed increments.

    Every call to :meth:`update` adds the same increment to the tracked value.
    Once the end value has been reached, further updates leave the schedule at
    the end value.
    """

    def __init__(self, init_value, end_value, steps):
        """
        Args:
            init_value (int | float): Starting value for the schedule.
            end_value (int | float): Value at which the schedule stops.
            steps (int): Number of increments between the two values. It is
                converted with ``int`` and values below 1 are treated as 1.
        """
        steps = max(int(steps), 1)
        self._end_value = end_value
        self._delta = (end_value - init_value) / steps
        # Start one increment before init_value so that the first update()
        # steps onto init_value (up to floating-point rounding).
        self._value = init_value - self._delta

    def get_value(self):
        """Returns the stored value without advancing the schedule."""
        return self._value

    def update(self):
        """Advances the schedule by one increment and returns the new value.

        Nothing changes once the stored value equals the end value.
        """
        if self._value != self._end_value:
            self._value += self._delta
            above_end = (self._value - self._end_value) > 0
            increasing = self._delta > 0
            # An increasing schedule is clamped once it is above the end
            # value; a decreasing one is clamped once it is no longer above it.
            if above_end == increasing:
                self._value = self._end_value
        return self._value

    def __repr__(self):
        return (
            f"<class {type(self).__name__}"
            f" value={self.get_value()}"
            f" delta={self._delta}"
            f" end_value={self._end_value}>"
        )
class ConstantSchedule(Schedule):
    """A schedule whose value never changes."""

    def __init__(self, value):
        """
        Args:
            value: The value returned by every call to get_value and update.
        """
        self._value = value

    def get_value(self):
        """Returns the constant value."""
        return self._value

    def update(self):
        """Returns the constant value; there is no state to advance."""
        return self._value

    def __repr__(self):
        return f"<class {type(self).__name__} value={self.get_value()}>"
class SwitchSchedule(Schedule):
    """Returns an "off" value at first, then switches to an "on" value.

    An internal counter is incremented on every update. The off value is
    returned while the counter is at most the configured number of steps, and
    the on value is returned after that.
    """

    def __init__(self, off_value, on_value, steps):
        """
        Args:
            off_value: The value returned during the first part of the
                schedule.
            on_value: The value returned during the second part of the
                schedule.
            steps (int): The number of updates after which the schedule
                switches from the off value to the on value.
        """
        self._off_value = off_value
        self._on_value = on_value
        self._flip_step = steps
        self._steps = 0

    def get_value(self):
        """Returns the off value until the counter exceeds the flip step."""
        return (
            self._off_value if self._steps <= self._flip_step else self._on_value
        )

    def update(self):
        """Increments the counter and returns the resulting value."""
        self._steps += 1
        return self.get_value()

    def __repr__(self):
        return (
            f"<class {type(self).__name__}"
            f" value={self.get_value()}"
            f" steps={self._steps}"
            f" off_value={self._off_value}"
            f" on_value={self._on_value}"
            f" flip_step={self._flip_step}>"
        )
class DoublePeriodicSchedule(Schedule):
    """Alternates between an off value and an on value.

    Each cycle consists of ``off_period`` updates that return the off value
    followed by ``on_period`` updates that return the on value.
    """

    def __init__(self, off_value, on_value, off_period, on_period):
        """
        Args:
            off_value: The value to be returned for the off period.
            on_value: The value to be returned for the on period.
            off_period (int): The number of steps in the off period.
            on_period (int): The number of steps in the on period.
        """
        self._off_value = off_value
        self._on_value = on_value
        self._off_period = off_period
        self._total_period = off_period + on_period
        # The counter starts at -1 so that the first update() moves it to 0,
        # the first position of the cycle.
        self._steps = -1

    def get_value(self):
        """Returns the value for the counter's current position in the cycle."""
        position = self._steps % self._total_period
        return self._off_value if position < self._off_period else self._on_value

    def update(self):
        """Increments the counter and returns the resulting value."""
        self._steps += 1
        return self.get_value()

    def __repr__(self):
        return (
            f"<class {type(self).__name__}"
            f" value={self.get_value()}"
            f" steps={self._steps}"
            f" off_value={self._off_value}"
            f" on_value={self._on_value}"
            f" off_period={self._off_period}"
            f" on_period={self._total_period - self._off_period}>"
        )
class PeriodicSchedule(DoublePeriodicSchedule):
    """Returns the on value once per period and the off value otherwise.

    The schedule is a DoublePeriodicSchedule with an off period of
    ``period - 1`` steps followed by an on period of a single step, so the on
    value is returned on the last step of each period.
    """

    def __init__(self, off_value, on_value, period):
        """
        Args:
            off_value: The value to be returned for every step of the period
                except the last one.
            on_value: The value to be returned on the last step of each
                period.
            period (int): The number of steps in the period.
        """
        super().__init__(
            off_value,
            on_value,
            off_period=period - 1,
            on_period=1,
        )

    def __repr__(self):
        return (
            f"<class {type(self).__name__}"
            f" value={self.get_value()}"
            f" steps={self._steps}"
            f" off_value={self._off_value}"
            f" on_value={self._on_value}"
            f" period={self._off_period + 1}>"
        )
# Register every concrete schedule with the project registry under its class
# name, grouped under the Schedule base class.
registry.register_all(
    Schedule,
    {
        "LinearSchedule": LinearSchedule,
        "ConstantSchedule": ConstantSchedule,
        "SwitchSchedule": SwitchSchedule,
        "PeriodicSchedule": PeriodicSchedule,
        "DoublePeriodicSchedule": DoublePeriodicSchedule,
    },
)

# Expose the registry attribute named "get_" + Schedule.type_name() (that is,
# "get_schedule") as a module-level name.
# NOTE(review): presumably register_all is what creates this attribute, so this
# lookup must stay after the registration above -- confirm in
# hive.utils.registry.
get_schedule = getattr(registry, f"get_{Schedule.type_name()}")