 import torch
 import requests
 import subprocess
+from tests.conftest import RunIf
 import threading
 import time
 import yaml
@@ -54,3 +55,46 @@ def run_server():
         if process:
             process.kill()
         server_thread.join()
+
+
+@RunIf(min_cuda_gpus=1)
+def test_quantize(tmp_path):
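+    # Prepare a minimal checkpoint directory: pythia-14m config, randomly initialized weights, and its tokenizer files.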
+    seed_everything(123)
+    ours_config = Config.from_name("pythia-14m")
+    download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path)
+    shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path))
+    shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path))
+    ours_model = GPT(ours_config)
+    checkpoint_path = tmp_path / "lit_model.pth"
+    torch.save(ours_model.state_dict(), checkpoint_path)
+    config_path = tmp_path / "model_config.yaml"
+    with open(config_path, "w", encoding="utf-8") as fp:
+        yaml.dump(asdict(ours_config), fp)
+
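+    # Serve the checkpoint with bitsandbytes NF4 quantization; the @RunIf guard above requires at least one CUDA GPU.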
+    run_command = [
+        "litgpt", "serve", tmp_path, "--quantize", "bnb.nf4"
+    ]
+
+    process = None
+
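+    # Launch the server as a subprocess in a background thread so the test can probe it over HTTP.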
+    def run_server():
+        nonlocal process
+        try:
+            process = subprocess.Popen(run_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+            stdout, stderr = process.communicate(timeout=10)
+        except subprocess.TimeoutExpired:
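+            # communicate() raising TimeoutExpired means the server process is still running after 10 s, which is expected here.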
+            print('Server start-up timeout expired')
+
+    server_thread = threading.Thread(target=run_server)
+    server_thread.start()
+
+    time.sleep(10)
+
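+    # After the start-up grace period, the server should answer on its root URL.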
+    try:
+        response = requests.get("http://127.0.0.1:8000")
+        print(response.status_code)
+        assert response.status_code == 200, "Server did not respond as expected."
+    finally:
+        if process:
+            process.kill()
+        server_thread.join()