- """ Example handler file. """
-
+ import os
  import runpod
-
- # If your handler runs inference on a model, load the model here.
- # You will want models to be loaded into memory before starting serverless.
-
+ from typing import List
+ from tensorrt_llm import LLM, SamplingParams
+
+ # Enable build caching
+ os.environ["TLLM_HLAPI_BUILD_CACHE"] = "1"
+ # Optionally, set a custom cache directory
+ # os.environ["TLLM_HLAPI_BUILD_CACHE_ROOT"] = "/path/to/custom/cache"
+
+ class TRTLLMWorker:
+     def __init__(self, model_path: str):
+         self.llm = LLM(model=model_path, enable_build_cache=True)
+
+     def generate(self, prompts: List[str], max_tokens: int = 100) -> List[str]:
+         sampling_params = SamplingParams(max_new_tokens=max_tokens)
+         outputs = self.llm.generate(prompts, sampling_params)
+
+         results = []
+         for output in outputs:
+             results.append(output.outputs[0].text)
+
+         return results
+
+ # Initialize the worker outside the handler
+ # This ensures the model is loaded only once when the serverless function starts
+ worker = TRTLLMWorker("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

  def handler(job):
-     """ Handler function that will be used to process jobs. """
+     """Handler function that will be used to process jobs."""
      job_input = job['input']
-
-     name = job_input.get('name', 'World')
-
-     return f"Hello, {name}!"
-
-
- runpod.serverless.start({"handler": handler})
+     prompts = job_input.get('prompts', ["Hello, how are you?"])
+     max_tokens = job_input.get('max_tokens', 100)
+
+     try:
+         results = worker.generate(prompts, max_tokens)
+         return {"status": "success", "output": results}
+     except Exception as e:
+         return {"status": "error", "message": str(e)}
+
+ if __name__ == "__main__":
+     runpod.serverless.start({"handler": handler})
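For a quick local sanity check of the handler's input contract (a `prompts` list and an optional `max_tokens` under `job['input']`), the handler can be called directly with a job-shaped dictionary. This is only a sketch: the job id and prompt values below are illustrative placeholders, and it assumes the model above has already been downloaded and built.

# Local sanity check with an illustrative payload (run after the definitions above).
sample_job = {
    "id": "local-test",  # hypothetical id; the handler only reads "input"
    "input": {
        "prompts": ["What is the capital of France?"],
        "max_tokens": 64,
    },
}
print(handler(sample_job))  # expected shape: {"status": "success", "output": [...]}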