hive.agents.qnets.utils module
- hive.agents.qnets.utils.calculate_output_dim(net, input_shape)
Calculates the resulting output shape for a given input shape and network.
- Parameters
net (torch.nn.Module) – The network which you want to calculate the output dimension for.
input_shape (int | tuple[int]) – The shape of the input being fed into the net. Batch dimension should not be included.
- Returns
The shape of the output of a network given an input shape. Batch dimension is not included.
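A minimal sketch of one way to compute this, assuming the network accepts a tensor of zeros: run a dummy forward pass on a batch of size 1 and drop the batch dimension from the result. The name output_dim_by_forward_pass is hypothetical, and this is not necessarily how calculate_output_dim is implemented.

```python
import torch
from torch import nn


def output_dim_by_forward_pass(net: nn.Module, input_shape):
    """Hypothetical helper: infer a network's output shape with a dummy forward pass."""
    # Accept either a single int (e.g. a flat observation size) or a shape tuple.
    if isinstance(input_shape, int):
        input_shape = (input_shape,)
    # Prepend a batch dimension of size 1; the caller's shape excludes it.
    placeholder = torch.zeros((1,) + tuple(input_shape))
    # No gradients are needed just to read off the output shape.
    with torch.no_grad():
        output = net(placeholder)
    # Drop the batch dimension again before returning.
    return tuple(output.shape[1:])


conv = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU(), nn.Flatten())
print(output_dim_by_forward_pass(conv, (3, 10, 10)))  # (512,)
```

Because this sketch runs the network once, layers whose behavior depends on training mode (batch normalization, for example) may update their running statistics or reject a batch of one; calling net.eval() beforehand avoids that.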
- hive.agents.qnets.utils.create_init_weights_fn(initialization_fn)
Returns a function that wraps initialization_fn and applies it to the weight of every module that has a weight attribute.
- Parameters
initialization_fn (callable) – A function that takes in a tensor and initializes it.
- Returns
Function that takes in PyTorch modules and initializes their weights. Can be used as follows:
    init_fn = create_init_weights_fn(variance_scaling_)
    network.apply(init_fn)
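Based on the description above, the wrapper can be sketched as a closure that Module.apply calls on every submodule. This is a hypothetical reconstruction (make_weight_init_fn is not part of hive): it skips modules without a weight or whose weight is None, and it leaves biases untouched, which the real function may handle differently.

```python
import torch
from torch import nn


def make_weight_init_fn(initialization_fn):
    """Hypothetical sketch: wrap a tensor initializer for use with Module.apply."""

    def init_weights(module: nn.Module) -> None:
        # Some modules (e.g. nn.ReLU) have no weight; others register it as None.
        weight = getattr(module, "weight", None)
        if isinstance(weight, torch.Tensor):
            initialization_fn(weight)

    return init_weights


network = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
# Module.apply visits every submodule (and the module itself) recursively.
network.apply(make_weight_init_fn(nn.init.zeros_))
```

Any callable that takes a tensor works here, so a keyword-configured initializer such as functools.partial(nn.init.kaiming_uniform_, nonlinearity="relu") can be passed the same way.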
- hive.agents.qnets.utils.calculate_correct_fan(tensor, mode)
Calculates the fan (number of input or output connections) of a tensor.
- Parameters
tensor (torch.Tensor) – Tensor to calculate fan of.
mode (str) – Which type of fan to compute. Must be one of “fan_in”, “fan_out”, or “fan_avg”.
- Returns
Fan of the tensor based on the mode.
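As a sketch of what these modes usually mean, the helper below computes the fans from a weight shape using PyTorch's layout convention, (out_features, in_features, *kernel_size), with “fan_avg” taken as the mean of the two, as in TensorFlow's VarianceScaling. The name fan_from_shape and the requirement of at least two dimensions are assumptions, not necessarily hive's behavior.

```python
import math
from typing import Sequence


def fan_from_shape(shape: Sequence[int], mode: str) -> float:
    """Hypothetical helper: fan of a weight tensor of the given shape."""
    if len(shape) < 2:
        raise ValueError("fan requires a tensor with at least 2 dimensions")
    # For conv kernels every input/output channel connects to a whole receptive
    # field; for linear layers shape[2:] is empty and math.prod returns 1.
    receptive_field = math.prod(shape[2:])
    fan_in = shape[1] * receptive_field
    fan_out = shape[0] * receptive_field
    if mode == "fan_in":
        return fan_in
    if mode == "fan_out":
        return fan_out
    if mode == "fan_avg":
        return (fan_in + fan_out) / 2
    raise ValueError(f"unknown mode: {mode!r}")


# Conv2d(32, 64, kernel_size=3) stores its weight with shape (64, 32, 3, 3).
print(fan_from_shape((64, 32, 3, 3), "fan_in"))   # 288
print(fan_from_shape((64, 32, 3, 3), "fan_out"))  # 576
print(fan_from_shape((64, 32, 3, 3), "fan_avg"))  # 432.0
```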
- hive.agents.qnets.utils.variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='uniform')
Implements the tf.keras.initializers.VarianceScaling initializer in PyTorch.
- Parameters
tensor (torch.Tensor) – Tensor to initialize.
scale (float) – Scaling factor (must be positive).
mode (str) – Must be one of “fan_in”, “fan_out”, or “fan_avg”.
distribution (str) – Random distribution to use. Must be one of “truncated_normal”, “untruncated_normal”, or “uniform”.
- Returns
Initialized tensor.
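For reference, TensorFlow's VarianceScaling divides scale by the selected fan (clamped to at least 1) and turns the resulting variance into a distribution parameter. The sketch below reproduces that arithmetic in plain Python; the name variance_scaling_param is hypothetical, and it assumes hive follows TensorFlow's convention of truncating the normal at two standard deviations.

```python
import math

# Standard deviation of a standard normal truncated to [-2, 2]; TensorFlow
# divides by it so the truncated samples keep the requested variance.
TRUNCATED_NORMAL_STD = 0.87962566103423978


def variance_scaling_param(fan: float, scale: float = 1.0,
                           distribution: str = "uniform") -> float:
    """Hypothetical helper: std (normal) or bound (uniform) for VarianceScaling."""
    if scale <= 0:
        raise ValueError("scale must be positive")
    variance = scale / max(1.0, fan)
    if distribution == "truncated_normal":
        # Std of the normal *before* truncation at +/- 2 std.
        return math.sqrt(variance) / TRUNCATED_NORMAL_STD
    if distribution == "untruncated_normal":
        return math.sqrt(variance)
    if distribution == "uniform":
        # U(-limit, limit) has variance limit**2 / 3.
        return math.sqrt(3.0 * variance)
    raise ValueError(f"unknown distribution: {distribution!r}")


print(variance_scaling_param(4, distribution="untruncated_normal"))  # 0.5
```

In PyTorch, the returned value would then be used with torch.nn.init.uniform_(tensor, -limit, limit), torch.nn.init.normal_(tensor, 0.0, std), or torch.nn.init.trunc_normal_(tensor, 0.0, std, -2 * std, 2 * std).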
- class hive.agents.qnets.utils.InitializationFn(fn)
Bases: CallableType
A wrapper for callables that produce initialization functions.
These wrapped callables can be partially initialized through configuration files or command line arguments.
- Parameters
fn – The callable to be wrapped.
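The partial configuration described above can be illustrated with a small, self-contained sketch: a wrapper that forwards calls to the wrapped callable, a name-to-wrapper registry, and a loader that binds keyword arguments from a config dict with functools.partial. Everything here (CallableWrapper, REGISTRY, register, from_config, and the config layout) is hypothetical and only mirrors the idea. It is not hive's registry API.

```python
import functools
from typing import Any, Callable, Dict


class CallableWrapper:
    """Hypothetical stand-in for a wrapper type such as InitializationFn."""

    def __init__(self, fn: Callable[..., Any]):
        self._fn = fn

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Forward every call unchanged to the wrapped callable.
        return self._fn(*args, **kwargs)


# Hypothetical registry mapping names used in config files to wrapped callables.
REGISTRY: Dict[str, CallableWrapper] = {}


def register(name: str, fn: Callable[..., Any]) -> None:
    REGISTRY[name] = CallableWrapper(fn)


def from_config(config: Dict[str, Any]) -> Callable[..., Any]:
    """Bind the keyword arguments in the config, leaving the rest for later."""
    wrapped = REGISTRY[config["name"]]
    return functools.partial(wrapped, **config.get("kwargs", {}))


def scale_values(values, scale=1.0):
    # Toy "initializer" working on a list so the example needs no PyTorch.
    return [v * scale for v in values]


register("scale_values", scale_values)
init_fn = from_config({"name": "scale_values", "kwargs": {"scale": 2.0}})
print(init_fn([1, 2, 3]))  # [2.0, 4.0, 6.0]
```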