mpx package#
Module contents#
Mixed Precision for JAX - A library for mixed precision training in JAX
- class mpx.DynamicLossScaling(loss_scaling: Array, min_loss_scaling: Array, factor: int = 2, period: int = 2000, counter=None)#
Bases:
Module

Implements dynamic loss scaling for mixed precision training in JAX. The basic structure is taken from jmp. This class automatically adjusts the loss scaling factor during training to prevent numerical underflow/overflow when using reduced precision (e.g., float16). The scaling factor is increased periodically if gradients are finite and decreased if non-finite gradients are detected, within specified bounds.
- Attributes:
loss_scaling (jnp.ndarray): Current loss scaling factor.
min_loss_scaling (jnp.ndarray): Minimum allowed loss scaling factor.
counter (jnp.ndarray): Counter for tracking update periods.
factor (int): Multiplicative factor for adjusting loss scaling.
period (int): Number of steps between potential increases of loss scaling.
- Methods:
- scale(tree):
Scales all leaves of a pytree by the current loss scaling factor.
- unscale(tree):
Unscales all leaves of a pytree by the inverse of the current loss scaling factor, casting the result to float32.
- adjust(grads_finite: jnp.ndarray) -> 'DynamicLossScaling':
Returns a new DynamicLossScaling instance with updated loss scaling and counter, depending on whether the gradients are finite.
- adjust(grads_finite: Array) DynamicLossScaling#
Adjust the loss scaling based on the finiteness of the gradients and update the internal counter. It follows https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html and is directly adapted from JMP (google-deepmind/jmp).
- Args:
- grads_finite (jnp.ndarray):
A boolean scalar (0-dimensional) indicating whether all gradients are finite. Must satisfy grads_finite.ndim == 0.
- Returns:
- DynamicLossScaling:
A new instance of DynamicLossScaling. Use the returned instance to replace the current one.
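The adjustment rule described above can be sketched in plain JAX. This is a minimal reimplementation for illustration following the NVIDIA recipe, not the library's exact code; the arguments mirror the attributes listed above:

```python
import jax.numpy as jnp

def adjust_sketch(loss_scaling, counter, grads_finite,
                  factor=2, period=2000, min_loss_scaling=2.0 ** -14):
    """Dynamic loss scaling update, following the NVIDIA recipe."""
    # Count another finite step, wrapping around at `period`.
    new_counter = (counter + 1) % period
    # After a full period of finite gradients, grow the scale by `factor`.
    grown = jnp.where(new_counter == 0, loss_scaling * factor, loss_scaling)
    # On non-finite gradients: shrink the scale (bounded below) and reset.
    shrunk = jnp.maximum(loss_scaling / factor, min_loss_scaling)
    loss_scaling = jnp.where(grads_finite, grown, shrunk)
    counter = jnp.where(grads_finite, new_counter, 0)
    return loss_scaling, counter
```

The counter resets both on a non-finite step and after a successful increase, so the scale only grows after `period` consecutive finite steps.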
- counter: Array#
- factor: int#
- loss_scaling: Array#
- min_loss_scaling: Array#
- period: int#
- scale(tree)#
Scales each element in the input tree by the loss scaling factor. This method applies a multiplication operation to every leaf in the given pytree, using the loss scaling factor (converted to jnp.float16) stored in the instance. It returns a new pytree where each element has been scaled accordingly.
- Args:
tree: A pytree (e.g., nested lists, tuples, dicts) containing numerical values that represent the data to be scaled.
- Returns:
A new pytree with each value multiplied by the loss scaling factor as a jnp.float16.
- unscale(tree)#
Unscales a pytree by multiplying each leaf element by the inverse of the loss scaling factor (in float32).
- Args:
- tree:
A pytree (nested structure of arrays, lists, tuples, dicts, etc.) where each leaf is a numeric array. These numerical values will be scaled by the computed inverse loss scaling factor.
- Returns:
A new pytree with the same structure as the input, where each numeric leaf is multiplied by 1 / loss_scaling (as a float32).
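The scale/unscale pair can be sketched with `jax.tree_util.tree_map`. This is an illustrative reimplementation of the behavior documented above (multiply by the scale in float16, multiply by the inverse scale in float32), not the library's code:

```python
import jax
import jax.numpy as jnp

def scale_sketch(tree, loss_scaling):
    # Multiply every leaf by the loss scaling factor, kept in float16.
    s = loss_scaling.astype(jnp.float16)
    return jax.tree_util.tree_map(lambda x: x * s, tree)

def unscale_sketch(tree, loss_scaling):
    # Multiply every leaf by the inverse scale in float32,
    # so the result is promoted to full precision.
    inv = (1.0 / loss_scaling).astype(jnp.float32)
    return jax.tree_util.tree_map(lambda x: x * inv, tree)
```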
- mpx.all_finite(tree: PyTree) Array#
Checks if all elements in a PyTree of arrays are finite.
This function traverses the input PyTree, extracts all array leaves, and verifies whether all elements in these arrays are finite (i.e., not NaN or Inf).
- Args:
tree (PyTree): A PyTree containing arrays to be checked for finiteness.
- Returns:
Array: A scalar ndarray of type bool indicating whether all elements in the input PyTree are finite. Returns True if all elements are finite, otherwise False.
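The check can be reproduced in a few lines of plain JAX (an illustrative sketch, not the library's implementation):

```python
import jax
import jax.numpy as jnp

def all_finite_sketch(tree):
    # Reduce each array leaf to a scalar "all finite?" flag, then AND them.
    leaves = jax.tree_util.tree_leaves(tree)
    if not leaves:
        return jnp.bool_(True)
    return jnp.all(jnp.stack([jnp.all(jnp.isfinite(x)) for x in leaves]))
```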
- mpx.calculate_scaled_grad(func, scaling: DynamicLossScaling, has_aux=False, use_mixed_precision=True) PyTree#
- mpx.cast_function(func, dtype, return_dtype=None)#
Casts the function to the specified data type.
- mpx.cast_to_bfloat16(x: PyTree) PyTree#
Casts the input PyTree to the bfloat16 data type.
- Args:
x (PyTree): A PyTree structure containing arrays or tensors to be cast.
- Returns:
PyTree: A PyTree with all arrays or tensors cast to the bfloat16 data type.
- mpx.cast_to_float16(x: PyTree) PyTree#
Casts all elements of a PyTree to the float16 data type.
- Args:
x (PyTree): A PyTree containing numerical data to be cast to float16.
- Returns:
PyTree: A new PyTree with all numerical elements cast to float16.
- mpx.cast_to_float32(x: PyTree) PyTree#
Cast the input PyTree to float32 data type.
This function takes a PyTree and casts all its elements to the float32 data type.
- Args:
x (PyTree): The input PyTree containing elements to be cast.
- Returns:
PyTree: A new PyTree with all elements cast to float32.
- mpx.cast_to_full_precision(x: PyTree) PyTree#
Casts all elements of a PyTree to full precision (float32).
- Args:
x (PyTree): The input PyTree containing elements to be cast.
- Returns:
PyTree: A new PyTree with all elements cast to float32 precision.
- mpx.cast_to_half_precision(x: PyTree) PyTree#
Cast the input PyTree to half precision.
This function converts all elements in the input PyTree to the half-precision datatype (either float16 or bfloat16), depending on the configuration set by set_half_precision_datatype.
- Args:
x (PyTree): The input PyTree containing elements to be cast to half precision.
- Returns:
PyTree: A new PyTree with all elements cast to the half-precision datatype.
- mpx.cast_tree(tree: PyTree, dtype)#
Casts all array elements in a PyTree to a specified data type. This function traverses a PyTree and applies a type casting operation to all array elements with a float dtype (float16, bfloat16, float32), leaving all other elements unchanged.
- Args:
tree (PyTree): The input PyTree containing arrays and other objects.
dtype (numpy.dtype or str): The target data type to cast the arrays to.
- Returns:
PyTree: A new PyTree with all array elements cast to the specified data type.
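The float-only casting behavior can be sketched as follows (an illustrative reimplementation; integer, boolean, and non-array leaves pass through unchanged):

```python
import jax
import jax.numpy as jnp

def cast_tree_sketch(tree, dtype):
    # Cast only floating-point array leaves; everything else is untouched.
    def cast(x):
        if hasattr(x, "dtype") and jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(dtype)
        return x
    return jax.tree_util.tree_map(cast, tree)
```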
- mpx.filter_grad(func, scaling: DynamicLossScaling, has_aux=False, use_mixed_precision=True) PyTree#
Filters the gradients of a function based on a predicate.
This function computes the gradients of the given function func with respect to its arguments (args and kwargs). It then filters the gradients based on a predicate function that checks whether the gradients are finite. The filtered gradients are returned as a new pytree.
- Args:
func (callable): The function to compute gradients for. This function must only use pytrees as parameters!
scaling (DynamicLossScaling): An instance of DynamicLossScaling to handle loss scaling and gradient unscaling.
has_aux (bool): If True, the function is expected to return auxiliary values along with the gradients.
use_mixed_precision (bool, optional): If True, the function will be cast to half precision. Defaults to True.
- Returns:
callable: A function that computes the filtered gradients of func. It returns the gradient, the updated loss scaling, and a boolean indicating whether the gradients are finite (plus the aux value if has_aux is True).
- mpx.filter_value_and_grad(func, scaling: DynamicLossScaling, has_aux=False, use_mixed_precision=True) PyTree#
Wraps a function to compute its value and gradient with support for mixed precision and dynamic loss scaling.
- Args:
func (Callable): The function for which the value and gradient are to be computed.
scaling (loss_scaling.DynamicLossScaling): An instance of DynamicLossScaling to handle loss scaling and gradient unscaling.
has_aux (bool, optional): Indicates whether the function func returns auxiliary outputs along with the main value. Defaults to False.
use_mixed_precision (bool, optional): If True, the function will be cast to half precision. Defaults to True.
- Returns:
Callable: A wrapped function that computes the value, gradient, and additional information:
- If has_aux is True:
((value, aux), loss_scaling_new, grads_finite, grad)
- If has_aux is False:
(value, loss_scaling_new, grads_finite, grad)
- Where:
value: The computed value of the function.
aux: Auxiliary outputs returned by the function (if has_aux is True).
loss_scaling_new: The updated loss scaling object.
grads_finite: A boolean indicating whether all gradients are finite.
grad: The computed gradients, unscaled.
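The overall control flow of such a wrapper (scale the loss, differentiate, unscale the gradients, check finiteness) can be sketched in plain JAX. Names and signature here are hypothetical for illustration; the actual mpx wrapper additionally handles half-precision casting and returns the updated scaling object:

```python
import jax
import jax.numpy as jnp

def value_and_scaled_grad_sketch(loss_fn, params, loss_scaling):
    """Sketch: scale the loss before differentiating, then unscale the
    gradients and report whether they are all finite."""
    def scaled_loss(p):
        return loss_fn(p) * loss_scaling.astype(jnp.float32)
    value, grad = jax.value_and_grad(scaled_loss)(params)
    inv = 1.0 / loss_scaling
    # Undo the scaling on the gradients, in full precision.
    grad = jax.tree_util.tree_map(lambda g: (g * inv).astype(jnp.float32), grad)
    grads_finite = jnp.all(
        jnp.stack([jnp.all(jnp.isfinite(g))
                   for g in jax.tree_util.tree_leaves(grad)]))
    return value * inv, grad, grads_finite
```

In a training loop, `grads_finite` would then be passed to DynamicLossScaling.adjust, and the parameter update skipped when it is False.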
- mpx.force_full_precision(func, return_dtype=<class 'jax.numpy.float16'>)#
A decorator to enforce full precision (float32) for the inputs and outputs of a function. This decorator ensures that all array arguments passed to the decorated function are converted to float32 precision before the function is executed. Additionally, it converts the outputs of the function to the specified return_dtype if they are arrays.
This might come in handy for functions that become numerically unstable in lower precision. Note that some jax-functions do this automatically (e.g., jnp.sum).
- Args:
func (callable): The function to be decorated.
return_dtype (dtype): The desired data type for the function's output arrays.
- Returns:
callable: The wrapped function with enforced input and output precision.
- Example:
@force_full_precision
def my_function(x, y):
    return x + y

# All array inputs to my_function will be cast to float32, and the output
# will be cast to the specified return_dtype if it is an array.
- mpx.half_precision_datatype()#
- mpx.optimizer_update(model: PyTree, optimizer: GradientTransformation, optimizer_state: PyTree, grads: PyTree, grads_finite: Bool)#
- mpx.scaled(func: callable, scaling: DynamicLossScaling, has_aux: bool = False) callable#
Scales the output of a function using dynamic loss scaling. This decorator wraps a given function such that its output is scaled using the provided dynamic loss scaling object. If the wrapped function returns auxiliary data (indicated by has_aux=True), only the primary value is scaled; otherwise, the sole returned value is scaled.
- Parameters:
func (callable): The original function whose output is to be scaled.
scaling (DynamicLossScaling): An object providing a scale method for scaling the function's output.
has_aux (bool, optional): Flag indicating whether the wrapped function returns a tuple (value, aux) where only the value should be scaled. Defaults to False.
- Returns:
callable: A new function that wraps the original function’s behavior by applying the dynamic loss scaling to its result.
- mpx.select_tree(pred: Array, a: PyTree, b: PyTree) PyTree#
Selects elements from one of two pytrees based on a scalar boolean predicate.
This function traverses two input pytrees (a and b) and selects elements from either a or b based on the value of the scalar boolean pred. If pred is True, elements from a are selected; otherwise, elements from b are selected. Non-array elements in the pytrees are taken directly from a.
- Args:
pred (jnp.ndarray): A scalar boolean array (jnp.bool_) that determines which pytree to select elements from.
a (PyTree): The first pytree to select elements from.
b (PyTree): The second pytree to select elements from.
- Returns:
PyTree: A new pytree with elements selected from a or b based on pred.
- Raises:
AssertionError: If pred is not a scalar boolean array (jnp.bool_).
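The selection can be sketched with `jnp.where` over paired leaves (an illustrative reimplementation of the behavior documented above):

```python
import jax
import jax.numpy as jnp

def select_tree_sketch(pred, a, b):
    # Pick every array leaf from `a` when pred is True, otherwise from `b`;
    # non-array leaves are taken directly from `a`.
    assert pred.ndim == 0 and pred.dtype == jnp.bool_
    return jax.tree_util.tree_map(
        lambda x, y: jnp.where(pred, x, y) if isinstance(x, jax.Array) else x,
        a, b)
```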
- mpx.set_half_precision_datatype(datatype)#
Set the half precision datatype for the module.
- Args:
datatype: The datatype to set as half precision (e.g., jnp.float16).