Skip to content

Commit 05c1d56

Browse files
committed
Test llamacpp when successive regex-guided generations
1 parent c1851df commit 05c1d56

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

tests/generate/test_integration_llamacpp.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -328,15 +328,28 @@ class Spam(BaseModel):
328328

329329
def test_llamacpp_json_function(model):
330330
model.model.reset()
331-
prompt = "<|im_start|>user\nOutput arguments for the function<|im_end|>\n<|im_start|>assistant\n"
331+
prompt = "<|im_start|>user\nOutput arguments for the function, array with 2 elements<|im_end|>\n<|im_start|>assistant\n"
332332

333333
def function(foo: int, bar: List[int]):
334334
return foo + sum(bar)
335335

336336
rng = torch.Generator(device="cpu")
337-
rng.manual_seed(0)
337+
rng.manual_seed(10)
338338
sequence = generate.json(model, function)(
339339
prompt, max_tokens=100, temperature=0.0, rng=rng
340340
)
341341
assert isinstance(sequence, dict)
342342
assert isinstance(function(**sequence), int)
343+
344+
345+
def test_llamacpp_successive_choices(model):
346+
model.model.reset()
347+
348+
choose = generate.regex(model, r"(one|two|three)")
349+
assert choose("pick a numner") in ["one", "two", "three"]
350+
351+
cities = ["New York", "Paris", "San Francisco"]
352+
city = generate.choice(model, cities)
353+
assert city("pick a city") in cities
354+
355+
assert choose("a number") in ["one", "two", "three"]

0 commit comments

Comments
 (0)