
GemmaCausalLM fails to load if TensorFlow NumPy behavior is enabled #2136

Open · t-kalinowski opened this issue Mar 12, 2025 · 4 comments
Labels: Gemma (Gemma model specific issues), type:Bug (Something isn't working)

@t-kalinowski

Describe the bug
Calling GemmaCausalLM.from_preset() errors if TF NumPy type promotion behavior is enabled. This happens regardless of which Keras backend is used.

To Reproduce

Given script bug2.py:

# /// script
# dependencies = [
#   "keras",
#   "keras-hub",
#   "tensorflow"
# ]
# ///


import tensorflow as tf

# Per the TF warning, the new type promotion must be enabled before
# any other TF API is used.
tf.experimental.numpy.experimental_enable_numpy_behavior(dtype_conversion_mode="safe")

import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
keras.config.set_dtype_policy("float16")

import json

# Load Kaggle credentials so keras_hub can download the preset.
with open(os.path.expanduser("~/.kaggle/kaggle.json"), "r") as f:
    kaggle_credentials = json.load(f)
os.environ["KAGGLE_USERNAME"] = kaggle_credentials["username"]
os.environ["KAGGLE_KEY"] = kaggle_credentials["key"]

import keras_hub

gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")

Running uv run --python 3.11 bug2.py produces:

tomasz@tomaszkalinows-WQVX deep_learning_with_r_3e % uv run --python 3.11 bug2.py
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
Traceback (most recent call last):
  File "/Users/tomasz/github/t-kalinowski/deep_learning_with_r_3e/bug2.py", line 27, in <module>
    gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras_hub/src/models/task.py", line 198, in from_preset
    return loader.load_task(cls, load_weights, load_task_weights, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras_hub/src/utils/preset_utils.py", line 670, in load_task
    return super().load_task(
           ^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras_hub/src/utils/preset_utils.py", line 618, in load_task
    kwargs["backbone"] = self.load_backbone(
                         ^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras_hub/src/utils/preset_utils.py", line 648, in load_backbone
    backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras/src/saving/saving_lib.py", line 631, in _raise_loading_failure
    raise ValueError(msg)
ValueError: A total of 127 objects could not be loaded. Example error message for object <ReversibleEmbedding name=token_embedding, built=True>:

Layer 'token_embedding' expected 1 variables, but received 0 variables during loading. Expected: ['embeddings']

List of objects that could not be loaded:
[<ReversibleEmbedding name=token_embedding, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, 
built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>]

A similar error happens if KERAS_BACKEND='jax' is configured.

@github-actions github-actions bot added the Gemma Gemma model specific issues label Mar 12, 2025
@Gopi-Uppari

Hi @t-kalinowski,

I was able to reproduce the issue, and it looks like the problem is with the model weights for GemmaCausalLM (gemma_2b_en). The error suggests that the weights didn’t load correctly, which could mean a corrupt download or missing files.
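One quick way to rule out a corrupt download is to clear the cached preset and retry (a sketch, assuming kagglehub's default cache directory; adjust the path if KAGGLEHUB_CACHE is set in your environment):

import os
import shutil

# Delete the local kagglehub cache so from_preset() re-downloads the weights.
cache_dir = os.path.expanduser("~/.cache/kagglehub")
if os.path.isdir(cache_dir):
    shutil.rmtree(cache_dir)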

Instead of troubleshooting keras_hub, a better approach would be to use Hugging Face’s transformers library to load the model. It’s more widely supported and tends to have fewer compatibility issues. Please refer to this gist.

Thank you.

@rlcauvin

rlcauvin commented Mar 17, 2025

> Instead of troubleshooting keras_hub, a better approach would be to use Hugging Face’s transformers library to load the model. It’s more widely supported and tends to have fewer compatibility issues. Please refer to this gist.

I would certainly hope that this alternate approach doesn't mean the KerasHub team will ignore the apparent bug in GemmaCausalLM. And shouldn't the issue be tagged as type:Bug?

@sonali-kumari1 sonali-kumari1 added the type:Bug Something isn't working label Mar 18, 2025
@divyashreepathihalli
Collaborator

I am not sure about Keras's support for tf.experimental.numpy.experimental_enable_numpy_behavior(dtype_conversion_mode="safe") with the TF backend.
I can confirm that the weights on Kaggle for GemmaCausalLM are complete. You can successfully load the weights and test inference and training without the experimental NumPy behavior enabled.
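For reference, a minimal sanity-check sketch of the working path (same preset, NumPy behavior not enabled; assumes Kaggle credentials are already configured):

import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras_hub

# Loads cleanly when the experimental NumPy behavior is not enabled first.
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")
print(gemma_lm.generate("The weather today is", max_length=16))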

@t-kalinowski
Author

The issue happens regardless of backend.

AFAIK enable_numpy_behavior() should only affect the __getitem__ method for tensors, while enable_numpy_behavior(dtype_conversion_mode="safe") will also permit some additional type promotions, e.g., when multiplying a 'float32' with an 'int32'.
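A minimal sketch of those two effects (an illustration only, assuming a TF version where dtype_conversion_mode is accepted; the exact "safe"-mode promotion rules may differ):

import tensorflow as tf

# Must run before any other TF API call.
tf.experimental.numpy.experimental_enable_numpy_behavior(dtype_conversion_mode="safe")

t = tf.range(10)
print(t.reshape(2, 5))  # tensors gain NumPy-style methods such as reshape()

x = tf.constant(1.5, dtype=tf.float32)
y = tf.constant(2, dtype=tf.int32)
# Without NumPy behavior enabled, mixing these dtypes raises
# InvalidArgumentError; with "safe" promotion it should promote instead.
print((x * y).dtype)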

Were you able to identify why it breaks?
