@@ -12,7 +12,8 @@
 from litgpt.model import GPT
 from litgpt.config import Config
 from litgpt.tokenizer import Tokenizer
-from litgpt.generate.base import generate
+from litgpt.generate.base import generate as plain_generate
+from litgpt.chat.base import generate as stream_generate
 from litgpt.prompts import load_prompt_style, has_prompt_style, PromptStyle
 from litgpt.utils import (
     extend_checkpoint_dir,
@@ -28,7 +29,7 @@
     LitAPI, LitServer = object, object


-class SimpleLitAPI(LitAPI):
+class BaseLitAPI(LitAPI):
     def __init__(self,
                  checkpoint_dir: Path,
                  precision: Optional[str] = None,
@@ -86,12 +87,26 @@ def decode_request(self, request: Dict[str, Any]) -> Any:
         encoded = self.tokenizer.encode(prompt, device=self.device)
         return encoded

+
+class SimpleLitAPI(BaseLitAPI):
+    def __init__(self,
+                 checkpoint_dir: Path,
+                 precision: Optional[str] = None,
+                 temperature: float = 0.8,
+                 top_k: int = 50,
+                 top_p: float = 1.0,
+                 max_new_tokens: int = 50):
+        super().__init__(checkpoint_dir, precision, temperature, top_k, top_p, max_new_tokens)
+
+    def setup(self, device: str):
+        super().setup(device)
+
     def predict(self, inputs: torch.Tensor) -> Any:
         # Run the model on the input and return the output.
         prompt_length = inputs.size(0)
         max_returned_tokens = prompt_length + self.max_new_tokens

-        y = generate(
+        y = plain_generate(
             self.model,
             inputs,
             max_returned_tokens,
@@ -111,6 +126,42 @@ def encode_response(self, output: torch.Tensor) -> Dict[str, Any]:
         return {"output": decoded_output}


+class StreamLitAPI(BaseLitAPI):
+    def __init__(self,
+                 checkpoint_dir: Path,
+                 precision: Optional[str] = None,
+                 temperature: float = 0.8,
+                 top_k: int = 50,
+                 top_p: float = 1.0,
+                 max_new_tokens: int = 50):
+        super().__init__(checkpoint_dir, precision, temperature, top_k, top_p, max_new_tokens)
+
+    def setup(self, device: str):
+        super().setup(device)
+
+    def predict(self, inputs: torch.Tensor) -> Any:
+        # Run the model on the input and return the output.
+        prompt_length = inputs.size(0)
+        max_returned_tokens = prompt_length + self.max_new_tokens
+
+        for block in self.model.transformer.h:
+            block.attn.kv_cache.reset_parameters()
+
+        yield from stream_generate(
+            self.model,
+            inputs,
+            max_returned_tokens,
+            temperature=self.temperature,
+            top_k=self.top_k,
+            top_p=self.top_p,
+            stop_tokens=([self.tokenizer.eos_id],)
+        )
+
+    def encode_response(self, output):
+        for out in output:
+            yield {"output": self.tokenizer.decode(out)}
+
+
 def run_server(
     checkpoint_dir: Path,
     precision: Optional[str] = None,
@@ -120,7 +171,8 @@ def run_server(
     max_new_tokens: int = 50,
     devices: int = 1,
     accelerator: str = "auto",
-    port: int = 8000
+    port: int = 8000,
+    stream: bool = False
 ) -> None:
     """Serve a LitGPT model using LitServe.

@@ -153,22 +205,40 @@ def run_server(
         accelerator: The type of accelerator to use. For example, "auto", "cuda", "cpu", or "mps".
             The "auto" setting (default) chooses a GPU if available, and otherwise uses a CPU.
         port: The network port number on which the model is configured to be served.
+        stream: Whether to stream the responses.
     """
     checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
     pprint(locals())

     check_valid_checkpoint_dir(checkpoint_dir, model_filename="lit_model.pth")

-    server = LitServer(
-        SimpleLitAPI(
-            checkpoint_dir=checkpoint_dir,
-            precision=precision,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            max_new_tokens=max_new_tokens,
-        ),
-        accelerator=accelerator,
-        devices=devices)
+    if not stream:
+        server = LitServer(
+            SimpleLitAPI(
+                checkpoint_dir=checkpoint_dir,
+                precision=precision,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                max_new_tokens=max_new_tokens,
+            ),
+            accelerator=accelerator,
+            devices=devices
+        )
+
+    else:
+        server = LitServer(
+            StreamLitAPI(
+                checkpoint_dir=checkpoint_dir,
+                precision=precision,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                max_new_tokens=max_new_tokens,
+            ),
+            accelerator=accelerator,
+            devices=devices,
+            stream=True
+        )

     server.run(port=port)
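Taken together, the patch keeps the non-streaming path as the default and routes `stream=True` to the new `StreamLitAPI`. A minimal launch sketch, assuming the file lives at `litgpt/deploy/serve.py` and using a placeholder checkpoint directory:

```python
from pathlib import Path

from litgpt.deploy.serve import run_server

# Placeholder checkpoint path; point this at a real downloaded checkpoint.
run_server(
    checkpoint_dir=Path("checkpoints/EleutherAI/pythia-14m"),
    stream=True,  # serve via StreamLitAPI instead of SimpleLitAPI
    port=8000,
)
```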
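On the client side, the streamed tokens can be printed as they arrive. A sketch assuming LitServe's default `/predict` route and that chunks arrive as newline-delimited JSON objects carrying the `"output"` field produced by `StreamLitAPI.encode_response`:

```python
import json

import requests

# Read the response incrementally instead of waiting for the full generation.
with requests.post(
    "http://127.0.0.1:8000/predict",
    json={"prompt": "What do llamas eat?"},
    stream=True,
    timeout=60,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if line:  # skip keep-alive blank lines
            print(json.loads(line)["output"], end="", flush=True)
```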