Crash on Gemma3 token_embedding Layer During Training #2205
Comments
Using `responses` as an array of empty strings of the same length as `prompts` solves the problem for now. Without that, I'm getting a segmentation fault in the tokenizer when calling the TensorFlow APIs; there is some issue with the tokenization.
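For concreteness, here is a minimal illustration of the workaround described above (the dict keys, preset name, and exact preprocessor signature are assumptions based on this discussion, not the poster's actual code):

```python
import keras_hub

# Hypothetical preset name; substitute whichever Gemma3 preset you are using.
preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset("gemma3_1b")

# Workaround: pass a response for every prompt, even if it is empty, so the
# preprocessor sees matching prompt/response lengths and does not segfault.
x, y, sample_weight = preprocessor(
    {
        "prompts": ["classify this text", "and this one"],
        "responses": ["", ""],  # empty strings, same length as prompts
    }
)
```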
Thank you, @pctablet505. I was able to get it working with the workaround above.
Also, I wonder what the "correct" behavior for the original code would be.
Somehow it requires the number of prompts to be the same as the number of responses. We should probably add a check and throw an error instead of proceeding and causing a segmentation fault.
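A sketch of the kind of validation suggested here (not the actual fix from the linked PR; the feature names are assumptions):

```python
def validate_prompt_response_lengths(features):
    """Fail fast when prompts and responses are misaligned, instead of
    letting the tokenizer segfault further down the line."""
    prompts = features.get("prompts")
    responses = features.get("responses")
    if responses is not None and len(prompts) != len(responses):
        raise ValueError(
            "`prompts` and `responses` must have the same length. "
            f"Received: len(prompts)={len(prompts)}, "
            f"len(responses)={len(responses)}."
        )
```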
Yeah, throwing an error here would be good. We could have allowed optional "response" inputs. The reason we don't do that is because this is a "causal LM" preprocessor, where we compute the loss only on the response tokens. So, passing an empty response doesn't make sense for the causal case (it's alright for your case though!). Generally, we use
Does something like this work?
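(The snippet originally posted with this comment was not preserved in this copy of the thread. Below is a minimal sketch of the kind of approach being suggested: build the classifier on `Gemma3Backbone` and tokenize prompts directly, rather than going through the causal-LM preprocessor. The preset name, sequence length, and special-token attributes are assumptions.)

```python
import keras
import keras_hub

tokenizer = keras_hub.models.Gemma3Tokenizer.from_preset("gemma3_1b")
backbone = keras_hub.models.Gemma3Backbone.from_preset("gemma3_1b")

# Pad/truncate tokenized prompts to a fixed length and emit a padding mask.
packer = keras_hub.layers.StartEndPacker(
    sequence_length=128,
    start_value=tokenizer.start_token_id,
    pad_value=tokenizer.pad_token_id,
    return_padding_mask=True,
)

def preprocess(prompts):
    token_ids, padding_mask = packer(tokenizer(prompts))
    return {"token_ids": token_ids, "padding_mask": padding_mask}

# Classification head over mean-pooled backbone features.
inputs = {
    "token_ids": keras.Input(shape=(128,), dtype="int32", name="token_ids"),
    "padding_mask": keras.Input(shape=(128,), dtype="bool", name="padding_mask"),
}
features = backbone(inputs)  # (batch, 128, hidden_dim)
pooled = keras.layers.GlobalAveragePooling1D()(features, mask=inputs["padding_mask"])
outputs = keras.layers.Dense(2, activation="softmax")(pooled)
model = keras.Model(inputs, outputs)
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")
```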
Also, note that your code will work only with TF (not JAX or Torch), because the latter two do not support string tensors.
Yes, @abheesht17, as I previously mentioned. Interesting that my code isn't portable to JAX and Torch; I will look into recommended approaches to handling strings in a portable way.
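One common, backend-portable pattern (a sketch, not something from this thread) is to keep all string handling in a `tf.data` input pipeline, so the Keras model itself only ever sees numeric tensors regardless of whether the JAX, Torch, or TF backend is active. This reuses the hypothetical `preprocess()` helper and `model` from the sketch above:

```python
import tensorflow as tf

prompts = ["some example text", "another example"]
labels = [1, 0]

# Tokenization runs inside the tf.data pipeline; the model never sees strings.
ds = tf.data.Dataset.from_tensor_slices((prompts, labels))
ds = ds.batch(2).map(
    lambda text, label: (preprocess(text), label),
    num_parallel_calls=tf.data.AUTOTUNE,
)

model.fit(ds, epochs=1)
```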
Ah, my bad. Missed it.
Once the linked PR is merged, this issue will be closed. We can continue to track the warning about missing gradients in the other issue: #2102. We can see the same warning for LlamaBackbone too, so this is not a Gemma-specific issue.
Describe the bug
When training a classification model that uses the Gemma3 token_embedding layer, the kernel dies.
To Reproduce
https://colab.research.google.com/drive/12BAorKsFy_1651K7LLKPbglG0Pe951pI?usp=sharing
Here is the relevant code:
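(The code block that followed here was not preserved in this copy of the issue; the Colab above has the actual reproduction. As a rough illustration of the pattern described, the sketch below reuses the pretrained Gemma3 `token_embedding` layer in a small classifier and runs raw prompts through the causal-LM preprocessor without responses, which is the step the discussion above ties to the crash. The preset name and exact preprocessor input format are assumptions.)

```python
import numpy as np
import keras
import keras_hub

prompts = ["a positive example", "a negative example"]
labels = np.array([1, 0])

preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset("gemma3_1b")
backbone = keras_hub.models.Gemma3Backbone.from_preset("gemma3_1b")

# Prompts only, no responses: the case the comments above tie to the
# segmentation fault in the tokenizer.
x, y, sample_weight = preprocessor({"prompts": prompts})

# Small classifier that reuses the pretrained Gemma3 token embedding.
token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
embedded = backbone.token_embedding(token_ids)
pooled = keras.layers.GlobalAveragePooling1D()(embedded)
outputs = keras.layers.Dense(2, activation="softmax")(pooled)
model = keras.Model(token_ids, outputs)
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")

# In the original Colab the crash surfaced once training began.
model.fit(x["token_ids"], labels, epochs=1)
```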
Expected behavior
The kernel shouldn't die.
Additional context
This code is a variation on the code in another open issue of mine that uses a Gemma (not Gemma3) model. In that case, the Gemma-based model trains without crashing but emits some concerning warnings and doesn't work when deployed to an endpoint. In this case, with the Gemma3-based model, it crashes immediately after training begins.
Would you like to help us fix it?
I'm happy to provide any information I can to assist with fixing the issue, but I suspect it's a bug in KerasHub Gemma3 code.