Skip to content

Commit ff4ed5a

Browse files
authored
Fix import and init issue with using OpenAIEmbeddings (#79)
* Fix init class of OpenAIEmbeddings * Removed sentence_transformers from init
1 parent e059839 commit ff4ed5a

File tree

6 files changed

+57
-13
lines changed

6 files changed

+57
-13
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## Next
44

5+
### Fixed
6+
- Corrected initialization to allow specifying the embedding model name.
7+
- Removed sentence_transformers from embeddings/__init__.py to avoid ImportError when the package is not installed.
8+
59
## 0.3.0
610

711
### Added

docs/source/user_guide.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ The `OpenAIEmbedder` was illustrated previously. Here is how to use the `Sentenc
266266

267267
.. code:: python
268268
269-
from neo4j_genai.embeddings import SentenceTransformerEmbeddings
269+
from neo4j_genai.embeddings.sentence_transformers import SentenceTransformerEmbeddings
270270
271271
embedder = SentenceTransformerEmbeddings(model="all-MiniLM-L6-v2") # Note: this is the default model
272272
Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
1-
from .sentence_transformers import SentenceTransformerEmbeddings
2-
3-
__all__ = [
4-
"SentenceTransformerEmbeddings",
5-
]
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.

src/neo4j_genai/embeddings/openai.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
116
from __future__ import annotations
217

318
from typing import Any
@@ -6,7 +21,7 @@
621

722

823
class OpenAIEmbeddings(Embedder):
9-
def __init__(self, *args: Any, **kwargs: Any) -> None:
24+
def __init__(self, model: str = "text-embedding-ada-002") -> None:
1025
try:
1126
import openai
1227
except ImportError:
@@ -15,10 +30,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
1530
"Please install it with `pip install openai`."
1631
)
1732

18-
self.model = openai.OpenAI(*args, **kwargs)
33+
self.openai_model = openai.OpenAI()
34+
self.model = model
1935

20-
def embed_query(
21-
self, text: str, model: str = "text-embedding-ada-002", **kwargs: Any
22-
) -> list[float]:
23-
response = self.model.embeddings.create(input=text, model=model, **kwargs)
36+
def embed_query(self, text: str, **kwargs: Any) -> list[float]:
37+
response = self.openai_model.embeddings.create(
38+
input=text, model=self.model, **kwargs
39+
)
2440
return response.data[0].embedding

src/neo4j_genai/embeddings/sentence_transformers.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
116
from typing import Any
217

318
import numpy as np

tests/unit/embeddings/test_sentence_transformers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import pytest
55
from neo4j_genai.embedder import Embedder
6-
from neo4j_genai.embeddings import SentenceTransformerEmbeddings
6+
from neo4j_genai.embeddings.sentence_transformers import SentenceTransformerEmbeddings
77

88

99
@patch("sentence_transformers.SentenceTransformer")

0 commit comments

Comments
 (0)