File tree 1 file changed +18
-3
lines changed
1 file changed +18
-3
lines changed Original file line number Diff line number Diff line change @@ -686,9 +686,24 @@ def recurse_elems(elem: Any):
686
686
config_dict ["hidden_act" ] = config_dict .get ("activation" , "silu" )
687
687
config_dict ["tie_word_embeddings" ] = config_dict .get (
688
688
"tie_embeddings" , False )
689
- config_dict ["max_seq_len" ] = config_dict .get ("max_seq_len" , 128_000 )
690
- config_dict ["max_position_embeddings" ] = config_dict .get (
691
- "max_position_embeddings" , 128_000 )
689
+
690
+ if config_dict .get ("max_position_embeddings" ) is None :
691
+ max_position_embeddings = 128_000
692
+ try :
693
+ trust_remote_code_val = kwargs .get ("trust_remote_code" , False )
694
+ hf_config = get_config (model = model ,
695
+ trust_remote_code = trust_remote_code_val ,
696
+ revision = revision ,
697
+ config_format = ConfigFormat .HF )
698
+ if hf_value := hf_config .get_text_config ().max_position_embeddings :
699
+ max_position_embeddings = hf_value
700
+ except Exception as e :
701
+ logger .warning (
702
+ "The params.json file is missing 'max_position_embeddings'"
703
+ " and could not get a value from the HF config."
704
+ " Defaulting to 128000" ,
705
+ exc_info = e )
706
+ config_dict ["max_position_embeddings" ] = max_position_embeddings
692
707
693
708
if config_dict .get ("quantization" ) is not None :
694
709
quantization = config_dict .get ("quantization" , {})
You can’t perform that action at this time.
0 commit comments