
Commit 405c79f

update handler

Signed-off-by: pandyamarut <pandyamarut@gmail.com>

1 parent: 85815b5

1 file changed: +39 −8 lines


src/handler.py (+39 −8)
@@ -1,8 +1,9 @@
 import os
 import runpod
-from typing import List
+from typing import List, AsyncGenerator
 from tensorrt_llm import LLM, SamplingParams
 from huggingface_hub import login
+from tensorrt_llm.hlapi import BuildConfig, KvCacheConfig
 
 # Enable build caching
 os.environ["TLLM_HLAPI_BUILD_CACHE"] = "1"
@@ -19,7 +20,7 @@
 
 class TRTLLMWorker:
     def __init__(self, model_path: str):
-        self.llm = LLM(model=model_path, enable_build_cache=True)
+        self.llm = LLM(model=model_path, enable_build_cache=True, kv_cache_config=KvCacheConfig(), build_config=BuildConfig())
 
     def generate(self, prompts: List[str], max_tokens: int = 100) -> List[str]:
         sampling_params = SamplingParams(max_new_tokens=max_tokens)
@@ -30,25 +31,55 @@ def generate(self, prompts: List[str], max_tokens: int = 100) -> List[str]:
             results.append(output.outputs[0].text)
         return results
 
+
+    async def generate_async(self, prompts: List[str], max_tokens: int = 100) -> AsyncGenerator[str, None]:
+        sampling_params = SamplingParams(max_new_tokens=max_tokens)
+
+        async for output in self.llm.generate_async(prompts, sampling_params):
+            for request_output in output.outputs:
+                if request_output.text:
+                    yield request_output.text
+
 # Initialize the worker outside the handler
 # This ensures the model is loaded only once when the serverless function starts
-# this path is hf model "<org_name>/model_name" egs: meta-llama/Meta-Llama-3.1-8B-Instruct
+# this path is hf model "<org_name>/model_name" egs: meta-llama/Meta-Llama-3.1-8B-Instruct
 model_path = os.environ["MODEL_PATH"]
 worker = TRTLLMWorker(model_path)
 
 
 
-def handler(job):
+# def handler(job):
+#     """Handler function that will be used to process jobs."""
+#     job_input = job['input']
+#     prompts = job_input.get('prompts', ["Hello, how are you?"])
+#     max_tokens = job_input.get('max_tokens', 100)
+#     streaming = job_input.get('streaming', False)
+
+#     try:
+#         results = worker.generate(prompts, max_tokens)
+#         return {"status": "success", "output": results}
+#     except Exception as e:
+#         return {"status": "error", "message": str(e)}
+
+
+async def handler(job):
     """Handler function that will be used to process jobs."""
    job_input = job['input']
    prompts = job_input.get('prompts', ["Hello, how are you?"])
    max_tokens = job_input.get('max_tokens', 100)
+    streaming = job_input.get('streaming', False)
 
    try:
-        results = worker.generate(prompts, max_tokens)
-        return {"status": "success", "output": results}
+        if streaming:
+            results = []
+            async for chunk in worker.generate_async(prompts, max_tokens):
+                results.append(chunk)
+                yield {"status": "streaming", "chunk": chunk}
+            yield {"status": "success", "output": results}
+        else:
+            results = worker.generate(prompts, max_tokens)
+            return {"status": "success", "output": results}
    except Exception as e:
        return {"status": "error", "message": str(e)}
 
-
-runpod.serverless.start({"handler": handler})
+runpod.serverless.start({"handler": handler, "return_aggregate_stream": True})
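
For reference, the job dict the new handler expects would look roughly like the sketch below. It is derived only from the keys the handler reads (prompts, max_tokens, streaming); the values are placeholders.

job = {
    "input": {
        "prompts": ["Hello, how are you?"],   # list of prompt strings
        "max_tokens": 100,                    # generation length cap
        "streaming": True                     # route through generate_async
    }
}

One caveat: because the new handler body contains yield, async def handler is an async generator, and Python rejects a value-carrying return inside an async generator at import time (SyntaxError: 'return' with value in async generator), which affects the non-streaming branch and the except block. A minimal sketch of one way to keep all branches valid, reusing the same worker and job shape as above, is to yield the final result everywhere; this is an illustrative variant, not the committed code.

async def handler(job):
    """Handler function that will be used to process jobs."""
    job_input = job['input']
    prompts = job_input.get('prompts', ["Hello, how are you?"])
    max_tokens = job_input.get('max_tokens', 100)
    streaming = job_input.get('streaming', False)

    try:
        if streaming:
            results = []
            async for chunk in worker.generate_async(prompts, max_tokens):
                results.append(chunk)
                yield {"status": "streaming", "chunk": chunk}
            yield {"status": "success", "output": results}
        else:
            # yield instead of return: a value-carrying return is illegal in an async generator
            yield {"status": "success", "output": worker.generate(prompts, max_tokens)}
    except Exception as e:
        yield {"status": "error", "message": str(e)}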
