Skip to content

Commit cb96815

Browse files
authored
Fix bug where AzureOpenAIEmbeddings inherits from OpenAIEmbeddings (#203)
* Fix bug where AzureOpenAIEmbeddings inherits from OpenAIEmbeddings * Cast embedding to list[float] * Refactored embed_query method to be in base class
1 parent bc8540e commit cb96815

File tree

9 files changed

+86
-32
lines changed

9 files changed

+86
-32
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
- Removed support for neo4j.AsyncDriver in the KG creation pipeline, affecting Neo4jWriter and related components.
1414
- Updated examples and unit tests to reflect the removal of async driver support.
1515

16+
### Fixed
17+
- Resolved issue with `AzureOpenAIEmbeddings` incorrectly inheriting from `OpenAIEmbeddings`, now inherits from `BaseOpenAIEmbeddings`.
1618

1719
## 1.1.0
1820

examples/customize/embeddings/azure_openai_embeddings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44

55
from neo4j_graphrag.embeddings import AzureOpenAIEmbeddings
66

7-
embeder = AzureOpenAIEmbeddings(
7+
embedder = AzureOpenAIEmbeddings(
88
model="text-embedding-ada-002",
99
azure_endpoint="https://my-endpoint.openai.azure.com/",
1010
api_key="<my key>",
1111
api_version="<update version>",
1212
)
13-
res = embeder.embed_query("my question")
13+
res = embedder.embed_query("my question")
1414
print(res[:10])

src/neo4j_graphrag/embeddings/openai.py

Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,22 @@
1515

1616
from __future__ import annotations
1717

18-
from typing import Any
18+
import abc
19+
from typing import TYPE_CHECKING, Any
20+
1921
from neo4j_graphrag.embeddings.base import Embedder
2022

23+
if TYPE_CHECKING:
24+
import openai
2125

22-
class OpenAIEmbeddings(Embedder):
23-
"""
24-
OpenAI embeddings class.
25-
This class uses the OpenAI python client to generate embeddings for text data.
2626

27-
Args:
28-
model (str): The name of the OpenAI embedding model to use. Defaults to "text-embedding-ada-002".
29-
kwargs: All other parameters will be passed to the openai.OpenAI init.
27+
class BaseOpenAIEmbeddings(Embedder, abc.ABC):
28+
"""
29+
Abstract base class for OpenAI embeddings.
3030
"""
3131

32+
client: openai.OpenAI
33+
3234
def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None:
3335
try:
3436
import openai
@@ -39,23 +41,52 @@ def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None
3941
)
4042
self.openai = openai
4143
self.model = model
42-
self.openai_client = self.openai.OpenAI(**kwargs)
44+
self.client = self._initialize_client(**kwargs)
45+
46+
@abc.abstractmethod
47+
def _initialize_client(self, **kwargs: Any) -> Any:
48+
"""
49+
Initialize the OpenAI client.
50+
Must be implemented by subclasses.
51+
"""
52+
pass
4353

4454
def embed_query(self, text: str, **kwargs: Any) -> list[float]:
4555
"""
46-
Generate embeddings for a given query using a OpenAI text embedding model.
56+
Generate embeddings for a given query using an OpenAI text embedding model.
4757
4858
Args:
4959
text (str): The text to generate an embedding for.
5060
**kwargs (Any): Additional arguments to pass to the OpenAI embedding generation function.
5161
"""
52-
response = self.openai_client.embeddings.create(
53-
input=text, model=self.model, **kwargs
54-
)
55-
return response.data[0].embedding
62+
response = self.client.embeddings.create(input=text, model=self.model, **kwargs)
63+
embedding: list[float] = response.data[0].embedding
64+
return embedding
5665

5766

58-
class AzureOpenAIEmbeddings(OpenAIEmbeddings):
59-
def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None:
60-
super().__init__(model, **kwargs)
61-
self.openai_client = self.openai.AzureOpenAI(**kwargs)
67+
class OpenAIEmbeddings(BaseOpenAIEmbeddings):
68+
"""
69+
OpenAI embeddings class.
70+
This class uses the OpenAI python client to generate embeddings for text data.
71+
72+
Args:
73+
model (str): The name of the OpenAI embedding model to use. Defaults to "text-embedding-ada-002".
74+
kwargs: All other parameters will be passed to the openai.OpenAI init.
75+
"""
76+
77+
def _initialize_client(self, **kwargs: Any) -> Any:
78+
return self.openai.OpenAI(**kwargs)
79+
80+
81+
class AzureOpenAIEmbeddings(BaseOpenAIEmbeddings):
82+
"""
83+
Azure OpenAI embeddings class.
84+
This class uses the Azure OpenAI python client to generate embeddings for text data.
85+
86+
Args:
87+
model (str): The name of the Azure OpenAI embedding model to use. Defaults to "text-embedding-ada-002".
88+
kwargs: All other parameters will be passed to the openai.AzureOpenAI init.
89+
"""
90+
91+
def _initialize_client(self, **kwargs: Any) -> Any:
92+
return self.openai.AzureOpenAI(**kwargs)

src/neo4j_graphrag/llm/openai_llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
import abc
18-
from typing import Any, Optional, TYPE_CHECKING, Iterable
18+
from typing import TYPE_CHECKING, Any, Iterable, Optional
1919

2020
from ..exceptions import LLMGenerationError
2121
from .base import LLMInterface

tests/unit/embeddings/test_openai_embedder.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
# limitations under the License.
1515
from unittest.mock import MagicMock, Mock, patch
1616

17+
import openai
1718
import pytest
1819
from neo4j_graphrag.embeddings.openai import (
1920
AzureOpenAIEmbeddings,
2021
OpenAIEmbeddings,
2122
)
22-
import openai
2323

2424

2525
def get_mock_openai() -> MagicMock:
@@ -71,3 +71,24 @@ def test_azure_openai_embedder_happy_path(mock_import: Mock) -> None:
7171
res = embedder.embed_query("my text")
7272
assert isinstance(res, list)
7373
assert res == [1.0, 2.0]
74+
75+
76+
def test_azure_openai_embedder_does_not_call_openai_client() -> None:
77+
from unittest.mock import patch
78+
79+
mock_openai = get_mock_openai()
80+
81+
with patch.dict("sys.modules", {"openai": mock_openai}):
82+
AzureOpenAIEmbeddings(
83+
model="text-embedding-ada-002",
84+
azure_endpoint="https://test.openai.azure.com/",
85+
api_key="my_key",
86+
api_version="2023-05-15",
87+
)
88+
89+
mock_openai.OpenAI.assert_not_called()
90+
mock_openai.AzureOpenAI.assert_called_once_with(
91+
azure_endpoint="https://test.openai.azure.com/",
92+
api_key="my_key",
93+
api_version="2023-05-15",
94+
)

tests/unit/embeddings/test_sentence_transformers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
from unittest.mock import MagicMock, patch, Mock
1+
from unittest.mock import MagicMock, Mock, patch
22

33
import numpy as np
44
import pytest
5+
import torch
56
from neo4j_graphrag.embeddings.base import Embedder
67
from neo4j_graphrag.embeddings.sentence_transformers import (
78
SentenceTransformerEmbeddings,
89
)
9-
import torch
1010

1111

1212
def get_mock_sentence_transformers() -> MagicMock:

tests/unit/llm/test_anthropic_llm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16-
from unittest.mock import AsyncMock, MagicMock, patch, Mock
16+
import sys
17+
from typing import Generator
18+
from unittest.mock import AsyncMock, MagicMock, Mock, patch
1719

20+
import anthropic
1821
import pytest
1922
from neo4j_graphrag.llm.anthropic_llm import AnthropicLLM
20-
import sys
21-
import anthropic
22-
from typing import Generator
2323

2424

2525
@pytest.fixture

tests/unit/llm/test_cohere_llm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
import sys
16+
from typing import Generator
1517
from unittest.mock import AsyncMock, MagicMock, Mock, patch
1618

1719
import cohere.core
1820
import pytest
1921
from neo4j_graphrag.exceptions import LLMGenerationError
2022
from neo4j_graphrag.llm import LLMResponse
2123
from neo4j_graphrag.llm.cohere_llm import CohereLLM
22-
import sys
23-
from typing import Generator
2424

2525

2626
@pytest.fixture

tests/unit/llm/test_openai_llm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from unittest.mock import MagicMock, patch, Mock
15+
from unittest.mock import MagicMock, Mock, patch
1616

17+
import openai
1718
import pytest
1819
from neo4j_graphrag.llm import LLMResponse
1920
from neo4j_graphrag.llm.openai_llm import AzureOpenAILLM, OpenAILLM
20-
import openai
2121

2222

2323
def get_mock_openai() -> MagicMock:

0 commit comments

Comments
 (0)