Skip to content

Commit 0b63ce0

Browse files
authored
Update gemma3_causal_lm_preprocessor.py (#2217)
* Update gemma3_causal_lm_preprocessor.py Added checks for invalid inputs * Update gemma3_causal_lm_preprocessor.py * Update gemma3_causal_lm_preprocessor_test.py Added tests to check invalid inputs
1 parent 342b684 commit 0b63ce0

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py

+1
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,7 @@ def call(
512512

513513
# Extract text part of the input.
514514
prompts, responses = x["prompts"], x["responses"]
515+
tf.debugging.assert_shapes([(prompts, ("N",)), (responses, ("N",))])
515516

516517
# Find out if the input is batched/not batched. Uprank if not batched.
517518
# In other preprocessors, we don't have to do this, but here, all

keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor_test.py

+14
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,20 @@ def test_generate_postprocess(self):
167167
x = preprocessor.generate_postprocess(input_data)
168168
self.assertAllEqual(x, "the quick brown fox \n\n <start_of_image>")
169169

170+
def test_invalid_shape(self):
171+
with self.assertRaises(ValueError):
172+
input_data = {
173+
"prompts": ["hello world", "this is testing"],
174+
"responses": [""],
175+
}
176+
self.text_preprocessor(input_data)
177+
with self.assertRaises(ValueError):
178+
input_data = {
179+
"prompts": ["hello world", "this is testing"],
180+
"responses": ["hello", "", ""],
181+
}
182+
self.text_preprocessor(input_data)
183+
170184
@pytest.mark.kaggle_key_required
171185
@pytest.mark.extra_large
172186
def test_all_presets(self):

0 commit comments

Comments
 (0)