@@ -661,36 +661,16 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
661
661
canonical_path (llm_config .base .params ) if llm_config .base .params else None
662
662
)
663
663
output_dir_path = canonical_path (llm_config .export .output_dir , dir = True )
664
- weight_type = WeightType .FAIRSEQ2 if llm_config .base .fairseq2 else WeightType .LLAMA
665
664
666
- # Convert dtype override string to actual type
665
+ llm_config .base .checkpoint = checkpoint_path
666
+ llm_config .base .checkpoint_dir = checkpoint_dir
667
+ llm_config .base .params = params_path
668
+ llm_config .export .output_dir = output_dir_path
669
+
670
+ # Convert dtype override string to actual type.
667
671
dtype_override = DType [llm_config .model .dtype_override ]
668
672
669
- edge_manager = _load_llama_model (
670
- llm_config ,
671
- checkpoint = checkpoint_path ,
672
- checkpoint_dir = checkpoint_dir ,
673
- params_path = params_path ,
674
- use_kv_cache = llm_config .model .use_kv_cache ,
675
- use_sdpa_with_kv_cache = llm_config .model .use_sdpa_with_kv_cache ,
676
- generate_full_logits = llm_config .debug .generate_full_logits ,
677
- weight_type = weight_type ,
678
- enable_dynamic_shape = llm_config .model .enable_dynamic_shape ,
679
- calibration_tasks = llm_config .quantization .calibration_tasks ,
680
- calibration_limit = llm_config .quantization .calibration_limit ,
681
- calibration_seq_length = llm_config .quantization .calibration_seq_length ,
682
- calibration_data = llm_config .quantization .calibration_data ,
683
- tokenizer_path = llm_config .base .tokenizer_path ,
684
- verbose = llm_config .debug .verbose ,
685
- max_seq_len = llm_config .export .max_seq_length ,
686
- max_context_len = llm_config .export .max_context_length ,
687
- input_prune_map_path = llm_config .model .input_prune_map ,
688
- output_prune_map_path = llm_config .model .output_prune_map ,
689
- metadata_str = llm_config .base .metadata ,
690
- dtype_override = dtype_override ,
691
- use_qnn = llm_config .backend .qnn .enabled ,
692
- export_only = llm_config .export .export_only ,
693
- )
673
+ edge_manager = _load_llama_model (llm_config )
694
674
695
675
# At this point, the model is loaded in the default fp32.
696
676
@@ -1167,32 +1147,7 @@ def _load_llama_model_metadata(
1167
1147
return metadata
1168
1148
1169
1149
1170
- def _load_llama_model (
1171
- llm_config : LlmConfig ,
1172
- * ,
1173
- checkpoint : Optional [str ] = None ,
1174
- checkpoint_dir : Optional [str ] = None ,
1175
- params_path : Optional [str ] = None ,
1176
- use_kv_cache : bool = False ,
1177
- use_sdpa_with_kv_cache : bool = False ,
1178
- generate_full_logits : bool = False ,
1179
- weight_type : WeightType = WeightType .LLAMA ,
1180
- enable_dynamic_shape : bool = False ,
1181
- calibration_tasks : Optional [List [str ]] = None ,
1182
- calibration_limit : Optional [int ] = None ,
1183
- calibration_seq_length : Optional [int ] = None ,
1184
- calibration_data : Optional [str ] = None ,
1185
- tokenizer_path : Optional [str ] = None ,
1186
- verbose : bool = False ,
1187
- max_seq_len : int = 128 ,
1188
- max_context_len : int = 128 ,
1189
- input_prune_map_path : Optional [str ] = None ,
1190
- output_prune_map_path : Optional [str ] = None ,
1191
- metadata_str : Optional [str ] = None ,
1192
- dtype_override : Optional [DType ] = None ,
1193
- use_qnn : bool = False ,
1194
- export_only : bool = False ,
1195
- ) -> "LLMEdgeManager" :
1150
+ def _load_llama_model (llm_config : LlmConfig ) -> "LLMEdgeManager" :
1196
1151
"""
1197
1152
A helper util that builds a Llama2 model. It returns a LLMEdgeManager that
1198
1153
can help further lower the model to ExecuTorch.
@@ -1220,31 +1175,33 @@ def _load_llama_model(
1220
1175
llm_config = llm_config ,
1221
1176
)
1222
1177
)
1178
+ # Convert dtype override string to actual type.
1179
+ dtype_override = DType [llm_config .model .dtype_override ]
1223
1180
1224
1181
return LLMEdgeManager (
1225
1182
model = model ,
1226
1183
modelname = modelname ,
1227
1184
max_seq_len = model .max_seq_len , # type: ignore
1228
1185
dtype = dtype_override ,
1229
- use_kv_cache = use_kv_cache ,
1230
- generate_full_logits = generate_full_logits ,
1186
+ use_kv_cache = llm_config . model . use_kv_cache ,
1187
+ generate_full_logits = llm_config . debug . generate_full_logits ,
1231
1188
example_inputs = example_inputs ,
1232
1189
example_kwarg_inputs = example_kwarg_inputs ,
1233
1190
dynamic_shapes = dynamic_shapes ,
1234
- enable_dynamic_shape = enable_dynamic_shape ,
1235
- calibration_tasks = calibration_tasks ,
1236
- calibration_limit = calibration_limit ,
1237
- calibration_seq_length = calibration_seq_length ,
1238
- calibration_data = calibration_data ,
1239
- tokenizer_path = tokenizer_path ,
1240
- use_legacy_export = use_qnn ,
1241
- save_exported_program = export_only ,
1242
- verbose = verbose ,
1191
+ enable_dynamic_shape = llm_config . model . enable_dynamic_shape ,
1192
+ calibration_tasks = llm_config . quantization . calibration_tasks ,
1193
+ calibration_limit = llm_config . quantization . calibration_limit ,
1194
+ calibration_seq_length = llm_config . quantization . calibration_seq_length ,
1195
+ calibration_data = llm_config . quantization . calibration_data ,
1196
+ tokenizer_path = llm_config . base . tokenizer_path ,
1197
+ use_legacy_export = llm_config . backend . qnn . enabled ,
1198
+ save_exported_program = llm_config . export . export_only ,
1199
+ verbose = llm_config . debug . verbose ,
1243
1200
metadata = _load_llama_model_metadata (
1244
- weight_type ,
1245
- use_kv_cache ,
1246
- use_sdpa_with_kv_cache ,
1247
- enable_dynamic_shape ,
1201
+ WeightType . FAIRSEQ2 if llm_config . base . fairseq2 else WeightType . LLAMA ,
1202
+ llm_config . model . use_kv_cache ,
1203
+ llm_config . model . use_sdpa_with_kv_cache ,
1204
+ llm_config . model . enable_dynamic_shape ,
1248
1205
# pyre-fixme[6]: For 5th argument expected `ModelArgs` but got
1249
1206
# `Union[Tensor, Module]`.
1250
1207
model .max_seq_len ,
@@ -1257,7 +1214,7 @@ def _load_llama_model(
1257
1214
# pyre-fixme[6]: For 8th argument expected `int` but got `Union[Tensor,
1258
1215
# Module]`.
1259
1216
model .vocab_size ,
1260
- metadata_str ,
1217
+ llm_config . base . metadata ,
1261
1218
),
1262
1219
)
1263
1220
0 commit comments