Skip to content

Commit 0ac06b7

Browse files
authored
Use json-repair package to fix LLM generated json (#226)
* Use json-repair package to fix LLM generated json * Removed redundant repaired_json.strip() * Rename test to test_extractor_llm_unfixable_json * Readded llama-index to dependencies * Removed print * Renamed JSONRepairError to InvalidJSONError * Use cast for repaired_json instead of isinstance * Add InvalidJSONError hyperlink to API docs
1 parent c166afc commit 0ac06b7

File tree

9 files changed

+1233
-1243
lines changed

9 files changed

+1233
-1243
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22

33
## Next
44

5+
## Added
6+
- Integrated json-repair package to handle and repair invalid JSON generated by LLMs.
7+
- Introduced InvalidJSONError exception for handling cases where JSON repair fails.
8+
9+
## Changed
10+
- Updated LLM prompts to include stricter instructions for generating valid JSON.
11+
512
### Fixed
613
- Added schema functions to the documentation.
714

docs/source/api.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,8 @@ Errors
384384

385385
* :class:`neo4j_graphrag.experimental.pipeline.exceptions.PipelineStatusUpdateError`
386386

387+
* :class:`neo4j_graphrag.experimental.pipeline.exceptions.InvalidJSONError`
388+
387389

388390
Neo4jGraphRagError
389391
==================
@@ -509,3 +511,10 @@ PipelineStatusUpdateError
509511

510512
.. autoclass:: neo4j_graphrag.experimental.pipeline.exceptions.PipelineStatusUpdateError
511513
:show-inheritance:
514+
515+
516+
InvalidJSONError
517+
================
518+
519+
.. autoclass:: neo4j_graphrag.experimental.pipeline.exceptions.InvalidJSONError
520+
:show-inheritance:

poetry.lock

Lines changed: 1152 additions & 1188 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ llama-index = {version = "^0.10.55", optional = true }
4848
openai = {version = "^1.51.1", optional = true }
4949
anthropic = { version = "^0.36.0", optional = true}
5050
sentence-transformers = {version = "^3.0.0", optional = true }
51+
json-repair = "^0.30.2"
5152

5253
[tool.poetry.group.dev.dependencies]
5354
urllib3 = "<2"

src/neo4j_graphrag/experimental/components/entity_relation_extractor.py

Lines changed: 24 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@
1919
import enum
2020
import json
2121
import logging
22-
import re
2322
from datetime import datetime
24-
from typing import Any, List, Optional, Union
23+
from typing import Any, List, Optional, Union, cast
24+
25+
import json_repair
2526

2627
from pydantic import ValidationError, validate_call
2728

@@ -36,6 +37,7 @@
3637
TextChunks,
3738
)
3839
from neo4j_graphrag.experimental.pipeline.component import Component
40+
from neo4j_graphrag.experimental.pipeline.exceptions import InvalidJSONError
3941
from neo4j_graphrag.generation.prompts import ERExtractionTemplate, PromptTemplate
4042
from neo4j_graphrag.llm import LLMInterface
4143

@@ -100,28 +102,15 @@ def balance_curly_braces(json_string: str) -> str:
100102
return "".join(fixed_json)
101103

102104

103-
def fix_invalid_json(invalid_json_string: str) -> str:
104-
# Fix missing quotes around field names
105-
invalid_json_string = re.sub(
106-
r"([{,]\s*)(\w+)(\s*:)", r'\1"\2"\3', invalid_json_string
107-
)
108-
109-
# Fix missing quotes around string values, correctly ignoring null, true, false, and numeric values
110-
invalid_json_string = re.sub(
111-
r"(?<=:\s)(?!(null|true|false|\d+\.?\d*))([a-zA-Z_][a-zA-Z0-9_]*)\s*(?=[,}])",
112-
r'"\2"',
113-
invalid_json_string,
114-
)
115-
116-
# Correct the specific issue: remove trailing commas within arrays or objects before closing braces or brackets
117-
invalid_json_string = re.sub(r",\s*(?=[}\]])", "", invalid_json_string)
105+
def fix_invalid_json(raw_json: str) -> str:
106+
repaired_json = json_repair.repair_json(raw_json)
107+
repaired_json = cast(str, repaired_json).strip()
118108

