Commit 7cfac37

Isotr0py and ywang96 authored and jimpang committed
[Model] Initialize Fuyu-8B support (vllm-project#3924)
Co-authored-by: Roger Wang <ywang@roblox.com>
1 parent 381e2cb commit 7cfac37

File tree

6 files changed: +844 -0 lines

docs/source/models/supported_models.rst

Lines changed: 8 additions & 0 deletions
@@ -137,6 +137,10 @@ Decoder-only Language Models
     - Phi-3-Small
     - :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc.
     -
+  * - :code:`PersimmonForCausalLM`
+    - Persimmon
+    - :code:`adept/persimmon-8b-base`, :code:`adept/persimmon-8b-chat`, etc.
+    -
   * - :code:`QWenLMHeadModel`
     - Qwen
     - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
@@ -178,6 +182,10 @@ Vision Language Models
     - Models
     - Example HuggingFace Models
     - :ref:`LoRA <lora>`
+  * - :code:`FuyuForCausalLM`
+    - Fuyu
+    - :code:`adept/fuyu-8b` etc.
+    -
   * - :code:`LlavaForConditionalGeneration`
     - LLaVA-1.5
     - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.

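The documentation change above only adds the two new architectures to the supported-model tables. As a quick illustration that is not part of this commit, the newly listed Persimmon checkpoints go through vLLM's ordinary text-only LLM entrypoint; in the sketch below the model name comes from the table, while the prompt template and sampling settings are assumptions.

from vllm import LLM, SamplingParams

# Illustrative sketch only: Persimmon is a plain decoder-only LM, so there is
# no image input. The "human:/adept:" prompt shape is assumed from the model
# card, not from this commit.
llm = LLM(model="adept/persimmon-8b-chat", max_model_len=4096)
sampling_params = SamplingParams(temperature=0.8, max_tokens=64)

outputs = llm.generate(
    ["human: What can you tell me about persimmons?\n\nadept:"],
    sampling_params)
print(outputs[0].outputs[0].text)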
examples/fuyu_example.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
import requests
from PIL import Image

from vllm import LLM, SamplingParams


def run_fuyu():
    llm = LLM(model="adept/fuyu-8b", max_model_len=4096)

    # single-image prompt
    prompt = "What is the highest life expectancy at of male?\n"
    url = "https://huggingface.co/adept/fuyu-8b/resolve/main/chart.png"
    image = Image.open(requests.get(url, stream=True).raw)
    sampling_params = SamplingParams(temperature=0, max_tokens=64)

    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {
                "image": image
            },
        },
        sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    run_fuyu()

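The example above sends a single prompt/image pair. As an illustrative variant that is not part of this commit, the same generate() call also accepts a list of such prompt dicts, so several image questions can be batched together; the second question below is made up for the sketch, and the batched behaviour is assumed to mirror the text-only case.

import requests
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="adept/fuyu-8b", max_model_len=4096)
sampling_params = SamplingParams(temperature=0, max_tokens=64)

url = "https://huggingface.co/adept/fuyu-8b/resolve/main/chart.png"
image = Image.open(requests.get(url, stream=True).raw)

# Batched multi-modal prompts: one dict per request, same schema as above.
batch = [
    {
        "prompt": "What is the highest life expectancy at of male?\n",
        "multi_modal_data": {"image": image},
    },
    {
        "prompt": "Describe this chart in one sentence.\n",  # made-up question
        "multi_modal_data": {"image": image},
    },
]
outputs = llm.generate(batch, sampling_params=sampling_params)
for o in outputs:
    print(o.outputs[0].text)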
tests/models/test_fuyu.py

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
from typing import List, Optional, Tuple, Type

import pytest

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign": "What's the content of the image?\n",  # noqa: E501
    "cherry_blossom": "What is the season?\n",
    "boardwalk": "What's in this image?\n",
})

models = ["adept/fuyu-8b"]


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]]):
    """Sanitize vllm output to be comparable with hf output."""
    output_ids, output_str, out_logprobs = vllm_output

    hf_output_str = output_str.lstrip() + "|ENDOFTEXT|"

    return output_ids, hf_output_str, out_logprobs


def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test are under tests/images.
    For the huggingface runner, we provide the PIL images as input.
    For the vllm runner, we provide MultiModalDataDict objects
    and the corresponding vision language config as input.
    Note, the text input is also adjusted to abide by the vllm contract.
    The text output is sanitized to be comparable with hf.
    """
    images = [asset.pil_image for asset in image_assets]

    inputs_per_image = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).

    # max_model_len should be greater than image_feature_size
    with vllm_runner(model,
                     max_model_len=2560,
                     max_num_seqs=1,
                     dtype=dtype,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:
        vllm_outputs_per_image = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                images=vllm_images)
            for prompts, vllm_images in inputs_per_image
        ]

    with hf_runner(model, dtype=dtype) as hf_model:
        hf_model.model.get_output_embeddings = lambda: \
            hf_model.model.language_model.get_output_embeddings()
        eos_token_id = hf_model.processor.tokenizer.eos_token_id
        hf_outputs_per_image = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=hf_images,
                                                    eos_token_id=eos_token_id)
            for prompts, hf_images in inputs_per_image
        ]

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
                                        vllm_outputs_per_image):
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=[
                vllm_to_hf_output(vllm_output) for vllm_output in vllm_outputs
            ],
            name_0="hf",
            name_1="vllm",
        )


target_dtype = "half"
if is_cpu():
    target_dtype = "bfloat16"


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [0.25],
        # Single-scale, batched
        [0.25, 0.25, 0.25],
        # Multi-scale
        [0.25, 0.2, 0.15],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                dtype: str, max_tokens: int, num_logprobs: int) -> None:
    run_test(
        hf_runner,
        vllm_runner,
        image_assets,
        model,
        size_factors=size_factors,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )

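The test module above tags itself with pytest.mark.vlm, so it is collected with the rest of the vision-language model tests. A minimal, illustrative way to invoke just this file programmatically (assuming the test dependencies and a GPU are available; the "vlm" marker comes from the pytestmark line above):

import sys

import pytest

if __name__ == "__main__":
    # Select only tests carrying the "vlm" marker in this file.
    sys.exit(pytest.main(["-m", "vlm", "tests/models/test_fuyu.py"]))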
vllm/model_executor/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
     "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
     "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
     "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
+    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
     "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
     "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
     "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
@@ -49,6 +50,7 @@
     "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
     "OrionForCausalLM": ("orion", "OrionForCausalLM"),
+    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
     "PaliGemmaForConditionalGeneration":
     ("paligemma", "PaliGemmaForConditionalGeneration"),
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),

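These registry entries map an architecture name to a (module, class name) pair so that the implementation under vllm/model_executor/models/ is only imported when that architecture is actually requested. Below is a minimal sketch of that lazy-lookup pattern; it is illustrative rather than the actual vLLM registry code, and the helper name load_model_cls is chosen for the sketch.

import importlib
from typing import Dict, Tuple, Type

# Same shape as the mapping edited in this diff: architecture -> (module, class).
_MODELS: Dict[str, Tuple[str, str]] = {
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
}


def load_model_cls(arch: str) -> Type:
    """Resolve an architecture name to its class, importing the module lazily."""
    module_name, cls_name = _MODELS[arch]
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, cls_name)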