import os
+ import asyncio
+ import threading  # the OpenAI-compatible server is run in a background thread (see _initialize_server)
+ from typing import Optional, Dict, Any
+ from dataclasses import dataclass
import runpod
- from typing import List, AsyncGenerator, Dict, Union
- from tensorrt_llm import LLM, SamplingParams
+ from transformers import AutoTokenizer
+ from tensorrt_llm import LLM, BuildConfig
+ from tensorrt_llm import LlmArgs
+ from serve import OpenAIServer
+ from dotenv import load_dotenv
from huggingface_hub import login
- from tensorrt_llm.hlapi import BuildConfig, KvCacheConfig
+ import requests

- # Enable build caching
- os.environ["TLLM_HLAPI_BUILD_CACHE"] = "1"
- # Optionally, set a custom cache directory
- # os.environ["TLLM_HLAPI_BUILD_CACHE_ROOT"] = "/path/to/custom/cache"
- # HF_TOKEN for downloading models
+ @dataclass
+ class ServerConfig:
+     model: str
+     tokenizer: Optional[str] = None
+     max_beam_width: Optional[int] = BuildConfig.max_beam_width
+     max_batch_size: Optional[int] = BuildConfig.max_batch_size
+     max_num_tokens: Optional[int] = BuildConfig.max_num_tokens
+     max_seq_len: Optional[int] = BuildConfig.max_seq_len
+     tp_size: Optional[int] = 1
+     pp_size: Optional[int] = 1
+     kv_cache_free_gpu_memory_fraction: Optional[float] = 0.9
+     trust_remote_code: bool = False

+     @classmethod
+     def from_env(cls) -> 'ServerConfig':
+         """Build a ServerConfig from TRTLLM_* environment variables; unset variables fall back to the defaults above."""
+         model = os.getenv('TRTLLM_MODEL')
+         if not model:
+             raise ValueError("TRTLLM_MODEL environment variable must be set")

+         return cls(
+             model=model,
+             tokenizer=os.getenv('TRTLLM_TOKENIZER'),
+             max_beam_width=int(os.getenv('TRTLLM_MAX_BEAM_WIDTH', str(BuildConfig.max_beam_width))),
+             max_batch_size=int(os.getenv('TRTLLM_MAX_BATCH_SIZE', str(BuildConfig.max_batch_size))),
+             max_num_tokens=int(os.getenv('TRTLLM_MAX_NUM_TOKENS', str(BuildConfig.max_num_tokens))),
+             max_seq_len=int(os.getenv('TRTLLM_MAX_SEQ_LEN', str(BuildConfig.max_seq_len))),
+             tp_size=int(os.getenv('TRTLLM_TP_SIZE', '1')),
+             pp_size=int(os.getenv('TRTLLM_PP_SIZE', '1')),
+             kv_cache_free_gpu_memory_fraction=float(os.getenv('TRTLLM_KV_CACHE_FREE_GPU_MEMORY_FRACTION', '0.9')),
+             trust_remote_code=os.getenv('TRTLLM_TRUST_REMOTE_CODE', '').lower() in ('true', '1', 'yes')
+         )

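# Illustrative environment configuration for ServerConfig.from_env (example values only;
# the variable names come from the code above, the settings are hypothetical):
#   TRTLLM_MODEL=meta-llama/Meta-Llama-3.1-8B-Instruct   (required: HF repo id or local path)
#   TRTLLM_TOKENIZER=                                     (optional: defaults to TRTLLM_MODEL)
#   TRTLLM_MAX_BEAM_WIDTH=1
#   TRTLLM_MAX_BATCH_SIZE=64
#   TRTLLM_MAX_NUM_TOKENS=8192
#   TRTLLM_MAX_SEQ_LEN=4096
#   TRTLLM_TP_SIZE=1
#   TRTLLM_PP_SIZE=1
#   TRTLLM_KV_CACHE_FREE_GPU_MEMORY_FRACTION=0.9
#   TRTLLM_TRUST_REMOTE_CODE=true
#   HF_TOKEN=<HuggingFace token, only needed for gated models>
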
- hf_token = os.environ["HF_TOKEN"]
- login(token=hf_token)
+     def validate(self) -> None:
+         if not self.model:
+             raise ValueError("Model path or name must be provided")

+ class TensorRTLLMServer:
+     """
+     Manages the TensorRT-LLM engine and the OpenAI-compatible server, and handles requests.
+     (Singleton enforcement is currently disabled; see the commented-out __new__ below.)
+     """
+     # _instance = None
+     # _initialized = False

+     # def __new__(cls):
+     #     if cls._instance is None:
+     #         cls._instance = super(TensorRTLLMServer, cls).__new__(cls)
+     #     return cls._instance

- class TRTLLMWorker:
-     def __init__(self, model_path: str):
-         self.llm = LLM(model=model_path, enable_build_cache=True, kv_cache_config=KvCacheConfig(), build_config=BuildConfig())
+     def __init__(self):
+         # host/port must be set before _initialize_server(), which starts the HTTP server.
+         self.host = '0.0.0.0'
+         self.port = 8000
+         self._initialize_server()
+
+     def _initialize_server(self):
+         """Initialize the TensorRT-LLM server and load the model"""
+         # Load environment variables
+         load_dotenv()

-     def generate(self, prompts: List[str], max_tokens: int = 100) -> List[str]:
-         sampling_params = SamplingParams(max_new_tokens=max_tokens)
-         outputs = self.llm.generate(prompts, sampling_params)
+         # Handle HuggingFace login
+         huggingface_token = os.getenv("HF_TOKEN")
+         if huggingface_token:
+             print("Logging in to Hugging Face...")
+             login(huggingface_token)
+
+         # Initialize configuration
+         self.config = ServerConfig.from_env()
+         self.config.validate()
+
+         # Create build configuration
+         build_config = BuildConfig(
+             max_batch_size=self.config.max_batch_size,
+             max_num_tokens=self.config.max_num_tokens,
+             max_beam_width=self.config.max_beam_width,
+             max_seq_len=self.config.max_seq_len
+         )
+
+         # Initialize LLM
+         # (note: kv_cache_free_gpu_memory_fraction is read from the env but not passed to the engine here)
+         self.llm = LLM(
+             model=self.config.model,
+             tokenizer=self.config.tokenizer,
+             tensor_parallel_size=self.config.tp_size,
+             pipeline_parallel_size=self.config.pp_size,
+             trust_remote_code=self.config.trust_remote_code,
+             build_config=build_config
+         )
+
+         # Initialize tokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             self.config.tokenizer or self.config.model,
+             trust_remote_code=self.config.trust_remote_code
+         )
+
+         # Initialize OpenAI compatible server
+         self.server = OpenAIServer(
+             llm=self.llm,
+             model=self.config.model,
+             hf_tokenizer=self.tokenizer
+         )

-         results = []
-         for output in outputs:
-             results.append(output.outputs[0].text)
-         return results
+         # Serve the OpenAI-compatible API in a background thread so that module import
+         # completes and runpod.serverless.start() below is reached.
+         threading.Thread(
+             target=lambda: asyncio.run(self.server(self.host, self.port)),
+             daemon=True,
+         ).start()

+ # Initialize the server at module load time
+ server = TensorRTLLMServer()
+
+ async def async_handler(job):
+     """Handle requests asynchronously by proxying them to the local OpenAI-compatible server."""
+     job_input = job["input"]
+     print(f"JOB_INPUT: {job_input}")

-     async def generate_async(self, prompts: List[str], max_tokens: int = 100) -> AsyncGenerator[str, None]:
-         sampling_params = SamplingParams(max_new_tokens=max_tokens)
-
-         async for output in self.llm.generate_async(prompts, sampling_params):
-             for request_output in output.outputs:
-                 if request_output.text:
-                     yield request_output.text
-
- # Initialize the worker outside the handler
- # This ensures the model is loaded only once when the serverless function starts
- # this path is hf model "<org_name>/model_name" egs: meta-llama/Meta-Llama-3.1-8B-Instruct
- model_path = os.environ["MODEL_PATH"]
- worker = TRTLLMWorker(model_path)
-
-
- async def handler(job: Dict) -> AsyncGenerator[Dict[str, Union[str, List[str]]], None]:
-     """Handler function that will be used to process jobs."""
-     job_input = job['input']
-     prompts = job_input.get('prompts', ["Hello, how are you?"])
-     max_tokens = job_input.get('max_tokens', 100)
-     streaming = job_input.get('streaming', False)
+     base_url = "http://0.0.0.0:8000"

-     try:
-         if streaming:
-             results = []
-             async for chunk in worker.generate_async(prompts, max_tokens):
-                 results.append(chunk)
-                 yield {"status": "streaming", "chunk": chunk}
-             yield {"status": "success", "output": results}
+     if job_input.get("openai_route"):
+         openai_route, openai_input = job_input.get("openai_route"), job_input.get("openai_input")
+
+         openai_url = f"{base_url}{openai_route}"
+         headers = {"Content-Type": "application/json"}
+
+         # Ask requests to stream the HTTP response when the client requested streaming.
+         stream = openai_input.get("stream", False)
+         response = requests.post(openai_url, headers=headers, json=openai_input, stream=stream)
+         # Process the streamed response
+         if stream:
+             for formatted_chunk in response.iter_lines():
+                 if formatted_chunk:
+                     yield formatted_chunk.decode('utf-8')
+         else:
+             for chunk in response.iter_lines():
+                 if chunk:
+                     decoded_chunk = chunk.decode('utf-8')
+                     yield decoded_chunk
+     else:
+         generate_url = f"{base_url}/generate"
+         headers = {"Content-Type": "application/json"}
+         # Pass job_input directly as the request body.
+         # TODO: document the fields that job_input may contain for this route.
+         response = requests.post(generate_url, json=job_input, headers=headers)
+         if response.status_code == 200:
+             yield response.json()
        else:
-             results = worker.generate(prompts, max_tokens)
-             yield {"status": "success", "output": results}
-     except Exception as e:
-         yield {"status": "error", "message": str(e)}
+             yield {"error": f"Generate request failed with status code {response.status_code}", "details": response.text}

- runpod.serverless.start({"handler": handler, "return_aggregate_stream": True})
+ runpod.serverless.start({"handler": async_handler, "return_aggregate_stream": True})