
MPX documentation


mixed_precision_for_JAX

  • mpx package
    • Module contents
      • DynamicLossScaling
        • DynamicLossScaling.adjust()
        • DynamicLossScaling.counter
        • DynamicLossScaling.factor
        • DynamicLossScaling.loss_scaling
        • DynamicLossScaling.min_loss_scaling
        • DynamicLossScaling.period
        • DynamicLossScaling.scale()
        • DynamicLossScaling.unscale()
      • all_finite()
      • calculate_scaled_grad()
      • cast_function()
      • cast_to_bfloat16()
      • cast_to_float16()
      • cast_to_float32()
      • cast_to_full_precision()
      • cast_to_half_precision()
      • cast_tree()
      • filter_grad()
      • filter_value_and_grad()
      • force_full_precision()
      • half_precision_datatype()
      • optimizer_update()
      • scaled()
      • select_tree()
      • set_half_precision_datatype()
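The `DynamicLossScaling` entries above list `scale()`, `unscale()` and `adjust()` methods along with `counter`, `factor`, `loss_scaling`, `min_loss_scaling` and `period` attributes. These match the usual dynamic loss-scaling scheme for half-precision training: multiply the loss before differentiation, divide the gradients afterwards, then grow the scale after `period` consecutive finite steps or shrink it by `factor` on overflow. The block below is a minimal pure-Python sketch of that scheme. It is not MPX's implementation. The class name `LossScaleState`, its default values, and the use of plain floats instead of JAX arrays and pytrees are all assumptions made for illustration.

```python
import math
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class LossScaleState:
    """Hypothetical stand-in for a dynamic loss-scaling state (NOT MPX's class)."""
    loss_scaling: float = 2.0 ** 15  # current multiplier applied to the loss
    min_loss_scaling: float = 1.0    # floor the scale never drops below
    factor: float = 2.0              # growth/shrink factor
    period: int = 2000               # finite steps required before growing
    counter: int = 0                 # consecutive finite steps seen so far

    def scale(self, loss: float) -> float:
        # Enlarge the loss so small half-precision gradients do not underflow.
        return loss * self.loss_scaling

    def unscale(self, grads: list) -> list:
        # Divide by the same multiplier to recover the true gradient magnitudes.
        return [g / self.loss_scaling for g in grads]

    def adjust(self, grads_finite: bool) -> "LossScaleState":
        if not grads_finite:
            # Overflow: shrink the scale (never below the floor) and restart the count.
            new_scale = max(self.loss_scaling / self.factor, self.min_loss_scaling)
            return replace(self, loss_scaling=new_scale, counter=0)
        if self.counter + 1 >= self.period:
            # A full period without overflow: try a larger scale.
            return replace(self, loss_scaling=self.loss_scaling * self.factor, counter=0)
        # Finite step inside the current period: just count it.
        return replace(self, counter=self.counter + 1)


def all_finite(values: list) -> bool:
    # True only if no entry is inf or NaN.
    return all(math.isfinite(v) for v in values)


# Usage: a tiny period makes the growth and shrink steps visible.
state = LossScaleState(loss_scaling=8.0, period=2)
scaled_loss = state.scale(0.5)            # 4.0
grads = state.unscale([16.0, 8.0])        # [2.0, 1.0]
state = state.adjust(all_finite(grads))   # counter -> 1
state = state.adjust(True)                # period reached: scale -> 16.0
state = state.adjust(all_finite([float("inf")]))  # overflow: scale -> 8.0
```

In a real training step the parameter update is typically discarded whenever the gradients are not all finite; the `optimizer_update()` and `select_tree()` entries are the places to check how MPX itself handles that case.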


© Copyright 2025, Alexander Gräfe.

Created using Sphinx 8.2.3.

Built with the PyData Sphinx Theme 0.16.1.