
Commit 1a014c2

update handle

Signed-off-by: pandyamarut <pandyamarut@gmail.com>
1 parent abd769a · commit 1a014c2

File tree: 5 files changed, +1210 -53 lines


Diff for: builder/requirements.txt (+9 lines)
@@ -6,3 +6,12 @@
 # To learn more, see https://pip.pypa.io/en/stable/reference/requirements-file-format/
 
 runpod~=1.7.0
+tensorrt_llm
+transformers
+fastapi
+uvicorn
+pydantic
+numpy
+torch
+huggingface-hub
+python-dotenv

Diff for: src/handler.py (+134 -53 lines)
@@ -1,70 +1,151 @@
 import os
+import asyncio
+from typing import Optional, Dict, Any
+from dataclasses import dataclass
 import runpod
-from typing import List, AsyncGenerator, Dict, Union
-from tensorrt_llm import LLM, SamplingParams
+from transformers import AutoTokenizer
+from tensorrt_llm import LLM, BuildConfig
+from tensorrt_llm import LlmArgs
+from serve import OpenAIServer
+from dotenv import load_dotenv
 from huggingface_hub import login
-from tensorrt_llm.hlapi import BuildConfig, KvCacheConfig
+import requests
 
-# Enable build caching
-os.environ["TLLM_HLAPI_BUILD_CACHE"] = "1"
-# Optionally, set a custom cache directory
-# os.environ["TLLM_HLAPI_BUILD_CACHE_ROOT"] = "/path/to/custom/cache"
-#HF_TOKEN for downloading models
+@dataclass
+class ServerConfig:
+    model: str
+    tokenizer: Optional[str] = None
+    max_beam_width: Optional[int] = BuildConfig.max_beam_width
+    max_batch_size: Optional[int] = BuildConfig.max_batch_size
+    max_num_tokens: Optional[int] = BuildConfig.max_num_tokens
+    max_seq_len: Optional[int] = BuildConfig.max_seq_len
+    tp_size: Optional[int] = 1
+    pp_size: Optional[int] = 1
+    kv_cache_free_gpu_memory_fraction: Optional[float] = 0.9
+    trust_remote_code: bool = False
 
+    @classmethod
+    def from_env(cls) -> 'ServerConfig':
+        model = os.getenv('TRTLLM_MODEL')
+        if not model:
+            raise ValueError("TRTLLM_MODEL environment variable must be set")
 
+        return cls(
+            model=model,
+            tokenizer=os.getenv('TRTLLM_TOKENIZER'),
+            max_beam_width=int(os.getenv('TRTLLM_MAX_BEAM_WIDTH', str(BuildConfig.max_beam_width))) if os.getenv('TRTLLM_MAX_BEAM_WIDTH') else None,
+            max_batch_size=int(os.getenv('TRTLLM_MAX_BATCH_SIZE', str(BuildConfig.max_batch_size))) if os.getenv('TRTLLM_MAX_BATCH_SIZE') else None,
+            max_num_tokens=int(os.getenv('TRTLLM_MAX_NUM_TOKENS', str(BuildConfig.max_num_tokens))) if os.getenv('TRTLLM_MAX_NUM_TOKENS') else None,
+            max_seq_len=int(os.getenv('TRTLLM_MAX_SEQ_LEN', str(BuildConfig.max_seq_len))) if os.getenv('TRTLLM_MAX_SEQ_LEN') else None,
+            tp_size=int(os.getenv('TRTLLM_TP_SIZE', '1')) if os.getenv('TRTLLM_TP_SIZE') else None,
+            pp_size=int(os.getenv('TRTLLM_PP_SIZE', '1')) if os.getenv('TRTLLM_PP_SIZE') else None,
+            kv_cache_free_gpu_memory_fraction=float(os.getenv('TRTLLM_KV_CACHE_FREE_GPU_MEMORY_FRACTION', '0.9')) if os.getenv('TRTLLM_KV_CACHE_FREE_GPU_MEMORY_FRACTION') else None,
+            trust_remote_code=os.getenv('TRTLLM_TRUST_REMOTE_CODE', '').lower() in ('true', '1', 'yes')
+        )
 
-hf_token = os.environ["HF_TOKEN"]
-login(token=hf_token)
+    def validate(self) -> None:
+        if not self.model:
+            raise ValueError("Model path or name must be provided")
 
+class TensorRTLLMServer:
+    """
+    Singleton class to manage TensorRT-LLM server instance and handle requests
+    """
+    # _instance = None
+    # _initialized = False
 
+    # def __new__(cls):
+    #     if cls._instance is None:
+    #         cls._instance = super(TensorRTLLMServer, cls).__new__(cls)
+    #     return cls._instance
 
-class TRTLLMWorker:
-    def __init__(self, model_path: str):
-        self.llm = LLM(model=model_path, enable_build_cache=True, kv_cache_config=KvCacheConfig(), build_config=BuildConfig())
+    def __init__(self):
+        self._initialize_server()
+        self.host = '0.0.0.0'
+        self.port = 8000
+
+    def _initialize_server(self):
+        """Initialize the TensorRT-LLM server and load model"""
+        # Load environment variables
+        load_dotenv()
 
-    def generate(self, prompts: List[str], max_tokens: int = 100) -> List[str]:
-        sampling_params = SamplingParams(max_new_tokens=max_tokens)
-        outputs = self.llm.generate(prompts, sampling_params)
+        # Handle HuggingFace login
+        huggingface_token = os.getenv("HF_TOKEN")
+        if huggingface_token:
+            print("Logging in to Hugging Face...")
+            login(huggingface_token)
+
+        # Initialize configuration
+        self.config = ServerConfig.from_env()
+        self.config.validate()
+
+        # Create build configuration
+        build_config = BuildConfig(
+            max_batch_size=self.config.max_batch_size,
+            max_num_tokens=self.config.max_num_tokens,
+            max_beam_width=self.config.max_beam_width,
+            max_seq_len=self.config.max_seq_len
+        )
+
+        # Initialize LLM
+        self.llm = LLM(
+            model=self.config.model,
+            tokenizer=self.config.tokenizer,
+            tensor_parallel_size=self.config.tp_size,
+            pipeline_parallel_size=self.config.pp_size,
+            trust_remote_code=self.config.trust_remote_code,
+            build_config=build_config
+        )
+
+        # Initialize tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.config.tokenizer or self.config.model,
+            trust_remote_code=self.config.trust_remote_code
+        )
+
+        # Initialize OpenAI compatible server
+        self.server = OpenAIServer(
+            llm=self.llm,
+            model=self.config.model,
+            hf_tokenizer=self.tokenizer
+        )
 
-        results = []
-        for output in outputs:
-            results.append(output.outputs[0].text)
-        return results
+        asyncio.run(self.server(self.host, self.port))
 
+# Initialize the server at module load time
+server = TensorRTLLMServer()
+
+async def async_handler(job):
+    """Handle the requests asynchronously."""
+    job_input = job["input"]
+    print(f"JOB_INPUT: {job_input}")
 
-    async def generate_async(self, prompts: List[str], max_tokens: int = 100) -> AsyncGenerator[str, None]:
-        sampling_params = SamplingParams(max_new_tokens=max_tokens)
-
-        async for output in self.llm.generate_async(prompts, sampling_params):
-            for request_output in output.outputs:
-                if request_output.text:
-                    yield request_output.text
-
-# Initialize the worker outside the handler
-# This ensures the model is loaded only once when the serverless function starts
-# this path is hf model "<org_name>/model_name" egs: meta-llama/Meta-Llama-3.1-8B-Instruct
-model_path = os.environ["MODEL_PATH"]
-worker = TRTLLMWorker(model_path)
-
-
-async def handler(job: Dict) -> AsyncGenerator[Dict[str, Union[str, List[str]]], None]:
-    """Handler function that will be used to process jobs."""
-    job_input = job['input']
-    prompts = job_input.get('prompts', ["Hello, how are you?"])
-    max_tokens = job_input.get('max_tokens', 100)
-    streaming = job_input.get('streaming', False)
+    base_url = "http://0.0.0.0:8000"
 
-    try:
-        if streaming:
-            results = []
-            async for chunk in worker.generate_async(prompts, max_tokens):
-                results.append(chunk)
-                yield {"status": "streaming", "chunk": chunk}
-            yield {"status": "success", "output": results}
+    if job_input.get("openai_route"):
+        openai_route, openai_input = job_input.get("openai_route"), job_input.get("openai_input")
+
+        openai_url = f"{base_url}" + openai_route
+        headers = {"Content-Type": "application/json"}
+
+        response = requests.post(openai_url, headers=headers, json=openai_input)
+        # Process the streamed response
+        if openai_input.get("stream", False):
+            for formated_chunk in response:
+                yield formated_chunk
+        else:
+            for chunk in response.iter_lines():
+                if chunk:
+                    decoded_chunk = chunk.decode('utf-8')
+                    yield decoded_chunk
+    else:
+        generate_url = f"{base_url}/generate"
+        headers = {"Content-Type": "application/json"}
+        # Directly pass `job_input` to `json`. Can we tell users the possible fields of `job_input`?
+        response = requests.post(generate_url, json=job_input, headers=headers)
+        if response.status_code == 200:
+            yield response.json()
         else:
-            results = worker.generate(prompts, max_tokens)
-            yield {"status": "success", "output": results}
-    except Exception as e:
-        yield {"status": "error", "message": str(e)}
+            yield {"error": f"Generate request failed with status code {response.status_code}", "details": response.text}
 
-runpod.serverless.start({"handler": handler, "return_aggregate_stream": True})
+runpod.serverless.start({"handler": async_handler, "return_aggregate_stream": True})
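
For reference, ServerConfig.from_env replaces the old MODEL_PATH variable and reads all engine settings from TRTLLM_* environment variables (a .env file also works, since load_dotenv() runs first). A minimal sketch of such a file, with assumed values; only TRTLLM_MODEL is required, and the model id shown is just the example from the removed comment, not something this commit mandates:

# Hypothetical .env for this worker; values are illustrative, not from the commit.
# Only TRTLLM_MODEL is required; the rest are optional (see ServerConfig.from_env).
TRTLLM_MODEL=meta-llama/Meta-Llama-3.1-8B-Instruct
# Optional Hugging Face token, used for huggingface_hub login when present.
HF_TOKEN=hf_xxx
TRTLLM_MAX_BATCH_SIZE=8
TRTLLM_MAX_SEQ_LEN=4096
TRTLLM_TP_SIZE=1
TRTLLM_PP_SIZE=1
TRTLLM_KV_CACHE_FREE_GPU_MEMORY_FRACTION=0.9
TRTLLM_TRUST_REMOTE_CODE=false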

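The comment left in async_handler asks which fields job_input may carry. From the handler itself, a job either supplies openai_route/openai_input, which are proxied to the local OpenAI-compatible server, or it is forwarded unchanged to /generate. A rough sketch of both request shapes, with assumed field names and route; the exact endpoints and schemas come from OpenAIServer, which is outside this diff:

# Sketch of two possible RunPod job payloads for async_handler.
# The route path and body fields are assumptions; only the top-level keys
# ("openai_route", "openai_input") are fixed by the handler code.

openai_job = {
    "input": {
        "openai_route": "/v1/chat/completions",   # assumed OpenAI-compatible route
        "openai_input": {
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
            "messages": [{"role": "user", "content": "Hello, how are you?"}],
            "stream": False,   # True exercises the streaming branch of the handler
        },
    }
}

# Anything without "openai_route" is POSTed as-is to /generate,
# so its fields must match whatever that endpoint accepts (hypothetical names below).
generate_job = {
    "input": {
        "prompt": "Hello, how are you?",
        "max_tokens": 100,
    }
}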