Skip to content

feat: Allow OpenAI client config in OpenAIChatGenerator and AzureOpenAIChatGenerator #9215

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Apr 16, 2025
Merged
20 changes: 7 additions & 13 deletions haystack/components/embedders/azure_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import os
from typing import Any, Dict, List, Optional

import httpx
from openai.lib.azure import AsyncAzureOpenAI, AzureADTokenProvider, AzureOpenAI

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.embedders import OpenAIDocumentEmbedder
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.http_client import init_http_client

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -152,18 +152,12 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p
"default_headers": self.default_headers,
}

self.client = AzureOpenAI(http_client=self._init_http_client(async_client=False), **client_args)
self.async_client = AsyncAzureOpenAI(http_client=self._init_http_client(async_client=True), **client_args)

def _init_http_client(self, async_client: bool = False):
"""Internal method to initialize the httpx.Client."""
if not self.http_client_kwargs:
return None
if not isinstance(self.http_client_kwargs, dict):
raise TypeError("The parameter 'http_client_kwargs' must be a dictionary.")
if async_client:
return httpx.AsyncClient(**self.http_client_kwargs)
return httpx.Client(**self.http_client_kwargs)
self.client = AzureOpenAI(
http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_args
)
self.async_client = AsyncAzureOpenAI(
http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_args
)

def to_dict(self) -> Dict[str, Any]:
"""
Expand Down
19 changes: 7 additions & 12 deletions haystack/components/embedders/azure_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import os
from typing import Any, Dict, Optional

import httpx
from openai.lib.azure import AsyncAzureOpenAI, AzureADTokenProvider, AzureOpenAI

from haystack import component, default_from_dict, default_to_dict
from haystack.components.embedders import OpenAITextEmbedder
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.http_client import init_http_client


@component
Expand Down Expand Up @@ -138,17 +138,12 @@ def __init__( # pylint: disable=too-many-positional-arguments
"default_headers": self.default_headers,
}

self.client = AzureOpenAI(http_client=self._init_http_client(async_client=False), **client_kwargs)
self.async_client = AsyncAzureOpenAI(http_client=self._init_http_client(async_client=True), **client_kwargs)

def _init_http_client(self, async_client: bool = False):
if not self.http_client_kwargs:
return None
if not isinstance(self.http_client_kwargs, dict):
raise TypeError("The parameter 'http_client_kwargs' must be a dictionary.")
if async_client:
return httpx.AsyncClient(**self.http_client_kwargs)
return httpx.Client(**self.http_client_kwargs)
self.client = AzureOpenAI(
http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_kwargs
)
self.async_client = AsyncAzureOpenAI(
http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_kwargs
)

def to_dict(self) -> Dict[str, Any]:
"""
Expand Down
16 changes: 13 additions & 3 deletions haystack/components/generators/chat/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
serialize_tools_or_toolset,
)
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.http_client import init_http_client


@component
Expand Down Expand Up @@ -66,6 +67,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
"""

# pylint: disable=super-init-not-called
# ruff: noqa: PLR0913
def __init__( # pylint: disable=too-many-positional-arguments
self,
azure_endpoint: Optional[str] = None,
Expand All @@ -83,6 +85,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
tools_strict: bool = False,
*,
azure_ad_token_provider: Optional[Union[AzureADTokenProvider, AsyncAzureADTokenProvider]] = None,
http_client_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Initialize the Azure OpenAI Chat Generator component.
Expand Down Expand Up @@ -128,6 +131,8 @@ def __init__( # pylint: disable=too-many-positional-arguments
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on
every request.
:param http_client_kwargs:
A dictionary of keyword arguments to configure a custom httpx.Client.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @Amnah199!
I want to add similar support to other components, and I think they all should have the same description. I suggest adding documentation similar to the following for :param http_client::


A dictionary of keyword arguments to configure a custom httpx.Client for your use case: proxies, authentication and other advanced functionalities of HTTPX.
You can set a proxy with basic authorization using the environment variables: HTTP_PROXY and HTTPS_PROXY, ALL_PROXY and NO_PROXY, for example HTTP_PROXY=http://user:password@your-proxy.net:8080.


A dictionary of keyword arguments to configure a custom `httpx.Client` for your use case: [proxies](https://www.python-httpx.org/advanced/proxies), [authentication](https://www.python-httpx.org/advanced/authentication) and other [advanced functionalities](https://www.python-httpx.org/advanced/clients) of HTTPX.
You can set a proxy with basic authorization using [the environment variables](https://www.python-httpx.org/environment_variables): `HTTP_PROXY` and `HTTPS_PROXY`, `ALL_PROXY` and `NO_PROXY`, for example `HTTP_PROXY=http://user:password@your-proxy.net:8080`.

"""
# We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
# with the API.
Expand Down Expand Up @@ -158,7 +163,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
self.default_headers = default_headers or {}
self.azure_ad_token_provider = azure_ad_token_provider

