jnp.where is not working with Metal GPU backend


XlaRuntimeError                           Traceback (most recent call last)
Cell In[49], line 4
      1 arr = jnp.array([7, 8, 9])
      3 # Find indices where the condition is True
----> 4 indices = jnp.where(arr > 1)
      6 print(indices)

XlaRuntimeError: UNKNOWN
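A small script like the following can help narrow down which operations fail on a given backend. It is only a sketch: the case names are illustrative, it catches every exception rather than a specific XlaRuntimeError class, and on a machine without jax-metal it simply runs on whatever default backend JAX selects.

```python
import jax
import jax.numpy as jnp

backend = jax.default_backend()  # e.g. "METAL" with jax-metal, "cpu" otherwise
arr = jnp.array([7, 8, 9])

# Each case is a zero-argument callable so failures surface inside the try block.
cases = {
    "where(cond)": lambda: jnp.where(arr > 1),
    "where(cond, x, y)": lambda: jnp.where(arr > 1, arr, 0),
    "round": lambda: jnp.round(arr.astype(jnp.float32)),
}

results = {}
for name, fn in cases.items():
    try:
        # Force execution so asynchronous dispatch cannot hide the error.
        jax.block_until_ready(fn())
        results[name] = "ok"
    except Exception as e:  # broad on purpose: report whatever the backend raises
        results[name] = f"{type(e).__name__}: {e}"

print(backend)
for name, outcome in results.items():
    print(f"{name}: {outcome}")
```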

jnp.round doesn't work either.

Same here. More specifically, the error is:

jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <ipython-input-5-64a76e03061b>:1:0: error: failed to legalize operation 'mhlo.pad'

And for jnp.round, the error is:

jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <ipython-input-9-ea6c0ef3275e>:1:0: error: failed to legalize operation 'mhlo.round_nearest_even'
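Until the op is supported, one possible stopgap is to express round-half-to-even (the rounding mode jnp.round uses, which is why it lowers to mhlo.round_nearest_even) with simpler operations. The helper name round_half_even is hypothetical, and whether jnp.floor, jnp.mod and three-argument jnp.where all lower on the Metal backend is an assumption that has not been checked here.

```python
import jax.numpy as jnp

def round_half_even(x):
    """Hypothetical helper: round float `x` to nearest, ties to even.

    Intended to match jnp.round(x) for float inputs while avoiding
    round_nearest_even. Assumes floor/mod/where are supported by the backend.
    """
    lower = jnp.floor(x)
    upper = lower + 1
    frac = x - lower  # in [0, 1)
    # On an exact tie (frac == 0.5), choose whichever neighbor is even.
    lower_is_even = jnp.mod(lower, 2) == 0
    tie = jnp.where(lower_is_even, lower, upper)
    return jnp.where(frac < 0.5, lower, jnp.where(frac > 0.5, upper, tie))

x = jnp.array([-2.5, -1.5, -0.4, 0.5, 1.5, 2.5, 2.6, 3.2], dtype=jnp.float32)
print(round_half_even(x))
```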

N.B. For jnp.where(), passing explicit x and y arguments (instead of leaving them as None) avoids the error.
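The two forms behave differently: with x and y, jnp.where is an elementwise select whose output has the same shape as the input; with only the condition, it behaves like jnp.nonzero and returns index arrays whose length depends on the data. The sketch below uses the three-argument form where it suffices and, for the one-argument form, commits the condition to the CPU backend with jax.device_put. That fallback assumes a CPU device is available alongside the Metal one and has not been verified against jax-metal here.

```python
import jax
import jax.numpy as jnp

arr = jnp.array([7, 8, 9])
cond = arr > 1

# Three-argument form: elementwise select, output shape matches `arr`.
# Per this thread, this form avoids the Metal error.
selected = jnp.where(cond, arr, 0)

# One-argument form (same as jnp.nonzero): output length depends on the data.
# Assumption: a CPU device exists next to the Metal device. Committing `cond`
# to it makes this computation run on the CPU backend instead of Metal.
cpu = jax.devices("cpu")[0]
indices = jnp.where(jax.device_put(cond, cpu))

print(selected)  # elementwise select result, same shape as arr
print(indices)   # tuple with one index array per dimension of cond
```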

Same issue. +1

Same issue with many JAX-based packages! +1
