Commit d7d33d7

Update on "Use llm_config instead of args in export_llama functions"
Differential Revision: [D75484927](https://our.internmc.facebook.com/intern/diff/D75484927)

[ghstack-poisoned]
2 parents: a14f548 + 1a40c98

5 files changed: 28 additions & 18 deletions
examples/models/llama/TARGETS

Lines changed: 2 additions & 0 deletions
@@ -67,6 +67,7 @@ runtime.python_library(
         "//caffe2:torch",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models/llama:llama_transformer",
+        "//executorch/examples/models/llama/config:llm_config",
         "//executorch/examples/models:checkpoint",
     ],
 )
@@ -266,6 +267,7 @@ runtime.python_library(
         ":export_library",
         "//executorch/examples/models/llama/config:llm_config",
         "fbsource//third-party/pypi/hydra-core:hydra-core",
+        "fbsource//third-party/pypi/omegaconf:omegaconf",
     ],
 )
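
The two added Buck deps back imports that appear later in this commit; as a hedged reminder of the mapping (target labels are copied from the hunks above, the import lines from the Python diffs below):

# //executorch/examples/models/llama/config:llm_config provides:
from executorch.examples.models.llama.config.llm_config import LlmConfig
# fbsource//third-party/pypi/omegaconf:omegaconf provides:
from omegaconf import OmegaConf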

examples/models/llama/export_llama_hydra.py

Lines changed: 2 additions & 1 deletion
@@ -13,14 +13,15 @@
 from executorch.examples.models.llama.config.llm_config import LlmConfig
 from executorch.examples.models.llama.export_llama_lib import export_llama
 from hydra.core.config_store import ConfigStore
+from omegaconf import OmegaConf

 cs = ConfigStore.instance()
 cs.store(name="llm_config", node=LlmConfig)


 @hydra.main(version_base=None, config_name="llm_config")
 def main(llm_config: LlmConfig) -> None:
-    export_llama(llm_config)
+    export_llama(OmegaConf.to_object(llm_config))


 if __name__ == "__main__":
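
A minimal sketch, independent of this repo, of why OmegaConf.to_object is needed here: Hydra passes main() a DictConfig even when the node was registered from a dataclass, and to_object() materializes it back into the dataclass instance, so the isinstance check against LlmConfig in export_llama_lib.py succeeds. ExampleConfig is a made-up stand-in for LlmConfig.

from dataclasses import dataclass

from omegaconf import OmegaConf


@dataclass
class ExampleConfig:  # stand-in for LlmConfig
    max_seq_length: int = 128


cfg = OmegaConf.structured(ExampleConfig)  # a DictConfig backed by the dataclass
assert not isinstance(cfg, ExampleConfig)  # still a DictConfig at this point
obj = OmegaConf.to_object(cfg)             # converted back into ExampleConfig
assert isinstance(obj, ExampleConfig)      # so isinstance-based dispatch works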

examples/models/llama/export_llama_lib.py

Lines changed: 2 additions & 3 deletions
@@ -56,7 +56,6 @@
     get_vulkan_quantizer,
 )
 from executorch.util.activation_memory_profiler import generate_memory_trace
-from omegaconf.dictconfig import DictConfig

 from ..model_factory import EagerModelFactory
 from .source_transformation.apply_spin_quant_r1_r2 import (
@@ -576,12 +575,12 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:


 def export_llama(
-    export_options: Union[argparse.Namespace, DictConfig],
+    export_options: Union[argparse.Namespace, LlmConfig],
 ) -> str:
     if isinstance(export_options, argparse.Namespace):
         # Legacy CLI.
         llm_config = convert_args_to_llm_config(export_options)
-    elif isinstance(export_options, DictConfig):
+    elif isinstance(export_options, LlmConfig):
         # Hydra CLI.
         llm_config = export_options
     else:
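
A hedged caller sketch of the updated export_llama contract: both the legacy argparse path and the new structured-config path are accepted and normalized to an LlmConfig internally (via convert_args_to_llm_config for the Namespace case). build_args_parser is used the same way in the test diff below; the specific flags set here are illustrative.

from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.examples.models.llama.export_llama_lib import build_args_parser, export_llama

# Structured-config path (what export_llama_hydra.py now passes in):
llm_config = LlmConfig()
llm_config.model.use_kv_cache = True
output_path = export_llama(llm_config)

# Legacy CLI path: an argparse.Namespace, converted internally via convert_args_to_llm_config.
args = build_args_parser().parse_args([])
args.use_kv_cache = True
output_path = export_llama(args)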

examples/models/llama/tests/test_export_llama_lib.py

Lines changed: 10 additions & 8 deletions
@@ -7,8 +7,10 @@
 import unittest

 from executorch.devtools.backend_debug import get_delegation_info
-from executorch.examples.models.llama.config.llm_config import LlmConfig
-from executorch.examples.models.llama.export_llama_lib import _export_llama
+from executorch.examples.models.llama.export_llama_lib import (
+    _export_llama,
+    build_args_parser,
+)

 UNWANTED_OPS = [
     "aten_permute_copy_default",
@@ -33,13 +35,13 @@ def test_has_expected_ops_and_op_counts(self):
         # we cannot test quantization args in this way
         # since quantization requires promoting meta tensors
         # to device=cpu, which requires real weights.
+        parser = build_args_parser()
+        args = parser.parse_args([])
+        args.use_sdpa_with_kv_cache = True
+        args.use_kv_cache = True
+        args.verbose = True

-        llm_config = LlmConfig()
-        llm_config.model.use_sdpa_with_kv_cache = True
-        llm_config.model.use_kv_cache = True
-        llm_config.debug.verbose = True
-
-        builder = _export_llama(llm_config)
+        builder = _export_llama(args)
         graph_module = builder.edge_manager.exported_program().graph_module
         delegation_info = get_delegation_info(graph_module)

examples/models/llama3_2_vision/runner/eager.py

Lines changed: 12 additions & 6 deletions
@@ -8,6 +8,7 @@
 from typing import Optional

 import torch
+from executorch.examples.models.llama.config.llm_config import LlmConfig

 from executorch.examples.models.llama.export_llama_lib import _prepare_for_llama_export
 from executorch.examples.models.llama.runner.eager import execute_runner
@@ -22,18 +23,23 @@ class EagerLlamaRunner(TorchTuneLlamaRunner):
     Runs llama in eager mode with provided checkpoint file.
     """

-    def __init__(self, args):
-        with open(args.params, "r") as f:
+    def __init__(
+        self,
+        llm_config: LlmConfig,
+        tokenizer_config_path: Optional[str] = None,
+        use_attention_sink: bool = False,
+    ):
+        with open(llm_config.base.params, "r") as f:
             params = json.loads(f.read())
         super().__init__(
-            tokenizer_path=args.tokenizer_path,
-            max_seq_len=args.max_seq_length,
+            tokenizer_path=llm_config.base.tokenizer_path,
+            max_seq_len=llm_config.export.max_seq_length,
             max_batch_size=1,
-            use_kv_cache=args.use_kv_cache,
+            use_kv_cache=llm_config.model.use_kv_cache,
             vocab_size=params["vocab_size"],
             device="cuda" if torch.cuda.is_available() else "cpu",
         )
-        manager: LLMEdgeManager = _prepare_for_llama_export(args)
+        manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)
         self.model = manager.model.eval().to(device=self.device)

     def forward(
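
A hedged construction sketch for the new signature; the import path is inferred from this file's location, the config values are placeholders, and only the fields this constructor actually reads (per the hunk above) are set. The two new keyword arguments are shown with their defaults.

from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.examples.models.llama3_2_vision.runner.eager import EagerLlamaRunner

llm_config = LlmConfig()
llm_config.base.params = "params.json"               # placeholder path
llm_config.base.tokenizer_path = "tokenizer.model"   # placeholder path
llm_config.export.max_seq_length = 2048
llm_config.model.use_kv_cache = True

runner = EagerLlamaRunner(
    llm_config,
    tokenizer_config_path=None,   # new optional parameter
    use_attention_sink=False,     # new optional parameter
)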
