jax-metal issues with IDEs

Hi,

Following the instructions at https://developer.apple.com/metal/jax/, JAX works fine on my M1 Pro, but only in Terminal. If I run the following in Jupyter Notebook or PyCharm, it always reports CPU:

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
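If the same check reports a different platform in Terminal than in Jupyter or PyCharm, one likely explanation is that they are running different Python interpreters. The snippet below is a minimal diagnostic sketch, not an official tool. It assumes only the standard library plus an optional JAX install, and the helper names `in_virtualenv` and `describe_interpreter` are hypothetical. It uses `jax.default_backend()`, which returns the same platform name as the `xla_bridge` call above. Run it once in Terminal with the venv activated and once inside the notebook or IDE, then compare the `executable` lines.

```python
import sys


def in_virtualenv() -> bool:
    # Inside a venv, sys.prefix points at the environment directory,
    # while sys.base_prefix still points at the base Python installation.
    return sys.prefix != sys.base_prefix


def describe_interpreter() -> dict:
    """Collect details about the Python process this code is running in."""
    info = {
        "executable": sys.executable,  # e.g. ~/jax-metal/bin/python3 when the venv is used
        "prefix": sys.prefix,
        "in_virtualenv": in_virtualenv(),
        "jax_backend": None,  # stays None if JAX is not installed in this interpreter
    }
    try:
        import jax  # only importable if JAX is installed for *this* interpreter

        info["jax_backend"] = jax.default_backend()  # platform name, e.g. "cpu"
    except ImportError:
        pass
    return info


if __name__ == "__main__":
    for key, value in describe_interpreter().items():
        print(f"{key}: {value}")
```

If the notebook's `executable` is not inside `~/jax-metal`, it is not using the environment where jax-metal was installed.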

I also noticed that after restarting Terminal, JAX falls back to CPU only. I always have to switch to the jax-metal virtual environment first to get the Apple Silicon GPU to work:

python3 -m venv ~/jax-metal
source ~/jax-metal/bin/activate

Is there any way to make Jupyter Notebook and other IDEs default to jax-metal? Currently I can only use it in Terminal, and only after manually activating the jax-metal virtual environment each time, which is annoying.
