-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtext_splitters.py
214 lines (181 loc) · 7.71 KB
/
text_splitters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import logging
import re
from typing import Callable, Dict, List, Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_experimental.text_splitter import SemanticChunker
logger = logging.getLogger(__name__)
class BaseTablePreservingTextSplitter:
    """
    Base class for table-preserving text splitters with common table extraction logic.

    Both entry points are stateless (a static method and a class method), so the
    class works as a namespace and does not need to be instantiated.
    """

    # Patterns are tried in this order; the order decides which "type" label wins
    # when two patterns match exactly the same span.
    _TABLE_PATTERNS = (
        # Markdown: one or more header rows, a separator row such as "|---|:-:|",
        # then any number of body rows. Each matched row starts and ends with "|".
        # The last separator/body row may end at end-of-input instead of "\n".
        (
            "markdown",
            re.compile(
                r"(?:\|[^\n]+\|\n)+\|[-:| ]+\|(?:\n|\Z)(?:\|[^\n]+\|(?:\n|\Z))*"
            ),
        ),
        # HTML: <table ...> up to the first closing </table>. Attributes on the
        # opening tag and any letter case are accepted; "." spans newlines.
        (
            "html",
            re.compile(r"<table\b[^>]*>.*?</table\s*>", re.DOTALL | re.IGNORECASE),
        ),
        # Pipe-separated: a header line containing "|", a separator line made of
        # "-", ":", "|" and spaces that ends with "|", then any number of lines
        # containing "|". The last line may end at end-of-input instead of "\n".
        (
            "pipe",
            re.compile(
                r"^.*\|.*\n^[-:| ]+\|(?:\n|\Z)(?:^.*\|.*(?:\n|\Z))*",
                re.MULTILINE,
            ),
        ),
    )

    @staticmethod
    def extract_tables(text: str) -> List[Dict[str, str]]:
        """
        Extract tables from the text.
        Supports different table formats (markdown, HTML, pipe-separated).

        Matches from different patterns that overlap are merged into a single
        entry covering their union, so the returned spans never overlap. The
        merged entry keeps the "type" of the match that starts first (the longest
        one, then pattern order, on ties).

        :param text: Input text to scan
        :return: List of dicts sorted by position, each with the keys "type"
            ("markdown", "html" or "pipe"), "content" (the matched text),
            "start" and "end" (integer offsets into ``text``, end exclusive)
        """
        candidates = []
        for kind, pattern in BaseTablePreservingTextSplitter._TABLE_PATTERNS:
            for match in pattern.finditer(text):
                candidates.append(
                    {
                        "type": kind,
                        "content": match.group(0),
                        "start": match.start(),
                        "end": match.end(),
                    }
                )

        # Earliest start first; for equal starts the longest match first. The sort
        # is stable, so identical spans keep pattern order (markdown, html, pipe).
        candidates.sort(key=lambda t: (t["start"], -t["end"]))

        merged = []
        for table in candidates:
            previous = merged[-1] if merged else None
            if previous is not None and table["start"] < previous["end"]:
                # Overlaps the table already kept: grow that one if this match
                # reaches further, otherwise this match adds nothing new.
                if table["end"] > previous["end"]:
                    previous["end"] = table["end"]
                    previous["content"] = text[previous["start"] : previous["end"]]
                continue
            merged.append(table)
        return merged

    @classmethod
    def split_text(
        cls,
        text: str,
        base_splitter: Callable[[str], List[str]],
        chunk_size: Optional[int] = None,
        length_function: Callable[[str], int] = len,
        table_augmenter: Optional[Callable[[str], str]] = None,
    ) -> List[str]:
        """
        Split text while preserving tables.

        Non-table text is split with ``base_splitter``; each table is kept as one
        piece. Consecutive pieces are then packed into chunks: text pieces are
        joined with a single space, tables with a blank line. A single piece that
        is longer than ``chunk_size`` is not split further and becomes its own
        chunk.

        :param text: Input text to split
        :param base_splitter: The base text splitter function to use
        :param chunk_size: Optional chunk size to limit chunk length; ``None``
            packs everything into a single chunk
        :param length_function: Function to calculate length of text segments
        :param table_augmenter: Optional function applied to each table's text
            before it is packed (e.g. to add context for retrieval)
        :return: List of text chunks
        """
        tables = cls.extract_tables(text)
        if tables:
            logger.info("%d tables to augment with additional context", len(tables))

        # Alternate between the plain text found between tables and the tables.
        segments = []
        cursor = 0
        for table in tables:
            if table["start"] > cursor:
                segments.append(("text", text[cursor : table["start"]]))
            segments.append(("table", table["content"]))
            cursor = table["end"]
        if cursor < len(text):
            segments.append(("text", text[cursor:]))

        chunks = []
        current = ""
        for kind, content in segments:
            if kind == "table":
                if table_augmenter:
                    content = table_augmenter(content)
                    logger.debug("Augmenting table:\n%s", content)
                pieces, separator = [content], "\n\n"
            else:
                pieces, separator = base_splitter(content), " "

            for piece in pieces:
                candidate = current + separator + piece if current else piece
                # Measure the joined candidate itself so the separator actually
                # used ("\n\n" is two characters) is what counts against the limit.
                if chunk_size is None or length_function(candidate) <= chunk_size:
                    current = candidate
                    continue
                # Does not fit: close the current chunk and start a new one.
                if current:
                    chunks.append(current)
                current = piece

        if current:
            chunks.append(current)
        return chunks
