3
3
# SPDX-License-Identifier: Apache-2.0
4
4
5
5
from io import StringIO
6
- from typing import Any , Dict , List , Literal , Optional , Tuple
6
+ from typing import Any , Dict , List , Literal , Optional , Tuple , get_args
7
7
8
8
from haystack import Document , component , logging
9
9
from haystack .lazy_imports import LazyImport
13
13
14
14
logger = logging .getLogger (__name__ )
15
15
16
+ SplitMode = Literal ["threshold" , "row-wise" ]
17
+
16
18
17
19
@component
18
20
class CSVDocumentSplitter :
19
21
"""
20
- A component for splitting CSV documents into sub-tables based on empty rows and columns .
22
+ A component for splitting CSV documents into sub-tables based on split arguments .
21
23
22
- The splitter identifies consecutive empty rows or columns that exceed a given threshold
24
+ The splitter supports two modes of operation:
25
+ - identify consecutive empty rows or columns that exceed a given threshold
23
26
and uses them as delimiters to segment the document into smaller tables.
27
+ - split each row into a separate sub-table, represented as a Document.
28
+
24
29
"""
25
30
26
31
def __init__ (
27
32
self ,
28
33
row_split_threshold : Optional [int ] = 2 ,
29
34
column_split_threshold : Optional [int ] = 2 ,
30
35
read_csv_kwargs : Optional [Dict [str , Any ]] = None ,
36
+ split_mode : SplitMode = "threshold" ,
31
37
) -> None :
32
38
"""
33
39
Initializes the CSVDocumentSplitter component.
@@ -40,8 +46,16 @@ def __init__(
40
46
- `skip_blank_lines=False` to preserve blank lines
41
47
- `dtype=object` to prevent type inference (e.g., converting numbers to floats).
42
48
See https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html for more information.
49
+ :param split_mode:
50
+ If `threshold`, the component will split the document based on the number of
51
+ consecutive empty rows or columns that exceed the `row_split_threshold` or `column_split_threshold`.
52
+ If `row-wise`, the component will split each row into a separate sub-table.
43
53
"""
44
54
pandas_import .check ()
55
+ if split_mode not in get_args (SplitMode ):
56
+ raise ValueError (
57
+ f"Split mode '{ split_mode } ' not recognized. Choose one among: { ', ' .join (get_args (SplitMode ))} ."
58
+ )
45
59
if row_split_threshold is not None and row_split_threshold < 1 :
46
60
raise ValueError ("row_split_threshold must be greater than 0" )
47
61
@@ -54,6 +68,7 @@ def __init__(
54
68
self .row_split_threshold = row_split_threshold
55
69
self .column_split_threshold = column_split_threshold
56
70
self .read_csv_kwargs = read_csv_kwargs or {}
71
+ self .split_mode = split_mode
57
72
58
73
@component .output_types (documents = List [Document ])
59
74
def run (self , documents : List [Document ]) -> Dict [str , List [Document ]]:
@@ -89,6 +104,7 @@ def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
89
104
resolved_read_csv_kwargs = {"header" : None , "skip_blank_lines" : False , "dtype" : object , ** self .read_csv_kwargs }
90
105
91
106
split_documents = []
107
+ split_dfs = []
92
108
for document in documents :
93
109
try :
94
110
df = pd .read_csv (StringIO (document .content ), ** resolved_read_csv_kwargs ) # type: ignore
@@ -97,19 +113,32 @@ def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
97
113
split_documents .append (document )
98
114
continue
99
115
100
- if self .row_split_threshold is not None and self .column_split_threshold is None :
101
- # split by rows
102
- split_dfs = self ._split_dataframe (df = df , split_threshold = self .row_split_threshold , axis = "row" )
103
- elif self .column_split_threshold is not None and self .row_split_threshold is None :
104
- # split by columns
105
- split_dfs = self ._split_dataframe (df = df , split_threshold = self .column_split_threshold , axis = "column" )
106
- else :
107
- # recursive split
108
- split_dfs = self ._recursive_split (
109
- df = df ,
110
- row_split_threshold = self .row_split_threshold , # type: ignore
111
- column_split_threshold = self .column_split_threshold , # type: ignore
116
+ if self .split_mode == "row-wise" :
117
+ # each row is a separate sub-table
118
+ split_dfs = self ._split_by_row (df = df )
119
+
120
+ elif self .split_mode == "threshold" :
121
+ if self .row_split_threshold is not None and self .column_split_threshold is None :
122
+ # split by rows
123
+ split_dfs = self ._split_dataframe (df = df , split_threshold = self .row_split_threshold , axis = "row" )
124
+ elif self .column_split_threshold is not None and self .row_split_threshold is None :
125
+ # split by columns
126
+ split_dfs = self ._split_dataframe (df = df , split_threshold = self .column_split_threshold , axis = "column" )
127
+ else :
128
+ # recursive split
129
+ split_dfs = self ._recursive_split (
130
+ df = df ,
131
+ row_split_threshold = self .row_split_threshold , # type: ignore
132
+ column_split_threshold = self .column_split_threshold , # type: ignore
133
+ )
134
+
135
+ # check if no sub-tables were found
136
+ if len (split_dfs ) == 0 :
137
+ logger .warning (
138
+ "No sub-tables found while splitting CSV Document with id {doc_id}. Skipping document." ,
139
+ doc_id = document .id ,
112
140
)
141
+ continue
113
142
114
143
# Sort split_dfs first by row index, then by column index
115
144
split_dfs .sort (key = lambda dataframe : (dataframe .index [0 ], dataframe .columns [0 ]))
@@ -242,3 +271,12 @@ def _recursive_split(
242
271
result .append (table )
243
272
244
273
return result
274
+
275
+ def _split_by_row (self , df : "pd.DataFrame" ) -> List ["pd.DataFrame" ]:
276
+ """Split each CSV row into a separate subtable"""
277
+ split_dfs = []
278
+ for idx , row in enumerate (df .itertuples (index = False )):
279
+ split_df = pd .DataFrame (row ).T
280
+ split_df .index = [idx ] # Set the index of the new DataFrame to idx
281
+ split_dfs .append (split_df )
282
+ return split_dfs
0 commit comments