Skip to content

Commit 4edefe3

Browse files
authored
Feat: Support Azure Workload Identity Credential (#9012)
* Start adding support for passing callable to Azure components * Add to chat version * Fix test * Add reno * Add support to azure doc and text embedder * Rename * update llm metadata extractor * Add tests for text embedder * Update tests * Remove unused fixture and import * Update reno
1 parent 4c1facd commit 4edefe3

16 files changed

+305
-22
lines changed

haystack/components/embedders/azure_document_embedder.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77

88
from more_itertools import batched
99
from openai import APIError
10-
from openai.lib.azure import AzureOpenAI
10+
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
1111
from tqdm import tqdm
1212

1313
from haystack import Document, component, default_from_dict, default_to_dict, logging
14-
from haystack.utils import Secret, deserialize_secrets_inplace
14+
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
1515

1616
logger = logging.getLogger(__name__)
1717

@@ -57,6 +57,7 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p
5757
max_retries: Optional[int] = None,
5858
*,
5959
default_headers: Optional[Dict[str, str]] = None,
60+
azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
6061
):
6162
"""
6263
Creates an AzureOpenAIDocumentEmbedder component.
@@ -102,6 +103,8 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p
102103
:param max_retries: Maximum number of retries to contact AzureOpenAI after an internal error.
103104
If not set, defaults to either the `OPENAI_MAX_RETRIES` environment variable or to 5 retries.
104105
:param default_headers: Default headers to send to the AzureOpenAI client.
106+
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on
107+
every request.
105108
"""
106109
# if not provided as a parameter, azure_endpoint is read from the env var AZURE_OPENAI_ENDPOINT
107110
azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
@@ -127,11 +130,13 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p
127130
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
128131
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
129132
self.default_headers = default_headers or {}
133+
self.azure_ad_token_provider = azure_ad_token_provider
130134

