Skip to content

Commit

Permalink
remove monkey patch, openshift ai now works with max_completion_tokens
Browse files Browse the repository at this point in the history
Signed-off-by: Shawn Hurley <shawn@hurley.page>
  • Loading branch information
shawn-hurley committed Feb 24, 2025
1 parent f5814b5 commit 76956f6
Showing 1 changed file with 16 additions and 28 deletions.
44 changes: 16 additions & 28 deletions kai/llm_interfacing/model_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,34 +68,6 @@ def __init__(
model_args = deep_update(defaults, config.args)
model_id = model_args["model"]

# NOTE(JonahSussman): This is a hack to prevent `max_tokens`
# from getting converted to `max_completion_tokens` for every
# model, except for the o1 and o3 family of models.

@property # type: ignore[misc]
def _default_params(self: ChatOpenAI) -> dict[str, Any]:
return super(ChatOpenAI, self)._default_params

def _get_request_payload(
self: ChatOpenAI,
input_: LanguageModelInput,
*,
stop: list[str] | None = None,
**kwargs: Any,
) -> dict: # type: ignore[type-arg]
return super(ChatOpenAI, self)._get_request_payload(
input_, stop=stop, **kwargs
)

if not (model_id.startswith("o1") or model_id.startswith("o3")):
ChatOpenAI._default_params = _default_params # type: ignore[method-assign]
ChatOpenAI._get_request_payload = _get_request_payload # type: ignore[method-assign]
else:
if "streaming" in model_args:
del model_args["streaming"]
if "temperature" in model_args:
del model_args["temperature"]

case "ChatBedrock":
model_class = ChatBedrock

Expand Down Expand Up @@ -177,6 +149,8 @@ def _get_request_payload(
raise Exception(f"Unrecognized provider '{config.provider}'")

self.provider_id: str = config.provider
# In the future we should consider https://github.com/langchain-ai/langchain/blob/b7a1705052f763b8a11e164041efef1465c07595/libs/langchain/langchain/chat_models/base.py#L85
# as this seems to be the more correct way of doing things.
self.llm: BaseChatModel = model_class(**model_args)
self.model_id: str = model_id

Expand All @@ -193,6 +167,20 @@ def validate_environment(
current model provider.
"""

# The only time the model must be invoked is when pydantic has not already
# validated the model class. For a quick list of models that use model_kwargs, see
# https://github.com/search?q=repo%3Alangchain-ai%2Flangchain%20_build_model_kwargs&type=code
# If there are extras that are not defined on the model but are saved/used somewhere
# else, then we need to validate them ourselves.
has_model_kwargs = (
hasattr(self.llm, "model_kwargs")
and isinstance(self.llm.model_kwargs, dict)
and self.llm.model_kwargs
)
if not self.llm.model_extra and not has_model_kwargs:
LOG.debug("validation done by pydantic")
return

cpr = SimplePathResolver("validate_environment.json")

def challenge(k: str) -> BaseMessage:
Expand Down

0 comments on commit 76956f6

Please sign in to comment.