4
4
import logging as python_logging
5
5
import os
6
6
from datetime import datetime
7
- from typing import Any , ClassVar , Dict , List , Optional
7
+ from typing import Any , Dict , List , Optional , Union
8
8
9
9
from azure .core .credentials import AzureKeyCredential
10
10
from azure .core .exceptions import ClientAuthenticationError , HttpResponseError , ResourceNotFoundError
85
85
86
86
87
87
class AzureAISearchDocumentStore :
88
- TYPE_MAP : ClassVar [Dict [str , type ]] = {"str" : str , "int" : int , "float" : float , "bool" : bool , "datetime" : datetime }
89
88
90
89
def __init__ (
91
90
self ,
@@ -94,8 +93,8 @@ def __init__(
94
93
azure_endpoint : Secret = Secret .from_env_var ("AZURE_AI_SEARCH_ENDPOINT" , strict = True ), # noqa: B008
95
94
index_name : str = "default" ,
96
95
embedding_dimension : int = 768 ,
97
- metadata_fields : Optional [Dict [str , type ]] = None ,
98
- vector_search_configuration : VectorSearch = None ,
96
+ metadata_fields : Optional [Dict [str , Union [ SearchField , type ] ]] = None ,
97
+ vector_search_configuration : Optional [ VectorSearch ] = None ,
99
98
** index_creation_kwargs ,
100
99
):
101
100
"""
@@ -106,10 +105,22 @@ def __init__(
106
105
:param api_key: The API key to use for authentication.
107
106
:param index_name: Name of index in Azure AI Search, if it doesn't exist it will be created.
108
107
:param embedding_dimension: Dimension of the embeddings.
109
- :param metadata_fields: A dictionary of metadata keys and their types to create
110
- additional fields in index schema. As fields in Azure SearchIndex cannot be dynamic,
111
- it is necessary to specify the metadata fields in advance.
112
- (e.g. metadata_fields = {"author": str, "date": datetime})
108
+ :param metadata_fields: A dictionary mapping metadata field names to their corresponding field definitions.
109
+ Each field can be defined either as:
110
+ - A SearchField object to specify detailed field configuration like type, searchability, and filterability
111
+ - A Python type (`str`, `bool`, `int`, `float`, or `datetime`) to create a simple filterable field
112
+
113
+ These fields are automatically added when creating the search index.
114
+ Example:
115
+ metadata_fields={
116
+ "Title": SearchField(
117
+ name="Title",
118
+ type="Edm.String",
119
+ searchable=True,
120
+ filterable=True
121
+ ),
122
+ "Pages": int
123
+ }
113
124
:param vector_search_configuration: Configuration option related to vector search.
114
125
Default configuration uses the HNSW algorithm with cosine similarity to handle vector searches.
115
126
@@ -139,13 +150,12 @@ def __init__(
139
150
self ._index_name = index_name
140
151
self ._embedding_dimension = embedding_dimension
141
152
self ._dummy_vector = [- 10.0 ] * self ._embedding_dimension
142
- self ._metadata_fields = metadata_fields
153
+ self ._metadata_fields = self . _normalize_metadata_index_fields ( metadata_fields )
143
154
self ._vector_search_configuration = vector_search_configuration or DEFAULT_VECTOR_SEARCH
144
155
self ._index_creation_kwargs = index_creation_kwargs
145
156
146
157
@property
147
158
def client (self ) -> SearchClient :
148
-
149
159
# resolve secrets for authentication
150
160
resolved_endpoint = (
151
161
self ._azure_endpoint .resolve_value () if isinstance (self ._azure_endpoint , Secret ) else self ._azure_endpoint
@@ -185,6 +195,45 @@ def client(self) -> SearchClient:
185
195
186
196
return self ._client
187
197
198
+ def _normalize_metadata_index_fields (
199
+ self , metadata_fields : Optional [Dict [str , Union [SearchField , type ]]]
200
+ ) -> Dict [str , SearchField ]:
201
+ """Create a list of index fields for storing metadata values."""
202
+
203
+ if not metadata_fields :
204
+ return {}
205
+
206
+ normalized_fields = {}
207
+
208
+ for key , value in metadata_fields .items ():
209
+ if isinstance (value , SearchField ):
210
+ if value .name == key :
211
+ normalized_fields [key ] = value
212
+ else :
213
+ msg = f"Name of SearchField ('{ value .name } ') must match metadata field name ('{ key } ')"
214
+ raise ValueError (msg )
215
+ else :
216
+ if not key [0 ].isalpha ():
217
+ msg = (
218
+ f"Azure Search index only allows field names starting with letters. "
219
+ f"Invalid key: { key } will be dropped."
220
+ )
221
+ logger .warning (msg )
222
+ continue
223
+
224
+ field_type = type_mapping .get (value )
225
+ if not field_type :
226
+ error_message = f"Unsupported field type for key '{ key } ': { value } "
227
+ raise ValueError (error_message )
228
+
229
+ normalized_fields [key ] = SimpleField (
230
+ name = key ,
231
+ type = field_type ,
232
+ filterable = True ,
233
+ )
234
+
235
+ return normalized_fields
236
+
188
237
def _create_index (self ) -> None :
189
238
"""
190
239
Internally creates a new search index.
@@ -205,29 +254,18 @@ def _create_index(self) -> None:
205
254
]
206
255
207
256
if self ._metadata_fields :
208
- default_fields .extend (self ._create_metadata_index_fields (self ._metadata_fields ))
257
+ default_fields .extend (self ._metadata_fields .values ())
258
+
209
259
index = SearchIndex (
210
260
name = self ._index_name ,
211
261
fields = default_fields ,
212
262
vector_search = self ._vector_search_configuration ,
213
263
** self ._index_creation_kwargs ,
214
264
)
265
+
215
266
if self ._index_client :
216
267
self ._index_client .create_index (index )
217
268
218
- @classmethod
219
- def _deserialize_metadata_fields (cls , fields : Optional [Dict [str , str ]]) -> Optional [Dict [str , type ]]:
220
- """Convert string representations back to type objects."""
221
- if not fields :
222
- return None
223
- try :
224
- # Use the class-level TYPE_MAP for conversion.
225
- ans = {key : cls .TYPE_MAP [value ] for key , value in fields .items ()}
226
- return ans
227
- except KeyError as e :
228
- msg = f"Unsupported type encountered in metadata_fields: { e } "
229
- raise ValueError (msg ) from e
230
-
231
269
@staticmethod
232
270
def _serialize_index_creation_kwargs (index_creation_kwargs : Dict [str , Any ]) -> Dict [str , Any ]:
233
271
"""
@@ -265,28 +303,19 @@ def _deserialize_index_creation_kwargs(cls, data: Dict[str, Any]) -> Any:
265
303
return result [key ]
266
304
267
305
def to_dict (self ) -> Dict [str , Any ]:
268
- # This is not the best solution to serialise this class but is the fastest to implement.
269
- # Not all kwargs types can be serialised to text so this can fail. We must serialise each
270
- # type explicitly to handle this properly.
271
306
"""
272
307
Serializes the component to a dictionary.
273
308
274
309
:returns:
275
310
Dictionary with serialized data.
276
311
"""
277
-
278
- if self ._metadata_fields :
279
- serialized_metadata = {key : value .__name__ for key , value in self ._metadata_fields .items ()}
280
- else :
281
- serialized_metadata = None
282
-
283
312
return default_to_dict (
284
313
self ,
285
314
azure_endpoint = self ._azure_endpoint .to_dict () if self ._azure_endpoint else None ,
286
315
api_key = self ._api_key .to_dict () if self ._api_key else None ,
287
316
index_name = self ._index_name ,
288
317
embedding_dimension = self ._embedding_dimension ,
289
- metadata_fields = serialized_metadata ,
318
+ metadata_fields = { key : value . as_dict () for key , value in self . _metadata_fields . items ()} ,
290
319
vector_search_configuration = self ._vector_search_configuration .as_dict (),
291
320
** self ._serialize_index_creation_kwargs (self ._index_creation_kwargs ),
292
321
)
@@ -303,7 +332,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchDocumentStore":
303
332
Deserialized component.
304
333
"""
305
334
if (fields := data ["init_parameters" ]["metadata_fields" ]) is not None :
306
- data ["init_parameters" ]["metadata_fields" ] = cls ._deserialize_metadata_fields (fields )
335
+ data ["init_parameters" ]["metadata_fields" ] = {
336
+ key : SearchField .from_dict (field ) for key , field in fields .items ()
337
+ }
338
+ else :
339
+ data ["init_parameters" ]["metadata_fields" ] = {}
307
340
308
341
for key , _value in AZURE_CLASS_MAPPING .items ():
309
342
if key in data ["init_parameters" ]:
@@ -461,46 +494,12 @@ def _convert_haystack_document_to_azure(self, document: Document) -> Dict[str, A
461
494
462
495
return index_document
463
496
464
- def _create_metadata_index_fields (self , metadata : Dict [str , Any ]) -> List [SimpleField ]:
465
- """Create a list of index fields for storing metadata values."""
466
-
467
- index_fields = []
468
- metadata_field_mapping = self ._map_metadata_field_types (metadata )
469
-
470
- for key , field_type in metadata_field_mapping .items ():
471
- index_fields .append (SimpleField (name = key , type = field_type , filterable = True ))
472
-
473
- return index_fields
474
-
475
- def _map_metadata_field_types (self , metadata : Dict [str , type ]) -> Dict [str , str ]:
476
- """Map metadata field types to Azure Search field types."""
477
-
478
- metadata_field_mapping = {}
479
-
480
- for key , value_type in metadata .items ():
481
-
482
- if not key [0 ].isalpha ():
483
- msg = (
484
- f"Azure Search index only allows field names starting with letters. "
485
- f"Invalid key: { key } will be dropped."
486
- )
487
- logger .warning (msg )
488
- continue
489
-
490
- field_type = type_mapping .get (value_type )
491
- if not field_type :
492
- error_message = f"Unsupported field type for key '{ key } ': { value_type } "
493
- raise ValueError (error_message )
494
- metadata_field_mapping [key ] = field_type
495
-
496
- return metadata_field_mapping
497
-
498
497
def _embedding_retrieval (
499
498
self ,
500
499
query_embedding : List [float ],
501
500
* ,
502
501
top_k : int = 10 ,
503
- filters : Optional [Dict [ str , Any ] ] = None ,
502
+ filters : Optional [str ] = None ,
504
503
** kwargs ,
505
504
) -> List [Document ]:
506
505
"""
@@ -534,7 +533,7 @@ def _bm25_retrieval(
534
533
self ,
535
534
query : str ,
536
535
top_k : int = 10 ,
537
- filters : Optional [Dict [ str , Any ] ] = None ,
536
+ filters : Optional [str ] = None ,
538
537
** kwargs ,
539
538
) -> List [Document ]:
540
539
"""
@@ -567,7 +566,7 @@ def _hybrid_retrieval(
567
566
query : str ,
568
567
query_embedding : List [float ],
569
568
top_k : int = 10 ,
570
- filters : Optional [Dict [ str , Any ] ] = None ,
569
+ filters : Optional [str ] = None ,
571
570
** kwargs ,
572
571
) -> List [Document ]:
573
572
"""
0 commit comments