Commit ef9647c

Multi-GPU serving (#1670)
1 parent 2433eaf commit ef9647c

File tree: 2 files changed (+59 −8 lines)

litgpt/deploy/serve.py

Lines changed: 16 additions & 8 deletions
```diff
@@ -27,7 +27,8 @@ def __init__(
         temperature: float = 0.8,
         top_k: int = 50,
         top_p: float = 1.0,
-        max_new_tokens: int = 50
+        max_new_tokens: int = 50,
+        devices: int = 1
     ) -> None:
 
         if not _LITSERVE_AVAILABLE:
@@ -41,6 +42,7 @@ def __init__(
         self.top_k = top_k
         self.max_new_tokens = max_new_tokens
         self.top_p = top_p
+        self.devices = devices
 
     def setup(self, device: str) -> None:
         if ":" in device:
@@ -57,9 +59,11 @@ def setup(self, device: str) -> None:
         )
 
         self.llm.distribute(
+            devices=self.devices,
             accelerator=accelerator,
             quantize=self.quantize,
-            precision=self.precision
+            precision=self.precision,
+            generate_strategy="sequential" if self.devices is not None and self.devices > 1 else None
         )
         print("Model successfully initialized.")
 
@@ -78,9 +82,10 @@ def __init__(
         temperature: float = 0.8,
         top_k: int = 50,
         top_p: float = 1.0,
-        max_new_tokens: int = 50
+        max_new_tokens: int = 50,
+        devices: int = 1
     ):
-        super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens)
+        super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices)
 
     def setup(self, device: str):
         super().setup(device)
@@ -109,9 +114,10 @@ def __init__(
         temperature: float = 0.8,
         top_k: int = 50,
         top_p: float = 1.0,
-        max_new_tokens: int = 50
+        max_new_tokens: int = 50,
+        devices: int = 1
     ):
-        super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens)
+        super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices)
 
     def setup(self, device: str):
         super().setup(device)
@@ -197,9 +203,10 @@ def run_server(
                 top_k=top_k,
                 top_p=top_p,
                 max_new_tokens=max_new_tokens,
+                devices=devices
             ),
             accelerator=accelerator,
-            devices=devices
+            devices=1  # We need to use the devices inside the `SimpleLitAPI` class
        )
 
    else:
@@ -212,9 +219,10 @@ def run_server(
                 top_k=top_k,
                 top_p=top_p,
                 max_new_tokens=max_new_tokens,
+                devices=devices  # We need to use the devices inside the `StreamLitAPI` class
             ),
             accelerator=accelerator,
-            devices=devices,
+            devices=1,
             stream=True
         )
 
```
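For readers following along, here is a minimal sketch of the distribution path that `SimpleLitAPI.setup` now takes, written against litgpt's public `LLM` API. The checkpoint path and prompt are placeholders; the `generate_strategy` expression is copied from the diff above, and loading with `distribute=None` before calling `llm.distribute()` manually mirrors what `setup` does.

```python
# Hedged sketch of the new multi-GPU path in SimpleLitAPI.setup().
# Assumptions: a checkpoint at "checkpoints/EleutherAI/pythia-14m"
# (placeholder path) and at least two visible CUDA devices.
from litgpt import LLM

devices = 2

# Load without auto-distribution so distribute() can be called explicitly,
# as SimpleLitAPI.setup() does.
llm = LLM.load("checkpoints/EleutherAI/pythia-14m", distribute=None)

# With devices > 1, the commit selects the "sequential" strategy, which
# places the model's layers across the available GPUs rather than
# replicating the whole model on one device.
llm.distribute(
    devices=devices,
    accelerator="cuda",
    generate_strategy="sequential" if devices is not None and devices > 1 else None,
)

print(llm.generate("Example prompt", max_new_tokens=50))
```

Note that `LitServer` itself is now always given `devices=1`; the fan-out across GPUs happens inside the API class via `llm.distribute`, not in the server.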

tests/test_serve.py

Lines changed: 43 additions & 0 deletions
```diff
@@ -98,3 +98,46 @@ def run_server():
         if process:
             process.kill()
         server_thread.join()
+
+
+@RunIf(min_cuda_gpus=2)
+def test_multi_gpu_serve(tmp_path):
+    seed_everything(123)
+    ours_config = Config.from_name("pythia-14m")
+    download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path)
+    shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path))
+    shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path))
+    ours_model = GPT(ours_config)
+    checkpoint_path = tmp_path / "lit_model.pth"
+    torch.save(ours_model.state_dict(), checkpoint_path)
+    config_path = tmp_path / "model_config.yaml"
+    with open(config_path, "w", encoding="utf-8") as fp:
+        yaml.dump(asdict(ours_config), fp)
+
+    run_command = [
+        "litgpt", "serve", tmp_path, "--devices", "2"
+    ]
+
+    process = None
+
+    def run_server():
+        nonlocal process
+        try:
+            process = subprocess.Popen(run_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+            stdout, stderr = process.communicate(timeout=10)
+        except subprocess.TimeoutExpired:
+            print('Server start-up timeout expired')
+
+    server_thread = threading.Thread(target=run_server)
+    server_thread.start()
+
+    time.sleep(10)
+
+    try:
+        response = requests.get("http://127.0.0.1:8000")
+        print(response.status_code)
+        assert response.status_code == 200, "Server did not respond as expected."
+    finally:
+        if process:
+            process.kill()
+        server_thread.join()
```
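The readiness check above only pings the root URL. For an actual generation request against the running server, litgpt's serving docs use LitServe's `/predict` route; a small client sketch, with an illustrative prompt and timeout:

```python
# Client-side sketch for a server started with:
#   litgpt serve <checkpoint_dir> --devices 2
# The /predict route is LitServe's default; prompt and timeout are illustrative.
import requests

response = requests.post(
    "http://127.0.0.1:8000/predict",
    json={"prompt": "What do llamas eat?"},
    timeout=60,
)
print(response.json())  # e.g. {"output": "..."}
```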
