
Commit 2433eaf

Support the refactored API in litgpt serve (#1668)
1 parent 3eab461 commit 2433eaf

File tree: 3 files changed, +51 -0 lines changed


litgpt/api.py

Lines changed: 3 additions & 0 deletions
@@ -239,6 +239,9 @@ def distribute(
             "Support for multiple devices is currently only implemented for generate_strategy='sequential'|'tensor_parallel'."
         )
 
+        if precision is None:
+            precision = get_default_supported_precision(training=False)
+
         plugins = None
         if quantize is not None and quantize.startswith("bnb."):
             if "mixed" in precision:

litgpt/deploy/serve.py

Lines changed: 4 additions & 0 deletions
@@ -53,6 +53,10 @@ def setup(self, device: str) -> None:
         print("Initializing model...")
         self.llm = LLM.load(
             model=self.checkpoint_dir,
+            distribute=None
+        )
+
+        self.llm.distribute(
             accelerator=accelerator,
             quantize=self.quantize,
             precision=self.precision
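
The serve path now splits model loading from device placement: LLM.load(..., distribute=None) reads the checkpoint without placing it, and llm.distribute(...) applies the accelerator, quantization, and precision afterwards. A hedged sketch of that two-step pattern outside the server (the checkpoint path and option values are placeholders):

    from litgpt.api import LLM

    # Load weights only; distribute=None skips automatic device placement.
    llm = LLM.load(model="checkpoints/EleutherAI/pythia-14m", distribute=None)

    # Place the model afterwards with the runtime options the server collects.
    llm.distribute(accelerator="cuda", quantize="bnb.nf4", precision="bf16-true")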

tests/test_serve.py

Lines changed: 44 additions & 0 deletions
@@ -6,6 +6,7 @@
 import torch
 import requests
 import subprocess
+from tests.conftest import RunIf
 import threading
 import time
 import yaml
@@ -54,3 +55,46 @@ def run_server():
         if process:
             process.kill()
         server_thread.join()
+
+
+@RunIf(min_cuda_gpus=1)
+def test_quantize(tmp_path):
+    seed_everything(123)
+    ours_config = Config.from_name("pythia-14m")
+    download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path)
+    shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path))
+    shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path))
+    ours_model = GPT(ours_config)
+    checkpoint_path = tmp_path / "lit_model.pth"
+    torch.save(ours_model.state_dict(), checkpoint_path)
+    config_path = tmp_path / "model_config.yaml"
+    with open(config_path, "w", encoding="utf-8") as fp:
+        yaml.dump(asdict(ours_config), fp)
+
+    run_command = [
+        "litgpt", "serve", tmp_path, "--quantize", "bnb.nf4"
+    ]
+
+    process = None
+
+    def run_server():
+        nonlocal process
+        try:
+            process = subprocess.Popen(run_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+            stdout, stderr = process.communicate(timeout=10)
+        except subprocess.TimeoutExpired:
+            print('Server start-up timeout expired')
+
+    server_thread = threading.Thread(target=run_server)
+    server_thread.start()
+
+    time.sleep(10)
+
+    try:
+        response = requests.get("http://127.0.0.1:8000")
+        print(response.status_code)
+        assert response.status_code == 200, "Server did not respond as expected."
+    finally:
+        if process:
+            process.kill()
+        server_thread.join()
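
The test only asserts that the root endpoint answers with HTTP 200. To actually query the quantized model once the server is up, a client would POST a prompt; a hedged sketch follows (the /predict route and the "output" field are assumptions based on the LitServe-backed server and are not part of this diff):

    import requests

    # Assumed endpoint and payload shape; not exercised by the test above.
    response = requests.post(
        "http://127.0.0.1:8000/predict",
        json={"prompt": "What do Llamas eat?"},
        timeout=30,
    )
    response.raise_for_status()
    print(response.json()["output"])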
