
Commit 4ee4826

princepride, tracelogfb, and Stephen Chen
authored
[BugFix] Correct max_model_len derivation from config.json for Mistral format (#17937)
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
Co-authored-by: tracelogfb <48808670+tracelogfb@users.noreply.github.com>
Co-authored-by: Stephen Chen <tracelog@meta.com>
1 parent 60017dc commit 4ee4826

File tree

1 file changed: +18 −3 lines changed


vllm/transformers_utils/config.py

Lines changed: 18 additions & 3 deletions
@@ -686,9 +686,24 @@ def recurse_elems(elem: Any):
     config_dict["hidden_act"] = config_dict.get("activation", "silu")
     config_dict["tie_word_embeddings"] = config_dict.get(
         "tie_embeddings", False)
-    config_dict["max_seq_len"] = config_dict.get("max_seq_len", 128_000)
-    config_dict["max_position_embeddings"] = config_dict.get(
-        "max_position_embeddings", 128_000)
+
+    if config_dict.get("max_position_embeddings") is None:
+        max_position_embeddings = 128_000
+        try:
+            trust_remote_code_val = kwargs.get("trust_remote_code", False)
+            hf_config = get_config(model=model,
+                                   trust_remote_code=trust_remote_code_val,
+                                   revision=revision,
+                                   config_format=ConfigFormat.HF)
+            if hf_value := hf_config.get_text_config().max_position_embeddings:
+                max_position_embeddings = hf_value
+        except Exception as e:
+            logger.warning(
+                "The params.json file is missing 'max_position_embeddings'"
+                " and could not get a value from the HF config."
+                " Defaulting to 128000",
+                exc_info=e)
+        config_dict["max_position_embeddings"] = max_position_embeddings
 
     if config_dict.get("quantization") is not None:
         quantization = config_dict.get("quantization", {})
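
Below is a minimal, self-contained sketch of the fallback this change introduces; it is not the vLLM code itself. The helper read_hf_text_config is a hypothetical stand-in for vLLM's get_config(..., config_format=ConfigFormat.HF).get_text_config(), and plain dicts stand in for the real config objects.

import logging

logger = logging.getLogger(__name__)


def read_hf_text_config(model: str) -> dict:
    # Hypothetical placeholder: in vLLM this role is played by
    # get_config(..., config_format=ConfigFormat.HF).get_text_config().
    raise FileNotFoundError(f"no config.json found for {model!r}")


def resolve_max_position_embeddings(config_dict: dict, model: str) -> dict:
    # Only derive a value when params.json did not provide one.
    if config_dict.get("max_position_embeddings") is None:
        max_position_embeddings = 128_000  # last-resort default
        try:
            # Prefer the value advertised by the model's HF config.json.
            hf_config = read_hf_text_config(model)
            if hf_value := hf_config.get("max_position_embeddings"):
                max_position_embeddings = hf_value
        except Exception as e:
            logger.warning(
                "params.json is missing 'max_position_embeddings' and the "
                "HF config could not be read; defaulting to 128000",
                exc_info=e)
        config_dict["max_position_embeddings"] = max_position_embeddings
    return config_dict


# With the placeholder above the HF lookup fails, so the default applies:
# {'hidden_act': 'silu', 'max_position_embeddings': 128000}
print(resolve_max_position_embeddings({"hidden_act": "silu"},
                                      model="some/mistral-model"))

Deriving the missing value from config.json rather than unconditionally defaulting to 128,000 keeps max_model_len consistent with the checkpoint's advertised context length when params.json omits it.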
