mixed_precision_for_JAX
- mpx package
- Module contents
  - DynamicLossScaling
  - all_finite()
  - calculate_scaled_grad()
  - cast_function()
  - cast_to_bfloat16()
  - cast_to_float16()
  - cast_to_float32()
  - cast_to_full_precision()
  - cast_to_half_precision()
  - cast_tree()
  - filter_grad()
  - filter_value_and_grad()
  - force_full_precision()
  - half_precision_datatype()
  - optimizer_update()
  - scaled()
  - select_tree()
  - set_half_precision_datatype()