
Commit 0f3bca7

Streaming for serving with chat's generate function (#1426)
1 parent fa88952 commit 0f3bca7

File tree (3 files changed: 86 additions, 16 deletions)


litgpt/deploy/__init__.py

Whitespace-only changes.

litgpt/deploy/serve.py

Lines changed: 85 additions & 15 deletions
@@ -12,7 +12,8 @@
 from litgpt.model import GPT
 from litgpt.config import Config
 from litgpt.tokenizer import Tokenizer
-from litgpt.generate.base import generate
+from litgpt.generate.base import generate as plain_generate
+from litgpt.chat.base import generate as stream_generate
 from litgpt.prompts import load_prompt_style, has_prompt_style, PromptStyle
 from litgpt.utils import (
     extend_checkpoint_dir,
@@ -28,7 +29,7 @@
     LitAPI, LitServer = object, object


-class SimpleLitAPI(LitAPI):
+class BaseLitAPI(LitAPI):
     def __init__(self,
                  checkpoint_dir: Path,
                  precision: Optional[str] = None,
@@ -86,12 +87,26 @@ def decode_request(self, request: Dict[str, Any]) -> Any:
         encoded = self.tokenizer.encode(prompt, device=self.device)
         return encoded

+
+class SimpleLitAPI(BaseLitAPI):
+    def __init__(self,
+                 checkpoint_dir: Path,
+                 precision: Optional[str] = None,
+                 temperature: float = 0.8,
+                 top_k: int = 50,
+                 top_p: float = 1.0,
+                 max_new_tokens: int = 50):
+        super().__init__(checkpoint_dir, precision, temperature, top_k, top_p, max_new_tokens)
+
+    def setup(self, device: str):
+        super().setup(device)
+
     def predict(self, inputs: torch.Tensor) -> Any:
         # Run the model on the input and return the output.
         prompt_length = inputs.size(0)
         max_returned_tokens = prompt_length + self.max_new_tokens

-        y = generate(
+        y = plain_generate(
             self.model,
             inputs,
             max_returned_tokens,
@@ -111,6 +126,42 @@ def encode_response(self, output: torch.Tensor) -> Dict[str, Any]:
         return {"output": decoded_output}


+class StreamLitAPI(BaseLitAPI):
+    def __init__(self,
+                 checkpoint_dir: Path,
+                 precision: Optional[str] = None,
+                 temperature: float = 0.8,
+                 top_k: int = 50,
+                 top_p: float = 1.0,
+                 max_new_tokens: int = 50):
+        super().__init__(checkpoint_dir, precision, temperature, top_k, top_p, max_new_tokens)
+
+    def setup(self, device: str):
+        super().setup(device)
+
+    def predict(self, inputs: torch.Tensor) -> Any:
+        # Run the model on the input and return the output.
+        prompt_length = inputs.size(0)
+        max_returned_tokens = prompt_length + self.max_new_tokens
+
+        for block in self.model.transformer.h:
+            block.attn.kv_cache.reset_parameters()
+
+        yield from stream_generate(
+            self.model,
+            inputs,
+            max_returned_tokens,
+            temperature=self.temperature,
+            top_k=self.top_k,
+            top_p=self.top_p,
+            stop_tokens=([self.tokenizer.eos_id],)
+        )
+
+    def encode_response(self, output):
+        for out in output:
+            yield {"output": self.tokenizer.decode(out)}
+
+
 def run_server(
     checkpoint_dir: Path,
     precision: Optional[str] = None,
@@ -120,7 +171,8 @@ def run_server(
     max_new_tokens: int = 50,
     devices: int = 1,
     accelerator: str = "auto",
-    port: int = 8000
+    port: int = 8000,
+    stream: bool = False
 ) -> None:
     """Serve a LitGPT model using LitServe.
@@ -153,22 +205,40 @@
         accelerator: The type of accelerator to use. For example, "auto", "cuda", "cpu", or "mps".
             The "auto" setting (default) chooses a GPU if available, and otherwise uses a CPU.
         port: The network port number on which the model is configured to be served.
+        stream: Whether to stream the responses.
     """
     checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
     pprint(locals())

     check_valid_checkpoint_dir(checkpoint_dir, model_filename="lit_model.pth")

-    server = LitServer(
-        SimpleLitAPI(
-            checkpoint_dir=checkpoint_dir,
-            precision=precision,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            max_new_tokens=max_new_tokens,
-        ),
-        accelerator=accelerator,
-        devices=devices)
+    if not stream:
+        server = LitServer(
+            SimpleLitAPI(
+                checkpoint_dir=checkpoint_dir,
+                precision=precision,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                max_new_tokens=max_new_tokens,
+            ),
+            accelerator=accelerator,
+            devices=devices
+        )
+
+    else:
+        server = LitServer(
+            StreamLitAPI(
+                checkpoint_dir=checkpoint_dir,
+                precision=precision,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                max_new_tokens=max_new_tokens,
+            ),
+            accelerator=accelerator,
+            devices=devices,
+            stream=True
+        )

     server.run(port=port)
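
As a quick orientation to the new flag, here is a minimal, hypothetical launch sketch (not part of this commit) that starts the server programmatically with streaming enabled. Only run_server and its stream parameter come from the diff above; the checkpoint path is a placeholder and the remaining arguments keep the defaults shown in the signature.

# Hypothetical launch script for the streaming mode added in this commit.
# The checkpoint path below is a placeholder, not a path used by the diff.
from pathlib import Path

from litgpt.deploy.serve import run_server

if __name__ == "__main__":
    run_server(
        checkpoint_dir=Path("checkpoints/some-org/some-model"),  # placeholder
        port=8000,
        stream=True,  # picks StreamLitAPI and passes stream=True to LitServer
    )

With stream=False (the default) the behavior is unchanged: SimpleLitAPI answers each request with a single, non-streamed response.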

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ all = [
     "tokenizers>=0.15.2",  # pythia, falcon, redpajama
     "requests>=2.31.0",  # litgpt.data
     "litdata==0.2.6",  # litgpt.data
-    "litserve>=0.1.0",  # litgpt.deploy
+    "litserve==0.1.1dev0",  # litgpt.deploy
     "zstandard>=0.22.0",  # litgpt.data.prepare_slimpajama.py
     "pandas>=1.9.0",  # litgpt.data.prepare_starcoder.py
     "pyarrow>=15.0.2",  # litgpt.data.prepare_starcoder.py
