Hi,
I am currently training an LSTM model with TensorFlow. However, since I switched from Keras 2 to Keras 3, training has become more than 10 times slower (756 ms/step vs 49 ms/step in the logs below); it seems there is no GPU acceleration.
Here is my setup:
batch size = 256
optimiser = adam
activation = tanh
_________________________________________________________________
Layer (type)                     Output Shape          Param #
=================================================================
input_1 (InputLayer)             [(None, 7, 16)]       0
bidirectional (Bidirectional)    (None, 7, 320)        226560
bidirectional_1 (Bidirectional)  (None, 7, 512)        1181696
bidirectional_2 (Bidirectional)  (None, 256)           656384
dense (Dense)                    (None, 1)             257
=================================================================
Total params: 2064897 (7.88 MB)
Trainable params: 2064897 (7.88 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
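As a sanity check on the summary above, the parameter counts match three bidirectional LSTMs with 160, 256 and 128 units per direction, followed by a one-unit Dense layer. The unit counts are inferred from the output shapes (each Bidirectional output is twice the per-direction units); they are not stated in the post. A stdlib-only sketch of the arithmetic:

```python
# Parameter counts for the model summary above, assuming the three
# Bidirectional wrappers hold LSTM layers (unit counts inferred from shapes).

def lstm_params(input_dim: int, units: int) -> int:
    """One LSTM direction: 4 gates, each with an input kernel,
    a recurrent kernel and a bias vector (Keras default use_bias=True)."""
    return 4 * (units * (input_dim + units) + units)


def bidirectional_lstm_params(input_dim: int, units: int) -> int:
    """A Bidirectional wrapper holds two independent LSTMs (forward, backward)."""
    return 2 * lstm_params(input_dim, units)


def dense_params(input_dim: int, units: int) -> int:
    """Kernel plus bias."""
    return input_dim * units + units


layers = [
    ("bidirectional", bidirectional_lstm_params(16, 160)),     # 2 * 160 = 320 outputs
    ("bidirectional_1", bidirectional_lstm_params(320, 256)),  # 2 * 256 = 512 outputs
    ("bidirectional_2", bidirectional_lstm_params(512, 128)),  # 2 * 128 = 256 outputs
    ("dense", dense_params(256, 1)),
]
total = sum(count for _, count in layers)
size_mib = total * 4 / 1024**2  # 4 bytes per float32 weight

for name, count in layers:
    print(f"{name:<16} {count:>10,}")
print(f"{'total':<16} {total:>10,}  ({size_mib:.2f} MiB)")  # 2,064,897 (7.88 MiB)
```

The same formula is a quick way to confirm that a model rebuilt under Keras 3 has the same architecture as the Keras 2 original.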
Training status with Keras 3.6.0 + TensorFlow 2.17.0 + tensorflow-metal 1.1.0:
Training------------
Epoch 1/200
28/681 ━━━━━━━━━━━━━━━━━━━━ 8:13 756ms/step - loss: 0.5901 - mape: 338.6876 - mse: 0.8591
Training status with Keras 2.14.0 + TensorFlow 2.14.0 + tensorflow-metal 1.1.0:
Training------------
Epoch 1/200
681/681 [==============================] - 37s 49ms/step - loss: 3.6345 - mape: 499038.7500 - mse: 34.4148 - val_loss: 3.5452 - val_mape: 41.7964 - val_mse: 32.0133 - lr: 0.0010
Is that because Keras 3 has no GPU support on macOS?
Apart from that, if I change the LSTM activation from tanh to sigmoid in Keras 2, GPU acceleration is lost as well.
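For context on the tanh/sigmoid observation: the Keras LSTM documentation says the fast cuDNN kernel is only used when activation is tanh, recurrent_activation is sigmoid, recurrent_dropout is 0, unroll is False and use_bias is True (plus conditions on masking and eager execution that are not visible in a layer config). Whether tensorflow-metal applies similar rules is an assumption, but it would be consistent with sigmoid falling back to a slower path. A hypothetical helper that checks a config dict against those documented conditions:

```python
# Hypothetical helper: checks an LSTM config dict against the cuDNN
# requirements listed in the Keras LSTM docs. Whether tensorflow-metal uses
# the same criteria is an assumption, not something the plugin documents.

CUDNN_REQUIREMENTS = {
    "activation": "tanh",
    "recurrent_activation": "sigmoid",
    "recurrent_dropout": 0,
    "unroll": False,
    "use_bias": True,
}


def fused_kernel_blockers(config: dict) -> list:
    """Return the settings that rule out the fused (cuDNN-style) kernel."""
    blockers = []
    for key, required in CUDNN_REQUIREMENTS.items():
        # The Keras defaults satisfy every requirement, so a missing key is fine.
        value = config.get(key, required)
        if value != required:
            blockers.append(f"{key}={value!r} (needs {required!r})")
    return blockers


print(fused_kernel_blockers({"units": 160, "activation": "tanh"}))     # []
print(fused_kernel_blockers({"units": 160, "activation": "sigmoid"}))
# ["activation='sigmoid' (needs 'tanh')"]
```

With a real layer, passing `layer.get_config()` gives the dict to check.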
My system is macOS 15.0.1, and the code runs on Python 3.11.
I am not sure why this happens.
Thanks
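Because the slowdown coincides with a change of package versions, it helps to record exactly what is installed in each environment. A stdlib sketch using importlib.metadata; the list of package names is an assumption and can be adjusted:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(names):
    """Map each distribution name to its installed version string, or None."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# Package names to report; an assumption, adjust to your environment.
PACKAGES = ["tensorflow", "tensorflow-macos", "tensorflow-metal", "keras"]
for pkg, ver in installed_versions(PACKAGES).items():
    print(f"{pkg:<18} {ver or 'not installed'}")
```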
tensorflow-metal
TensorFlow accelerates machine learning model training with Metal on Mac GPUs.
Posts under tensorflow-metal tag
I was working on my project, and when I tried to train a model the kernel crashed. I restarted the kernel and tried again, but it crashed the same way. Then I read a thread about the same issue in which Apple support recommended installing tensorflow-macos and tensorflow-metal and following the guide on this site:
https://developer.apple.com/metal/tensorflow-plugin/
I did so and tried everything, but when I ran the test code provided on that site I got the same error. Here's the code and the output.
Code:
import tensorflow as tf
cifar = tf.keras.datasets.cifar100
(x_train, y_train), (x_test, y_test) = cifar.load_data()
model = tf.keras.applications.ResNet50(
    include_top=True,
    weights=None,
    input_shape=(32, 32, 3),
    classes=100,)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
model.fit(x_train, y_train, epochs=5, batch_size=64)
And here's the output:
Epoch 1/5
The Kernel crashed while executing code in the current cell or a previous cell.
Please review the code in the cell(s) to identify a possible cause of the failure.
Click here for more info.
View Jupyter log for further details.
And here's about half of the log file; the rest was not captured:
metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1
2024-10-06 23:30:49.894405: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 8.00 GB
2024-10-06 23:30:49.894420: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 2.67 GB
2024-10-06 23:30:49.894444: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-10-06 23:30:49.894460: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: )
2024-10-06 23:30:56.701461: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.
[libprotobuf FATAL google/protobuf/message_lite.cc:353] CHECK failed: target + size == res:
libc++abi: terminating due to uncaught exception of type google::protobuf::FatalException: CHECK failed: target + size == res:
Please respond to this post as soon as possible; I am working on my project now and keep getting this error.
Device: Apple MacBook Air M1.
The Metal plugin for TensorFlow had its GitHub repo taken down, and on PyPI the last update was a year ago, for TF 2.14. What's the status of the Metal plugin? For now it seems to work fine with TF 2.15, but what's the plan for the future?
The following code, taken from keras.io, produces this error:
InternalError: Exception encountered when calling GPT2Tokenizer.call().
...
2 root error(s) found.
(0) INTERNAL: stream cannot wait for itself
macOS on a MacBook with an M2 Max. Setting the optimizer to "Adam" does not help.
import keras_nlp # version 0.15
causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
causal_lm.compile(sampler="greedy")
# the next call produces the error
causal_lm.generate(["Keras is a"])
Following these instructions to install JAX (https://developer.apple.com/metal/jax/), I still encountered this error:
RuntimeError: This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. This error is frequently encountered on macOS when running an x86 Python installation on ARM hardware. In this case, try installing an ARM build of Python. Otherwise, you may be able work around this issue by building jaxlib from source.
How to fix it?
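The error text points at an x86_64 Python running under Rosetta on Apple silicon. A stdlib sketch to check that, assuming macOS: platform.machine() reports the interpreter's architecture, and the sysctl.proc_translated key (queried in-process through ctypes) is 1 when Rosetta is translating the process and absent on Intel Macs:

```python
import ctypes
import ctypes.util
import platform


def interpreter_arch() -> str:
    """CPU architecture this Python reports, e.g. 'arm64' or 'x86_64'."""
    return platform.machine()


def running_under_rosetta() -> bool:
    """True if macOS is translating this process with Rosetta 2."""
    if platform.system() != "Darwin":
        return False
    libc = ctypes.CDLL(ctypes.util.find_library("c"))
    sysctlbyname = libc.sysctlbyname
    sysctlbyname.argtypes = [
        ctypes.c_char_p, ctypes.c_void_p, ctypes.POINTER(ctypes.c_size_t),
        ctypes.c_void_p, ctypes.c_size_t,
    ]
    sysctlbyname.restype = ctypes.c_int
    value = ctypes.c_int(0)
    size = ctypes.c_size_t(ctypes.sizeof(value))
    # Returns nonzero (key not found) on Intel Macs; treat that as "not translated".
    ret = sysctlbyname(b"sysctl.proc_translated", ctypes.byref(value),
                       ctypes.byref(size), None, 0)
    return ret == 0 and value.value == 1


print("interpreter arch:", interpreter_arch())
print("under Rosetta:", running_under_rosetta())
```

If this prints x86_64 and True, recreating the environment with a native arm64 Python is the usual fix, as the error message itself suggests.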
I keep getting this error, even after reinstalling.
Traceback (most recent call last):
File "", line 1, in <module>
File "/Users/aman/LLM/env/lib/python3.8/site-packages/tensorflow/__init__.py", line 439, in <module>
_ll.load_library(_plugin_dir)
File "/Users/aman/LLM/env/lib/python3.8/site-packages/tensorflow/python/framework/load_library.py", line 151, in load_library
py_tf.TF_LoadLibrary(lib)
tensorflow.python.framework.errors_impl.NotFoundError: dlopen(/Users/aman/LLM/env/lib/python3.8/site-packages/tensorflow-plugins/libmetal_plugin.dylib, 0x0006): Symbol not found: _OBJC_CLASS_$_MPSGraphRandomOpDescriptor
Referenced from: /Users/aman/LLM/env/lib/python3.8/site-packages/tensorflow-plugins/libmetal_plugin.dylib
Expected in: /System/Library/Frameworks/MetalPerformanceShadersGraph.framework/Versions/A/MetalPerformanceShadersGraph
I've been trying to install tensorflow-metal so that I can train on the GPU instead of the CPU. I already have tensorflow-macos installed, and pip and TensorFlow are fully up to date. I'm two months into building and training a TensorFlow CNN, and I've reached the point where a single training epoch will take a week (I have a lot of data that I need to use). I desperately need the GPU but am stuck with the CPU for now. I can't get access to a cluster, so the best I can do is keep using my M2 MacBook. Is there any other way to install tensorflow-metal? Is there a way to use the GPU with TensorFlow if I can't install tensorflow-metal?
I keep getting this error message:
"ERROR: Could not find a version that satisfies the requirement tensorflow-metal (from versions: none) ERROR: No matching distribution found for tensorflow-metal"
I looked on apple forums, tried to download it from GitHub (the page is down), and anything else I could think of and/or find on the internet to help, but it still isn't installing.
I've used the following commands and still no luck:
python -m pip install tensorflow-metal
pip install https://github.com/apple/tensorflow_metal/releases/download/v0.5.0/tensorflow_metal-0.5.0-py3-none-any.whl
pip install tensorflow-metal
pip3 install tensorflow-metal
SYSTEM_VERSION_COMPAT=0 python -m pip install tensorflow-metal
SYSTEM_VERSION_COMPAT=0 pip install tensorflow-macos tensorflow-metal
conda install -c anaconda tensorflow-gpu
Any help would be appreciated! Thanks so much!
Last year, I upgraded to an M2 Max laptop, expecting that tensorflow-metal would enable effective local prototyping using Apple silicon's capabilities.
It has been quite some time since tensorflow-metal was last updated, and there appear to be several unresolved issues noted by the community here. I've personally observed the following behavior with my setup:
Without tensorflow-metal:
import tensorflow as tf
for _ in range(10):
    print(tf.random.normal((3,)).numpy())
[-1.4213976 0.08230731 -1.1260201 ]
[ 1.2913705 -0.47693467 -1.2886043 ]
[ 0.09144169 -1.0892165 0.9313669 ]
[ 1.1081179 0.9865657 -1.0298151]
[ 0.03328908 -0.00655857 -0.02662632]
[-1.002391 -1.1873596 -1.1168724]
[-1.2135247 -1.2823236 -1.0396363]
[-0.03492929 -0.9228362 0.19147137]
[-0.59353966 0.502279 0.80000925]
[-0.82247525 -0.13076428 0.99579334]
With tensorflow-metal:
import tensorflow as tf
for _ in range(10):
    print(tf.random.normal((3,)).numpy())
[ 1.0031303 0.8095635 -0.0610961]
[-1.3544159 0.7045493 0.03666191]
[-1.3544159 0.7045493 0.03666191]
[-1.3544159 0.7045493 0.03666191]
[-1.3544159 0.7045493 0.03666191]
[-1.3544159 0.7045493 0.03666191]
[-1.3544159 0.7045493 0.03666191]
[-1.3544159 0.7045493 0.03666191]
[-1.3544159 0.7045493 0.03666191]
[-1.3544159 0.7045493 0.03666191]
Given these observations, it seems there may be an issue with the randomness of tf.random.normal when using tensorflow-metal.
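To turn this observation into a quick regression check, here is a small sketch that reports which draws exactly repeat the previous one. The input format (a list of rows) is an assumption; with TensorFlow you could build it with tf.random.normal((3,)).numpy().tolist() in a loop:

```python
def repeated_draw_indices(samples):
    """Return indices i where samples[i] exactly equals samples[i - 1]."""
    return [
        i for i in range(1, len(samples))
        if list(samples[i]) == list(samples[i - 1])
    ]


# The "with tensorflow-metal" output above: one distinct row, then nine repeats.
with_metal = [[1.0031303, 0.8095635, -0.0610961]] + [
    [-1.3544159, 0.7045493, 0.03666191]
] * 9

print(repeated_draw_indices(with_metal))  # [2, 3, 4, 5, 6, 7, 8, 9]
```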
My current setup is macOS 14.5, tensorflow 2.14.1, and tensorflow-macos 2.14.1. I'd like to know whether there are known solutions or workarounds for this behavior.
Furthermore, could anyone provide an update on whether tensorflow-metal is still being actively developed, or if alternative approaches are recommended for utilizing the GPU capabilities of this hardware?
Hello,
I’m currently working on TinyML (ML on the edge) using Google Colab. Since I’ve used up my free compute units, I’m being prompted to pay. I’ve been considering leveraging the GPU capabilities of my M1 iPad and my Intel-based Mac. Both devices have Thunderbolt ports capable of connections of up to 30GB/s. Since I’m primarily using a classification model, I don’t need extensive GPU power.
I’m looking for assistance or guidance on utilizing the iPad’s processor as an eGPU on my Mac, possibly through an API or Apple technology. Any help would be greatly appreciated!
The Keras Embedding layer cannot run on Metal because the op StatelessRandomGetKeyCounter is not supported there, as shown in this error message:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Could not satisfy device specification '/job:localhost/replica:0/task:0/device:GPU:0'. enable_soft_placement=0. Supported device types [CPU]. All available devices [/job:localhost/replica:0/task:0/device:GPU:0, /job:localhost/replica:0/task:0/device:CPU:0]. [Op:StatelessRandomGetKeyCounter]
A workaround is to enable soft device placement so that the op falls back to the CPU, but this is slower:
tf.config.set_soft_device_placement(True)
Reporting it here as recommended by the TensorFlow Plugin Metal team.
I consistently receive corrupted results from tf.signal.fft3d() when it is within a function that has a @tf.function decorator. The results are all zero (0.) for entries after a certain x index (see image). Surprisingly, the issue depends on the matrix size. For example, (1023, 1023, 287) works but (1023, 1023, 575) does not. The issue is problematic because it occurs silently and not for all matrix sizes, i.e. can easily slip through tests.
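One observation, offered as a hypothesis rather than a diagnosis: as complex64, the working shape stays under 4 GiB while the failing one exceeds it, which could hint at a 32-bit size or offset limit somewhere in the Metal path. The arithmetic (8 bytes per complex64 element):

```python
import math

BYTES_PER_COMPLEX64 = 8  # two float32 values per element
FOUR_GIB = 2**32         # 4,294,967,296 bytes


def complex64_nbytes(shape):
    """Buffer size in bytes of a complex64 tensor with the given shape."""
    return math.prod(shape) * BYTES_PER_COMPLEX64


for shape in [(1023, 1023, 287), (1023, 1023, 575)]:
    nbytes = complex64_nbytes(shape)
    print(f"{shape}: {nbytes:,} bytes, exceeds 4 GiB: {nbytes > FOUR_GIB}")
# (1023, 1023, 287): 2,402,830,584 bytes, exceeds 4 GiB: False
# (1023, 1023, 575): 4,814,033,400 bytes, exceeds 4 GiB: True
```

Trying shapes just below and just above 2**32 bytes would confirm or rule this out.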
The error occurs only when tensorflow-metal is installed. The TensorFlow version is 2.16.1. My hardware is a MacBook Pro M3 Max with 40 GPU cores and 128 GB RAM, running macOS Sonoma 14.5 (23F79). A Python environment to reproduce the bug can be created as follows:
conda create --name tfmetalbug python=3.11.9
conda activate tfmetalbug
pip install tensorflow tensorflow-metal
conda install matplotlib
The following code reproduces the issue:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Wrap fft3d with tf.function
@tf.function
def fft3d_wrapper_function(x):
    return tf.signal.fft3d(x)
# Generate a 3D image
img = tf.random.normal(shape=(1023, 1023, 575), stddev=1., dtype=float) # generate random 3d image
img = tf.dtypes.cast(img, tf.complex64) # convert to complex values
# Compute the 3D FFT
img_fft = fft3d_wrapper_function(img)
# Visualize the 3D FFT
plt.imshow(np.real(img_fft)[:, img_fft.shape[1]//2+10, :], cmap="gray", vmin=-0.001, vmax=0.001)
plt.savefig("fft3d_wrapper_function.png")
For me, removing the @tf.function decorator has resolved the issue.
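Because the corruption is silent, a check like the following can guard tests. It is a NumPy sketch: first_zero_tail_index finds where an all-zero tail starts along an axis (which axis corresponds to the post's "x index" is an assumption), and fft3d_relative_error compares a result against np.fft.fftn, which computes the same transform as tf.signal.fft3d for a rank-3 input:

```python
import numpy as np


def first_zero_tail_index(arr, axis=0):
    """Index along `axis` where an all-zero tail begins, or None if there is none."""
    moved = np.moveaxis(np.asarray(arr), axis, 0)
    # True for each slice along `axis` that contains at least one nonzero value.
    has_nonzero = np.any(moved.reshape(moved.shape[0], -1) != 0, axis=1)
    nonzero_idx = np.flatnonzero(has_nonzero)
    if nonzero_idx.size == 0:
        return 0  # the whole array is zero
    last = int(nonzero_idx[-1])
    return None if last == moved.shape[0] - 1 else last + 1


def fft3d_relative_error(x, y):
    """Max abs difference between `y` and NumPy's 3D FFT of `x`, relative
    to the reference's largest magnitude. Meant for rank-3 inputs."""
    reference = np.fft.fftn(np.asarray(x, dtype=np.complex128))
    diff = np.abs(np.asarray(y, dtype=np.complex128) - reference)
    return float(diff.max() / np.abs(reference).max())
```

For the full 1023 x 1023 x 575 case the complex128 reference needs roughly 9.6 GB, so a smaller shape that still reproduces the bug, if one exists, is more practical.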
I'm trying to convert a TensorFlow model that I didn't create and know approximately nothing about to CoreML so that I can use it in some functional tests. I can't tell you much about the model, but you can read about it on the blog from the team that created it: https://research.google/blog/improving-mobile-app-accessibility-with-icon-detection/
I can't convert this model to a TensorFlow Lite model because it uses a few full TensorFlow operations (which I could work around) and it exceeds the 4-tensor output limit (which I can't, AFAIK). So instead, I'm trying to convert the model to CoreML so that I can run it on-device.
The issue I'm running into is that every approach fails in different ways. If I load the model with tf.saved_model.load and pass that as the first parameter to the convert call, it says
NotImplementedError: Expected model format: [SavedModel | concrete_function | tf.keras.Model | .h5 | GraphDef], got <tensorflow.python.trackable.autotrackable.AutoTrackable object at 0x30d90c250>
If I pass model.signatures['serving_default'] as the first parameter to convert, I get
NotImplementedError: Expected model format: [SavedModel | concrete_function | tf.keras.Model | .h5 | GraphDef], got ConcreteFunction [...a page or two of info about the function here...]
If I try to wrap it in a Keras layer using the instructions provided in the converter, it fails because a sequential model can't have multiple outputs.
If I try to use a tf.keras.layers.TFSMLayer to load the model, it fails because there are multiple tags, and there's no way to specify tags when constructing the layer. (It tells me that I need to add 'tags' to load the model, but if I do that, it tells me that tags isn't a valid parameter to the call.)
If I load the model with tf.saved_model.load and specify a single tag, then re-save it in a different location with tf.saved_model.save to generate a new model with only a single tag, then do
input_layer = tf.keras.Input(shape=(768, 768, 3), dtype="int8")
layer = tf.keras.layers.TFSMLayer("./serve_model", call_endpoint='serving_default')
outputs = layer(input_layer)
model = tf.keras.Model(input_layer, outputs)
I get
AttributeError: 'Functional' object has no attribute '_get_save_spec'
At one point, I also tried this:
class LayerFromSavedModel(tf.keras.layers.Layer):
    def __init__(self):
        super(LayerFromSavedModel, self).__init__()
        self.vars = legacy_model.variables

    def call(self, inputs):
        return legacy_model.signatures['serving_default'](inputs)
input = tf.keras.Input(shape=(3000, 3000, 3))
model = tf.keras.Model(input, LayerFromSavedModel()(input))
and saw a similar failure.
I've run out of ideas here. Is there simply no support whatsoever in the converter for importing a TensorFlow 2 SavedModel into CoreML, or am I missing something fundamental?
I followed the instructions in
https://developer.apple.com/metal/jax/
I got
Successfully installed importlib-metadata-7.1.0 jax-0.4.28 jax-metal-0.0.7 jaxlib-0.4.28 opt-einsum-3.3.0 scipy-1.13.0 six-1.16.0 zipp-3.18.2
but the test failed
python -c 'import jax; print(jax.numpy.arange(10))'
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/Users/erivas/jax-metal/lib/python3.9/site-packages/jax/__init__.py", line 37, in <module>
import jax.core as _core
File "/Users/erivas/jax-metal/lib/python3.9/site-packages/jax/core.py", line 18, in <module>
from jax._src.core import (
File "/Users/erivas/jax-metal/lib/python3.9/site-packages/jax/_src/core.py", line 39, in <module>
from jax._src import dtypes
File "/Users/erivas/jax-metal/lib/python3.9/site-packages/jax/_src/dtypes.py", line 33, in <module>
from jax._src import config
File "/Users/erivas/jax-metal/lib/python3.9/site-packages/jax/_src/config.py", line 27, in <module>
from jax._src import lib
File "/Users/erivas/jax-metal/lib/python3.9/site-packages/jax/_src/lib/__init__.py", line 84, in <module>
cpu_feature_guard.check_cpu_features()
RuntimeError: This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. You may be able work around this issue by building jaxlib from source.
Cannot assign a device for operation encoder/down1/downs_0/conv1/weight/Initializer/random_uniform/RandomUniform: Could not satisfy explicit device specification '' because the node {{colocation_node encoder/down1/downs_0/conv1/weight/Initializer/random_uniform/RandomUniform}} was colocated with a group of nodes that required incompatible device '/device:GPU:0'. All available devices [/job:localhost/replica:0/task:0/device:CPU:0, /job:localhost/replica:0/task:0/device:GPU:0].
Colocation Debug Info:
Colocation group had the following types and supported devices:
Root Member(assigned_device_name_index_=-1 requested_device_name_='/device:GPU:0' assigned_device_name_='' resource_device_name_='/device:GPU:0' supported_device_types_=[CPU] possible_devices_=[]
Identity: GPU CPU
Mul: GPU CPU
AddV2: GPU CPU
Sub: GPU CPU
RandomUniform: GPU CPU
Assign: CPU
VariableV2: GPU CPU
Const: GPU CPU
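Reading the debug info: every op in the colocation group supports GPU except Assign, which is CPU-only, while the group was requested on /device:GPU:0. One interpretation (not stated in the message) is that the TF1-style VariableV2/Assign pair cannot be placed on the Metal device, so the whole group fails; soft device placement, as in the Embedding workaround earlier on this page, may let it fall back to the CPU. A small stdlib sketch that extracts such ops from the debug text:

```python
def ops_without_device(debug_text: str, device: str = "GPU") -> list:
    """Return op types in a 'Colocation Debug Info' block that lack `device`.

    Expects lines of the form 'OpType: DEV1 DEV2', as in the message above.
    """
    missing = []
    for line in debug_text.splitlines():
        op, sep, devices = line.strip().partition(":")
        # Headers and the 'Root Member(...)' line have spaces or parentheses
        # before the first colon, so they are not valid op identifiers.
        if not sep or not op.isidentifier():
            continue
        if device not in devices.split():
            missing.append(op)
    return missing


sample = """\
Colocation Debug Info:
Identity: GPU CPU
Assign: CPU
VariableV2: GPU CPU
"""
print(ops_without_device(sample))  # ['Assign']
```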
Regardless of which tensorflow and tensorflow-metal versions I install (2.14, 2.15, 2.16), I find a Metal/non-Metal incompatibility for some layer types. For the GRU layer, for example, Metal-trained weights (saved with model.save_weights() and restored with load_weights()) are not compatible with inference on the CPU. That is, if I train a model using Metal, run inference using Metal, then run inference again after uninstalling tensorflow-metal, the results differ, sometimes drastically. This essentially eliminates the usefulness of tensorflow-metal for me. In my limited testing, models using other simple combinations of layer types, including Dense and LSTM, do not show this problem; only GRU does. And by "testing" I mean really simple models, like a single GRU layer. Apple Framework Metal Team: you are doing very useful work, and I kindly ask you to please address this bug :)
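To tell which backend is producing the wrong GRU outputs, one option is an independent reference. Below is a NumPy sketch of a Keras-style GRU forward pass with the defaults reset_after=True, activation tanh and recurrent_activation sigmoid. The gate order (update, reset, candidate) and bias layout follow my reading of the Keras GRUCell source and should be checked against your version:

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_reference(x, kernel, recurrent_kernel, bias):
    """Final hidden state of a Keras-style GRU (reset_after=True).

    x: (batch, time, input_dim); kernel: (input_dim, 3 * units);
    recurrent_kernel: (units, 3 * units); bias: (2, 3 * units).
    """
    units = recurrent_kernel.shape[0]
    input_bias, recurrent_bias = bias[0], bias[1]
    h = np.zeros((x.shape[0], units), dtype=np.float64)
    for t in range(x.shape[1]):
        x_proj = x[:, t, :] @ kernel + input_bias          # (batch, 3 * units)
        h_proj = h @ recurrent_kernel + recurrent_bias     # (batch, 3 * units)
        x_z, x_r, x_h = np.split(x_proj, 3, axis=1)
        h_z, h_r, h_h = np.split(h_proj, 3, axis=1)
        z = sigmoid(x_z + h_z)            # update gate
        r = sigmoid(x_r + h_r)            # reset gate
        h_tilde = np.tanh(x_h + r * h_h)  # reset applied after the matmul
        h = z * h + (1.0 - z) * h_tilde
    return h
```

With a trained single-GRU model, `kernel, recurrent_kernel, bias = gru_layer.get_weights()` gives the three arrays; comparing gru_reference on the same input with the layer's output on the CPU and on Metal would show which one disagrees.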
I noticed from the system requirements that the TensorFlow Metal plugin only seems to support Python. Are there any plans to add JavaScript support, given that TensorFlow has a JS version?
Thank you for your time...
I'm using a MacBook Pro with an M2 Pro chip. I was trying to work with TensorFlow, but I encountered an "illegal hardware instruction" error. To resolve it, I started installing the Metal plugin, which throws the following error:
or semicolon (after version specifier)
awscli>=1.16.100boto3>=1.9.100
~~~~~~~~~~~^
Unable to locate awscli
[end of output]
When fitting a CNN model, every second epoch takes zero seconds and produces OUT_OF_RANGE warnings. I'm using structured folders of categorized images for training and validation. Below is the warning message that appears after every second epoch.
The fitting looks like this...
37/37 ━━━━━━━━━━━━━━━━━━━━ 14s 337ms/step - accuracy: 0.5255 - loss: 1.0819 - val_accuracy: 0.2578 - val_loss: 2.4472
Epoch 4/20
37/37 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - accuracy: 0.5312 - loss: 1.1106 - val_accuracy: 0.1250 - val_loss: 3.0711
Epoch 5/20
2024-04-19 09:22:51.673909: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
[[{{node IteratorGetNext}}]]
2024-04-19 09:22:51.673928: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
[[{{node IteratorGetNext}}]]
[[IteratorGetNext/_59]]
2024-04-19 09:22:51.673940: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 10431687783238222105
2024-04-19 09:22:51.673944: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 17360824274615977051
2024-04-19 09:22:51.673955: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 10732905483452597729
My setup is..
Tensor Flow Version: 2.16.1
Python 3.9.19 (main, Mar 21 2024, 12:07:41)
[Clang 14.0.6 ]
Pandas 2.2.2
Scikit-Learn 1.4.2
GPU is available
My generator is..
train_generator = datagen.flow_from_directory(
    scalp_dir_train,  # directory
    target_size=(256, 256),  # all images found will be resized
    batch_size=32,
    class_mode='categorical'
    #subset='training' # Specify the subset as training
)
n_samples = train_generator.samples  # gets the number of samples
validation_generator = datagen.flow_from_directory(
    scalp_dir_test,  # directory path
    target_size=(256, 256),
    batch_size=32,
    class_mode='categorical'
    #subset='validation' # Specifying the subset as validation
)
Here is my model.
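from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import BatchNormalization, Conv2D, Dense, Dropout, Flatten, MaxPooling2D
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers import SGD
early_stopping_monitor = EarlyStopping(patience=10, restore_best_weights=True)
optimizer = Adam(learning_rate=0.01)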
model = Sequential()
model.add(Conv2D(128, (3, 3), activation='relu',padding='same', input_shape=(256, 256, 3)))
model.add(BatchNormalization())
model.add(MaxPooling2D((2, 2)))
model.add(Dropout(0.3))
model.add(Conv2D(64, (3, 3),padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D((2, 2)))
model.add(Dropout(0.3))
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.4))
model.add(Dense(256, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.3))
model.add(Dense(4, activation='softmax')) # Defined by the number of classes
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
Here is the fit...
history = model.fit(
    train_generator,
    steps_per_epoch=37,
    epochs=20,
    validation_data=validation_generator,
    validation_steps=12,
    callbacks=[early_stopping_monitor]
    #verbose=2
)
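One plausible explanation, not confirmed by the post: with steps_per_epoch set, the input iterator may not be restarted at every epoch. If 37 steps is exactly one pass over the training images (37 batches of 32 covers 1,153 to 1,184 images), epoch 1 consumes everything, epoch 2 immediately hits end of sequence (the OUT_OF_RANGE warning) and reports about 0 s, and the input then restarts. The toy simulation below reproduces that alternating pattern; removing steps_per_epoch and validation_steps so Keras takes the step counts from the generators is a reasonable first thing to try.

```python
def simulate_epochs(num_batches, steps_per_epoch, epochs):
    """Toy model (an assumption, not Keras source code) of a finite input
    whose iterator is reused across epochs while steps_per_epoch is set.

    Returns how many batches each epoch actually gets.
    """
    consumed = []
    iterator = iter(range(num_batches))
    for _ in range(epochs):
        steps = 0
        for _ in range(steps_per_epoch):
            try:
                next(iterator)
            except StopIteration:
                # "End of sequence": the epoch ends early and the input restarts.
                iterator = iter(range(num_batches))
                break
            steps += 1
        consumed.append(steps)
    return consumed


print(simulate_epochs(num_batches=37, steps_per_epoch=37, epochs=4))  # [37, 0, 37, 0]
```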
Hi,
I just noticed that jax.numpy.insert() returns an incorrect result (zero-padding the array) when compiled with jax.jit. When not jitted, the results are correct.
Config:
M1 Pro Macbook Pro 2021
python 3.12.3 ; jax-metal 0.0.6 ; jax 0.4.26 ; jaxlib 0.4.23
MWE:
import jax
import jax.numpy as jnp
x = jnp.arange(20).reshape(5, 4)
print(f"{x=}\n")
def return_arr_with_ins(arr, ins):
    return jnp.insert(arr, 2, ins, axis=1)
x2 = return_arr_with_ins(x, 99)
print(f"{x2=}\n")
return_arr_with_ins_jit = jax.jit(return_arr_with_ins)
x3 = return_arr_with_ins_jit(x, 99)
print(f"{x3=}\n")
Output: x2 (computed with the non-jitted function) is correct; x3 just has zero-padding instead of a column of 99
x=Array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15],
[16, 17, 18, 19]], dtype=int32)
x2=Array([[ 0, 1, 99, 2, 3],
[ 4, 5, 99, 6, 7],
[ 8, 9, 99, 10, 11],
[12, 13, 99, 14, 15],
[16, 17, 99, 18, 19]], dtype=int32)
x3=Array([[ 0, 1, 2, 3, 0],
[ 4, 5, 6, 7, 0],
[ 8, 9, 10, 11, 0],
[12, 13, 14, 15, 0],
[16, 17, 18, 19, 0]], dtype=int32)
The same code run on a non-metal machine gives the correct results:
x=Array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15],
[16, 17, 18, 19]], dtype=int32)
x2=Array([[ 0, 1, 99, 2, 3],
[ 4, 5, 99, 6, 7],
[ 8, 9, 99, 10, 11],
[12, 13, 99, 14, 15],
[16, 17, 99, 18, 19]], dtype=int32)
x3=Array([[ 0, 1, 99, 2, 3],
[ 4, 5, 99, 6, 7],
[ 8, 9, 99, 10, 11],
[12, 13, 99, 14, 15],
[16, 17, 99, 18, 19]], dtype=int32)
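A possible workaround, untested on jax-metal: build the result from slices and concatenate instead of jnp.insert, in case only the insert lowering is affected. The sketch takes the array module as a parameter (a hypothetical `xp` argument) so it can be checked with NumPy here and reused with jax.numpy:

```python
import numpy as np


def insert_column(arr, index, value, xp=np):
    """Insert a constant column at `index` along axis 1 via slicing + concatenate.

    `xp` is the array module: numpy here, jax.numpy to use under jax.jit.
    `index` must be a Python int so the slices stay static under jit.
    """
    column = xp.full((arr.shape[0], 1), value, dtype=arr.dtype)
    return xp.concatenate([arr[:, :index], column, arr[:, index:]], axis=1)


x = np.arange(20).reshape(5, 4)
print(insert_column(x, 2, 99))
```

Under JIT, keep the index static, e.g. jax.jit(lambda a: insert_column(a, 2, 99, xp=jnp)), and compare the result with the non-jitted output.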
Not sure if this is the correct channel for bug reports, please feel free to let me know if there's a more appropriate place!
(Copied from https://github.com/google/jax/issues/20835)
I am attempting to use JAX on Metal (on an M1 Pro chip) to model discrete (count) data. I've installed the latest version, jax-metal 0.0.6, using pip.
The installation seems to have worked overall as I can perform basic Jax array operations on GPU. However, when I try to compute the (log-)PMFs/PDFs of random variables which are defined in terms of the (log-)Gamma function I get errors like the one below which seems to indicate that the lax.lgamma function is not supported under the hood on M1 metal.
This is essential functionality for a wide class of probabilistic machine learning models. Note that the following functions (among others) are broken as a result:
jax.scipy.stats.binom.logpmf
jax.scipy.stats.nbinom.logpmf
jax.scipy.stats.poisson.logpmf
jax.scipy.stats.dirichlet.logpdf
jax.scipy.stats.beta.logpdf
jax.scipy.stats.gamma.logpdf
...
>>> jax.scipy.stats.binom.logpmf(1, n=2, p=0.5)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/scipy/stats/binom.py", line 31, in logpmf
gammaln(n + 1),
File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/scipy/special.py", line 44, in gammaln
return lax.lgamma(x)
File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/lax/special.py", line 46, in lgamma
return lgamma_p.bind(x)
File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 422, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 425, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 913, in process_primitive
return primitive.impl(*tracers, **params)
File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <stdin>:1:0: error: failed to legalize operation 'chlo.lgamma'
<stdin>:1:0: note: see current operation: %0 = "chlo.lgamma"(%arg0) : (tensor<f32>) -> tensor<f32>
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.26
jaxlib: 0.4.23
numpy: 1.26.4
python: 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='PHS027794', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:10:42 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T6000', machine='arm64')
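Until chlo.lgamma is supported by jax-metal, results can at least be cross-checked on the CPU. A stdlib sketch of the binomial log-PMF used in the failing call, valid for 0 < p < 1; it is a reference for testing, not a fix for the Metal backend:

```python
import math


def binom_logpmf(k: int, n: int, p: float) -> float:
    """log P(K = k) for K ~ Binomial(n, p), via math.lgamma (CPU, stdlib)."""
    log_choose = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
    return log_choose + k * math.log(p) + (n - k) * math.log1p(-p)


print(binom_logpmf(1, n=2, p=0.5))  # log(0.5), about -0.6931
```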