class TablePreservingSemanticChunker(SemanticChunker):
    """
    SemanticChunker whose ``split_text`` is routed through
    BaseTablePreservingTextSplitter.split_text, with the parent's own
    ``split_text`` used as the base splitter.
    """

    def __init__(self, chunk_size, length_function=len, table_augmenter=None, **kwargs):
        """
        Store the table-aware settings, then initialise the SemanticChunker.

        :param chunk_size: Passed as ``chunk_size`` to
            BaseTablePreservingTextSplitter.split_text
        :param length_function: Passed as ``length_function`` to the same call
        :param table_augmenter: Passed as ``table_augmenter`` to the same call
        :param kwargs: Forwarded unchanged to ``SemanticChunker.__init__``
        """
        # Settings read back by split_text below.
        self._chunk_size = chunk_size
        self._length_function = length_function
        self._table_augmenter = table_augmenter
        # Everything else belongs to the parent class.
        super().__init__(**kwargs)

    def split_text(self, text: str) -> List[str]:
        """
        Split ``text`` with the stored table-aware settings.

        :param text: Input text to split
        :return: List of text chunks
        """
        semantic_split = super().split_text
        table_options = {
            "base_splitter": semantic_split,
            "chunk_size": self._chunk_size,
            "length_function": self._length_function,
            "table_augmenter": self._table_augmenter,
        }
        return BaseTablePreservingTextSplitter.split_text(text, **table_options)
class TablePreservingTextSplitter(RecursiveCharacterTextSplitter):
    """
    RecursiveCharacterTextSplitter whose ``split_text`` is routed through
    BaseTablePreservingTextSplitter.split_text, with the parent's own
    ``split_text`` used as the base splitter.
    """

    def __init__(self, chunk_size, length_function=len, table_augmenter=None, **kwargs):
        """
        Initialise the recursive splitter, then apply the table-aware settings.

        :param chunk_size: Chunk size limit; ``None`` keeps whatever value the
            parent constructor set
        :param length_function: Function to calculate length of text segments
        :param table_augmenter: Optional function applied to each table's text
        :param kwargs: Forwarded unchanged to the parent constructor
        """
        super().__init__(**kwargs)
        # NOTE(review): LangChain's TextSplitter.__init__ assigns
        # self._chunk_size and self._length_function from its own arguments
        # (library defaults when they are not passed). Setting ours *before* the
        # super() call, as was done previously, let the parent overwrite them, so
        # the caller's chunk_size / length_function were ignored. Confirm against
        # the installed LangChain version.
        if chunk_size is not None:
            self._chunk_size = chunk_size
        self._length_function = length_function
        self._table_augmenter = table_augmenter

    def split_text(self, text: str) -> List[str]:
        """
        Split ``text`` with the stored table-aware settings.

        :param text: Input text to split
        :return: List of text chunks
        """
        return BaseTablePreservingTextSplitter.split_text(
            text,
            base_splitter=super().split_text,
            chunk_size=self._chunk_size,
            length_function=self._length_function,
            table_augmenter=self._table_augmenter,
        )