131135
self._client = AzureOpenAI(
132136
api_version=api_version,
133137
azure_endpoint=azure_endpoint,
134138
azure_deployment=azure_deployment,
139+
azure_ad_token_provider=azure_ad_token_provider,
135140
api_key=api_key.resolve_value() if api_key is not None else None,
136141
azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
137142
organization=organization,
@@ -153,6 +158,9 @@ def to_dict(self) -> Dict[str, Any]:
153158
:returns:
154159
Dictionary with serialized data.
155160
"""
161+
azure_ad_token_provider_name = None
162+
if self.azure_ad_token_provider:
163+
azure_ad_token_provider_name = serialize_callable(self.azure_ad_token_provider)
156164
return default_to_dict(
157165
self,
158166
azure_endpoint=self.azure_endpoint,
@@ -171,6 +179,7 @@ def to_dict(self) -> Dict[str, Any]:
171179
timeout=self.timeout,
172180
max_retries=self.max_retries,
173181
default_headers=self.default_headers,
182+
azure_ad_token_provider=azure_ad_token_provider_name,
174183
)
175184

176185
@classmethod
@@ -184,6 +193,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAIDocumentEmbedder":
184193
Deserialized component.
185194
"""
186195
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_ad_token"])
196+
serialized_azure_ad_token_provider = data["init_parameters"].get("azure_ad_token_provider")
197+
if serialized_azure_ad_token_provider:
198+
data["init_parameters"]["azure_ad_token_provider"] = deserialize_callable(
199+
serialized_azure_ad_token_provider
200+
)
187201
return default_from_dict(cls, data)
188202

189203
def _prepare_texts_to_embed(self, documents: List[Document]) -> Dict[str, str]:

haystack/components/embedders/azure_text_embedder.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
import os
66
from typing import Any, Dict, List, Optional
77

8-
from openai.lib.azure import AzureOpenAI
8+
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
99

1010
from haystack import Document, component, default_from_dict, default_to_dict
11-
from haystack.utils import Secret, deserialize_secrets_inplace
11+
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
1212

1313

1414
@component
@@ -48,6 +48,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
4848
suffix: str = "",
4949
*,
5050
default_headers: Optional[Dict[str, str]] = None,
51+
azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
5152
):
5253
"""
5354
Creates an AzureOpenAITextEmbedder component.
@@ -85,6 +86,8 @@ def __init__( # pylint: disable=too-many-positional-arguments
8586
:param suffix:
8687
A string to add at the end of each text.
8788
:param default_headers: Default headers to send to the AzureOpenAI client.
89+
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on
90+
every request.
8891
"""
8992
# Why is this here?
9093
# AzureOpenAI init is forcing us to use an init method that takes either base_url or azure_endpoint as not
@@ -109,11 +112,13 @@ def __init__( # pylint: disable=too-many-positional-arguments
109112
self.prefix = prefix
110113
self.suffix = suffix
111114
self.default_headers = default_headers or {}
115+
self.azure_ad_token_provider = azure_ad_token_provider
112116

113117
self._client = AzureOpenAI(
114118
api_version=api_version,
115119
azure_endpoint=azure_endpoint,
116120
azure_deployment=azure_deployment,
121+
azure_ad_token_provider=azure_ad_token_provider,
117122
api_key=api_key.resolve_value() if api_key is not None else None,
118123
azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
119124
organization=organization,
@@ -135,6 +140,9 @@ def to_dict(self) -> Dict[str, Any]:
135140
:returns:
136141
Dictionary with serialized data.
137142
"""
143+
azure_ad_token_provider_name = None
144+
if self.azure_ad_token_provider:
145+
azure_ad_token_provider_name = serialize_callable(self.azure_ad_token_provider)
138146
return default_to_dict(
139147
self,
140148
azure_endpoint=self.azure_endpoint,
@@ -149,6 +157,7 @@ def to_dict(self) -> Dict[str, Any]:
149157
timeout=self.timeout,
150158
max_retries=self.max_retries,
151159
default_headers=self.default_headers,
160+
azure_ad_token_provider=azure_ad_token_provider_name,
152161
)
153162

154163
@classmethod
@@ -162,6 +171,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAITextEmbedder":
162171
Deserialized component.
163172
"""
164173
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_ad_token"])
174+
serialized_azure_ad_token_provider = data["init_parameters"].get("azure_ad_token_provider")
175+
if serialized_azure_ad_token_provider:
176+
data["init_parameters"]["azure_ad_token_provider"] = deserialize_callable(
177+
serialized_azure_ad_token_provider
178+
)
165179
return default_from_dict(cls, data)
166180

167181
@component.output_types(embedding=List[float], meta=Dict[str, Any])

haystack/components/extractors/llm_metadata_extractor.py

+7
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "LLMMetadataExtractor":
293293
init_parameters["generator_api_params"]["generation_config"]
294294
)
295295

296+
# For AzureOpenAI
297+
serialized_azure_ad_token_provider = init_parameters["generator_api_params"].get("azure_ad_token_provider")
298+
if serialized_azure_ad_token_provider:
299+
data["init_parameters"]["azure_ad_token_provider"] = deserialize_callable(
300+
serialized_azure_ad_token_provider
301+
)
302+
296303
# For all
297304
serialized_callback_handler = init_parameters["generator_api_params"].get("streaming_callback")
298305
if serialized_callback_handler:

haystack/components/generators/azure.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
import os
66
from typing import Any, Callable, Dict, Optional
77

8-
# pylint: disable=import-error
9-
from openai.lib.azure import AzureOpenAI
8+
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
109

1110
from haystack import component, default_from_dict, default_to_dict, logging
1211
from haystack.components.generators import OpenAIGenerator
@@ -21,7 +20,9 @@ class AzureOpenAIGenerator(OpenAIGenerator):
2120
"""
2221
Generates text using OpenAI's large language models (LLMs).
2322
24-
It works with the gpt-4 and gpt-3.5-turbo family of models.
23+
It works with the gpt-4 - type models and supports streaming responses
24+
from OpenAI API.
25+
2526
You can customize how the text is generated by passing parameters to the
2627
OpenAI API. Use the `**generation_kwargs` argument when you initialize
2728
the component or when you run it. Any parameter that works with
@@ -69,6 +70,8 @@ def __init__( # pylint: disable=too-many-positional-arguments
6970
max_retries: Optional[int] = None,
7071
generation_kwargs: Optional[Dict[str, Any]] = None,
7172
default_headers: Optional[Dict[str, str]] = None,
73+
*,
74+
azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
7275
):
7376
"""
7477
Initialize the Azure OpenAI Generator.
@@ -109,6 +112,8 @@ def __init__( # pylint: disable=too-many-positional-arguments
109112
- `logit_bias`: Adds a logit bias to specific tokens. The keys of the dictionary are tokens, and the
110113
values are the bias to add to that token.
111114
:param default_headers: Default headers to use for the AzureOpenAI client.
115+
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on
116+
every request.
112117
"""
113118
# We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
114119
# with the API.
@@ -139,11 +144,13 @@ def __init__( # pylint: disable=too-many-positional-arguments
139144
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
140145
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
141146
self.default_headers = default_headers or {}
147+
self.azure_ad_token_provider = azure_ad_token_provider
142148

143149
self.client = AzureOpenAI(
144150
api_version=api_version,
145151
azure_endpoint=azure_endpoint,
146152
azure_deployment=azure_deployment,
153+
azure_ad_token_provider=azure_ad_token_provider,
147154
api_key=api_key.resolve_value() if api_key is not None else None,
148155
azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
149156
organization=organization,
@@ -160,6 +167,9 @@ def to_dict(self) -> Dict[str, Any]:
160167
The serialized component as a dictionary.
161168
"""
162169
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
170+
azure_ad_token_provider_name = None
171+
if self.azure_ad_token_provider:
172+
azure_ad_token_provider_name = serialize_callable(self.azure_ad_token_provider)
163173
return default_to_dict(
164174
self,
165175
azure_endpoint=self.azure_endpoint,
@@ -174,6 +184,7 @@ def to_dict(self) -> Dict[str, Any]:
174184
timeout=self.timeout,
175185
max_retries=self.max_retries,
176186
default_headers=self.default_headers,
187+
azure_ad_token_provider=azure_ad_token_provider_name,
177188
)
178189

179190
@classmethod
@@ -191,4 +202,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAIGenerator":
191202
serialized_callback_handler = init_params.get("streaming_callback")
192203
if serialized_callback_handler:
193204
data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
205+
serialized_azure_ad_token_provider = init_params.get("azure_ad_token_provider")
206+
if serialized_azure_ad_token_provider:
207+
data["init_parameters"]["azure_ad_token_provider"] = deserialize_callable(
208+
serialized_azure_ad_token_provider
209+
)
194210
return default_from_dict(cls, data)

haystack/components/generators/chat/azure.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import os
6-
from typing import Any, Callable, Dict, List, Optional
6+
from typing import Any, Callable, Dict, List, Optional, Union
77

8-
# pylint: disable=import-error
9-
from openai.lib.azure import AsyncAzureOpenAI, AzureOpenAI
8+
from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI, AzureADTokenProvider, AzureOpenAI
109

1110
from haystack import component, default_from_dict, default_to_dict, logging
1211
from haystack.components.generators.chat import OpenAIChatGenerator
@@ -22,7 +21,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
2221
"""
2322
Generates text using OpenAI's models on Azure.
2423
25-
It works with the gpt-4 and gpt-3.5-turbo - type models and supports streaming responses
24+
It works with the gpt-4 - type models and supports streaming responses
2625
from OpenAI API. It uses [ChatMessage](https://docs.haystack.deepset.ai/docs/chatmessage)
2726
format in input and output.
2827
@@ -78,6 +77,8 @@ def __init__( # pylint: disable=too-many-positional-arguments
7877
default_headers: Optional[Dict[str, str]] = None,
7978
tools: Optional[List[Tool]] = None,
8079
tools_strict: bool = False,
80+
*,
81+
azure_ad_token_provider: Optional[Union[AzureADTokenProvider, AsyncAzureADTokenProvider]] = None,
8182
):
8283
"""
8384
Initialize the Azure OpenAI Chat Generator component.
@@ -120,6 +121,8 @@ def __init__( # pylint: disable=too-many-positional-arguments
120121
:param tools_strict:
121122
Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
122123
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
124+
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on
125+
every request.
123126
"""
124127
# We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
125128
# with the API.
@@ -149,6 +152,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
149152
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
150153
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
151154
self.default_headers = default_headers or {}
155+
self.azure_ad_token_provider = azure_ad_token_provider
152156

153157
_check_duplicate_tool_names(tools)
154158
self.tools = tools
@@ -164,6 +168,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
164168
"timeout": self.timeout,
165169
"max_retries": self.max_retries,
166170
"default_headers": self.default_headers,
171+
"azure_ad_token_provider": azure_ad_token_provider,
167172
}
168173

169174
self.client = AzureOpenAI(**client_args)
@@ -177,6 +182,9 @@ def to_dict(self) -> Dict[str, Any]:
177182
The serialized component as a dictionary.
178183
"""
179184
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
185+
azure_ad_token_provider_name = None
186+
if self.azure_ad_token_provider:
187+
azure_ad_token_provider_name = serialize_callable(self.azure_ad_token_provider)
180188
return default_to_dict(
181189
self,
182190
azure_endpoint=self.azure_endpoint,
@@ -192,6 +200,7 @@ def to_dict(self) -> Dict[str, Any]:
192200
default_headers=self.default_headers,
193201
tools=[tool.to_dict() for tool in self.tools] if self.tools else None,
194202
tools_strict=self.tools_strict,
203+
azure_ad_token_provider=azure_ad_token_provider_name,
195204
)
196205

197206
@classmethod
@@ -209,4 +218,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAIChatGenerator":
209218
serialized_callback_handler = init_params.get("streaming_callback")
210219
if serialized_callback_handler:
211220
data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
221+
serialized_azure_ad_token_provider = init_params.get("azure_ad_token_provider")
222+
if serialized_azure_ad_token_provider:
223+
data["init_parameters"]["azure_ad_token_provider"] = deserialize_callable(
224+
serialized_azure_ad_token_provider
225+
)
212226
return default_from_dict(cls, data)

haystack/utils/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
_import_structure = {
1111
"auth": ["Secret", "deserialize_secrets_inplace"],
12+
"azure": ["default_azure_ad_token_provider"],
1213
"callable_serialization": ["deserialize_callable", "serialize_callable"],
1314
"device": ["ComponentDevice", "Device", "DeviceMap", "DeviceType"],
1415
"docstore_deserialization": ["deserialize_document_store_in_init_params_inplace"],
@@ -22,6 +23,7 @@
2223

2324
if TYPE_CHECKING:
2425
from .auth import Secret, deserialize_secrets_inplace
26+
from .azure import default_azure_ad_token_provider
2527
from .callable_serialization import deserialize_callable, serialize_callable
2628
from .device import ComponentDevice, Device, DeviceMap, DeviceType
2729
from .docstore_deserialization import deserialize_document_store_in_init_params_inplace

haystack/utils/azure.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from haystack.lazy_imports import LazyImport
6+
7+
with LazyImport(message="Run 'pip install azure-identity") as azure_import:
8+
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
9+
10+
11+
def default_azure_ad_token_provider() -> str:
12+
"""
13+
Get a Azure AD token using the DefaultAzureCredential and the "https://cognitiveservices.azure.com/.default" scope.
14+
"""
15+
azure_import.check()
16+
return get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")()

pyproject.toml

+3
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ extra-dependencies = [
132132
# ComponentTool
133133
"docstring-parser",
134134

135+
# Azure Utils
136+
"azure-identity",
137+
135138
# Test
136139
"pytest",
137140
"pytest-bdd",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
features:
3+
- |
4+
- Added a new parameter `azure_ad_token_provider` to all Azure OpenAI components: `AzureOpenAIGenerator`, `AzureOpenAIChatGenerator`, `AzureOpenAITextEmbedder` and `AzureOpenAIDocumentEmbedder`. This parameter optionally accepts a callable that returns a bearer token, enabling authentication via Azure AD.
5+
- Introduced the utility function `default_azure_token_provider` in `haystack/utils/azure.py`. This function provides a default token provider that is serializable by Haystack. Users can now pass `default_azure_token_provider` as the `azure_ad_token_provider` or implement a custom token provider.

0 commit comments

Comments
 (0)