Skip to content

Commit 7d8ace2

Browse files
DarkLight1337garg-amit
authored andcommitted
[CI/Build] Update pixtral tests to use JSON (vllm-project#8436)
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
1 parent dd97620 commit 7d8ace2

File tree

3 files changed

+39
-17
lines changed

3 files changed

+39
-17
lines changed
-20.4 KB
Binary file not shown.
-20.4 KB
Binary file not shown.

tests/models/test_pixtral.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
33
Run `pytest tests/models/test_mistral.py`.
44
"""
5-
import pickle
5+
import json
66
import uuid
7-
from typing import Any, Dict, List
7+
from dataclasses import asdict
8+
from typing import Any, Dict, List, Optional, Tuple
89

910
import pytest
1011
from mistral_common.protocol.instruct.messages import ImageURLChunk
@@ -14,6 +15,7 @@
1415

1516
from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
1617
from vllm.multimodal import MultiModalDataBuiltins
18+
from vllm.sequence import Logprob, SampleLogprobs
1719

1820
from .utils import check_logprobs_close
1921

@@ -81,13 +83,33 @@ def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
8183
LIMIT_MM_PER_PROMPT = dict(image=4)
8284

8385
MAX_MODEL_LEN = [8192, 65536]
84-
FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.pickle"
85-
FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.pickle"
86+
FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.json"
87+
FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.json"
8688

89+
OutputsLogprobs = List[Tuple[List[int], str, Optional[SampleLogprobs]]]
8790

88-
def load_logprobs(filename: str) -> Any:
89-
with open(filename, 'rb') as f:
90-
return pickle.load(f)
91+
92+
# For the test author to store golden output in JSON
93+
def _dump_outputs_w_logprobs(outputs: OutputsLogprobs, filename: str) -> None:
94+
json_data = [(tokens, text,
95+
[{k: asdict(v)
96+
for k, v in token_logprobs.items()}
97+
for token_logprobs in (logprobs or [])])
98+
for tokens, text, logprobs in outputs]
99+
100+
with open(filename, "w") as f:
101+
json.dump(json_data, f)
102+
103+
104+
def load_outputs_w_logprobs(filename: str) -> OutputsLogprobs:
105+
with open(filename, "rb") as f:
106+
json_data = json.load(f)
107+
108+
return [(tokens, text,
109+
[{int(k): Logprob(**v)
110+
for k, v in token_logprobs.items()}
111+
for token_logprobs in logprobs])
112+
for tokens, text, logprobs in json_data]
91113

92114

93115
@pytest.mark.skip(
@@ -103,7 +125,7 @@ def test_chat(
103125
model: str,
104126
dtype: str,
105127
) -> None:
106-
EXPECTED_CHAT_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_CHAT)
128+
EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT)
107129
with vllm_runner(
108130
model,
109131
dtype=dtype,
@@ -120,10 +142,10 @@ def test_chat(
120142
outputs.extend(output)
121143

122144
logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
123-
check_logprobs_close(outputs_0_lst=logprobs,
124-
outputs_1_lst=EXPECTED_CHAT_LOGPROBS,
125-
name_0="output",
126-
name_1="h100_ref")
145+
check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS,
146+
outputs_1_lst=logprobs,
147+
name_0="h100_ref",
148+
name_1="output")
127149

128150

129151
@pytest.mark.skip(
@@ -133,7 +155,7 @@ def test_chat(
133155
@pytest.mark.parametrize("model", MODELS)
134156
@pytest.mark.parametrize("dtype", ["bfloat16"])
135157
def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
136-
EXPECTED_ENGINE_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_ENGINE)
158+
EXPECTED_ENGINE_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_ENGINE)
137159
args = EngineArgs(
138160
model=model,
139161
tokenizer_mode="mistral",
@@ -162,7 +184,7 @@ def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
162184
break
163185

164186
logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
165-
check_logprobs_close(outputs_0_lst=logprobs,
166-
outputs_1_lst=EXPECTED_ENGINE_LOGPROBS,
167-
name_0="output",
168-
name_1="h100_ref")
187+
check_logprobs_close(outputs_0_lst=EXPECTED_ENGINE_LOGPROBS,
188+
outputs_1_lst=logprobs,
189+
name_0="h100_ref",
190+
name_1="output")

0 commit comments

Comments
 (0)