diff --git a/packages/cloudera-ai-inference-package/cloudera_ai_inference_package/cloudera_ai_inference_provider.py b/packages/cloudera-ai-inference-package/cloudera_ai_inference_package/cloudera_ai_inference_provider.py
index a1bb8c1f..0eea92be 100644
--- a/packages/cloudera-ai-inference-package/cloudera_ai_inference_package/cloudera_ai_inference_provider.py
+++ b/packages/cloudera-ai-inference-package/cloudera_ai_inference_package/cloudera_ai_inference_provider.py
@@ -95,7 +95,7 @@ def _stream(
             # OpenAI Chat completions API
             request_messages = self.BuildChatCompletionMessage(messages)
-            request = {"messages": request_messages, "model": self.model, "temperature": 1, "max_tokens": MAX_TOKENS, "stream": True}
+            request = {"messages": request_messages, "model": self.model, "temperature": 1, "max_tokens": self.MAX_TOKENS, "stream": True}
             logging.info(f"request: {request}")
             try:
                 r = requests.post(
@@ -124,7 +124,7 @@ def _stream(
             prompt = self.BuildCompletionPrompt(messages)
             req_data = '{"prompt": "' + prompt.encode('unicode_escape').decode("utf-8")
-            my_req_data = req_data + '","model":"' + self.model + '","temperature":1,"max_tokens":' + str(MAX_TOKENS) + ',"stream":true}'
+            my_req_data = req_data + '","model":"' + self.model + '","temperature":1,"max_tokens":' + str(self.MAX_TOKENS) + ',"stream":true}'
             logging.info('req:')
             logging.info(my_req_data)
@@ -172,7 +172,7 @@ def _call(
         if inference_endpoint.find("chat/completions") != -1:
             # OpenAI Chat completions API
             request_messages = self.BuildChatCompletionMessage(messages)
-            request = {"messages": request_messages, "model": self.model, "temperature": 1, "max_tokens": 1024, "stream": False}
+            request = {"messages": request_messages, "model": self.model, "temperature": 1, "max_tokens": self.MAX_TOKENS, "stream": False}
             logging.info(json.dumps(request))
             try:
                 r = requests.post(inference_endpoint,
@@ -192,7 +192,7 @@ def _call(
             # OpenAI Completions API
             prompt = self.BuildCompletionPrompt(messages)
             logging.info(f"prompt: {prompt}")
-            request = {"prompt": prompt, "model": self.model, "temperature": 1, "max_tokens": 1024, "stream": False}
+            request = {"prompt": prompt, "model": self.model, "temperature": 1, "max_tokens": self.MAX_TOKENS, "stream": False}
             logging.info(json.dumps(request))
             try:
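The diff replaces the bare `MAX_TOKENS` references in `_stream` and the hardcoded `1024` in `_call` with `self.MAX_TOKENS`, so all four request payloads read the same limit from the provider instance. A minimal sketch of that pattern, assuming `MAX_TOKENS` is declared as a class attribute on the provider; the class name, default value, and helper method below are illustrative only, not taken from the diff:

```python
# Sketch only: demonstrates why self.MAX_TOKENS resolves once the limit
# is a class attribute, rather than a module constant or inline literal.
class ExampleProvider:
    MAX_TOKENS = 1024  # hypothetical default; the real value lives in the package

    def __init__(self, model: str):
        self.model = model

    def build_chat_request(self, request_messages: list, stream: bool) -> dict:
        # Mirrors the request dicts built in _stream/_call after this change.
        return {
            "messages": request_messages,
            "model": self.model,
            "temperature": 1,
            "max_tokens": self.MAX_TOKENS,
            "stream": stream,
        }


if __name__ == "__main__":
    provider = ExampleProvider(model="example-model")
    print(provider.build_chat_request([{"role": "user", "content": "hi"}], stream=False))
```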