self.http_client_kwargs = http_client_kwargs
_check_duplicate_tool_names(list(tools or []))
self.tools = tools
self.tools_strict = tools_strict
Expand All @@ -176,8 +181,12 @@ def __init__( # pylint: disable=too-many-positional-arguments
"azure_ad_token_provider": azure_ad_token_provider,
}

self.client = AzureOpenAI(**client_args)
self.async_client = AsyncAzureOpenAI(**client_args)
self.client = AzureOpenAI(
http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_args
)
self.async_client = AsyncAzureOpenAI(
http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_args
)

def to_dict(self) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -206,6 +215,7 @@ def to_dict(self) -> Dict[str, Any]:
tools=serialize_tools_or_toolset(self.tools),
tools_strict=self.tools_strict,
azure_ad_token_provider=azure_ad_token_provider_name,
http_client_kwargs=self.http_client_kwargs,
)

@classmethod
Expand Down
15 changes: 11 additions & 4 deletions haystack/components/generators/chat/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
serialize_tools_or_toolset,
)
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.http_client import init_http_client

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -89,6 +90,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
max_retries: Optional[int] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
tools_strict: bool = False,
http_client_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Creates an instance of OpenAIChatGenerator. Unless specified otherwise in `model`, uses OpenAI's gpt-4o-mini
Expand Down Expand Up @@ -138,6 +140,8 @@ def __init__( # pylint: disable=too-many-positional-arguments
:param tools_strict:
Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
:param http_client_kwargs:
A dictionary of keyword arguments to configure a custom httpx.Client.
"""
self.api_key = api_key
self.model = model
Expand All @@ -149,7 +153,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
self.max_retries = max_retries
self.tools = tools # Store tools as-is, whether it's a list or a Toolset
self.tools_strict = tools_strict

self.http_client_kwargs = http_client_kwargs
# Check for duplicate tool names
_check_duplicate_tool_names(list(self.tools or []))

Expand All @@ -158,16 +162,18 @@ def __init__( # pylint: disable=too-many-positional-arguments
if max_retries is None:
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))

client_args: Dict[str, Any] = {
client_kwargs: Dict[str, Any] = {
"api_key": api_key.resolve_value(),
"organization": organization,
"base_url": api_base_url,
"timeout": timeout,
"max_retries": max_retries,
}

self.client = OpenAI(**client_args)
self.async_client = AsyncOpenAI(**client_args)
self.client = OpenAI(http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_kwargs)
self.async_client = AsyncOpenAI(
http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_kwargs
)

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -195,6 +201,7 @@ def to_dict(self) -> Dict[str, Any]:
max_retries=self.max_retries,
tools=serialize_tools_or_toolset(self.tools),
tools_strict=self.tools_strict,
http_client_kwargs=self.http_client_kwargs,
)

@classmethod
Expand Down
28 changes: 28 additions & 0 deletions haystack/utils/http_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, Optional

import httpx


def init_http_client(http_client_kwargs: Optional[Dict[str, Any]] = None, async_client: bool = False):
    """
    Initialize an httpx client based on the http_client_kwargs.

    :param http_client_kwargs:
        The kwargs to pass to the httpx client constructor (e.g. proxy, authentication, timeouts).
    :param async_client:
        Whether to initialize an `httpx.AsyncClient` instead of a synchronous `httpx.Client`.
    :returns:
        An httpx client instance, or `None` when no configuration is provided (callers then fall
        back to the OpenAI SDK's default client).
    :raises TypeError:
        If `http_client_kwargs` is provided but is not a dictionary.
    """
    # Validate the type before the truthiness check so that falsy non-dict values
    # (e.g. [] or "") are rejected instead of being silently treated as "no config".
    if http_client_kwargs is None:
        return None
    if not isinstance(http_client_kwargs, dict):
        raise TypeError("The parameter 'http_client_kwargs' must be a dictionary.")
    if not http_client_kwargs:
        # An empty dict means no custom configuration was requested.
        return None
    if async_client:
        return httpx.AsyncClient(**http_client_kwargs)
    return httpx.Client(**http_client_kwargs)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest using openai.DefaultHttpxClient because it provides specific configurations from OpenAI (see here).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My 2 cents.

In case the user wants to provide their own client, I would leave them free to choose the configurations.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@anakin87 As per my understanding, DefaultAsyncHttpxClient does not limit the configuration options; only the default parameters are pre-configured to be optimal. The user retains full configurability, as with a regular httpx.AsyncClient.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While I understand this point, as a user I would find it unexpected to pass some client kwargs and then get a client that automatically sets others as well.

If we decide to go this path, I would clearly indicate this in all docstrings.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alexengrig I see the merits in both approaches. However, we've decided to stick with httpx.Client for now to avoid introducing potential confusion from using a preconfigured client. This way, users have full transparency and control when creating and customizing their own client.

Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
`OpenAIChatGenerator` and `AzureOpenAIChatGenerator` now support custom HTTP client config via `http_client_kwargs`, enabling proxy and SSL setup.
2 changes: 2 additions & 0 deletions test/components/agents/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def test_to_dict(self, weather_tool, component_tool, monkeypatch):
"max_retries": None,
"tools": None,
"tools_strict": False,
"http_client_kwargs": None,
},
},
"tools": [
Expand Down Expand Up @@ -205,6 +206,7 @@ def test_from_dict(self, weather_tool, component_tool, monkeypatch):
"max_retries": None,
"tools": None,
"tools_strict": False,
"http_client_kwargs": None,
},
},
"tools": [
Expand Down
7 changes: 4 additions & 3 deletions test/components/embedders/test_azure_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from haystack.components.embedders import AzureOpenAIDocumentEmbedder
from haystack.utils.azure import default_azure_ad_token_provider
from unittest.mock import Mock, patch
from haystack.utils.http_client import init_http_client


class TestAzureOpenAIDocumentEmbedder:
Expand Down Expand Up @@ -220,14 +221,14 @@ def test_init_http_client(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_ENDPOINT", "https://test.openai.azure.com")

embedder = AzureOpenAIDocumentEmbedder()
client = embedder._init_http_client()
client = init_http_client(embedder.http_client_kwargs, async_client=False)
assert client is None

embedder.http_client_kwargs = {"proxy": "http://example.com:3128"}
client = embedder._init_http_client(async_client=False)
client = init_http_client(embedder.http_client_kwargs, async_client=False)
assert isinstance(client, httpx.Client)

client = embedder._init_http_client(async_client=True)
client = init_http_client(embedder.http_client_kwargs, async_client=True)
assert isinstance(client, httpx.AsyncClient)

def test_http_client_kwargs_type_validation(self, monkeypatch):
Expand Down
7 changes: 4 additions & 3 deletions test/components/embedders/test_azure_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from haystack.components.embedders import AzureOpenAITextEmbedder
from haystack.utils.azure import default_azure_ad_token_provider
from haystack.utils.http_client import init_http_client


class TestAzureOpenAITextEmbedder:
Expand Down Expand Up @@ -174,14 +175,14 @@ def test_init_http_client(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_ENDPOINT", "https://test.openai.azure.com")

embedder = AzureOpenAITextEmbedder()
client = embedder._init_http_client()
client = init_http_client(embedder.http_client_kwargs, async_client=False)
assert client is None

embedder.http_client_kwargs = {"proxy": "http://example.com:3128"}
client = embedder._init_http_client(async_client=False)
client = init_http_client(embedder.http_client_kwargs, async_client=False)
assert isinstance(client, httpx.Client)

client = embedder._init_http_client(async_client=True)
client = init_http_client(embedder.http_client_kwargs, async_client=True)
assert isinstance(client, httpx.AsyncClient)

def test_http_client_kwargs_type_validation(self, monkeypatch):
Expand Down
1 change: 1 addition & 0 deletions test/components/extractors/test_llm_metadata_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def test_to_dict_openai(self, monkeypatch):
"timeout": None,
"tools": None,
"tools_strict": False,
"http_client_kwargs": None,
},
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
},
Expand Down
6 changes: 6 additions & 0 deletions test/components/generators/chat/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def test_to_dict_default(self, monkeypatch):
"tools": None,
"tools_strict": False,
"azure_ad_token_provider": None,
"http_client_kwargs": None,
},
}

Expand All @@ -124,6 +125,7 @@ def test_to_dict_with_parameters(self, monkeypatch):
max_retries=10,
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
azure_ad_token_provider=default_azure_ad_token_provider,
http_client_kwargs={"proxy": "http://localhost:8080"},
)
data = component.to_dict()
assert data == {
Expand All @@ -143,6 +145,7 @@ def test_to_dict_with_parameters(self, monkeypatch):
"tools_strict": False,
"default_headers": {},
"azure_ad_token_provider": "haystack.utils.azure.default_azure_ad_token_provider",
"http_client_kwargs": {"proxy": "http://localhost:8080"},
},
}

Expand Down Expand Up @@ -175,6 +178,7 @@ def test_from_dict(self, monkeypatch):
}
],
"tools_strict": False,
"http_client_kwargs": None,
},
}

Expand All @@ -196,6 +200,7 @@ def test_from_dict(self, monkeypatch):
Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)
]
assert generator.tools_strict == False
assert generator.http_client_kwargs is None

def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
Expand Down Expand Up @@ -225,6 +230,7 @@ def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch):
"tools": None,
"tools_strict": False,
"azure_ad_token_provider": None,
"http_client_kwargs": None,
},
}
},
Expand Down
Loading
Loading