From b1a9dfff216aa0d8873d9b9778df3ae91c479b4f Mon Sep 17 00:00:00 2001 From: Alex Weston Date: Thu, 30 Jan 2025 14:11:05 -0500 Subject: [PATCH] Add tests for all aliases --- .../test_grammar_response_format_llama.py | 61 ++++++++++++++----- 1 file changed, 45 insertions(+), 16 deletions(-) diff --git a/integration-tests/models/test_grammar_response_format_llama.py b/integration-tests/models/test_grammar_response_format_llama.py index f2a8a96da46..809dc3dd793 100644 --- a/integration-tests/models/test_grammar_response_format_llama.py +++ b/integration-tests/models/test_grammar_response_format_llama.py @@ -29,26 +29,55 @@ class Weather(BaseModel): unit: str temperature: List[int] + json_payload={ + "model": "tgi", + "messages": [ + { + "role": "system", + "content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}", + }, + { + "role": "user", + "content": "What's the weather like the next 3 days in San Francisco, CA?", + }, + ], + "seed": 42, + "max_tokens": 500, + "response_format": {"type": "json_object", "value": Weather.schema()}, + } # send the request response = requests.post( f"{llama_grammar.base_url}/v1/chat/completions", headers=llama_grammar.headers, - json={ - "model": "tgi", - "messages": [ - { - "role": "system", - "content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}", - }, - { - "role": "user", - "content": "What's the weather like the next 3 days in San Francisco, CA?", - }, - ], - "seed": 42, - "max_tokens": 500, - "response_format": {"type": "json_object", "value": Weather.schema()}, - }, + json=json_payload, + ) + + chat_completion = response.json() + called = chat_completion["choices"][0]["message"]["content"] + + assert response.status_code == 200 + assert called == '{ "unit": "fahrenheit", "temperature": [ 72, 79, 88 ] }' + assert chat_completion == response_snapshot + + json_payload["response_format"]["type"] = "json" + response = requests.post( + f"{llama_grammar.base_url}/v1/chat/completions", + headers=llama_grammar.headers, + json=json_payload, + ) + + chat_completion = response.json() + called = chat_completion["choices"][0]["message"]["content"] + + assert response.status_code == 200 + assert called == '{ "unit": "fahrenheit", "temperature": [ 72, 79, 88 ] }' + assert chat_completion == response_snapshot + + json_payload["response_format"]["type"] = "json_schema" + response = requests.post( + f"{llama_grammar.base_url}/v1/chat/completions", + headers=llama_grammar.headers, + json=json_payload, ) chat_completion = response.json()