 Run `pytest tests/models/test_mistral.py`.
 """
-import pickle
+import json
 import uuid
-from typing import Any, Dict, List
+from dataclasses import asdict
+from typing import Any, Dict, List, Optional, Tuple

 import pytest
 from mistral_common.protocol.instruct.messages import ImageURLChunk

 from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
 from vllm.multimodal import MultiModalDataBuiltins
+from vllm.sequence import Logprob, SampleLogprobs

 from .utils import check_logprobs_close
@@ -81,13 +83,33 @@ def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
 LIMIT_MM_PER_PROMPT = dict(image=4)

 MAX_MODEL_LEN = [8192, 65536]
-FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.pickle"
-FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.pickle"
+FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.json"
+FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.json"

+OutputsLogprobs = List[Tuple[List[int], str, Optional[SampleLogprobs]]]

-def load_logprobs(filename: str) -> Any:
-    with open(filename, 'rb') as f:
-        return pickle.load(f)
+
+# For the test author to store golden output in JSON
+def _dump_outputs_w_logprobs(outputs: OutputsLogprobs, filename: str) -> None:
+    json_data = [(tokens, text,
+                  [{k: asdict(v)
+                    for k, v in token_logprobs.items()}
+                   for token_logprobs in (logprobs or [])])
+                 for tokens, text, logprobs in outputs]
+
+    with open(filename, "w") as f:
+        json.dump(json_data, f)
+
+
+def load_outputs_w_logprobs(filename: str) -> OutputsLogprobs:
+    with open(filename, "rb") as f:
+        json_data = json.load(f)
+
+    return [(tokens, text,
+             [{int(k): Logprob(**v)
+               for k, v in token_logprobs.items()}
+              for token_logprobs in logprobs])
+            for tokens, text, logprobs in json_data]


 @pytest.mark.skip(
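
To make the new fixture format concrete, here is a minimal, self-contained round-trip sketch of what _dump_outputs_w_logprobs and load_outputs_w_logprobs do. The Logprob dataclass below is only a stand-in for vllm.sequence.Logprob (its rank and decoded_token fields are assumptions for illustration); the point is that asdict() turns each Logprob into a plain dict for json.dump, and that JSON object keys are always strings, which is why the loader has to cast them back with int(k).

# Stand-alone sketch of the JSON fixture round trip (not vLLM code).
import json
from dataclasses import asdict, dataclass
from typing import Dict, List, Optional, Tuple


@dataclass
class Logprob:  # stand-in for vllm.sequence.Logprob; extra fields are assumed
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


SampleLogprobs = List[Dict[int, Logprob]]
OutputsLogprobs = List[Tuple[List[int], str, Optional[SampleLogprobs]]]

outputs: OutputsLogprobs = [
    ([1, 42], "hi", [{1: Logprob(-0.1)}, {42: Logprob(-0.3, rank=1)}]),
]

# Dump: dataclasses become dicts, and the int token-id keys become strings.
blob = json.dumps([(toks, text, [{k: asdict(v) for k, v in step.items()}
                                 for step in (lps or [])])
                   for toks, text, lps in outputs])

# Load: rebuild the Logprob objects and restore the int keys.
restored = [(toks, text, [{int(k): Logprob(**v) for k, v in step.items()}
                          for step in lps])
            for toks, text, lps in json.loads(blob)]

assert restored[0][2][1][42].logprob == -0.3

Presumably _dump_outputs_w_logprobs is a one-off tool for refreshing the golden files on the reference machine (the "h100_ref" label below suggests an H100), while only load_outputs_w_logprobs runs during the tests.
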
@@ -103,7 +125,7 @@ def test_chat(
     model: str,
     dtype: str,
 ) -> None:
-    EXPECTED_CHAT_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_CHAT)
+    EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT)
     with vllm_runner(
         model,
         dtype=dtype,
@@ -120,10 +142,10 @@ def test_chat(
         outputs.extend(output)

     logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
-    check_logprobs_close(outputs_0_lst=logprobs,
-                         outputs_1_lst=EXPECTED_CHAT_LOGPROBS,
-                         name_0="output",
-                         name_1="h100_ref")
+    check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS,
+                         outputs_1_lst=logprobs,
+                         name_0="h100_ref",
+                         name_1="output")


 @pytest.mark.skip(
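
The only change in this hunk is the argument order: the golden fixture is now passed as outputs_0_lst ("h100_ref") and the fresh run as outputs_1_lst ("output"). As a reading aid, here is a simplified, self-contained sketch of the style of comparison check_logprobs_close performs; it is not the vLLM implementation, and the tolerate-one-divergence rule and the purely-for-labeling role of name_0/name_1 are assumptions from how the test reads.

# Illustrative stand-in only, not tests/models/utils.check_logprobs_close.
from typing import Dict, List, Tuple

Step = Dict[int, float]  # token id -> logprob of the top-k candidates


def logprobs_close(run_0: Tuple[List[int], List[Step]],
                   run_1: Tuple[List[int], List[Step]],
                   name_0: str, name_1: str) -> None:
    tokens_0, steps_0 = run_0
    tokens_1, steps_1 = run_1
    for i, (t0, t1) in enumerate(zip(tokens_0, tokens_1)):
        if t0 == t1:
            continue
        # Where the greedy tokens diverge, each run's choice must still be
        # among the other run's top-k candidates at that position.
        assert t0 in steps_1[i] and t1 in steps_0[i], (
            f"{name_0} and {name_1} diverge at position {i}: {t0} vs {t1}")
        break  # later tokens are conditioned on different prefixes


# Toy usage: the runs disagree at position 1, but each token is still in the
# other run's candidate set, so the check passes.
ref = ([5, 7], [{5: -0.1, 9: -2.0}, {7: -0.5, 8: -0.7}])
new = ([5, 8], [{5: -0.1, 9: -2.1}, {8: -0.6, 7: -0.8}])
logprobs_close(ref, new, name_0="h100_ref", name_1="output")

Under this reading, swapping the arguments does not change what is asserted; it only makes any failure message name the H100 reference as side 0 and the current run as side 1.
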
@@ -133,7 +155,7 @@ def test_chat(
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
-    EXPECTED_ENGINE_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_ENGINE)
+    EXPECTED_ENGINE_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_ENGINE)
     args = EngineArgs(
         model=model,
         tokenizer_mode="mistral",
@@ -162,7 +184,7 @@ def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
             break

     logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
-    check_logprobs_close(outputs_0_lst=logprobs,
-                         outputs_1_lst=EXPECTED_ENGINE_LOGPROBS,
-                         name_0="output",
-                         name_1="h100_ref")
+    check_logprobs_close(outputs_0_lst=EXPECTED_ENGINE_LOGPROBS,
+                         outputs_1_lst=logprobs,
+                         name_0="h100_ref",
+                         name_1="output")