119-
# Normalize excessive curly braces
120-
invalid_json_string = re.sub(r"{{+", "{", invalid_json_string)
121-
invalid_json_string = re.sub(r"}}+", "}", invalid_json_string)
122-
123-
# Balance curly braces
124-
return balance_curly_braces(invalid_json_string)
109+
if repaired_json == '""':
110+
raise InvalidJSONError("JSON repair resulted in an empty or invalid JSON.")
111+
if not repaired_json:
112+
raise InvalidJSONError("JSON repair resulted in an empty string.")
113+
return repaired_json
125114

126115

127116
class EntityRelationExtractor(Component, abc.ABC):
@@ -223,24 +212,18 @@ async def extract_for_chunk(
223212
)
224213
llm_result = await self.llm.ainvoke(prompt)
225214
try:
226-
result = json.loads(llm_result.content)
227-
except json.JSONDecodeError:
228-
logger.info(
229-
f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk.index}. Trying to fix it."
230-
)
231-
fixed_content = fix_invalid_json(llm_result.content)
232-
try:
233-
result = json.loads(fixed_content)
234-
except json.JSONDecodeError as e:
235-
if self.on_error == OnError.RAISE:
236-
raise LLMGenerationError(
237-
f"LLM response is not valid JSON {fixed_content}: {e}"
238-
)
239-
else:
240-
logger.error(
241-
f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk.index}"
242-
)
243-
result = {"nodes": [], "relationships": []}
215+
llm_generated_json = fix_invalid_json(llm_result.content)
216+
result = json.loads(llm_generated_json)
217+
except (json.JSONDecodeError, InvalidJSONError) as e:
218+
if self.on_error == OnError.RAISE:
219+
raise LLMGenerationError(
220+
f"LLM response is not valid JSON {llm_result.content}: {e}"
221+
)
222+
else:
223+
logger.error(
224+
f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk.index}"
225+
)
226+
result = {"nodes": [], "relationships": []}
244227
try:
245228
chunk_graph = Neo4jGraph(**result)
246229
except ValidationError as e:

src/neo4j_graphrag/experimental/pipeline/exceptions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,9 @@ class PipelineStatusUpdateError(Neo4jGraphRagError):
3131
"""Raises when trying an invalid change of state (e.g. DONE => DOING)"""
3232

3333
pass
34+
35+
36+
class InvalidJSONError(Neo4jGraphRagError):
37+
"""Raised when JSON repair fails to produce valid JSON."""
38+
39+
pass

src/neo4j_graphrag/generation/prompts.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,11 @@ class ERExtractionTemplate(PromptTemplate):
174174
Do respect the source and target node types for relationship and
175175
the relationship direction.
176176
177-
Do not return any additional information other than the JSON in it.
177+
Make sure you adhere to the following rules to produce valid JSON objects:
178+
- Do not return any additional information other than the JSON in it.
179+
- Omit any backticks around the JSON - simply output the JSON on its own.
180+
- The JSON object must not wrapped into a list - it is its own JSON object.
181+
- Property names must be enclosed in double quotes
178182
179183
Examples:
180184
{examples}

