hive.agents.qnets.utils module

hive.agents.qnets.utils.calculate_output_dim(net, input_shape)

Calculates the resulting output shape for a given input shape and network.

Parameters
  • net (torch.nn.Module) – The network which you want to calculate the output dimension for.

  • input_shape (int | tuple[int]) – The shape of the input being fed into the net. Batch dimension should not be included.

Returns

The shape of the output of a network given an input shape. Batch dimension is not included.
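One plausible way to compute this is a dummy forward pass: prepend a batch dimension of 1, run a zero tensor through the network, and drop the batch dimension from the result. The sketch below assumes that approach. calculate_output_dim_sketch is a hypothetical name, and Hive's calculate_output_dim may differ in details such as the exact return type.

```python
import torch
from torch import nn


def calculate_output_dim_sketch(net, input_shape):
    """Hypothetical sketch: infer a network's output shape with a dummy forward pass."""
    # Accept a bare int as a 1-D shape, mirroring the int | tuple[int] signature.
    if isinstance(input_shape, int):
        input_shape = (input_shape,)
    # Add a batch dimension of 1 so layers that expect batched input work.
    placeholder = torch.zeros((1,) + tuple(input_shape))
    with torch.no_grad():
        output = net(placeholder)
    # Drop the batch dimension again before returning.
    return tuple(output.shape[1:])


net = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, stride=2), nn.ReLU(), nn.Flatten())
print(calculate_output_dim_sketch(net, (3, 32, 32)))  # prints (3600,)
```

Because the shape comes from an actual forward pass, this needs no per-layer shape arithmetic, at the cost of one forward call through the network.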

hive.agents.qnets.utils.create_init_weights_fn(initialization_fn)

Returns a function that wraps initialization_fn and applies it to the weight of every module that has a weight attribute.

Parameters

initialization_fn (callable) – A function that takes in a tensor and initializes it.

Returns

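Function that takes in a PyTorch module and initializes its weight. It can be passed to torch.nn.Module.apply, which calls it on every submodule of a network:

from hive.agents.qnets.utils import create_init_weights_fn, variance_scaling_
init_fn = create_init_weights_fn(variance_scaling_)
network.apply(init_fn)  # network: any torch.nn.Module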
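A minimal sketch of the kind of wrapper create_init_weights_fn returns, assuming it simply checks each module for a weight tensor and passes that tensor to initialization_fn. The name create_init_weights_fn_sketch and the constant initializer used here are illustrative only.

```python
import torch
from torch import nn


def create_init_weights_fn_sketch(initialization_fn):
    """Hypothetical sketch: wrap a tensor initializer for use with Module.apply."""

    def init_weights(module):
        # Only touch modules that expose a (non-None) weight tensor;
        # containers such as nn.Sequential and activations are skipped.
        weight = getattr(module, "weight", None)
        if isinstance(weight, torch.Tensor):
            with torch.no_grad():
                initialization_fn(weight)

    return init_weights


network = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
init_fn = create_init_weights_fn_sketch(lambda w: nn.init.constant_(w, 0.5))
network.apply(init_fn)  # visits every submodule, including the Sequential itself
```

Biases are left untouched because only the weight attribute is passed to the initializer.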

hive.agents.qnets.utils.calculate_correct_fan(tensor, mode)

Calculates the fan (the number of input or output connections per unit) of a weight tensor.

Parameters
  • tensor (torch.Tensor) – Tensor whose fan is computed.

  • mode (str) – Which type of fan to compute. Must be one of “fan_in”, “fan_out”, or “fan_avg”.

Returns

Fan of the tensor based on the mode.
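A sketch of the fan computation, assuming the same weight layout convention as torch.nn.init: dimension 0 is the output size, dimension 1 is the input size, and any remaining dimensions form the receptive field of a convolution kernel. It operates on a shape tuple, so with a real tensor you would pass tensor.shape. correct_fan_from_shape is a hypothetical helper, not part of Hive.

```python
import math


def correct_fan_from_shape(shape, mode):
    """Hypothetical sketch: fan of a weight with PyTorch's (out, in, *kernel) layout."""
    if len(shape) < 2:
        raise ValueError("fan is undefined for tensors with fewer than 2 dimensions")
    # Convolution kernels contribute their spatial size to both fans;
    # for a plain Linear weight the product over an empty slice is 1.
    receptive_field = math.prod(shape[2:])
    fan_in = shape[1] * receptive_field
    fan_out = shape[0] * receptive_field
    if mode == "fan_in":
        return fan_in
    if mode == "fan_out":
        return fan_out
    if mode == "fan_avg":
        return (fan_in + fan_out) / 2
    raise ValueError(f"unknown mode: {mode!r}")


# A Conv2d(32, 64, kernel_size=3) weight has shape (64, 32, 3, 3).
for mode in ("fan_in", "fan_out", "fan_avg"):
    print(mode, correct_fan_from_shape((64, 32, 3, 3), mode))
```

For this shape the loop prints 288, 576, and 432.0 for “fan_in”, “fan_out”, and “fan_avg” respectively.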

hive.agents.qnets.utils.variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='uniform')

Implements the tf.keras.initializers.VarianceScaling initializer in PyTorch.

Parameters
  • tensor (torch.Tensor) – Tensor to initialize.

  • scale (float) – Scaling factor (must be positive).

  • mode (str) – Which fan to divide scale by. Must be one of “fan_in”, “fan_out”, or “fan_avg”.

  • distribution (str) – Random distribution to sample from. Must be one of “truncated_normal”, “untruncated_normal”, or “uniform”.

Returns

Initialized tensor.
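tf.keras.initializers.VarianceScaling divides scale by max(1, fan) and samples with that variance: a uniform distribution on [-limit, limit] with limit = sqrt(3 * variance), an untruncated normal with std = sqrt(variance), or a normal truncated at two standard deviations whose std is divided by 0.87962566103423978 to compensate for the truncation. The sketch below reproduces that recipe with PyTorch in-place ops. variance_scaling_sketch is a hypothetical stand-in and may differ from Hive's variance_scaling_ in details such as error handling.

```python
import math

import torch


def variance_scaling_sketch(tensor, scale=1.0, mode="fan_in", distribution="uniform"):
    """Hypothetical sketch of Keras-style VarianceScaling for a PyTorch weight tensor."""
    if scale <= 0.0:
        raise ValueError("scale must be positive")
    if tensor.dim() < 2:
        raise ValueError("expected a weight tensor with at least 2 dimensions")
    # Fan from PyTorch's (out, in, *kernel) weight layout.
    receptive_field = math.prod(tensor.shape[2:])
    fan_in = tensor.shape[1] * receptive_field
    fan_out = tensor.shape[0] * receptive_field
    fans = {"fan_in": fan_in, "fan_out": fan_out, "fan_avg": (fan_in + fan_out) / 2}
    if mode not in fans:
        raise ValueError(f"unknown mode: {mode!r}")
    variance = scale / max(1.0, fans[mode])
    with torch.no_grad():
        if distribution == "truncated_normal":
            # 0.8796... is the std of a unit normal truncated to [-2, 2], so
            # dividing by it keeps the post-truncation variance on target.
            std = math.sqrt(variance) / 0.87962566103423978
            return torch.nn.init.trunc_normal_(tensor, mean=0.0, std=std, a=-2 * std, b=2 * std)
        if distribution == "untruncated_normal":
            return tensor.normal_(0.0, math.sqrt(variance))
        if distribution == "uniform":
            # Uniform(-limit, limit) has variance limit**2 / 3.
            limit = math.sqrt(3.0 * variance)
            return tensor.uniform_(-limit, limit)
    raise ValueError(f"unknown distribution: {distribution!r}")


w = torch.empty(256, 128)
variance_scaling_sketch(w, scale=2.0, mode="fan_in", distribution="truncated_normal")
```

All three branches target the same variance, scale / fan; they differ only in the shape of the distribution the samples come from.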

class hive.agents.qnets.utils.InitializationFn(fn)

Bases: CallableType

A wrapper for callables that produce initialization functions.

These wrapped callables can be partially initialized through configuration files or command line arguments.

Parameters

fn – Callable to be wrapped.

classmethod type_name()
Returns

“init_fn”
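Hive's configuration and registry API is not shown in this section, so the following is only a conceptual sketch of what partial initialization means here: a configuration entry names an initializer and supplies some of its keyword arguments, and the result is a callable that still expects the tensor. The registry, the config dictionary format, resolve_init_fn, and constant_fill_ (which fills a Python list as a stand-in for a tensor) are all hypothetical.

```python
import functools


def constant_fill_(values, value=0.0):
    """Toy in-place initializer standing in for a tensor initializer."""
    for i in range(len(values)):
        values[i] = value
    return values


# Hypothetical registry and config format; not Hive's actual registry API.
REGISTRY = {"constant_fill_": constant_fill_}


def resolve_init_fn(config):
    """Bind config-supplied keyword arguments, leaving the tensor argument open."""
    fn = REGISTRY[config["name"]]
    return functools.partial(fn, **config.get("kwargs", {}))


init = resolve_init_fn({"name": "constant_fill_", "kwargs": {"value": 0.1}})
print(init([0.0, 0.0, 0.0]))  # prints [0.1, 0.1, 0.1]
```

Binding arguments with functools.partial keeps the tensor argument open, so the resolved callable can still be handed to something like create_init_weights_fn.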