// The module takes in a string as input and emits a string as output.

#include <executorch/examples/models/llama/runner/runner.h>
-
-#include <executorch/extension/llm/runner/util.h>
+#include <executorch/extension/module/module.h>

#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
#include <pytorch/tokenizers/hf_tokenizer.h>
@@ -26,41 +25,14 @@ using ::executorch::runtime::Result;

namespace llm = ::executorch::extension::llm;

-namespace {
-static constexpr auto kEnableDynamicShape = "enable_dynamic_shape";
-static constexpr auto kBosId = "get_bos_id";
-static constexpr auto kEosIds = "get_eos_ids";
-static constexpr auto kMaxSeqLen = "get_max_seq_len";
-static constexpr auto kMaxContextLen = "get_max_context_len";
-static constexpr auto kVocabSize = "get_vocab_size";
-static constexpr auto kUseKVCache = "use_kv_cache";
-static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
-
-std::unique_ptr<::tokenizers::Tokenizer> load_tokenizer(
-    const std::string& tokenizer_path) {
-  auto json_tokenizer = std::make_unique<tokenizers::HFTokenizer>();
-  if (json_tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
-    ET_LOG(Info, "Loaded json tokenizer");
-    return json_tokenizer;
-  }
-
-  auto tiktoken_tokenizer = get_tiktoken_for_llama();
-  if (tiktoken_tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
-    ET_LOG(Info, "Loaded TikToken tokenizer");
-    return tiktoken_tokenizer;
-  }
-
-  auto bpe_tokenizer = std::make_unique<::tokenizers::Llama2cTokenizer>();
-  if (bpe_tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
-    ET_LOG(Info, "Loaded BPE tokenizer");
-    return bpe_tokenizer;
-  }
-
-  return nullptr;
+std::unique_ptr<::tokenizers::Tokenizer> load_llama_tokenizer(
+    const std::string& tokenizer_path,
+    Version version) {
+  auto special_tokens = get_special_tokens(version);
+  return llm::load_tokenizer(tokenizer_path, std::move(special_tokens));
}
-} // namespace

-std::unique_ptr<Runner> Runner::create(
+std::unique_ptr<llm::TextLLMRunner> create_llama_runner(
    const std::string& model_path,
    const std::string& tokenizer_path,
    std::optional<const std::string> data_path,
@@ -71,309 +43,19 @@ std::unique_ptr<Runner> Runner::create(
    model_path.c_str(),
    tokenizer_path.c_str());

-  // Create the Module
-  std::unique_ptr<Module> module;
-  if (data_path.has_value()) {
-    module = std::make_unique<Module>(
-        model_path, data_path.value(), Module::LoadMode::File);
-  } else {
-    module = std::make_unique<Module>(model_path, Module::LoadMode::File);
-  }
-
-  // Initialize metadata with default values
-  std::unordered_map<std::string, int64_t> metadata({
-      {kEnableDynamicShape, false},
-      {kMaxSeqLen, 128},
-      {kMaxContextLen, 128},
-      {kUseKVCache, true},
-      {kUseSDPAWithKVCache, false},
-  });
-
  // Create and load tokenizer
  std::unique_ptr<::tokenizers::Tokenizer> tokenizer =
-      load_tokenizer(tokenizer_path);
+      load_llama_tokenizer(tokenizer_path, Version::Default);

-  // Fallback to BPE tokenizer if tiktoken fails
  if (tokenizer == nullptr) {
    ET_LOG(
        Info,
        "Failed to load %s as a Tiktoken, Sentencepiece or Llama2.c tokenizer, make sure the artifact is one of these types",
        tokenizer_path.c_str());
    return nullptr;
  }
-
-  ET_LOG(Info, "Reading metadata from model");
-
-  // Set tokenizer-related metadata
-  metadata[kBosId] = tokenizer->bos_tok();
-  auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
-      std::unordered_set<uint64_t>{tokenizer->eos_tok()});
-  metadata[kVocabSize] = tokenizer->vocab_size();
-
-  // Read metadata from the model
-  auto method_names_result = module->method_names();
-  if (method_names_result.error() != Error::Ok) {
-    ET_LOG(Error, "Failed reading method names");
-    return nullptr;
-  }
-  const auto method_names = method_names_result.get();
-
-  for (auto& pair : metadata) {
-    const auto& method_name = pair.first;
-    auto& value = pair.second;
-
-    if (method_names.count(method_name)) {
-      auto get_result = module->get(method_name);
-      value = get_result.get().toScalar().to<decltype(metadata)::mapped_type>();
-    } else {
-      ET_LOG(
-          Info,
-          "Method %s not found, using the default value %" PRId64,
-          method_name.c_str(),
-          value);
-    }
-    ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
-  }
-
-  // Get EOS IDs if available
-  if (method_names.count(kEosIds)) {
-    eos_ids->clear();
-    auto execute_result = module->execute(kEosIds);
-    if (execute_result.error() != Error::Ok) {
-      ET_LOG(Error, "Failed to execute %s", kEosIds);
-      return nullptr;
-    }
-    for (const auto& eos_id : execute_result.get()) {
-      auto value = eos_id.toScalar().to<int64_t>();
-      eos_ids->emplace(value);
-      ET_LOG(Info, "eos_id = %" PRId64, value);
-    }
-  }
-
-  // Create text_decoder_runner. Use a shared_ptr so that it can be shared with
-  // TextPrefiller and TextTokenGenerator
-  auto text_decoder_runner = std::make_unique<llm::TextDecoderRunner>(
-      module.get(), metadata.at(kUseKVCache));
-
-  // Create text_prefiller
-  auto text_prefiller = std::make_unique<llm::TextPrefiller>(
-      text_decoder_runner.get(),
-      metadata.at(kUseKVCache),
-      metadata.at(kEnableDynamicShape),
-      metadata.at(kMaxSeqLen));
-
-  // Create text_token_generator with stats
-  auto stats = std::make_unique<llm::Stats>();
-  auto text_token_generator = std::make_unique<llm::TextTokenGenerator>(
-      tokenizer.get(),
-      text_decoder_runner.get(),
-      metadata.at(kUseKVCache),
-      std::move(eos_ids),
-      stats.get());
-
-  // Create and return the Runner instance
-  return std::make_unique<Runner>(
-      std::move(metadata),
-      std::move(tokenizer),
-      std::move(module),
-      std::move(text_decoder_runner),
-      std::move(text_prefiller),
-      std::move(text_token_generator),
-      std::move(stats),
-      temperature);
-}
-
-Runner::Runner(
-    std::unordered_map<std::string, int64_t> metadata,
-    std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
-    std::unique_ptr<::executorch::extension::Module> module,
-    std::unique_ptr<::executorch::extension::llm::TextDecoderRunner>
-        text_decoder_runner,
-    std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller,
-    std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
-        text_token_generator,
-    std::unique_ptr<::executorch::extension::llm::Stats> stats,
-    float temperature)
-    : tokenizer_(std::move(tokenizer)),
-      metadata_(std::move(metadata)),
-      module_(std::move(module)),
-      text_decoder_runner_(std::move(text_decoder_runner)),
-      text_prefiller_(std::move(text_prefiller)),
-      text_token_generator_(std::move(text_token_generator)),
-      stats_(std::move(stats)),
-      temperature_(temperature) {
-  // Note: This constructor assumes that text_prefiller and text_token_generator
-  // already have references to the Module and TextDecoderRunner they need
-}
-
-bool Runner::is_loaded() const {
-  return text_prefiller_->is_loaded() && text_token_generator_->is_loaded();
-}
-
-Error Runner::load() {
-  if (is_loaded()) {
-    return Error::Ok;
-  }
-  ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load());
-  ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load());
-  return Error::Ok;
-}
-
-// Don't print with the same priority during warmup
-#define RUNNER_ET_LOG(warmup, format, ...) \
-  if (warmup) {                            \
-    ET_LOG(Debug, format, __VA_ARGS__);    \
-  } else {                                 \
-    ET_LOG(Info, format, __VA_ARGS__);     \
-  }
-
-Error Runner::generate(
-    const std::string& prompt,
-    const ::executorch::extension::llm::GenerationConfig& config,
-    std::function<void(const std::string&)> token_callback,
-    std::function<void(const llm::Stats&)> stats_callback) {
-  // Prepare the inputs.
-  // Use ones-initialized inputs.
-  ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
-  if (!is_loaded()) {
-    stats_->model_load_start_ms = llm::time_in_ms();
-    ET_CHECK_OK_OR_RETURN_ERROR(load());
-    stats_->model_load_end_ms = llm::time_in_ms();
-  }
-
-  if (config.warming) {
-    ET_LOG(Info, "Doing a warmup run...");
-  }
-
-  RUNNER_ET_LOG(
-      config.warming,
-      "RSS after loading model: %f MiB (0 if unsupported)",
-      llm::get_rss_bytes() / 1024.0 / 1024.0);
-
-  // Wrap the token_callback with print function
-  std::function<void(const std::string&)> wrapped_callback =
-      [token_callback, config](const std::string& piece) {
-        if (!config.warming) {
-          llm::safe_printf(piece.c_str());
-          fflush(stdout);
-        }
-        if (token_callback) {
-          token_callback(piece);
-        }
-      };
-  // First token time only measures the time it takes to encode the prompt and
-  // return a response token.
-
-  stats_->inference_start_ms = llm::time_in_ms();
-  shouldStop_ = false;
-
-  ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
-      prompt,
-      /* bos */ 0,
-      /* eos */ 0);
-
-  ET_CHECK_TK_OK_OR_RETURN_ERROR(
-      encode_res.error(), "Failed to encode prompt %s", prompt.c_str());
-
-  // encode the (string) prompt into tokens sequence
-  std::vector<uint64_t> prompt_tokens = encode_res.get();
-  int num_prompt_tokens = prompt_tokens.size();
-
-  ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
-  ET_CHECK_MSG(
-      num_prompt_tokens < metadata_.at(kMaxContextLen),
-      "num_prompt_tokens %d >= max_seq_len_ %" PRId64
-      ", Max seq length exceeded - please increase max seq len value in your export script",
-      num_prompt_tokens,
-      metadata_.at(kMaxContextLen));
-
-  // Determine max_new_tokens using the GenerationConfig's resolve method
-  int max_new_tokens = config.resolve_max_new_tokens(
-      metadata_.at(kMaxContextLen), num_prompt_tokens);
-
-  ET_LOG(Info, "Max new tokens resolved: %d", max_new_tokens);
-
-  // Prefill first
-  // Here feed all tokens to the model and get the next predicted token
-  // after the prompt. After that we will enter generate loop.
-
-  // print prompts
-  if (config.echo) {
-    wrapped_callback(prompt);
-  }
-  int64_t pos = 0;
-  auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
-  ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
-  uint64_t cur_token = prefill_res.get();
-  stats_->first_token_ms = llm::time_in_ms();
-  stats_->prompt_eval_end_ms = llm::time_in_ms();
-
-  // print the first token from prefill. No prev_token so use cur_token for it.
-  wrapped_callback(
-      ET_UNWRAP_TOKENIZER(tokenizer_->decode(cur_token, cur_token)));
-  RUNNER_ET_LOG(
-      config.warming,
-      "RSS after prompt prefill: %f MiB (0 if unsupported)",
-      llm::get_rss_bytes() / 1024.0 / 1024.0);
-
-  // start the main loop
-  prompt_tokens.push_back(cur_token);
-
-  // Generate max_new_tokens - 1 because prefill already generated 1 token.
-  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
-      prompt_tokens,
-      num_prompt_tokens,
-      max_new_tokens - 1,
-      temperature_ == -1.0f ? config.temperature : temperature_,
-      wrapped_callback));
-
-  stats_->inference_end_ms = llm::time_in_ms();
-  if (!config.warming) {
-    printf("\n");
-  }
-  RUNNER_ET_LOG(
-      config.warming,
-      "RSS after finishing text generation: %f MiB (0 if unsupported)",
-      llm::get_rss_bytes() / 1024.0 / 1024.0);
-
-  if (num_generated_tokens == max_new_tokens) {
-    RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens);
-  }
-
-  stats_->num_prompt_tokens = num_prompt_tokens;
-  stats_->num_generated_tokens = num_generated_tokens;
-
-  if (config.warming) {
-    ET_LOG(Info, "Warmup run finished!");
-  } else {
-    // Do not print report during warmup
-    ::executorch::llm::print_report(*stats_);
-  }
-  if (stats_callback) {
-    stats_callback(*stats_);
-  }
-
-  return Error::Ok;
-}
-
-Error Runner::warmup(const std::string& prompt, int32_t max_new_tokens) {
-  // Create a GenerationConfig for warmup
-  llm::GenerationConfig config{
-      .echo = false, .max_new_tokens = max_new_tokens, .warming = true};
-
-  // Call generate with the warmup config
-  Error err = generate(prompt, config);
-
-  // Reset stats after warmup, not resetting the std::unique_ptr!
-  stats_->reset();
-  return err;
+  return llm::create_text_llm_runner(
+      model_path, std::move(tokenizer), data_path);
}

-void Runner::stop() {
-  if (is_loaded()) {
-    text_token_generator_->stop();
-  } else {
-    ET_LOG(Error, "Token generator is not loaded, cannot stop");
-  }
-}
} // namespace example
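
Note: the sketch below is illustrative only and is not part of this change. It shows how a caller might use the refactored entry point, assuming that llm::TextLLMRunner::generate() mirrors the old Runner::generate signature removed above, that any create_llama_runner parameters beyond data_path (e.g. temperature) have defaults, and that the file paths are placeholders.

#include <cstdio>
#include <optional>
#include <string>

#include <executorch/examples/models/llama/runner/runner.h>

int main() {
  // Tokenizer loading and model-metadata discovery now happen inside
  // create_llama_runner() via load_llama_tokenizer()/llm::create_text_llm_runner().
  auto runner = example::create_llama_runner(
      "/path/to/llama.pte", "/path/to/tokenizer.model", std::nullopt);
  if (runner == nullptr) {
    ::printf("failed to create runner\n");
    return 1;
  }

  // Field names follow the llm::GenerationConfig usage shown in the removed
  // warmup()/generate() code above (echo, max_new_tokens, warming).
  ::executorch::extension::llm::GenerationConfig config{
      .echo = true, .max_new_tokens = 128};

  // Stream decoded pieces to stdout as they are produced (assumed callback shape).
  auto err = runner->generate(
      "Hello, my name is", config, [](const std::string& piece) {
        ::printf("%s", piece.c_str());
      });
  return err == ::executorch::runtime::Error::Ok ? 0 : 1;
}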