src/neo4j_graphrag/llm/mistralai_llm.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,11 @@ def invoke(self, input: str) -> LLMResponse:
8484
messages=self.get_messages(input),
8585
**self.model_params,
8686
)
87-
if response is None or response.choices is None or not response.choices:
88-
content = ""
89-
else:
90-
content = response.choices[0].message.content or ""
87+
content: str = ""
88+
if response and response.choices:
89+
possible_content = response.choices[0].message.content
90+
if isinstance(possible_content, str):
91+
content = possible_content
9192
return LLMResponse(content=content)
9293
except SDKError as e:
9394
raise LLMGenerationError(e)
@@ -111,10 +112,11 @@ async def ainvoke(self, input: str) -> LLMResponse:
111112
messages=self.get_messages(input),
112113
**self.model_params,
113114
)
114-
if response is None or response.choices is None or not response.choices:
115-
content = ""
116-
else:
117-
content = response.choices[0].message.content or ""
115+
content: str = ""
116+
if response and response.choices:
117+
possible_content = response.choices[0].message.content
118+
if isinstance(possible_content, str):
119+
content = possible_content
118120
return LLMResponse(content=content)
119121
except SDKError as e:
120122
raise LLMGenerationError(e)

tests/unit/experimental/components/test_entity_relation_extractor.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
import json
18-
from unittest.mock import MagicMock
18+
from unittest.mock import MagicMock, patch
1919

2020
import pytest
2121
from neo4j_graphrag.exceptions import LLMGenerationError
@@ -31,6 +31,7 @@
3131
TextChunk,
3232
TextChunks,
3333
)
34+
from neo4j_graphrag.experimental.pipeline.exceptions import InvalidJSONError
3435
from neo4j_graphrag.llm import LLMInterface, LLMResponse
3536

3637

@@ -144,16 +145,17 @@ async def test_extractor_llm_ainvoke_failed() -> None:
144145

145146

146147
@pytest.mark.asyncio
147-
async def test_extractor_llm_badly_formatted_json() -> None:
148+
async def test_extractor_llm_unfixable_json() -> None:
148149
llm = MagicMock(spec=LLMInterface)
149150
llm.ainvoke.return_value = LLMResponse(
150-
content='{"nodes": [{"id": "0", "label": "Person", "properties": {}}], "relationships": [}'
151+
content='{"nodes": [{"id": "0", "label": "Person", "properties": {}}], "relationships": }'
151152
)
152153

153154
extractor = LLMEntityRelationExtractor(
154155
llm=llm,
155156
)
156157
chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])
158+
157159
with pytest.raises(LLMGenerationError):
158160
await extractor.run(chunks=chunks)
159161

@@ -177,7 +179,7 @@ async def test_extractor_llm_invalid_json() -> None:
177179

178180

179181
@pytest.mark.asyncio
180-
async def test_extractor_llm_badly_formatted_json_do_not_raise() -> None:
182+
async def test_extractor_llm_badly_formatted_json_gets_fixed() -> None:
181183
llm = MagicMock(spec=LLMInterface)
182184
llm.ainvoke.return_value = LLMResponse(
183185
content='{"nodes": [{"id": "0", "label": "Person", "properties": {}}], "relationships": [}'
@@ -190,7 +192,11 @@ async def test_extractor_llm_badly_formatted_json_do_not_raise() -> None:
190192
)
191193
chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])
192194
res = await extractor.run(chunks=chunks)
193-
assert res.nodes == []
195+
196+
assert len(res.nodes) == 1
197+
assert res.nodes[0].label == "Person"
198+
assert res.nodes[0].properties == {"chunk_index": 0}
199+
assert res.nodes[0].embedding_properties is None
194200
assert res.relationships == []
195201

196202

@@ -205,6 +211,14 @@ async def test_extractor_custom_prompt() -> None:
205211
llm.ainvoke.assert_called_once_with("this is my prompt")
206212

207213

214+
def test_fix_invalid_json_empty_result() -> None:
215+
json_string = "invalid json"
216+
217+
with patch("json_repair.repair_json", return_value=""):
218+
with pytest.raises(InvalidJSONError):
219+
fix_invalid_json(json_string)
220+
221+
208222
def test_fix_unquoted_keys() -> None:
209223
json_string = '{name: "John", age: "30"}'
210224
expected_result = '{"name": "John", "age": "30"}'

0 commit comments

Comments
 (0)