From 94b9fae0e9c110f4190174a963b8a246d7de5103 Mon Sep 17 00:00:00 2001
From: Gavan Kwan
Date: Wed, 8 Jan 2025 15:49:53 -0800
Subject: [PATCH] DSE-41570 - Fix MAX_TOKENS error for AI Inference models.

Should be self.MAX_TOKENS
---
 .../cloudera_ai_inference_provider.py         | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/packages/cloudera-ai-inference-package/cloudera_ai_inference_package/cloudera_ai_inference_provider.py b/packages/cloudera-ai-inference-package/cloudera_ai_inference_package/cloudera_ai_inference_provider.py
index a1bb8c1f..0eea92be 100644
--- a/packages/cloudera-ai-inference-package/cloudera_ai_inference_package/cloudera_ai_inference_provider.py
+++ b/packages/cloudera-ai-inference-package/cloudera_ai_inference_package/cloudera_ai_inference_provider.py
@@ -95,7 +95,7 @@ def _stream(
             # OpenAI Chat completions API
             request_messages = self.BuildChatCompletionMessage(messages)
-            request = {"messages": request_messages, "model": self.model, "temperature": 1, "max_tokens": MAX_TOKENS, "stream": True}
+            request = {"messages": request_messages, "model": self.model, "temperature": 1, "max_tokens": self.MAX_TOKENS, "stream": True}
             logging.info(f"request: {request}")
             try:
                 r = requests.post(
@@ -124,7 +124,7 @@ def _stream(
             prompt = self.BuildCompletionPrompt(messages)
             req_data = '{"prompt": "' + prompt.encode('unicode_escape').decode("utf-8")
-            my_req_data = req_data + '","model":"' + self.model + '","temperature":1,"max_tokens":' + str(MAX_TOKENS) + ',"stream":true}'
+            my_req_data = req_data + '","model":"' + self.model + '","temperature":1,"max_tokens":' + str(self.MAX_TOKENS) + ',"stream":true}'
             logging.info('req:')
             logging.info(my_req_data)
@@ -172,7 +172,7 @@ def _call(
         if inference_endpoint.find("chat/completions") != -1:
             # OpenAI Chat completions API
             request_messages = self.BuildChatCompletionMessage(messages)
-            request = {"messages": request_messages, "model": self.model, "temperature": 1, "max_tokens": 1024, "stream": False}
+            request = {"messages": request_messages, "model": self.model, "temperature": 1, "max_tokens": self.MAX_TOKENS, "stream": False}
             logging.info(json.dumps(request))
             try:
                 r = requests.post(inference_endpoint,
@@ -192,7 +192,7 @@ def _call(
             # OpenAI Completions API
             prompt = self.BuildCompletionPrompt(messages)
             logging.info(f"prompt: {prompt}")
-            request = {"prompt": prompt, "model": self.model, "temperature": 1, "max_tokens": 1024, "stream": False}
+            request = {"prompt": prompt, "model": self.model, "temperature": 1, "max_tokens": self.MAX_TOKENS, "stream": False}
             logging.info(json.dumps(request))
             try:
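
Note on the fix: the patch replaces a bare MAX_TOKENS (and two hard-coded 1024 literals) with self.MAX_TOKENS. The minimal sketch below shows why the bare name fails, assuming MAX_TOKENS is defined as a class attribute on the provider, as the fix implies; the class name and build_request helper here are illustrative only, not taken from the actual package.

    # Hypothetical, simplified stand-in for the provider class.
    class ClouderaAIInferenceProvider:
        MAX_TOKENS = 1024  # class attribute; value assumed for illustration

        def build_request(self, request_messages, model):
            # Broken form: a bare MAX_TOKENS is looked up as a local or
            # module-level name, which does not exist, so this raises
            # NameError: name 'MAX_TOKENS' is not defined at runtime.
            #
            #   return {"messages": request_messages, "model": model,
            #           "max_tokens": MAX_TOKENS}

            # Fixed form: qualify the class attribute through the instance
            # (or explicitly via ClouderaAIInferenceProvider.MAX_TOKENS).
            return {
                "messages": request_messages,
                "model": model,
                "temperature": 1,
                "max_tokens": self.MAX_TOKENS,
                "stream": False,
            }

Using self.MAX_TOKENS also means a subclass or an instance can override the token cap without touching the request-building code, which the hard-coded 1024 literals did not allow.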