- """ Example handler file. """
-
+ import os
import runpod
+ from typing import List, AsyncGenerator, Dict, Union
+ from tensorrt_llm import LLM, SamplingParams
+ from huggingface_hub import login
+ from tensorrt_llm.hlapi import BuildConfig, KvCacheConfig

- # If your handler runs inference on a model, load the model here.
- # You will want models to be loaded into memory before starting serverless.
+ # Enable build caching
+ os.environ["TLLM_HLAPI_BUILD_CACHE"] = "1"
+ # Optionally, set a custom cache directory
+ # os.environ["TLLM_HLAPI_BUILD_CACHE_ROOT"] = "/path/to/custom/cache"
+ # HF_TOKEN for downloading models


- def handler(job):
-     """ Handler function that will be used to process jobs. """
-     job_input = job['input']

-     name = job_input.get('name', 'World')
+ hf_token = os.environ["HF_TOKEN"]
+ login(token=hf_token)
+
+

-     return f"Hello, {name}!"
+ class TRTLLMWorker:
+     def __init__(self, model_path: str):
+         self.llm = LLM(model=model_path, enable_build_cache=True, kv_cache_config=KvCacheConfig(), build_config=BuildConfig())
+
+     def generate(self, prompts: List[str], max_tokens: int = 100) -> List[str]:
+         sampling_params = SamplingParams(max_new_tokens=max_tokens)
+         outputs = self.llm.generate(prompts, sampling_params)
+
+         results = []
+         for output in outputs:
+             results.append(output.outputs[0].text)
+         return results

+
+     async def generate_async(self, prompts: List[str], max_tokens: int = 100) -> AsyncGenerator[str, None]:
+         sampling_params = SamplingParams(max_new_tokens=max_tokens)
+
+         async for output in self.llm.generate_async(prompts, sampling_params):
+             for request_output in output.outputs:
+                 if request_output.text:
+                     yield request_output.text
+
+ # Initialize the worker outside the handler.
+ # This ensures the model is loaded only once when the serverless function starts.
+ # MODEL_PATH is a Hugging Face model ID ("<org_name>/<model_name>"), e.g. meta-llama/Meta-Llama-3.1-8B-Instruct
+ model_path = os.environ["MODEL_PATH"]
+ worker = TRTLLMWorker(model_path)
+
+
+ async def handler(job: Dict) -> AsyncGenerator[Dict[str, Union[str, List[str]]], None]:
+     """Handler function that will be used to process jobs."""
+     job_input = job['input']
+     prompts = job_input.get('prompts', ["Hello, how are you?"])
+     max_tokens = job_input.get('max_tokens', 100)
+     streaming = job_input.get('streaming', False)
+
+     try:
+         if streaming:
+             results = []
+             async for chunk in worker.generate_async(prompts, max_tokens):
+                 results.append(chunk)
+                 yield {"status": "streaming", "chunk": chunk}
+             yield {"status": "success", "output": results}
+         else:
+             results = worker.generate(prompts, max_tokens)
+             yield {"status": "success", "output": results}
+     except Exception as e:
+         yield {"status": "error", "message": str(e)}

- runpod.serverless.start({"handler": handler})
+ runpod.serverless.start({"handler": handler, "return_aggregate_stream": True})
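
A minimal sketch of the request payload this handler expects, based on the job_input.get(...) calls above; the prompt text and values are illustrative, not part of the change:

# Illustrative request payload for this worker; field names mirror
# the lookups in handler() ('prompts', 'max_tokens', 'streaming').
example_job = {
    "input": {
        "prompts": ["Hello, how are you?"],
        "max_tokens": 100,
        "streaming": False,
    }
}

With "streaming": True the handler yields {"status": "streaming", "chunk": ...} updates followed by a final {"status": "success", "output": [...]}; the "return_aggregate_stream": True option on the last line is what lets those yielded results also be returned as an aggregated job output rather than only via the stream endpoint.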