Metal breaks jax.debug_callback

When developing JAX code locally, I rely on JAX's debug_callback. The Metal backend does not implement it, so any jitted function that uses it fails to compile with:

NotImplementedError: MLIR translation rule for primitive 'debug_callback' not found for platform METAL
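A minimal reproducer follows, written as a sketch. It uses the public jax.debug.callback API, which lowers to the 'debug_callback' primitive named in the error. As a possible workaround, it pins execution to the CPU device with jax.default_device. This assumes a CPU backend is still available alongside jax-metal in your installation. The record and double functions are hypothetical names used only for illustration. Without the default_device block, running on a Metal device should raise the NotImplementedError above.

```python
import jax
import jax.numpy as jnp

seen = []  # values captured on the host by the callback

def record(x):
    # Host-side callback (hypothetical helper): receives a NumPy array at run time.
    seen.append(x.tolist())

@jax.jit
def double(x):
    # jax.debug.callback lowers to the 'debug_callback' primitive,
    # which the Metal backend has no MLIR translation rule for.
    jax.debug.callback(record, x)
    return x * 2

# Workaround sketch (assumption: a CPU device exists next to Metal):
# run the computation on the CPU backend, which implements debug_callback.
cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    out = double(jnp.arange(3.0))
    out.block_until_ready()
    jax.effects_barrier()  # wait until pending callbacks have run

print(out)
```

Falling back to CPU trades GPU speed for working debug output, so it is mainly useful while developing and debugging locally.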
