Commit 8f55962

Authored by reidliu41

[Misc] refactor prompt embedding examples (#18405)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

1 parent be48360 · commit 8f55962

File tree

3 files changed: +191 −102 lines

- docs/source/features/prompt_embeds.md
- examples/offline_inference/prompt_embed_inference.py
- examples/online_serving/prompt_embed_inference_with_openai_client.py

docs/source/features/prompt_embeds.md

Lines changed: 2 additions & 102 deletions
@@ -20,59 +20,7 @@ To input multi-modal data, follow this schema in {class}`vllm.inputs.EmbedsPromp
 
 You can pass prompt embeddings from Hugging Face Transformers models to the `'prompt_embeds'` field of the prompt embedding dictionary, as shown in the following examples:
 
-```python
-from vllm import LLM
-import transformers
-
-model_name = "meta-llama/Llama-3.2-1B-Instruct"
-
-# Transformers
-tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
-transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
-
-llm = LLM(model=model_name, enable_prompt_embeds=True)
-
-# Refer to the HuggingFace repo for the correct format to use
-chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
-token_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt')
-
-embedding_layer = transformers_model.get_input_embeddings()
-prompt_embeds = embedding_layer(token_ids).squeeze(0)
-
-# Single prompt inference
-outputs = llm.generate({
-    "prompt_embeds": prompt_embeds,
-})
-
-for o in outputs:
-    generated_text = o.outputs[0].text
-    print(generated_text)
-
-# Batch inference
-
-chats = [
-    [{"role": "user", "content": "Please tell me about the capital of France."}],
-    [{"role": "user", "content": "When is the day longest during the year?"}],
-    [{"role": "user", "content": "Where is bigger, the moon or the sun?"}]
-]
-
-token_ids_list = [
-    tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt') for chat in chats
-]
-prompt_embeds_list = [embedding_layer(token_ids).squeeze(0) for token_ids in token_ids_list]
-
-outputs = llm.generate(
-    [
-        {
-            "prompt_embeds": prompt_embeds,
-        } for prompt_embeds in prompt_embeds_list
-    ]
-)
-
-for o in outputs:
-    generated_text = o.outputs[0].text
-    print(generated_text)
-```
+<gh-file:examples/offline_inference/prompt_embed_inference.py>
 
 ## Online Serving
 
@@ -93,52 +41,4 @@ vllm serve meta-llama/Llama-3.2-1B-Instruct --task generate \
 
 Then, you can use the OpenAI client as follows:
 
-```python
-from openai import OpenAI
-import transformers
-import torch
-
-openai_api_key = "EMPTY"
-openai_api_base = "http://localhost:8000/v1"
-
-client = OpenAI(
-    api_key=openai_api_key,
-    base_url=openai_api_base,
-)
-
-model_name = "meta-llama/Llama-3.2-1B-Instruct"
-
-# Transformers
-tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
-transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
-
-
-# Refer to the HuggingFace repo for the correct format to use
-chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
-token_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt')
-
-embedding_layer = transformers_model.get_input_embeddings()
-prompt_embeds = embedding_layer(token_ids).squeeze(0)
-
-# Prompt embeddings
-buffer = io.BytesIO()
-torch.save(prompt_embeds, buffer)
-buffer.seek(0)
-binary_data = buffer.read()
-encoded_embeds = base64.b64encode(binary_data).decode('utf-8')
-
-
-completion = client_with_prompt_embeds.completions.create(
-    model=model_name,
-    # NOTE: The OpenAI client does not allow `None` as an input to
-    # `prompt`. Use an empty string if you have no text prompts.
-    prompt="",
-    max_tokens=5,
-    temperature=0.0,
-    # NOTE: The OpenAI client allows passing in extra JSON body via the
-    # `extra_body` argument.
-    extra_body={"prompt_embeds": encoded_embeds}
-)
-
-print(completion.choices[0].text)
-```
+<gh-file:examples/online_serving/prompt_embed_inference_with_openai_client.py>
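Note: the base64-encoded embeddings travel as an ordinary field of the Completions request body; the OpenAI client's `extra_body` simply merges that field into the JSON payload, so any HTTP client can send the same request. A minimal sketch using `requests` (not part of this commit; the placeholder tensor and values are illustrative only):

```python
import base64
import io

import requests
import torch

# Placeholder tensor for illustration only: shape (seq_len, hidden_size) for
# Llama-3.2-1B. In practice, build `prompt_embeds` from the HF embedding layer
# as shown in the examples above.
prompt_embeds = torch.zeros(8, 2048)

# Serialize with torch.save and encode as base64, as in the client example.
buffer = io.BytesIO()
torch.save(prompt_embeds, buffer)
encoded_embeds = base64.b64encode(buffer.getvalue()).decode('utf-8')

# Same endpoint and fields the OpenAI client targets; `prompt_embeds` is the
# extra field that `extra_body` would add.
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "meta-llama/Llama-3.2-1B-Instruct",
        "prompt": "",
        "max_tokens": 5,
        "temperature": 0.0,
        "prompt_embeds": encoded_embeds,
    },
)
print(resp.json()["choices"][0]["text"])
```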
examples/offline_inference/prompt_embed_inference.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Demonstrates how to generate prompt embeddings using
+Hugging Face Transformers and use them as input to vLLM
+for both single and batch inference.
+
+Model: meta-llama/Llama-3.2-1B-Instruct
+Note: This model is gated on Hugging Face Hub.
+You must request access to use it:
+https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct
+
+Requirements:
+- vLLM
+- transformers
+
+Run:
+    python examples/offline_inference/prompt_embed_inference.py
+"""
+
+import torch
+from transformers import (AutoModelForCausalLM, AutoTokenizer,
+                          PreTrainedTokenizer)
+
+from vllm import LLM
+
+
+def init_tokenizer_and_llm(model_name: str):
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    transformers_model = AutoModelForCausalLM.from_pretrained(model_name)
+    embedding_layer = transformers_model.get_input_embeddings()
+    llm = LLM(model=model_name, enable_prompt_embeds=True)
+    return tokenizer, embedding_layer, llm
+
+
+def get_prompt_embeds(chat: list[dict[str, str]],
+                      tokenizer: PreTrainedTokenizer,
+                      embedding_layer: torch.nn.Module):
+    token_ids = tokenizer.apply_chat_template(chat,
+                                              add_generation_prompt=True,
+                                              return_tensors='pt')
+    prompt_embeds = embedding_layer(token_ids).squeeze(0)
+    return prompt_embeds
+
+
+def single_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
+                            embedding_layer: torch.nn.Module):
+    chat = [{
+        "role": "user",
+        "content": "Please tell me about the capital of France."
+    }]
+    prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer)
+
+    outputs = llm.generate({
+        "prompt_embeds": prompt_embeds,
+    })
+
+    print("\n[Single Inference Output]")
+    print("-" * 30)
+    for o in outputs:
+        print(o.outputs[0].text)
+    print("-" * 30)
+
+
+def batch_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
+                           embedding_layer: torch.nn.Module):
+    chats = [[{
+        "role": "user",
+        "content": "Please tell me about the capital of France."
+    }],
+             [{
+                 "role": "user",
+                 "content": "When is the day longest during the year?"
+             }],
+             [{
+                 "role": "user",
+                 "content": "Where is bigger, the moon or the sun?"
+             }]]
+
+    prompt_embeds_list = [
+        get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats
+    ]
+
+    outputs = llm.generate([{
+        "prompt_embeds": embeds
+    } for embeds in prompt_embeds_list])
+
+    print("\n[Batch Inference Outputs]")
+    print("-" * 30)
+    for i, o in enumerate(outputs):
+        print(f"Q{i+1}: {chats[i][0]['content']}")
+        print(f"A{i+1}: {o.outputs[0].text}\n")
+    print("-" * 30)
+
+
+def main():
+    model_name = "meta-llama/Llama-3.2-1B-Instruct"
+    tokenizer, embedding_layer, llm = init_tokenizer_and_llm(model_name)
+    single_prompt_inference(llm, tokenizer, embedding_layer)
+    batch_prompt_inference(llm, tokenizer, embedding_layer)
+
+
+if __name__ == "__main__":
+    main()
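One detail not exercised by the script above: `LLM.generate` also accepts `SamplingParams` alongside embedding inputs, so decoding settings are controlled the same way as for text prompts. A brief sketch (not part of this commit; the sampling values are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm import LLM, SamplingParams

model_name = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)

chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
token_ids = tokenizer.apply_chat_template(chat,
                                          add_generation_prompt=True,
                                          return_tensors='pt')
# Tensor of shape (seq_len, hidden_size), built as in the script above
prompt_embeds = hf_model.get_input_embeddings()(token_ids).squeeze(0)

llm = LLM(model=model_name, enable_prompt_embeds=True)
outputs = llm.generate(
    {"prompt_embeds": prompt_embeds},
    sampling_params=SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```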
examples/online_serving/prompt_embed_inference_with_openai_client.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+vLLM OpenAI-Compatible Client with Prompt Embeddings
+
+This script demonstrates how to:
+1. Generate prompt embeddings using Hugging Face Transformers
+2. Encode them in base64 format
+3. Send them to a vLLM server via the OpenAI-compatible Completions API
+
+Run the vLLM server first:
+    vllm serve meta-llama/Llama-3.2-1B-Instruct \
+        --task generate \
+        --max-model-len 4096 \
+        --enable-prompt-embeds
+
+Run the client:
+    python examples/online_serving/prompt_embed_inference_with_openai_client.py
+
+Model: meta-llama/Llama-3.2-1B-Instruct
+Note: This model is gated on Hugging Face Hub.
+You must request access to use it:
+https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct
+
+Dependencies:
+- transformers
+- torch
+- openai
+"""
+import base64
+import io
+
+import torch
+import transformers
+from openai import OpenAI
+
+
+def main():
+    client = OpenAI(
+        api_key="EMPTY",
+        base_url="http://localhost:8000/v1",
+    )
+
+    model_name = "meta-llama/Llama-3.2-1B-Instruct"
+
+    # Transformers
+    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
+    transformers_model = transformers.AutoModelForCausalLM.from_pretrained(
+        model_name)
+
+    # Refer to the HuggingFace repo for the correct format to use
+    chat = [{
+        "role": "user",
+        "content": "Please tell me about the capital of France."
+    }]
+    token_ids = tokenizer.apply_chat_template(chat,
+                                              add_generation_prompt=True,
+                                              return_tensors='pt')
+
+    embedding_layer = transformers_model.get_input_embeddings()
+    prompt_embeds = embedding_layer(token_ids).squeeze(0)
+
+    # Prompt embeddings
+    buffer = io.BytesIO()
+    torch.save(prompt_embeds, buffer)
+    buffer.seek(0)
+    binary_data = buffer.read()
+    encoded_embeds = base64.b64encode(binary_data).decode('utf-8')
+
+    completion = client.completions.create(
+        model=model_name,
+        # NOTE: The OpenAI client does not allow `None` as an input to
+        # `prompt`. Use an empty string if you have no text prompts.
+        prompt="",
+        max_tokens=5,
+        temperature=0.0,
+        # NOTE: The OpenAI client allows passing in extra JSON body via the
+        # `extra_body` argument.
+        extra_body={"prompt_embeds": encoded_embeds})
+
+    print("-" * 30)
+    print(completion.choices[0].text)
+    print("-" * 30)
+
+
+if __name__ == "__main__":
+    main()
