
MPX documentation

For basic usage, please read the README in the Data-Science-in-Mechanical-Engineering/mixed_precision_for_JAX repository.

Contents:

  • mixed_precision_for_JAX
    • mpx package
      • Module contents
        • DynamicLossScaling
          • DynamicLossScaling.adjust()
          • DynamicLossScaling.counter
          • DynamicLossScaling.factor
          • DynamicLossScaling.loss_scaling
          • DynamicLossScaling.min_loss_scaling
          • DynamicLossScaling.period
          • DynamicLossScaling.scale()
          • DynamicLossScaling.unscale()
        • all_finite()
        • calculate_scaled_grad()
        • cast_function()
        • cast_to_bfloat16()
        • cast_to_float16()
        • cast_to_float32()
        • cast_to_full_precision()
        • cast_to_half_precision()
        • cast_tree()
        • filter_grad()
        • filter_value_and_grad()
        • force_full_precision()
        • half_precision_datatype()
        • optimizer_update()
        • scaled()
        • select_tree()
        • set_half_precision_datatype()
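
The entries above name the pieces of a typical mixed-precision training step: scaling the loss, unscaling the gradients, checking them for non-finite values, and adjusting the scaling factor. The sketch below illustrates that general technique in plain JAX. It does not call mpx and is not the package's implementation. Every function name ending in `_sketch`, the toy `loss_fn`, and the default values for `period`, `factor`, and `min_loss_scaling` are assumptions made for illustration only (the parameter names are borrowed from the `DynamicLossScaling` attributes listed above, but their exact semantics in mpx may differ).

```python
import jax
import jax.numpy as jnp


def all_finite_sketch(tree):
    """Return a scalar bool array: True if every leaf holds only finite values."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.all(jnp.stack([jnp.all(jnp.isfinite(leaf)) for leaf in leaves]))


def loss_fn(w, x, y):
    # Hypothetical toy loss: the matrix product runs in float16,
    # the reduction runs in float32.
    pred = x.astype(jnp.float16) @ w.astype(jnp.float16)
    return jnp.mean((pred.astype(jnp.float32) - y) ** 2)


def scaled_grad_sketch(w, x, y, loss_scaling):
    # Multiply the loss by loss_scaling so that small gradients stay
    # representable while they flow through the float16 operations.
    def scaled_loss(w_):
        return loss_fn(w_, x, y) * loss_scaling

    grads = jax.grad(scaled_loss)(w)
    # Undo the scaling in float32 before the optimizer sees the gradients.
    grads = jax.tree_util.tree_map(
        lambda g: g.astype(jnp.float32) / loss_scaling, grads
    )
    return grads, all_finite_sketch(grads)


def adjust_sketch(loss_scaling, counter, grads_finite,
                  period=2000, factor=2.0, min_loss_scaling=1.0):
    # Assumed policy: after `period` consecutive finite steps, grow the
    # scaling by `factor`; on a non-finite step, shrink it by `factor`
    # (never below `min_loss_scaling`) and reset the counter.
    counter = jnp.where(grads_finite, counter + 1, 0)
    grow = counter >= period
    grown = jnp.where(grow, loss_scaling * factor, loss_scaling)
    shrunk = jnp.maximum(loss_scaling / factor, min_loss_scaling)
    new_scaling = jnp.where(grads_finite, grown, shrunk)
    counter = jnp.where(grow, 0, counter)
    return new_scaling, counter
```

In a training loop built on this sketch, the parameter update would be skipped whenever the finiteness flag is false, and `adjust_sketch` would be called once per step with that flag. For the actual mpx signatures and behavior, refer to the module contents listed above.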

© Copyright 2025, Alexander Gräfe.
