Skip to content

Commit 63da1af

Browse files
fix: picklable ingest errors wrapper (#448)
1 parent 8e619d1 commit 63da1af

File tree

8 files changed

+127
-96
lines changed

8 files changed

+127
-96
lines changed

packages/ragbits-document-search/CHANGELOG.md

+2-3
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22

33
## Unreleased
44

5-
### Changed
6-
7-
- Add support for Git source to fetch files from Git repositories
5+
- Introduce picklable ingest error wrapper (#448)
6+
- Add support for Git source to fetch files from Git repositories (#439)
87

98
## 0.10.2 (2025-03-21)
109

packages/ragbits-document-search/src/ragbits/document_search/documents/exceptions.py

-17
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
import inspect
2-
from typing import Any
3-
4-
51
class SourceError(Exception):
62
"""
73
Class for all exceptions raised by the document source.
@@ -11,19 +7,6 @@ def __init__(self, message: str) -> None:
117
super().__init__(message)
128
self.message = message
139

14-
def __reduce__(self) -> tuple[type["SourceError"], tuple[Any, ...]]:
15-
# This __reduce__ method is written in a way that it automatically handles any subclass of SourceError.
16-
# It requires the subclass to have an initializer that store the arguments in the instance's state,
17-
# under the same name.
18-
init_params = inspect.signature(self.__class__.__init__).parameters
19-
20-
args = [
21-
self.__getattribute__(param_name)
22-
for param_name in list(init_params.keys())[1:] # Skip 'self'
23-
]
24-
25-
return self.__class__, tuple(args)
26-
2710

2811
class SourceConnectionError(SourceError):
2912
"""

packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/dummy.py

+22
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,25 @@ async def process(self, document_meta: DocumentMeta) -> Sequence[Element | Inter
6767
document_meta=document_meta,
6868
)
6969
]
70+
71+
72+
class DummyFailureProvider(BaseProvider):
73+
"""
74+
This is a simple provider that always raises an exception.
75+
"""
76+
77+
SUPPORTED_DOCUMENT_TYPES = {DocumentType.TXT, DocumentType.MD, DocumentType.JPG, DocumentType.PNG}
78+
79+
async def process(self, document_meta: DocumentMeta) -> Sequence[Element | IntermediateElement]:
80+
"""
81+
Process the text document.
82+
83+
Args:
84+
document_meta: The document to process.
85+
86+
Raises:
87+
RuntimeError: This is a dummy exception.
88+
"""
89+
self.validate_document_type(document_meta.document_type)
90+
91+
raise RuntimeError("This is a dummy exception")

packages/ragbits-document-search/src/ragbits/document_search/ingestion/strategies/base.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import random
3+
import traceback
34
from abc import ABC, abstractmethod
45
from collections import defaultdict
56
from collections.abc import Awaitable, Callable, Iterable, Sequence
@@ -20,6 +21,31 @@
2021
_CallReturnT = TypeVar("_CallReturnT")
2122

2223

24+
@dataclass
25+
class IngestError:
26+
"""
27+
Represents an error that occurred during the document ingest execution
28+
"""
29+
30+
type: type[Exception]
31+
message: str
32+
stacktrace: str
33+
34+
@classmethod
35+
def from_exception(cls, exc: Exception) -> "IngestError":
36+
"""
37+
Create an IngestError from an exception.
38+
39+
Args:
40+
exc: The exception to create the IngestError from.
41+
42+
Returns:
43+
The IngestError instance.
44+
"""
45+
stacktrace = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
46+
return cls(type=type(exc), message=str(exc), stacktrace=stacktrace)
47+
48+
2349
@dataclass
2450
class IngestDocumentResult:
2551
"""
@@ -28,7 +54,7 @@ class IngestDocumentResult:
2854

2955
document_uri: str
3056
num_elements: int = 0
31-
error: BaseException | None = None
57+
error: IngestError | None = None
3258

3359

3460
@dataclass

packages/ragbits-document-search/src/ragbits/document_search/ingestion/strategies/batched.py

+49-26
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from ragbits.document_search.ingestion.intermediate_handlers.base import BaseIntermediateHandler
1212
from ragbits.document_search.ingestion.strategies.base import (
1313
IngestDocumentResult,
14+
IngestError,
1415
IngestExecutionResult,
1516
IngestStrategy,
1617
)
@@ -19,7 +20,7 @@
1920
@dataclass
2021
class IngestTaskResult:
2122
"""
22-
Represents the result of the document batch ingest tast.
23+
Represents the result of the document batch ingest task.
2324
"""
2425

2526
document_uri: str
@@ -145,18 +146,29 @@ async def _parse_batch(
145146
],
146147
return_exceptions=True,
147148
)
148-
return [
149-
IngestDocumentResult(
150-
document_uri=uri,
151-
error=response,
152-
)
153-
if isinstance(response, BaseException)
154-
else IngestTaskResult(
155-
document_uri=uri,
156-
elements=response,
157-
)
158-
for uri, response in zip(uris, responses, strict=True)
159-
]
149+
150+
results: list[IngestTaskResult | IngestDocumentResult] = []
151+
for uri, response in zip(uris, responses, strict=True):
152+
if isinstance(response, BaseException):
153+
if isinstance(response, Exception):
154+
results.append(
155+
IngestDocumentResult(
156+
document_uri=uri,
157+
error=IngestError.from_exception(response),
158+
)
159+
)
160+
# Handle only standard exceptions, not BaseExceptions like SystemExit, KeyboardInterrupt, etc.
161+
else:
162+
raise response
163+
else:
164+
results.append(
165+
IngestTaskResult(
166+
document_uri=uri,
167+
elements=response,
168+
)
169+
)
170+
171+
return results
160172

161173
async def _enrich_batch(
162174
self,
@@ -184,18 +196,29 @@ async def _enrich_batch(
184196
],
185197
return_exceptions=True,
186198
)
187-
return [
188-
IngestDocumentResult(
189-
document_uri=result.document_uri,
190-
error=response,
191-
)
192-
if isinstance(response, BaseException)
193-
else IngestTaskResult(
194-
document_uri=result.document_uri,
195-
elements=[element for element in result.elements if isinstance(element, Element)] + response,
196-
)
197-
for result, response in zip(batch, responses, strict=True)
198-
]
199+
200+
results: list[IngestTaskResult | IngestDocumentResult] = []
201+
for result, response in zip(batch, responses, strict=True):
202+
if isinstance(response, BaseException):
203+
if isinstance(response, Exception):
204+
results.append(
205+
IngestDocumentResult(
206+
document_uri=result.document_uri,
207+
error=IngestError.from_exception(response),
208+
)
209+
)
210+
# Handle only standard exceptions, not BaseExceptions like SystemExit, KeyboardInterrupt, etc.
211+
else:
212+
raise response
213+
else:
214+
results.append(
215+
IngestTaskResult(
216+
document_uri=result.document_uri,
217+
elements=[element for element in result.elements if isinstance(element, Element)] + response,
218+
)
219+
)
220+
221+
return results
199222

200223
async def _index_batch(
201224
self,
@@ -230,7 +253,7 @@ async def _index_batch(
230253
return [
231254
IngestDocumentResult(
232255
document_uri=result.document_uri,
233-
error=exc,
256+
error=IngestError.from_exception(exc),
234257
)
235258
for result in batch
236259
]

packages/ragbits-document-search/src/ragbits/document_search/ingestion/strategies/sequential.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from ragbits.document_search.ingestion.intermediate_handlers.base import BaseIntermediateHandler
99
from ragbits.document_search.ingestion.strategies.base import (
1010
IngestDocumentResult,
11+
IngestError,
1112
IngestExecutionResult,
1213
IngestStrategy,
1314
)
@@ -70,7 +71,7 @@ async def __call__(
7071
results.failed.append(
7172
IngestDocumentResult(
7273
document_uri=document_uri,
73-
error=exc,
74+
error=IngestError.from_exception(exc),
7475
)
7576
)
7677
else:

packages/ragbits-document-search/tests/unit/test_ingest_strategies.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from ragbits.core.vector_stores.in_memory import InMemoryVectorStore
55
from ragbits.document_search.documents.document import DocumentMeta, DocumentType
66
from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter
7-
from ragbits.document_search.ingestion.providers.dummy import DummyProvider
7+
from ragbits.document_search.ingestion.providers.dummy import DummyFailureProvider, DummyProvider
88
from ragbits.document_search.ingestion.strategies.base import IngestStrategy
99
from ragbits.document_search.ingestion.strategies.batched import BatchedIngestStrategy
1010
from ragbits.document_search.ingestion.strategies.ray import RayDistributedIngestStrategy
@@ -48,3 +48,27 @@ async def test_ingest_strategy_call(ingest_strategy: IngestStrategy, documents:
4848

4949
assert len(results.successful) == len(documents)
5050
assert len(results.failed) == 0
51+
52+
53+
async def test_ingest_errors_are_returned(ingest_strategy: IngestStrategy, documents: list[DocumentMeta]) -> None:
54+
vector_store = InMemoryVectorStore(embedder=NoopEmbedder())
55+
parser_router = DocumentProcessorRouter.from_config({DocumentType.TXT: DummyFailureProvider()})
56+
57+
results = await ingest_strategy(
58+
documents=documents,
59+
vector_store=vector_store,
60+
parser_router=parser_router,
61+
enricher_router={},
62+
)
63+
64+
assert len(results.successful) == 0
65+
assert len(results.failed) == len(documents)
66+
67+
expected_error_message = "This is a dummy exception"
68+
expected_error_type = RuntimeError
69+
for result in results.failed:
70+
assert result.error is not None
71+
assert result.error.type == expected_error_type
72+
assert result.error.message == expected_error_message
73+
assert result.error.stacktrace.startswith("Traceback")
74+
assert expected_error_message in result.error.stacktrace
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import pickle
2-
31
from ragbits.document_search.documents.exceptions import (
42
SourceConnectionError,
53
SourceError,
@@ -14,50 +12,19 @@ def test_source_error_init():
1412
assert str(error) == "Test error message"
1513

1614

17-
def test_source_error_pickle():
18-
original = SourceError("Test pickle message")
19-
pickled = pickle.dumps(original)
20-
unpickled = pickle.loads(pickled) # noqa: S301
21-
22-
assert isinstance(unpickled, SourceError)
23-
assert unpickled.message == "Test pickle message"
24-
assert str(unpickled) == "Test pickle message"
25-
26-
2715
def test_source_connection_error_init():
2816
error = SourceConnectionError()
2917
assert error.message == "Connection error."
3018
assert str(error) == "Connection error."
3119

3220

33-
def test_source_connection_error_pickle():
34-
original = SourceConnectionError()
35-
pickled = pickle.dumps(original)
36-
unpickled = pickle.loads(pickled) # noqa: S301
37-
38-
assert isinstance(unpickled, SourceConnectionError)
39-
assert unpickled.message == "Connection error."
40-
assert str(unpickled) == "Connection error."
41-
42-
4321
def test_source_not_found_error_init():
4422
error = SourceNotFoundError("test-source-id")
4523
assert error.source_id == "test-source-id"
4624
assert error.message == "Source with ID test-source-id not found."
4725
assert str(error) == "Source with ID test-source-id not found."
4826

4927

50-
def test_source_not_found_error_pickle():
51-
original = SourceNotFoundError("test-source-id")
52-
pickled = pickle.dumps(original)
53-
unpickled = pickle.loads(pickled) # noqa: S301
54-
55-
assert isinstance(unpickled, SourceNotFoundError)
56-
assert unpickled.source_id == "test-source-id"
57-
assert unpickled.message == "Source with ID test-source-id not found."
58-
assert str(unpickled) == "Source with ID test-source-id not found."
59-
60-
6128
def test_web_download_error_init():
6229
url = "https://example.com/file.pdf"
6330
code = 404
@@ -67,17 +34,3 @@ def test_web_download_error_init():
6734
assert error.code == code
6835
assert error.message == f"Download of {url} failed with code {code}."
6936
assert str(error) == f"Download of {url} failed with code {code}."
70-
71-
72-
def test_web_download_error_pickle():
73-
url = "https://example.com/file.pdf"
74-
code = 404
75-
original = WebDownloadError(url, code)
76-
pickled = pickle.dumps(original)
77-
unpickled = pickle.loads(pickled) # noqa: S301
78-
79-
assert isinstance(unpickled, WebDownloadError)
80-
assert unpickled.url == url
81-
assert unpickled.code == code
82-
assert unpickled.message == f"Download of {url} failed with code {code}."
83-
assert str(unpickled) == f"Download of {url} failed with code {code}."

0 commit comments

Comments
 (0)