Commit 4a875d8

Update on "Use llm_config instead of args in export_llama functions"
Differential Revision: [D75484927](https://our.internmc.facebook.com/intern/diff/D75484927) [ghstack-poisoned]
2 parents d7d33d7 + 41c4e81 commit 4a875d8

2 files changed: +75 −75 lines changed

examples/models/llama/config/llm_config.py

Lines changed: 49 additions & 6 deletions
@@ -41,15 +41,23 @@ class ModelType(str, Enum):


 class PreqMode(str, Enum):
+    """
+    If you are dealing with pre-quantized checkpoints, this used to
+    be the way to specify them. Now you don't need to specify these
+    options if you use a TorchAo-prequantized checkpoint, but they
+    are still around to preserve backward compatibility.
+    """
+
     PREQ_8DA4W = "8da4w"
     PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w"


 @dataclass
 class BaseConfig:
     """
-    These are specific to the specific model, e.g. whether it’s Qwen3 0.6B or Phi-4-mini.
-    For each of these different models, you can expect each of these fields to change.
+    Configurations specific to the model, e.g. whether it’s Qwen3 or Phi-4-mini,
+    and are the minimal set of parameters needed to load the pretrained
+    eager model and its weights.
     """

     model_class: ModelType = ModelType.LLAMA3
@@ -73,6 +81,12 @@ class BaseConfig:


 class DtypeOverride(str, Enum):
+    """
+    DType of the model. Highly recommended to use "fp32", unless you want to
+    export without a backend, in which case you can also use "bf16". "fp16"
+    is not recommended.
+    """
+
     FP32 = "fp32"
     FP16 = "fp16"
     BF16 = "bf16"
@@ -81,10 +95,10 @@ class DtypeOverride(str, Enum):
 @dataclass
 class ModelConfig:
     """
-    These are not necessarily specific to the model, but are needed to finish off
-    the rest of the model configuration in eager. You can think of these like
-    optimizations / actual configurations. The same ModelConfig can be applied
-    to different models.
+    Configurations not necessarily specific to the model, but are needed to
+    finish off the rest of the model configuration in eager. You can think
+    of these like optimizations / actual configurations. The same ModelConfig
+    can be applied to multiple models.
     """

     dtype_override: DtypeOverride = DtypeOverride.FP32
@@ -109,6 +123,10 @@ class ModelConfig:

 @dataclass
 class ExportConfig:
+    """
+    Configures properties relevant to the export process.
+    """
+
     max_seq_length: int = 128
     max_context_length: int = 128
     output_dir: Optional[str] = None
@@ -124,6 +142,10 @@ class ExportConfig:

 @dataclass
 class DebugConfig:
+    """
+    Configures options to debug the export process.
+    """
+
     profile_memory: bool = False
     profile_path: Optional[str] = None
     generate_etrecord: bool = False
@@ -137,6 +159,14 @@ class DebugConfig:


 class Pt2eQuantize(str, Enum):
+    """
+    Type of backend-specific Pt2e quantization strategy to use.
+
+    Pt2e uses a different quantization library that is graph-based
+    compared to `qmode`, which is also specified in the QuantizationConfig
+    and is source transform-based.
+    """
+
     XNNPACK_DYNAMIC = "xnnpack_dynamic"
     XNNPACK_DYNAMIC_QC4 = "xnnpack_dynamic_qc4"
     QNN_8A8W = "qnn_8a8w"
@@ -157,6 +187,10 @@ class SpinQuant(str, Enum):

 @dataclass
 class QuantizationConfig:
+    """
+    Configures how the model should be quantized (PTQ).
+    """
+
     qmode: Optional[str] = None
     embedding_quantize: Optional[str] = None
     pt2e_quantize: Optional[Pt2eQuantize] = None
@@ -248,6 +282,11 @@ class MPSConfig:

 @dataclass
 class BackendConfig:
+    """
+    Configures which backends should be used and how the backends
+    should be set up.
+    """
+
     xnnpack: XNNPackConfig = field(default_factory=XNNPackConfig)
     coreml: CoreMLConfig = field(default_factory=CoreMLConfig)
     vulkan: VulkanConfig = field(default_factory=VulkanConfig)
@@ -262,6 +301,10 @@ class BackendConfig:

 @dataclass
 class LlmConfig:
+    """
+    The overall configuration for customizing the LLM export process.
+    """
+
     base: BaseConfig = field(default_factory=BaseConfig)
     model: ModelConfig = field(default_factory=ModelConfig)
     export: ExportConfig = field(default_factory=ExportConfig)
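
For reference, a minimal sketch of how these nested dataclasses compose into an LlmConfig, using only fields that appear in this diff. The import path is assumed from the file location, and the overridden values are illustrative, not defaults taken from the commit:

    # Sketch only: module path inferred from examples/models/llama/config/llm_config.py.
    from examples.models.llama.config.llm_config import (
        BaseConfig,
        DtypeOverride,
        ExportConfig,
        LlmConfig,
        ModelConfig,
        ModelType,
    )

    # Each section keeps its dataclass defaults unless overridden here.
    llm_config = LlmConfig(
        base=BaseConfig(model_class=ModelType.LLAMA3),
        model=ModelConfig(
            dtype_override=DtypeOverride.FP32,  # "fp32" is the recommended dtype per the new docstring
            use_kv_cache=True,  # read as llm_config.model.use_kv_cache in export_llama_lib.py below
        ),
        export=ExportConfig(max_seq_length=256, max_context_length=256),
    )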

examples/models/llama/export_llama_lib.py

Lines changed: 26 additions & 69 deletions
@@ -661,36 +661,16 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
         canonical_path(llm_config.base.params) if llm_config.base.params else None
     )
     output_dir_path = canonical_path(llm_config.export.output_dir, dir=True)
-    weight_type = WeightType.FAIRSEQ2 if llm_config.base.fairseq2 else WeightType.LLAMA

-    # Convert dtype override string to actual type
+    llm_config.base.checkpoint = checkpoint_path
+    llm_config.base.checkpoint_dir = checkpoint_dir
+    llm_config.base.params = params_path
+    llm_config.export.output_dir = output_dir_path
+
+    # Convert dtype override string to actual type.
     dtype_override = DType[llm_config.model.dtype_override]

-    edge_manager = _load_llama_model(
-        llm_config,
-        checkpoint=checkpoint_path,
-        checkpoint_dir=checkpoint_dir,
-        params_path=params_path,
-        use_kv_cache=llm_config.model.use_kv_cache,
-        use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache,
-        generate_full_logits=llm_config.debug.generate_full_logits,
-        weight_type=weight_type,
-        enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
-        calibration_tasks=llm_config.quantization.calibration_tasks,
-        calibration_limit=llm_config.quantization.calibration_limit,
-        calibration_seq_length=llm_config.quantization.calibration_seq_length,
-        calibration_data=llm_config.quantization.calibration_data,
-        tokenizer_path=llm_config.base.tokenizer_path,
-        verbose=llm_config.debug.verbose,
-        max_seq_len=llm_config.export.max_seq_length,
-        max_context_len=llm_config.export.max_context_length,
-        input_prune_map_path=llm_config.model.input_prune_map,
-        output_prune_map_path=llm_config.model.output_prune_map,
-        metadata_str=llm_config.base.metadata,
-        dtype_override=dtype_override,
-        use_qnn=llm_config.backend.qnn.enabled,
-        export_only=llm_config.export.export_only,
-    )
+    edge_manager = _load_llama_model(llm_config)

     # At this point, the model is loaded in the default fp32.

@@ -1167,32 +1147,7 @@ def _load_llama_model_metadata(
     return metadata


-def _load_llama_model(
-    llm_config: LlmConfig,
-    *,
-    checkpoint: Optional[str] = None,
-    checkpoint_dir: Optional[str] = None,
-    params_path: Optional[str] = None,
-    use_kv_cache: bool = False,
-    use_sdpa_with_kv_cache: bool = False,
-    generate_full_logits: bool = False,
-    weight_type: WeightType = WeightType.LLAMA,
-    enable_dynamic_shape: bool = False,
-    calibration_tasks: Optional[List[str]] = None,
-    calibration_limit: Optional[int] = None,
-    calibration_seq_length: Optional[int] = None,
-    calibration_data: Optional[str] = None,
-    tokenizer_path: Optional[str] = None,
-    verbose: bool = False,
-    max_seq_len: int = 128,
-    max_context_len: int = 128,
-    input_prune_map_path: Optional[str] = None,
-    output_prune_map_path: Optional[str] = None,
-    metadata_str: Optional[str] = None,
-    dtype_override: Optional[DType] = None,
-    use_qnn: bool = False,
-    export_only: bool = False,
-) -> "LLMEdgeManager":
+def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
     """
     A helper util that builds a Llama2 model. It returns a LLMEdgeManager that
     can help further lower the model to ExecuTorch.
@@ -1220,31 +1175,33 @@ def _load_llama_model(
             llm_config=llm_config,
         )
     )
+    # Convert dtype override string to actual type.
+    dtype_override = DType[llm_config.model.dtype_override]

     return LLMEdgeManager(
         model=model,
         modelname=modelname,
         max_seq_len=model.max_seq_len,  # type: ignore
         dtype=dtype_override,
-        use_kv_cache=use_kv_cache,
-        generate_full_logits=generate_full_logits,
+        use_kv_cache=llm_config.model.use_kv_cache,
+        generate_full_logits=llm_config.debug.generate_full_logits,
         example_inputs=example_inputs,
         example_kwarg_inputs=example_kwarg_inputs,
         dynamic_shapes=dynamic_shapes,
-        enable_dynamic_shape=enable_dynamic_shape,
-        calibration_tasks=calibration_tasks,
-        calibration_limit=calibration_limit,
-        calibration_seq_length=calibration_seq_length,
-        calibration_data=calibration_data,
-        tokenizer_path=tokenizer_path,
-        use_legacy_export=use_qnn,
-        save_exported_program=export_only,
-        verbose=verbose,
+        enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
+        calibration_tasks=llm_config.quantization.calibration_tasks,
+        calibration_limit=llm_config.quantization.calibration_limit,
+        calibration_seq_length=llm_config.quantization.calibration_seq_length,
+        calibration_data=llm_config.quantization.calibration_data,
+        tokenizer_path=llm_config.base.tokenizer_path,
+        use_legacy_export=llm_config.backend.qnn.enabled,
+        save_exported_program=llm_config.export.export_only,
+        verbose=llm_config.debug.verbose,
         metadata=_load_llama_model_metadata(
-            weight_type,
-            use_kv_cache,
-            use_sdpa_with_kv_cache,
-            enable_dynamic_shape,
+            WeightType.FAIRSEQ2 if llm_config.base.fairseq2 else WeightType.LLAMA,
+            llm_config.model.use_kv_cache,
+            llm_config.model.use_sdpa_with_kv_cache,
+            llm_config.model.enable_dynamic_shape,
             # pyre-fixme[6]: For 5th argument expected `ModelArgs` but got
             # `Union[Tensor, Module]`.
             model.max_seq_len,
@@ -1257,7 +1214,7 @@ def _load_llama_model(
             # pyre-fixme[6]: For 8th argument expected `int` but got `Union[Tensor,
             # Module]`.
             model.vocab_size,
-            metadata_str,
+            llm_config.base.metadata,
         ),
     )

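
The shape of this refactor, reduced to a self-contained toy (the Toy* names below are illustrative only, not ExecuTorch APIs): rather than unpacking every field into keyword arguments at the call site and re-declaring them in the callee's signature, the single nested config object is passed through and each callee reads the fields it needs, as _load_llama_model(llm_config) now does.

    from dataclasses import dataclass, field


    @dataclass
    class ToyExportConfig:
        max_seq_length: int = 128


    @dataclass
    class ToyLlmConfig:
        export: ToyExportConfig = field(default_factory=ToyExportConfig)
        verbose: bool = False


    # Before: each field is threaded through as its own keyword argument.
    def load_model_old(*, max_seq_len: int = 128, verbose: bool = False) -> str:
        return f"seq={max_seq_len} verbose={verbose}"


    # After: the config object travels as one argument; fields are read where used.
    def load_model_new(cfg: ToyLlmConfig) -> str:
        return f"seq={cfg.export.max_seq_length} verbose={cfg.verbose}"


    cfg = ToyLlmConfig(export=ToyExportConfig(max_seq_length=256), verbose=True)
    assert load_model_old(max_seq_len=cfg.export.max_seq_length, verbose=cfg.verbose) == load_model_new(cfg)

The trade-off is the usual one for config objects: fewer signatures to keep in sync when a field is added, at the cost of callees depending on the full config type.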
