@@ -27,7 +27,8 @@ def __init__(
27
27
temperature : float = 0.8 ,
28
28
top_k : int = 50 ,
29
29
top_p : float = 1.0 ,
30
- max_new_tokens : int = 50
30
+ max_new_tokens : int = 50 ,
31
+ devices : int = 1
31
32
) -> None :
32
33
33
34
if not _LITSERVE_AVAILABLE :
@@ -41,6 +42,7 @@ def __init__(
41
42
self .top_k = top_k
42
43
self .max_new_tokens = max_new_tokens
43
44
self .top_p = top_p
45
+ self .devices = devices
44
46
45
47
def setup (self , device : str ) -> None :
46
48
if ":" in device :
@@ -57,9 +59,11 @@ def setup(self, device: str) -> None:
57
59
)
58
60
59
61
self .llm .distribute (
62
+ devices = self .devices ,
60
63
accelerator = accelerator ,
61
64
quantize = self .quantize ,
62
- precision = self .precision
65
+ precision = self .precision ,
66
+ generate_strategy = "sequential" if self .devices is not None and self .devices > 1 else None
63
67
)
64
68
print ("Model successfully initialized." )
65
69
@@ -78,9 +82,10 @@ def __init__(
78
82
temperature : float = 0.8 ,
79
83
top_k : int = 50 ,
80
84
top_p : float = 1.0 ,
81
- max_new_tokens : int = 50
85
+ max_new_tokens : int = 50 ,
86
+ devices : int = 1
82
87
):
83
- super ().__init__ (checkpoint_dir , quantize , precision , temperature , top_k , top_p , max_new_tokens )
88
+ super ().__init__ (checkpoint_dir , quantize , precision , temperature , top_k , top_p , max_new_tokens , devices )
84
89
85
90
def setup (self , device : str ):
86
91
super ().setup (device )
@@ -109,9 +114,10 @@ def __init__(
109
114
temperature : float = 0.8 ,
110
115
top_k : int = 50 ,
111
116
top_p : float = 1.0 ,
112
- max_new_tokens : int = 50
117
+ max_new_tokens : int = 50 ,
118
+ devices : int = 1
113
119
):
114
- super ().__init__ (checkpoint_dir , quantize , precision , temperature , top_k , top_p , max_new_tokens )
120
+ super ().__init__ (checkpoint_dir , quantize , precision , temperature , top_k , top_p , max_new_tokens , devices )
115
121
116
122
def setup (self , device : str ):
117
123
super ().setup (device )
@@ -197,9 +203,10 @@ def run_server(
197
203
top_k = top_k ,
198
204
top_p = top_p ,
199
205
max_new_tokens = max_new_tokens ,
206
+ devices = devices
200
207
),
201
208
accelerator = accelerator ,
202
- devices = devices
209
+ devices = 1 # We need to use the devives inside the `SimpleLitAPI` class
203
210
)
204
211
205
212
else :
@@ -212,9 +219,10 @@ def run_server(
212
219
top_k = top_k ,
213
220
top_p = top_p ,
214
221
max_new_tokens = max_new_tokens ,
222
+ devices = devices # We need to use the devives inside the `StreamLitAPI` class
215
223
),
216
224
accelerator = accelerator ,
217
- devices = devices ,
225
+ devices = 1 ,
218
226
stream = True
219
227
)
220
228
0 commit comments