Skip to content

Commit

Permalink
Add tests for all aliases
Browse files Browse the repository at this point in the history
  • Loading branch information
aW3st committed Jan 30, 2025
1 parent 67a696f commit b1a9dff
Showing 1 changed file with 45 additions and 16 deletions.
61 changes: 45 additions & 16 deletions integration-tests/models/test_grammar_response_format_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,55 @@ class Weather(BaseModel):
unit: str
temperature: List[int]

json_payload={
"model": "tgi",
"messages": [
{
"role": "system",
"content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
},
{
"role": "user",
"content": "What's the weather like the next 3 days in San Francisco, CA?",
},
],
"seed": 42,
"max_tokens": 500,
"response_format": {"type": "json_object", "value": Weather.schema()},
}
# send the request
response = requests.post(
f"{llama_grammar.base_url}/v1/chat/completions",
headers=llama_grammar.headers,
json={
"model": "tgi",
"messages": [
{
"role": "system",
"content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
},
{
"role": "user",
"content": "What's the weather like the next 3 days in San Francisco, CA?",
},
],
"seed": 42,
"max_tokens": 500,
"response_format": {"type": "json_object", "value": Weather.schema()},
},
json=json_payload,
)

chat_completion = response.json()
called = chat_completion["choices"][0]["message"]["content"]

assert response.status_code == 200
assert called == '{ "unit": "fahrenheit", "temperature": [ 72, 79, 88 ] }'
assert chat_completion == response_snapshot

json_payload["response_format"]["type"] = "json"
response = requests.post(
f"{llama_grammar.base_url}/v1/chat/completions",
headers=llama_grammar.headers,
json=json_payload,
)

chat_completion = response.json()
called = chat_completion["choices"][0]["message"]["content"]

assert response.status_code == 200
assert called == '{ "unit": "fahrenheit", "temperature": [ 72, 79, 88 ] }'
assert chat_completion == response_snapshot

json_payload["response_format"]["type"] = "json_schema"
response = requests.post(
f"{llama_grammar.base_url}/v1/chat/completions",
headers=llama_grammar.headers,
json=json_payload,
)

chat_completion = response.json()
Expand Down

0 comments on commit b1a9dff

Please sign in to comment.