@@ -1,8 +1,9 @@
import os
import runpod
- from typing import List
+ from typing import List, AsyncGenerator
from tensorrt_llm import LLM, SamplingParams
from huggingface_hub import login
+ from tensorrt_llm.hlapi import BuildConfig, KvCacheConfig

# Enable build caching
os.environ["TLLM_HLAPI_BUILD_CACHE"] = "1"
@@ -19,7 +20,7 @@

class TRTLLMWorker:
    def __init__(self, model_path: str):
-         self.llm = LLM(model=model_path, enable_build_cache=True)
+         self.llm = LLM(model=model_path, enable_build_cache=True, kv_cache_config=KvCacheConfig(), build_config=BuildConfig())

    def generate(self, prompts: List[str], max_tokens: int = 100) -> List[str]:
        sampling_params = SamplingParams(max_new_tokens=max_tokens)
@@ -30,25 +31,55 @@ def generate(self, prompts: List[str], max_tokens: int = 100) -> List[str]:
            results.append(output.outputs[0].text)
        return results

+ 
+     async def generate_async(self, prompts: List[str], max_tokens: int = 100) -> AsyncGenerator[str, None]:
+         sampling_params = SamplingParams(max_new_tokens=max_tokens)
+ 
+         async for output in self.llm.generate_async(prompts, sampling_params):
+             for request_output in output.outputs:
+                 if request_output.text:
+                     yield request_output.text
+ 
# Initialize the worker outside the handler
# This ensures the model is loaded only once when the serverless function starts
- # this path is hf model "<org_name>/model_name" egs: meta-llama/Meta-Llama-3.1-8B-Instruct
+ # this path is a Hugging Face model ID "<org_name>/<model_name>", e.g. meta-llama/Meta-Llama-3.1-8B-Instruct
model_path = os.environ["MODEL_PATH"]
worker = TRTLLMWorker(model_path)


- def handler(job):
+ # def handler(job):
+ #     """Handler function that will be used to process jobs."""
+ #     job_input = job['input']
+ #     prompts = job_input.get('prompts', ["Hello, how are you?"])
+ #     max_tokens = job_input.get('max_tokens', 100)
+ #     streaming = job_input.get('streaming', False)
+ 
+ #     try:
+ #         results = worker.generate(prompts, max_tokens)
+ #         return {"status": "success", "output": results}
+ #     except Exception as e:
+ #         return {"status": "error", "message": str(e)}
+ 
+ 
+ async def handler(job):
    """Handler function that will be used to process jobs."""
    job_input = job['input']
    prompts = job_input.get('prompts', ["Hello, how are you?"])
    max_tokens = job_input.get('max_tokens', 100)
+     streaming = job_input.get('streaming', False)

    try:
-         results = worker.generate(prompts, max_tokens)
-         return {"status": "success", "output": results}
+         if streaming:
+             results = []
+             async for chunk in worker.generate_async(prompts, max_tokens):
+                 results.append(chunk)
+                 yield {"status": "streaming", "chunk": chunk}
+             yield {"status": "success", "output": results}
+         else:
+             results = worker.generate(prompts, max_tokens)
+             yield {"status": "success", "output": results}  # 'return <value>' is a SyntaxError inside an async generator
    except Exception as e:
-         return {"status": "error", "message": str(e)}
+         yield {"status": "error", "message": str(e)}

- 
- runpod.serverless.start({"handler": handler})
+ runpod.serverless.start({"handler": handler, "return_aggregate_stream": True})
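
For reference, a minimal usage sketch of the handler added above. The input keys (`prompts`, `max_tokens`, `streaming`) and the yielded `status` payloads come from the diff; the example job dict and the local asyncio driver are illustrative assumptions, not part of the worker itself.

# Local smoke test (sketch). Assumes it runs in the same module as the worker
# code above (or after importing `handler` from it), with MODEL_PATH set and
# the model weights available so the LLM can actually load.
import asyncio

job = {
    "input": {
        "prompts": ["Hello, how are you?"],
        "max_tokens": 64,
        "streaming": True,  # False yields a single aggregated response
    }
}

async def main():
    async for update in handler(job):
        # With streaming=True: a series of {"status": "streaming", "chunk": ...}
        # updates, then {"status": "success", "output": [...]} with every chunk.
        # Errors surface as {"status": "error", "message": ...}.
        print(update)

asyncio.run(main())

When deployed on RunPod, the `return_aggregate_stream: True` option passed to `runpod.serverless.start` is intended to let the updates yielded by the generator handler also be collected into a single aggregated result for non-streaming requests.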