Skip to content

Commit 81cf909

Browse files
authored
multiple fixes, var copy
1 parent 1fd8840 commit 81cf909

File tree

1 file changed

+8
-15
lines changed

1 file changed

+8
-15
lines changed

llama_cpp/llama_grammar.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,6 @@ def get_symbol_id(state: parse_state, src: const_char_p, len: int) -> int:
523523
result = state.symbol_ids.insert(std.string(src, len), next_id)
524524
return result[0].second # type: ignore
525525

526-
527526
# uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
528527
# uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
529528
# state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
@@ -841,23 +840,22 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
841840
raise RuntimeError("expecting preceding item to */+/?/{ at " + str(pos))
842841

843842

844-
previous_elements:std.vector[LlamaGrammarElement] = out_elements[last_sym_start:out_elements.size()]
843+
previous_elements:std.vector[LlamaGrammarElement] = std.vector(out_elements[last_sym_start:]) # type: std.vector[LlamaGrammarElement]
845844

846845
if min_times == 0:
847846
out_elements.resize(last_sym_start)
848847
else:
849848
# Repeat the previous elements (min_times - 1) times
850849
for i in range(1, min_times):
851-
out_elements.extend(previous_elements)
850+
out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end())
852851

853852
last_rec_rule_id = 0 # type: int
854853
n_opt = 1 if max_times < 0 else max_times - min_times # type: int
855-
rec_rule = previous_elements # type: List[LlamaGrammarElement]
854+
rec_rule = std.vector(previous_elements) # type: List[LlamaGrammarElement]
856855

857856
for i in range(n_opt):
858-
rec_rule = previous_elements
859-
rec_rule.resize(len(previous_elements))
860-
rec_rule_id = generate_symbol_id(state, rule_name) # type: int
857+
rec_rule.resize(previous_elements.size())
858+
rec_rule_id = generate_symbol_id(state, rule_name) # type: int
861859
if i > 0 or max_times < 0:
862860
rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, rec_rule_id if max_times < 0 else last_rec_rule_id))
863861
rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0))
@@ -868,12 +866,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
868866
if n_opt > 0:
869867
out_elements.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, last_rec_rule_id))
870868

871-
872-
873-
874-
875-
876-
869+
877870

878871
# while (*pos) {
879872
while pos[0]:
@@ -1208,8 +1201,8 @@ def parse(src: const_char_p) -> parse_state:
12081201
if elem.value >= len(state.rules) or not state.rules[elem.value]:
12091202
# Get the name of the rule that is missing
12101203
for kv in state.symbol_ids:
1211-
if kv.second == elem.value:
1212-
raise RuntimeError("Undefined rule identifier '" + kv.first + "'")
1204+
if kv[1] == elem.value:
1205+
raise RuntimeError("Undefined rule identifier '" + kv[0] + "'")
12131206
return state
12141207
except Exception as err:
12151208
print(f"{parse.__name__}: error parsing grammar: {err}")

0 commit comments

Comments
 (0)