diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a9208e6a..1a8c84da0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## Next +### Fixed + +- Fix a bug where the `OllamaEmbedder` would return a `list[list[float]]` instead of the expected `list[float]`. + ## 1.4.1 ### Fixed diff --git a/src/neo4j_graphrag/embeddings/ollama.py b/src/neo4j_graphrag/embeddings/ollama.py index d0c9fec9b..78775ba60 100644 --- a/src/neo4j_graphrag/embeddings/ollama.py +++ b/src/neo4j_graphrag/embeddings/ollama.py @@ -55,10 +55,12 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]: **kwargs, ) - if embeddings_response is None or embeddings_response.embeddings is None: + if embeddings_response is None or not embeddings_response.embeddings: raise EmbeddingsGenerationError("Failed to retrieve embeddings.") - embedding = embeddings_response.embeddings + embeddings = embeddings_response.embeddings + # client always returns a sequence of sequences + embedding = embeddings[0] if not isinstance(embedding, list): raise EmbeddingsGenerationError("Embedding is not a list of floats.") diff --git a/tests/unit/embeddings/test_ollama_embedder.py b/tests/unit/embeddings/test_ollama_embedder.py index dded2b87f..1cddd1103 100644 --- a/tests/unit/embeddings/test_ollama_embedder.py +++ b/tests/unit/embeddings/test_ollama_embedder.py @@ -16,6 +16,7 @@ import pytest from neo4j_graphrag.embeddings.ollama import OllamaEmbeddings +from neo4j_graphrag.exceptions import EmbeddingsGenerationError @patch("builtins.__import__", side_effect=ImportError) @@ -27,9 +28,19 @@ def test_ollama_embedder_missing_dependency(mock_import: Mock) -> None: @patch("builtins.__import__") def test_ollama_embedder_happy_path(mock_import: Mock) -> None: mock_import.return_value.Client.return_value.embed.return_value = MagicMock( - embeddings=[1.0, 2.0], + embeddings=[[1.0, 2.0]], ) embedder = OllamaEmbeddings(model="test") res = embedder.embed_query("my text") assert isinstance(res, list) assert res == [1.0, 2.0] + + +@patch("builtins.__import__") +def test_ollama_embedder_empty_list(mock_import: Mock) -> None: + mock_import.return_value.Client.return_value.embed.return_value = MagicMock( + embeddings=[], + ) + embedder = OllamaEmbeddings(model="test") + with pytest.raises(EmbeddingsGenerationError): + embedder.embed_query("my text")