Jax-metal dependencies issue - now requires ml_dtypes==0.2.0

The ml_dtypes package that JAX depends on was recently updated to 0.3.0. As part of this change, the deprecated 'float8_e4m3b11' dtype has been removed, and newer versions of JAX reflect this change. As a result, ml_dtypes 0.3.0 appears to be incompatible with JAX v0.4.11.

Since jax-metal currently requires JAX v0.4.11, perhaps its dependency list should pin ml_dtypes==0.2.0 to prevent the following import error:

AttributeError: module 'ml_dtypes' has no attribute 'float8_e4m3b11'

This essentially makes JAX unusable, because the error is raised as soon as JAX is imported. Running pip install ml_dtypes==0.2.0 appears to fix it.
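
For anyone else hitting this, here is a rough sketch for checking whether the installed ml_dtypes still exposes the dtype that JAX v0.4.11 looks up, without importing JAX itself. It only tests for the 'float8_e4m3b11' attribute named in the error above. The helper names are made up for this example and are not part of JAX, jax-metal, or ml_dtypes.

```python
import importlib
import importlib.metadata


def ml_dtypes_is_compatible(module) -> bool:
    """Return True if `module` still has the dtype JAX v0.4.11 expects.

    Hypothetical helper: it checks only the attribute from the reported
    AttributeError, not full compatibility.
    """
    return hasattr(module, "float8_e4m3b11")


def check_installed() -> str:
    """Describe whether the installed ml_dtypes looks usable with JAX v0.4.11."""
    try:
        version = importlib.metadata.version("ml_dtypes")
    except importlib.metadata.PackageNotFoundError:
        return "ml_dtypes is not installed"
    # Import ml_dtypes directly, so JAX's own import-time failure is avoided.
    module = importlib.import_module("ml_dtypes")
    if ml_dtypes_is_compatible(module):
        return f"ml_dtypes {version} still has float8_e4m3b11"
    return (
        f"ml_dtypes {version} lacks float8_e4m3b11; "
        "try: pip install ml_dtypes==0.2.0"
    )


if __name__ == "__main__":
    print(check_installed())
```

Running it as a script prints a single line describing what it found. It checks the attribute rather than comparing version numbers, since the missing attribute is what actually breaks the import.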

I can confirm that this is still an issue: JAX and TensorFlow are unusable on a Mac unless ml_dtypes==0.2.0 is explicitly installed (pip install ml_dtypes==0.2.0). Please fix. :)