
MPX documentation

  • mixed_precision_for_JAX

Index

A | C | D | F | H | L | M | O | P | S | U

A

  • adjust() (mpx.DynamicLossScaling method)
  • all_finite() (in module mpx)

C

  • calculate_scaled_grad() (in module mpx)
  • cast_function() (in module mpx)
  • cast_to_bfloat16() (in module mpx)
  • cast_to_float16() (in module mpx)
  • cast_to_float32() (in module mpx)
  • cast_to_full_precision() (in module mpx)
  • cast_to_half_precision() (in module mpx)
  • cast_tree() (in module mpx)
  • counter (mpx.DynamicLossScaling attribute)

D

  • DynamicLossScaling (class in mpx)

F

  • factor (mpx.DynamicLossScaling attribute)
  • filter_grad() (in module mpx)
  • filter_value_and_grad() (in module mpx)
  • force_full_precision() (in module mpx)

H

  • half_precision_datatype() (in module mpx)

L

  • loss_scaling (mpx.DynamicLossScaling attribute)

M

  • min_loss_scaling (mpx.DynamicLossScaling attribute)
  • module
    • mpx
  • mpx
    • module

O

  • optimizer_update() (in module mpx)

P

  • period (mpx.DynamicLossScaling attribute)

S

  • scale() (mpx.DynamicLossScaling method)
  • scaled() (in module mpx)
  • select_tree() (in module mpx)
  • set_half_precision_datatype() (in module mpx)

U

  • unscale() (mpx.DynamicLossScaling method)
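The DynamicLossScaling entries in this index list the attributes loss_scaling, min_loss_scaling, factor, period and counter, along with the methods scale(), unscale() and adjust(). The sketch below illustrates generic dynamic loss scaling and reuses those names for readability. It is not mpx's implementation. The class `LossScalingSketch` and the helper `all_finite_sketch` are hypothetical, and their semantics are assumptions that follow the common scheme: multiply the scale by `factor` after `period` consecutive finite steps, and divide it by `factor` (never going below `min_loss_scaling`) when gradients overflow.

```python
import dataclasses

import jax
import jax.numpy as jnp


# Hypothetical sketch: names mirror the index entries, but the behaviour is an
# assumed textbook dynamic-loss-scaling scheme, not mpx.DynamicLossScaling.
@dataclasses.dataclass(frozen=True)
class LossScalingSketch:
    loss_scaling: float = 2.0 ** 15  # current multiplier applied to the loss
    min_loss_scaling: float = 1.0    # lower bound for the multiplier
    factor: float = 2.0              # growth / shrink factor
    period: int = 2000               # finite steps required before growing
    counter: int = 0                 # finite steps since the last change

    def scale(self, tree):
        # Multiply every leaf of a pytree (e.g. the loss) by the current scale.
        return jax.tree_util.tree_map(lambda x: x * self.loss_scaling, tree)

    def unscale(self, tree):
        # Divide every leaf of a pytree (e.g. the gradients) by the scale.
        return jax.tree_util.tree_map(lambda x: x / self.loss_scaling, tree)

    def adjust(self, grads_finite: bool) -> "LossScalingSketch":
        # Overflow: shrink the scale (clamped at the minimum) and reset the counter.
        if not grads_finite:
            new_scale = max(self.loss_scaling / self.factor, self.min_loss_scaling)
            return dataclasses.replace(self, loss_scaling=new_scale, counter=0)
        # Finite step: count it, and grow the scale once `period` is reached.
        counter = self.counter + 1
        if counter >= self.period:
            return dataclasses.replace(
                self, loss_scaling=self.loss_scaling * self.factor, counter=0
            )
        return dataclasses.replace(self, counter=counter)


def all_finite_sketch(tree) -> bool:
    # True if no leaf of the pytree contains inf or NaN.
    leaves = jax.tree_util.tree_leaves(tree)
    return all(bool(jnp.all(jnp.isfinite(x))) for x in leaves)


if __name__ == "__main__":
    scaler = LossScalingSketch(loss_scaling=8.0, period=2)
    grads = scaler.unscale({"w": jnp.array([4.0, 8.0])})  # w -> [0.5, 1.0]
    scaler = scaler.adjust(all_finite_sketch(grads))  # counter 1, scale 8.0
    scaler = scaler.adjust(True)   # period reached: scale 16.0, counter 0
    scaler = scaler.adjust(False)  # overflow: scale back to 8.0
    print(scaler.loss_scaling, scaler.counter)
```

Because adjust() branches on a concrete Python boolean, this sketch runs eagerly and is not jit-compatible; a traced version would express the same updates with jnp.where. For the actual signatures of DynamicLossScaling, all_finite() and related functions, refer to the mpx API reference.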

© Copyright 2025, Alexander Gräfe.

Created using Sphinx 8.2.3.

Built with the PyData Sphinx Theme 0.16.1.