Skip to content

Commit 1fd8840

Browse files
authored
1 parent 4c74a82 commit 1fd8840

File tree

1 file changed

+44
-14
lines changed

1 file changed

+44
-14
lines changed

llama_cpp/llama_grammar.py

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
891891
pos += 1
892892
last_sym_start = out_elements.size()
893893
while pos[0] != '"':
894+
assert pos[0] is not None, "Unexpected end of input"
894895
char_pair = parse_char(pos) # type: Tuple[int, const_char_p]
895896
pos = char_pair[1]
896897
out_elements.push_back(
@@ -920,6 +921,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
920921
# : start_type;
921922
# out_elements.push_back({type, char_pair.first});
922923
while pos[0] != "]":
924+
assert pos[0] is not None, "Unexpected end of input"
923925
char_pair = parse_char(pos) # type: Tuple[int, const_char_p]
924926
pos = char_pair[1]
925927
_type = (
@@ -935,6 +937,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
935937
# }
936938
# }
937939
if pos[0] == "-" and pos[1] != "]":
940+
assert pos[1] is not None, "Unexpected end of input"
938941
endchar_pair = parse_char(pos + 1) # type: Tuple[int, const_char_p]
939942
pos = endchar_pair[1]
940943
out_elements.push_back(
@@ -1159,33 +1162,59 @@ def parse_rule(state: parse_state, src: const_char_p) -> const_char_p:
11591162
elif pos[0]:
11601163
raise RuntimeError("expecting newline or end at " + str(pos))
11611164
return parse_space(pos, True)
1165+
1166+
#parse_state parse(const char * src) {
1167+
# try {
1168+
# parse_state state;
1169+
# const char * pos = parse_space(src, true);
1170+
# while (*pos) {
1171+
# pos = parse_rule(state, pos);
1172+
# }
1173+
# // Validate the state to ensure that all rules are defined
1174+
# for (const auto & rule : state.rules) {
1175+
# for (const auto & elem : rule) {
1176+
# if (elem.type == LLAMA_GRETYPE_RULE_REF) {
1177+
# // Ensure that the rule at that location exists
1178+
# if (elem.value >= state.rules.size() || state.rules[elem.value].empty()) {
1179+
# // Get the name of the rule that is missing
1180+
# for (const auto & kv : state.symbol_ids) {
1181+
# if (kv.second == elem.value) {
1182+
# throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
1183+
# }
1184+
# }
1185+
# }
1186+
# }
1187+
# }
1188+
# }
1189+
# return state;
1190+
# } catch (const std::exception & err) {
1191+
# fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
1192+
# return parse_state();
1193+
# }
1194+
#}
11621195

11631196

1164-
# parse_state parse(const char * src) {
1165-
# try {
1166-
# parse_state state;
1167-
# const char * pos = parse_space(src, true);
1168-
# while (*pos) {
1169-
# pos = parse_rule(state, pos);
1170-
# }
1171-
# return state;
1172-
# } catch (const std::exception & err) {
1173-
# fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
1174-
# return parse_state();
1175-
# }
1176-
# }
11771197
def parse(src: const_char_p) -> parse_state:
    """Parse a whole grammar source into a ``parse_state``.

    Leading whitespace is skipped, then rules are parsed one after another
    until the character at the current position is falsy (end of input).
    After parsing, every rule reference is checked to make sure it points at
    a rule that exists and is non-empty.

    Args:
        src: The grammar source to parse.

    Returns:
        The populated parse state. If any exception is raised while parsing
        or validating, the error is printed and a fresh, empty
        ``parse_state()`` is returned instead of propagating the exception.
    """
    try:
        state = parse_state()  # type: parse_state
        pos = parse_space(src, True)  # type: const_char_p
        while pos[0]:
            pos = parse_rule(state, pos)

        # Validate the state to ensure that all rules are defined
        rule_count = len(state.rules)
        for rule in state.rules:
            for elem in rule:
                if elem.type != llama_gretype.LLAMA_GRETYPE_RULE_REF:
                    continue
                # Ensure that the rule at that location exists
                if elem.value < rule_count and state.rules[elem.value]:
                    continue
                # Get the name of the rule that is missing.
                # NOTE(review): assumes state.symbol_ids is a dict-like
                # mapping of rule name -> rule id (the Python counterpart of
                # the C++ std::map) -- confirm against parse_state. Iterating
                # such a mapping directly yields only its keys, so the
                # (name, id) pairs have to come from items().
                for name, rule_id in state.symbol_ids.items():
                    if rule_id == elem.value:
                        raise RuntimeError("Undefined rule identifier '" + name + "'")
        return state
    except Exception as err:
        # Deliberate best-effort: report the problem and hand back an empty
        # state rather than raising to the caller.
        print(f"{parse.__name__}: error parsing grammar: {err}")
        return parse_state()
11871217

1188-
11891218
# void print_grammar_char(FILE * file, uint32_t c) {
11901219
# if (0x20 <= c && c <= 0x7f) {
11911220
# fprintf(file, "%c", static_cast<char>(c));
@@ -1283,6 +1312,7 @@ def print_rule(
12831312
# }
12841313

12851314

1315+
12861316
for i, elem in enumerate(rule[:-1]):
12871317
case = elem.type # type: llama_gretype
12881318
if case is llama_gretype.LLAMA_GRETYPE_END:

0 commit comments

Comments
 (0)