MPX documentation
For basic usage, please read the README in the Data-Science-in-Mechanical-Engineering/mixed_precision_for_JAX repository.
Contents:
- mixed_precision_for_JAX
- mpx package
- Module contents
  - DynamicLossScaling
  - all_finite()
  - calculate_scaled_grad()
  - cast_function()
  - cast_to_bfloat16()
  - cast_to_float16()
  - cast_to_float32()
  - cast_to_full_precision()
  - cast_to_half_precision()
  - cast_tree()
  - filter_grad()
  - filter_value_and_grad()
  - force_full_precision()
  - half_precision_datatype()
  - optimizer_update()
  - scaled()
  - select_tree()
  - set_half_precision_datatype()