Skip to content

Commit 3c101cd

Browse files
authored
feat: add split_by_row feature to CSVDocumentSplitter (#9031)
* Add split by row feature
1 parent ed931b4 commit 3c101cd

File tree

3 files changed

+100
-16
lines changed

3 files changed

+100
-16
lines changed

haystack/components/preprocessors/csv_document_splitter.py

+53-15
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from io import StringIO
6-
from typing import Any, Dict, List, Literal, Optional, Tuple
6+
from typing import Any, Dict, List, Literal, Optional, Tuple, get_args
77

88
from haystack import Document, component, logging
99
from haystack.lazy_imports import LazyImport
@@ -13,21 +13,27 @@
1313

1414
logger = logging.getLogger(__name__)
1515

16+
SplitMode = Literal["threshold", "row-wise"]
17+
1618

1719
@component
1820
class CSVDocumentSplitter:
1921
"""
20-
A component for splitting CSV documents into sub-tables based on empty rows and columns.
22+
A component for splitting CSV documents into sub-tables based on split arguments.
2123
22-
The splitter identifies consecutive empty rows or columns that exceed a given threshold
24+
The splitter supports two modes of operation:
25+
- identify consecutive empty rows or columns that exceed a given threshold
2326
and uses them as delimiters to segment the document into smaller tables.
27+
- split each row into a separate sub-table, represented as a Document.
28+
2429
"""
2530

2631
def __init__(
2732
self,
2833
row_split_threshold: Optional[int] = 2,
2934
column_split_threshold: Optional[int] = 2,
3035
read_csv_kwargs: Optional[Dict[str, Any]] = None,
36+
split_mode: SplitMode = "threshold",
3137
) -> None:
3238
"""
3339
Initializes the CSVDocumentSplitter component.
@@ -40,8 +46,16 @@ def __init__(
4046
- `skip_blank_lines=False` to preserve blank lines
4147
- `dtype=object` to prevent type inference (e.g., converting numbers to floats).
4248
See https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html for more information.
49+
:param split_mode:
50+
If `threshold`, the component will split the document based on the number of
51+
consecutive empty rows or columns that exceed the `row_split_threshold` or `column_split_threshold`.
52+
If `row-wise`, the component will split each row into a separate sub-table.
4353
"""
4454
pandas_import.check()
55+
if split_mode not in get_args(SplitMode):
56+
raise ValueError(
57+
f"Split mode '{split_mode}' not recognized. Choose one among: {', '.join(get_args(SplitMode))}."
58+
)
4559
if row_split_threshold is not None and row_split_threshold < 1:
4660
raise ValueError("row_split_threshold must be greater than 0")
4761

@@ -54,6 +68,7 @@ def __init__(
5468
self.row_split_threshold = row_split_threshold
5569
self.column_split_threshold = column_split_threshold
5670
self.read_csv_kwargs = read_csv_kwargs or {}
71+
self.split_mode = split_mode
5772

5873
@component.output_types(documents=List[Document])
5974
def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
@@ -89,6 +104,7 @@ def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
89104
resolved_read_csv_kwargs = {"header": None, "skip_blank_lines": False, "dtype": object, **self.read_csv_kwargs}
90105

91106
split_documents = []
107+
split_dfs = []
92108
for document in documents:
93109
try:
94110
df = pd.read_csv(StringIO(document.content), **resolved_read_csv_kwargs) # type: ignore
@@ -97,19 +113,32 @@ def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
97113
split_documents.append(document)
98114
continue
99115

100-
if self.row_split_threshold is not None and self.column_split_threshold is None:
101-
# split by rows
102-
split_dfs = self._split_dataframe(df=df, split_threshold=self.row_split_threshold, axis="row")
103-
elif self.column_split_threshold is not None and self.row_split_threshold is None:
104-
# split by columns
105-
split_dfs = self._split_dataframe(df=df, split_threshold=self.column_split_threshold, axis="column")
106-
else:
107-
# recursive split
108-
split_dfs = self._recursive_split(
109-
df=df,
110-
row_split_threshold=self.row_split_threshold, # type: ignore
111-
column_split_threshold=self.column_split_threshold, # type: ignore
116+
if self.split_mode == "row-wise":
117+
# each row is a separate sub-table
118+
split_dfs = self._split_by_row(df=df)
119+
120+
elif self.split_mode == "threshold":
121+
if self.row_split_threshold is not None and self.column_split_threshold is None:
122+
# split by rows
123+
split_dfs = self._split_dataframe(df=df, split_threshold=self.row_split_threshold, axis="row")
124+
elif self.column_split_threshold is not None and self.row_split_threshold is None:
125+
# split by columns
126+
split_dfs = self._split_dataframe(df=df, split_threshold=self.column_split_threshold, axis="column")
127+
else:
128+
# recursive split
129+
split_dfs = self._recursive_split(
130+
df=df,
131+
row_split_threshold=self.row_split_threshold, # type: ignore
132+
column_split_threshold=self.column_split_threshold, # type: ignore
133+
)
134+
135+
# check if no sub-tables were found
136+
if len(split_dfs) == 0:
137+
logger.warning(
138+
"No sub-tables found while splitting CSV Document with id {doc_id}. Skipping document.",
139+
doc_id=document.id,
112140
)
141+
continue
113142

114143
# Sort split_dfs first by row index, then by column index
115144
split_dfs.sort(key=lambda dataframe: (dataframe.index[0], dataframe.columns[0]))
@@ -242,3 +271,12 @@ def _recursive_split(
242271
result.append(table)
243272

244273
return result
274+
275+
def _split_by_row(self, df: "pd.DataFrame") -> List["pd.DataFrame"]:
276+
"""Split each CSV row into a separate subtable"""
277+
split_dfs = []
278+
for idx, row in enumerate(df.itertuples(index=False)):
279+
split_df = pd.DataFrame(row).T
280+
split_df.index = [idx] # Set the index of the new DataFrame to idx
281+
split_dfs.append(split_df)
282+
return split_dfs
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
features:
3+
- |
4+
Added a new parameter `split_mode` to the `CSVDocumentSplitter` component to control the splitting mode.
5+
The new parameter can be set to `row-wise` to split the CSV file by rows.
6+
The default value is `threshold`, which is the previous behavior.

test/components/preprocessors/test_csv_document_splitter.py

+41-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import pytest
6+
import logging
67
from pandas import read_csv
78
from io import StringIO
89
from haystack import Document, Pipeline
@@ -15,6 +16,15 @@ def splitter() -> CSVDocumentSplitter:
1516
return CSVDocumentSplitter()
1617

1718

19+
@pytest.fixture
20+
def csv_with_four_rows() -> str:
21+
return """A,B,C
22+
1,2,3
23+
X,Y,Z
24+
7,8,9
25+
"""
26+
27+
1828
@pytest.fixture
1929
def two_tables_sep_by_two_empty_rows() -> str:
2030
return """A,B,C
@@ -255,7 +265,12 @@ def test_to_dict_with_defaults(self) -> None:
255265
config_serialized = component_to_dict(splitter, name="CSVDocumentSplitter")
256266
config = {
257267
"type": "haystack.components.preprocessors.csv_document_splitter.CSVDocumentSplitter",
258-
"init_parameters": {"row_split_threshold": 2, "column_split_threshold": 2, "read_csv_kwargs": {}},
268+
"init_parameters": {
269+
"row_split_threshold": 2,
270+
"column_split_threshold": 2,
271+
"read_csv_kwargs": {},
272+
"split_mode": "threshold",
273+
},
259274
}
260275
assert config_serialized == config
261276

@@ -268,6 +283,7 @@ def test_to_dict_non_defaults(self) -> None:
268283
"row_split_threshold": 1,
269284
"column_split_threshold": None,
270285
"read_csv_kwargs": {"sep": ";"},
286+
"split_mode": "threshold",
271287
},
272288
}
273289
assert config_serialized == config
@@ -284,6 +300,7 @@ def test_from_dict_defaults(self) -> None:
284300
assert splitter.row_split_threshold == 2
285301
assert splitter.column_split_threshold == 2
286302
assert splitter.read_csv_kwargs == {}
303+
assert splitter.split_mode == "threshold"
287304

288305
def test_from_dict_non_defaults(self) -> None:
289306
splitter = component_from_dict(
@@ -294,10 +311,33 @@ def test_from_dict_non_defaults(self) -> None:
294311
"row_split_threshold": 1,
295312
"column_split_threshold": None,
296313
"read_csv_kwargs": {"sep": ";"},
314+
"split_mode": "row-wise",
297315
},
298316
},
299317
name="CSVDocumentSplitter",
300318
)
301319
assert splitter.row_split_threshold == 1
302320
assert splitter.column_split_threshold is None
303321
assert splitter.read_csv_kwargs == {"sep": ";"}
322+
assert splitter.split_mode == "row-wise"
323+
324+
def test_split_by_row(self, csv_with_four_rows: str) -> None:
325+
splitter = CSVDocumentSplitter(split_mode="row-wise")
326+
doc = Document(content=csv_with_four_rows)
327+
result = splitter.run([doc])["documents"]
328+
assert len(result) == 4
329+
assert result[0].content == "A,B,C\n"
330+
assert result[1].content == "1,2,3\n"
331+
assert result[2].content == "X,Y,Z\n"
332+
333+
def test_split_by_row_with_empty_rows(self, caplog) -> None:
334+
splitter = CSVDocumentSplitter(split_mode="row-wise")
335+
doc = Document(content="")
336+
with caplog.at_level(logging.ERROR):
337+
result = splitter.run([doc])["documents"]
338+
assert len(result) == 1
339+
assert result[0].content == ""
340+
341+
def test_incorrect_split_mode(self) -> None:
342+
with pytest.raises(ValueError, match="not recognized"):
343+
CSVDocumentSplitter(split_mode="incorrect_mode")

0 commit comments

Comments
 (0)