Commit

support batch for part ops
Cathy0908 committed Aug 30, 2024
1 parent b5bf71b commit 0f52909
Showing 23 changed files with 212 additions and 135 deletions.
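
Every file in this commit follows the same pattern: the op gains a _batched_op = True class attribute, and its process() method switches from taking a single sample dict to taking a batched samples dict, where each field maps to a list of values. A minimal sketch of the two calling conventions (illustrative only; the 'text' field assumes the default text_key):

# Before this commit: one flat sample per call.
sample = {'text': 'some document text'}
# sample = op.process(sample)

# After this commit: a column-oriented batch, one list per field.
samples = {'text': ['first document', 'second document']}
# samples = op.process(samples)
# Each op rewrites samples[op.text_key] element by element and returns the batch.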
10 changes: 7 additions & 3 deletions data_juicer/ops/mapper/chinese_convert_mapper.py
@@ -27,6 +27,8 @@ class ChineseConvertMapper(Mapper):
"""Mapper to convert Chinese between Traditional Chinese, Simplified Chinese
and Japanese Kanji."""

_batched_op = True

def __init__(self, mode: str = 's2t', *args, **kwargs):
"""
Initialization method.
@@ -82,8 +84,10 @@ def __init__(self, mode: str = 's2t', *args, **kwargs):
         self.mode = mode
         prepare_converter(self.mode)
 
-    def process(self, sample):
+    def process(self, samples):
         prepare_converter(self.mode)
 
-        sample[self.text_key] = OPENCC_CONVERTER.convert(sample[self.text_key])
-        return sample
+        samples[self.text_key] = [
+            OPENCC_CONVERTER.convert(text) for text in samples[self.text_key]
+        ]
+        return samples
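
For illustration, a batched mapper like the one above can be driven directly on such a dict (a sketch, not part of the commit; it assumes the default text_key of 'text' and that the opencc dependency is installed):

from data_juicer.ops.mapper.chinese_convert_mapper import ChineseConvertMapper

# 's2t' converts Simplified Chinese to Traditional Chinese (the default mode).
op = ChineseConvertMapper(mode='s2t')
samples = {'text': ['数据处理', '简体中文测试']}
samples = op.process(samples)
# samples['text'] now holds one converted string per input text.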
23 changes: 15 additions & 8 deletions data_juicer/ops/mapper/clean_copyright_mapper.py
@@ -12,6 +12,8 @@ class CleanCopyrightMapper(Mapper):
"""Mapper to clean copyright comments at the beginning of the text
samples."""

_batched_op = True

def __init__(self, *args, **kwargs):
"""
Initialization method.
@@ -23,21 +25,19 @@ def __init__(self, *args, **kwargs):
         self.pat = re.compile('/\\*[^*]*\\*+(?:[^/*][^*]*\\*+)*/')
         self.cpat = re.compile('copyright', re.IGNORECASE)
 
-    def process(self, sample):
-
-        r = self.pat.search(sample[self.text_key])
+    def _process_single_sample(self, sample):
+        r = self.pat.search(sample)
         if r:
             # found one, now see if it contains "copyright", if so strip it
             span = r.span()
-            sub = sample[self.text_key][span[0]:span[1]]
+            sub = sample[span[0]:span[1]]
             if self.cpat.search(sub):
                 # cut it
-                sample[self.text_key] = sample[
-                    self.text_key][:span[0]] + sample[self.text_key][span[1]:]
+                sample = sample[:span[0]] + sample[span[1]:]
 
             return sample
 
-        lines = sample[self.text_key].split('\n')
+        lines = sample.split('\n')
         skip = 0
 
         # Greedy replace any file that begins with comment block, most
@@ -51,5 +51,12 @@ def process(self, sample):

         if skip:
             # we skipped, consume it
-            sample[self.text_key] = '\n'.join(lines[skip:])
+            sample = '\n'.join(lines[skip:])
         return sample
+
+    def process(self, samples):
+        samples[self.text_key] = [
+            self._process_single_sample(text)
+            for text in samples[self.text_key]
+        ]
+        return samples
20 changes: 11 additions & 9 deletions data_juicer/ops/mapper/clean_email_mapper.py
@@ -7,6 +7,8 @@
 class CleanEmailMapper(Mapper):
     """Mapper to clean email in text samples."""
 
+    _batched_op = True
+
     def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs):
         """
         Initialization method.
@@ -28,13 +30,13 @@ def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs):

         self.repl = repl
 
-    def process(self, sample):
-
-        if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL):
-            return sample
+    def process(self, samples):
+        for i, text in enumerate(samples[self.text_key]):
+            if not re.search(self.pattern, text, flags=re.DOTALL):
+                continue
+            samples[self.text_key][i] = re.sub(pattern=self.pattern,
+                                               repl=self.repl,
+                                               string=text,
+                                               flags=re.DOTALL)
 
-        sample[self.text_key] = re.sub(pattern=self.pattern,
-                                       repl=self.repl,
-                                       string=sample[self.text_key],
-                                       flags=re.DOTALL)
-        return sample
+        return samples
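
Unlike the comprehension-based ops, this mapper rewrites the text list in place and skips entries with no match via the re.search pre-check. A usage sketch under the same assumptions as before (default text_key; the example address is assumed to match the op's default email pattern):

from data_juicer.ops.mapper.clean_email_mapper import CleanEmailMapper

op = CleanEmailMapper(repl='<EMAIL>')
samples = {'text': ['contact: alice@example.com', 'no address in this one']}
samples = op.process(samples)
# Only the first entry is rewritten; the second hits the continue branch
# and is returned unchanged.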
10 changes: 7 additions & 3 deletions data_juicer/ops/mapper/clean_html_mapper.py
@@ -16,6 +16,8 @@
 class CleanHtmlMapper(Mapper):
     """Mapper to clean html code in text samples."""
 
+    _batched_op = True
+
     def __init__(self, *args, **kwargs):
         """
         Initialization method.
@@ -25,7 +27,7 @@ def __init__(self, *args, **kwargs):
"""
super().__init__(*args, **kwargs)

def process(self, sample):
def process(self, samples):

def _clean_html(raw_html):
raw_html = raw_html.replace('<li>', '\n*')
@@ -35,5 +37,7 @@ def _clean_html(raw_html):
             parser = HTMLParser(raw_html)
             return parser.text()
 
-        sample[self.text_key] = _clean_html(sample[self.text_key])
-        return sample
+        samples[self.text_key] = [
+            _clean_html(text) for text in samples[self.text_key]
+        ]
+        return samples
21 changes: 11 additions & 10 deletions data_juicer/ops/mapper/clean_ip_mapper.py
@@ -7,6 +7,8 @@
 class CleanIpMapper(Mapper):
     """Mapper to clean ipv4 and ipv6 address in text samples."""
 
+    _batched_op = True
+
     def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs):
         """
         Initialization method.
@@ -32,13 +34,12 @@ def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs):
             self.pattern = pattern[2:-1]
         self.repl = repl
 
-    def process(self, sample):
-
-        if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL):
-            return sample
-
-        sample[self.text_key] = re.sub(pattern=self.pattern,
-                                       repl=self.repl,
-                                       string=sample[self.text_key],
-                                       flags=re.DOTALL)
-        return sample
+    def process(self, samples):
+        for i, text in enumerate(samples[self.text_key]):
+            if not re.search(self.pattern, text, flags=re.DOTALL):
+                continue
+            samples[self.text_key][i] = re.sub(pattern=self.pattern,
+                                               repl=self.repl,
+                                               string=text,
+                                               flags=re.DOTALL)
+        return samples
20 changes: 11 additions & 9 deletions data_juicer/ops/mapper/clean_links_mapper.py
@@ -10,6 +10,8 @@
 class CleanLinksMapper(Mapper):
     """Mapper to clean links like http/https/ftp in text samples."""
 
+    _batched_op = True
+
     def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs):
         """
         Initialization method.
@@ -38,13 +40,13 @@ def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs):
             self.pattern = pattern[2:-1]
         self.repl = repl
 
-    def process(self, sample):
-
-        if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL):
-            return sample
+    def process(self, samples):
+        for i, text in enumerate(samples[self.text_key]):
+            if not re.search(self.pattern, text, flags=re.DOTALL):
+                continue
 
-        sample[self.text_key] = re.sub(pattern=self.pattern,
-                                       repl=self.repl,
-                                       string=sample[self.text_key],
-                                       flags=re.DOTALL)
-        return sample
+            samples[self.text_key][i] = re.sub(pattern=self.pattern,
+                                               repl=self.repl,
+                                               string=text,
+                                               flags=re.DOTALL)
+        return samples
51 changes: 28 additions & 23 deletions data_juicer/ops/mapper/expand_macro_mapper.py
@@ -12,6 +12,8 @@ class ExpandMacroMapper(Mapper):
"""Mapper to expand macro definitions in the document body of Latex
samples."""

_batched_op = True

def __init__(self, *args, **kwargs):
"""
Initialization method.
@@ -55,26 +57,29 @@ def _build_non_arg_macros_dict(self, file_content):
             macros[macro_name] = macro_val
         return macros
 
-    def process(self, sample):
-        non_arg_macros = self._build_non_arg_macros_dict(sample[self.text_key])
-
-        # TODO: macros that take arguments are not supported yet
-        arg_macros = {}
-
-        # inline-expand all non-arg macros
-        for macro_name, macro_value in non_arg_macros.items():
-            sample[self.text_key] = re.sub(
-                # make pattern grouped to make sure that the macro is not part
-                # of a longer alphanumeric word
-                pattern=r'(' + macro_name + r')' + r'([^a-zA-Z0-9])',
-                # replace the macro with its value and add back the character
-                # that was matched after the macro
-                repl=macro_value + r'\2',
-                string=sample[self.text_key])
-
-        # inline-expand all macros that use args
-        # TODO: inline-expand macros with args
-        for macro_name, macro_value in arg_macros.items():
-            pass
-
-        return sample
+    def process(self, samples):
+        for i, text in enumerate(samples[self.text_key]):
+            non_arg_macros = self._build_non_arg_macros_dict(text)
+
+            # TODO: macros that take arguments are not supported yet
+            arg_macros = {}
+
+            # inline-expand all non-arg macros
+            for macro_name, macro_value in non_arg_macros.items():
+                text = re.sub(
+                    # make pattern grouped to make sure that the macro
+                    # is not part of a longer alphanumeric word
+                    pattern=r'(' + macro_name + r')' + r'([^a-zA-Z0-9])',
+                    # replace the macro with its value and add back the
+                    # character that was matched after the macro
+                    repl=macro_value + r'\2',
+                    string=text)
+
+            # inline-expand all macros that use args
+            # TODO: inline-expand macros with args
+            for macro_name, macro_value in arg_macros.items():
+                pass
+
+            samples[self.text_key][i] = text
+
+        return samples
4 changes: 2 additions & 2 deletions data_juicer/ops/mapper/fix_unicode_mapper.py
@@ -37,7 +37,7 @@ def __init__(self, normalization: str = None, *args, **kwargs):

     def process(self, samples):
         samples[self.text_key] = [
-            ftfy.fix_text(i, normalization=self.normalization)
-            for i in samples[self.text_key]
+            ftfy.fix_text(text, normalization=self.normalization)
+            for text in samples[self.text_key]
         ]
         return samples
14 changes: 9 additions & 5 deletions data_juicer/ops/mapper/punctuation_normalization_mapper.py
@@ -10,6 +10,8 @@ class PunctuationNormalizationMapper(Mapper):
"""Mapper to normalize unicode punctuations to English punctuations in text
samples."""

_batched_op = True

def __init__(self, *args, **kwargs):
"""
Initialization method.
@@ -55,8 +57,10 @@ def __init__(self, *args, **kwargs):
             '►': '-',
         }
 
-    def process(self, sample):
-        sample[self.text_key] = ''.join([
-            self.punctuation_unicode.get(c, c) for c in sample[self.text_key]
-        ])
-        return sample
+    def process(self, samples):
+        samples[self.text_key] = [
+            ''.join([self.punctuation_unicode.get(c, c) for c in text])
+            for text in samples[self.text_key]
+        ]
+
+        return samples
17 changes: 11 additions & 6 deletions data_juicer/ops/mapper/remove_bibliography_mapper.py
@@ -12,6 +12,8 @@ class RemoveBibliographyMapper(Mapper):
"""Mapper to remove bibliography at the end of documents in Latex
samples."""

_batched_op = True

def __init__(self, *args, **kwargs):
"""
Initialization method.
@@ -27,9 +29,12 @@ def __init__(self, *args, **kwargs):
         self.pattern += r'\\bibliography\{.*\}'
         self.pattern += r').*$'
 
-    def process(self, sample):
-        sample[self.text_key] = re.sub(pattern=self.pattern,
-                                       repl=r'',
-                                       string=sample[self.text_key],
-                                       flags=re.DOTALL)
-        return sample
+    def process(self, samples):
+        samples[self.text_key] = [
+            re.sub(pattern=self.pattern,
+                   repl=r'',
+                   string=text,
+                   flags=re.DOTALL) for text in samples[self.text_key]
+        ]
+
+        return samples
34 changes: 20 additions & 14 deletions data_juicer/ops/mapper/remove_comments_mapper.py
@@ -17,6 +17,8 @@ class RemoveCommentsMapper(Mapper):
     Only support 'tex' for now.
     """
 
+    _batched_op = True
+
     def __init__(self,
                  doc_type: Union[str, List[str]] = 'tex',
                  inline: bool = True,
@@ -37,19 +39,23 @@ def __init__(self,
         self.inline = inline
         self.multiline = multiline
 
-    def process(self, sample):
+    def process(self, samples):
         # TODO: remove different comments by sample type
 
-        if self.inline:
-            # remove all in comments within a line
-            sample[self.text_key] = re.sub(pattern=r'[^\\]%.+$',
-                                           repl=r'',
-                                           string=sample[self.text_key],
-                                           flags=re.MULTILINE)
-
-        if self.multiline:
-            sample[self.text_key] = re.sub(pattern=r'(?m)^%.*\n?',
-                                           repl=r'',
-                                           string=sample[self.text_key],
-                                           flags=re.MULTILINE)
-        return sample
+        for i, text in enumerate(samples[self.text_key]):
+            if self.inline:
+                # remove all in comments within a line
+                text = re.sub(pattern=r'[^\\]%.+$',
+                              repl=r'',
+                              string=text,
+                              flags=re.MULTILINE)
+
+            if self.multiline:
+                text = re.sub(pattern=r'(?m)^%.*\n?',
+                              repl=r'',
+                              string=text,
+                              flags=re.MULTILINE)
+
+            samples[self.text_key][i] = text
+
+        return samples
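
The _batched_op = True attribute added throughout is what signals the new calling convention to whatever drives these ops. As a purely hypothetical sketch (not Data-Juicer code), a caller could branch on such a flag to keep supporting both styles:

def run_mapper(op, sample, text_key='text'):
    # Hypothetical helper for illustration only.
    if getattr(op, '_batched_op', False):
        batch = {text_key: [sample[text_key]]}  # wrap one sample as a batch
        batch = op.process(batch)
        sample[text_key] = batch[text_key][0]   # unwrap the single result
        return sample
    return op.process(sample)                   # legacy single-sample op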
(Diffs for the remaining changed files are not shown here.)