From 2246299536539964eb4aa488e7cc827f68d89ffc Mon Sep 17 00:00:00 2001
From: Laxma Reddy Patlolla
Date: Mon, 7 Apr 2025 13:17:34 -0700
Subject: [PATCH] update qwen conversion script

---
 .../convert_qwen_checkpoints.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/tools/checkpoint_conversion/convert_qwen_checkpoints.py b/tools/checkpoint_conversion/convert_qwen_checkpoints.py
index 0ca26bc6f7..1cbd355081 100644
--- a/tools/checkpoint_conversion/convert_qwen_checkpoints.py
+++ b/tools/checkpoint_conversion/convert_qwen_checkpoints.py
@@ -106,7 +106,7 @@ def main(_):
     hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset, return_tensors="pt")
 
     hf_model.eval()
-    keras_hub_model = keras_hub.models.QwenBackbone.from_preset(
+    keras_hub_backbone = keras_hub.models.QwenBackbone.from_preset(
         f"hf://{hf_preset}"
     )
     keras_hub_tokenizer = keras_hub.models.QwenTokenizer.from_preset(
@@ -117,9 +117,18 @@ def main(_):
 
     # === Check that the models and tokenizers outputs match ===
     test_tokenizer(keras_hub_tokenizer, hf_tokenizer)
-    test_model(keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer)
+    test_model(keras_hub_backbone, keras_hub_tokenizer, hf_model, hf_tokenizer)
     print("\n-> Tests passed!")
 
+    preprocessor = keras_hub.models.Qwen2CausalLMPreprocessor(
+        keras_hub_tokenizer
+    )
+    keras_hub_model = keras_hub.models.Qwen2CausalLM(
+        keras_hub_backbone, preprocessor
+    )
+
+    keras_hub_model.save_to_preset(f"./{preset}")
+
 
 if __name__ == "__main__":
     flags.mark_flag_as_required("preset")