Hi,
Are there plans to support complex numbers?
Something simple like this:
import jax.numpy as jnp
def return_complex(x):
    return x*1+1.0j
x = jnp.ones((10))
print(return_complex(x))
results in an error.
When I run the snippet below:
import jax
import jax.numpy as jnp
def return_complex(x):
    return x*1+1.0j
x = jnp.ones((10))
print(return_complex(x))
I get no errors.
https://colab.research.google.com/notebooks/welcome.ipynb#scrollTo=VpzLDQeWxgyY&line=8&uniqifier=1
It runs on other platforms (like Colab) but not on jax-metal.
There is a related discussion about the Metal backend here: https://github.com/google/jax/issues/8074
Is there any follow-up on complex number support on M1 yet? I am facing the same issue.
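A possible stopgap, assuming complex dtypes are still unsupported on the Metal backend as reported above, is to run only the complex part of the computation on the CPU backend. The sketch below is not an official jax-metal recommendation. The helper name `run_on_cpu` is hypothetical, and it assumes a CPU backend is registered alongside the accelerator, which I believe is the usual case but have not verified on jax-metal.

```python
import numpy as np
import jax
import jax.numpy as jnp

def return_complex(x):
    return x * 1 + 1.0j

def run_on_cpu(fn, *args):
    # Hypothetical helper: assumes jax.devices("cpu") returns at least one device.
    cpu = jax.devices("cpu")[0]
    # Make the CPU the default device for everything created inside this block.
    with jax.default_device(cpu):
        cpu_args = [jax.device_put(a, cpu) for a in args]
        return fn(*cpu_args)

# Show which backend JAX picked by default (e.g. "cpu", "gpu", or a Metal plugin name).
print("default backend:", jax.default_backend())

# Build the input with NumPy so nothing is allocated on the accelerator first.
x = np.ones((10,), dtype=np.float32)
result = run_on_cpu(return_complex, x)
print(result.dtype, result)
```

The result stays on the CPU device, so any later step that has to run on the accelerator would need the real and imaginary parts moved over separately as real-valued arrays.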