Crash on Gemma3 token_embedding Layer During Training #2205

Closed
rlcauvin opened this issue Apr 5, 2025 · 7 comments · Fixed by #2217
Labels: Gemma (Gemma model specific issues), keras-team-review-pending, type:Bug (Something isn't working)

rlcauvin commented Apr 5, 2025

Describe the bug
When training a classification model that uses the Gemma3 token_embedding layer, the kernel dies.

To Reproduce
https://colab.research.google.com/drive/12BAorKsFy_1651K7LLKPbglG0Pe951pI?usp=sharing

Here is the relevant code:

import keras
import keras_hub


class GemmaEncoder(keras.Layer):

  def __init__(
    self,
    preprocessor: keras_hub.models.Gemma3CausalLMPreprocessor,
    backbone: keras_hub.models.Gemma3Backbone,
    pooling_layer: keras.layers.Layer,
    **kwargs):

    super().__init__(**kwargs)

    self.preprocessor = preprocessor
    self.backbone = backbone
    self.pooling_layer = pooling_layer

  @classmethod
  def from_preset(
    cls,
    preset: str = "gemma3_1b",
    pooling_layer: keras.layers.Layer = None,
    name = "gemma_encoder",
    **kwargs):

    preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(preset, sequence_length = 128)
    backbone = keras_hub.models.Gemma3Backbone.from_preset(preset)
    pooling_layer = keras.layers.GlobalAveragePooling1D(name = name + "_global_average_pooling1d") if pooling_layer is None else pooling_layer

    return cls(preprocessor = preprocessor, backbone = backbone, pooling_layer = pooling_layer, name = name, **kwargs)

  def call(self, inputs):

    adapted = inputs if isinstance(inputs, dict) and "prompts" in inputs else \
      {
      "prompts": keras.ops.array(inputs),
      "responses": keras.ops.array([""])
      }
    tokenized = self.preprocessor(adapted)
    embedded = self.backbone.token_embedding(tokenized[0]["token_ids"])
    pooled = self.pooling_layer(embedded)

    return pooled

gse_layer = GemmaEncoder.from_preset(preset = "gemma3_1b");

gse_layer(inputs = ["oranges and lemons are sour", "lemons and oranges are tart"])

headline_input = keras.layers.Input(shape = (), dtype = "string", name = "headline")
headline_featurizer = gse_layer(headline_input)
dense_16 = keras.layers.Dense(16, activation = "relu", name = "dense_16")(headline_featurizer)
activation = keras.layers.Dense(1, activation = "sigmoid", name = "activation")(dense_16)

inputs = [headline_input]
outputs = [activation]
nn_model = keras.Model(inputs = inputs, outputs = outputs, name = "nn_model")

optimizer = keras.optimizers.Adam(learning_rate=0.001) # keras.optimizers.Nadam(learning_rate = 0.00007)
nn_model.compile(optimizer = optimizer, loss = "binary_crossentropy", metrics = ["accuracy"], run_eagerly = True)

x_train = {"headline" : keras.ops.array(["hello", "goodbye", "see you soon"])}
y_train = keras.ops.array([[1], [0], [0]])

nn_model_history = nn_model.fit(
  x = x_train,
  y = y_train,
  # batch_size = 1,
  epochs = 3,
  verbose = 1)

Expected behavior
The kernel shouldn't die.

Additional context
This code is a variation on the code in another open issue of mine that uses a Gemma (not Gemma3) model. In that case, the Gemma-based model trains without crashing but produces some concerning warnings and doesn't work when deployed to an endpoint. In this case, with the Gemma3-based model, it crashes immediately after training begins.

Would you like to help us fix it?
I'm happy to provide any information I can to assist with fixing the issue, but I suspect it's a bug in KerasHub Gemma3 code.

pctablet505 (Collaborator) commented:

adapted = inputs if isinstance(inputs, dict) and "prompts" in inputs else \
      {
        "prompts": keras.ops.array(inputs),
        "responses": keras.ops.array([""]*len(inputs))
      }

Using responses as an array of empty strings of the same length as the prompts solves the problem for now. With "responses": keras.ops.array([""]), I get a segmentation fault in the tokenizer when it calls TensorFlow APIs.

There is some issue with tokenization.

rlcauvin (Author) commented Apr 16, 2025

Thank you, @pctablet505. "responses": keras.ops.array([""]*len(inputs)) resulted in the following error for me:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-77-35021b1f02b5> in <cell line: 0>()
      1 headline_input = keras.layers.Input(shape = (), dtype = "string", name = "headline")
----> 2 headline_featurizer = gse_layer(headline_input)
      3 dense_16 = keras.layers.Dense(16, activation = "relu", name = "dense_16")(headline_featurizer)
      4 activation = keras.layers.Dense(1, activation = "sigmoid", name = "activation")(dense_16)
      5 

1 frames
/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
    120             # To get the full stack trace, call:
    121             # `keras.config.disable_traceback_filtering()`
--> 122             raise e.with_traceback(filtered_tb) from None
    123         finally:
    124             del filtered_tb

<ipython-input-74-a8041c42a4aa> in call(self, inputs)
     34       "prompts": keras.ops.array(inputs),
---> 35       "responses": keras.ops.array([""] * len(inputs))
     36       }
     37     tokenized = self.preprocessor(adapted)

TypeError: Exception encountered when calling GemmaEncoder.call().

Could not automatically infer the output shape / dtype of 'gemma_encoder' (of type GemmaEncoder). Either the `GemmaEncoder.call()` method is incorrect, or you need to implement the `GemmaEncoder.compute_output_spec() / compute_output_shape()` method. Error encountered:

len is not well defined for a symbolic Tensor (Placeholder:0). Please call `x.shape` rather than `len(x)` for shape information.

Arguments received by GemmaEncoder.call():
  • args=('<KerasTensor shape=(None,), dtype=string, sparse=False, ragged=False, name=headline>',)
  • kwargs=<class 'inspect._empty'>

However, I was able to get it working with "responses": keras.ops.full(keras.ops.shape(inputs), "", "string"). Yet it output this warning during training:

/usr/local/lib/python3.11/dist-packages/keras/src/optimizers/base_optimizer.py:774: UserWarning: Gradients do not exist for variables ['decoder_block_0/pre_attention_norm/scale', 'decoder_block_0/post_attention_norm/scale', 'decoder_block_0/attention/query/kernel', 'decoder_block_0/attention/key/kernel', 'decoder_block_0/attention/value/kernel', 'decoder_block_0/attention/query_norm/scale', 'decoder_block_0/attention/key_norm/scale', 'decoder_block_0/attention/attention_output/kernel', 'decoder_block_0/pre_ffw_norm/scale', 'decoder_block_0/post_ffw_norm/scale', 'decoder_block_0/ffw_gating/kernel', 'decoder_block_0/ffw_gating_2/kernel', 'decoder_block_0/ffw_linear/kernel', (the same set of variables repeated for decoder_block_1 through decoder_block_25), 'final_normalization/scale'] when minimizing the loss. If using `model.compile()`, did you forget to provide a `loss` argument?
  warnings.warn(

Also, I wonder what the "correct" behavior for the original code "responses": keras.ops.array([""]) is. Should it crash the runtime during training or gracefully output some sort of error or warning?
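For reference, a minimal sketch of how the working expression slots into the encoder's call() method from the reproduction above (same class as before, with only the "responses" entry changed):

  def call(self, inputs):

    # Build "responses" with the same (possibly symbolic) batch shape as the
    # prompts, so the preprocessor sees matching batch sizes even when Keras
    # traces the layer with a KerasTensor, where len(inputs) is undefined.
    adapted = inputs if isinstance(inputs, dict) and "prompts" in inputs else \
      {
      "prompts": keras.ops.array(inputs),
      "responses": keras.ops.full(keras.ops.shape(inputs), "", "string")
      }
    tokenized = self.preprocessor(adapted)
    embedded = self.backbone.token_embedding(tokenized[0]["token_ids"])
    pooled = self.pooling_layer(embedded)

    return pooled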

pctablet505 (Collaborator) commented:

Somehow it requires the number of prompts to be the same as the number of responses. We probably need to add a check and throw an error instead of moving forward and causing a segmentation fault.
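For illustration, a minimal sketch of the kind of check being suggested (the function name and placement are hypothetical; the actual fix lands in the linked PR):

import tensorflow as tf

def check_prompt_response_batch(x):
  # Fail fast with a clear error when "prompts" and "responses" have
  # different batch sizes, instead of segfaulting in the tokenizer later.
  prompts = tf.convert_to_tensor(x["prompts"])
  responses = tf.convert_to_tensor(x["responses"])
  n_prompts, n_responses = prompts.shape[0], responses.shape[0]
  # Only compare when both batch sizes are statically known.
  if None not in (n_prompts, n_responses) and n_prompts != n_responses:
    raise ValueError(
      f"'prompts' and 'responses' must have the same batch size; "
      f"got {n_prompts} prompts and {n_responses} responses."
    )
  return x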

abheesht17 assigned pctablet505 and unassigned abheesht17 on Apr 16, 2025
abheesht17 (Collaborator) commented Apr 16, 2025

Also, I wonder what the "correct" behavior for the original code "responses": keras.ops.array([""]) is. Should it crash the runtime during training or gracefully output some sort of error or warning?

Yeah, throwing an error here would be good. We could have allowed optional "responses" inputs; the reason we don't is that this is a "causal LM" preprocessor, where we compute the loss only on the response tokens, so passing an empty response doesn't make sense for the causal case (it's alright for your case, though!).

Generally, we use tf.data when we do a .fit(): when you call .batch() on your data, it batches up all the values in the dictionary, resulting in the same shape for "prompts" and "responses".

"responses": keras.ops.array([""]*len(inputs)) resulted in the following error for me:

Does something like this work?

        adapted["responses"] = (
            keras.ops.full(
                shape=keras.ops.shape(adapted["prompts"]),
                fill_value="",
                dtype="string",
            )
        )

Also, note that your code will work only with the TF backend (and not JAX or Torch), because the latter two do not support string tensors.
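As a small illustration of the tf.data point above (a sketch with made-up data, not part of the original discussion): batching a dict batches every value in it, so "prompts" and "responses" come out with the same shape.

import tensorflow as tf

features = {
  "prompts": ["oranges and lemons are sour", "lemons and oranges are tart"],
  "responses": ["", ""],
}
ds = tf.data.Dataset.from_tensor_slices(features).batch(2)
for batch in ds:
  # Both tensors have shape (2,), so the preprocessor sees matching batch sizes.
  print(batch["prompts"].shape, batch["responses"].shape)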

rlcauvin (Author) commented:

Yes, @abheesht17. As I previously mentioned, "responses": keras.ops.full(keras.ops.shape(inputs), "", "string") works fine. (Though I prefer your version with keyword arguments instead of positional arguments.)

Interesting about my code not being portable to JAX and Torch. I will look into recommended approaches to handling strings in a portable way.
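One possible direction (a sketch under assumptions, not an official recommendation): keep the string handling and tokenization in a tf.data pipeline and feed already-tokenized token_ids into the Keras model, so the model graph itself never sees string tensors. Here preprocessor, texts, and labels are assumed to already exist (the Gemma3CausalLMPreprocessor and plain Python lists).

import tensorflow as tf

def to_token_ids(text, label):
  # Tokenize outside the model; empty responses are matched to the prompts' shape.
  tokenized = preprocessor(
    {"prompts": text, "responses": tf.fill(tf.shape(text), "")}
  )
  return tokenized[0]["token_ids"], label

ds = (
  tf.data.Dataset.from_tensor_slices((texts, labels))
  .batch(8)
  .map(to_token_ids, num_parallel_calls=tf.data.AUTOTUNE)
)
# The Keras model then takes integer token_ids as its input instead of raw
# strings, which keeps the model itself portable across backends.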

abheesht17 (Collaborator) commented Apr 16, 2025

As I previously mentioned, "responses": keras.ops.full(keras.ops.shape(inputs), "", "string") works fine.

Ah, my bad. Missed it.

pctablet505 (Collaborator) commented:

Once the linked PR is merged, this issue will be closed. We can continue to track the warning about missing gradients in the other issue, #2102. I can see the same warning for LlamaBackbone too, so it is not a Gemma-specific issue.
