From 0f529096d9671917adb6a2cac5a2d6d952174dbd Mon Sep 17 00:00:00 2001 From: "jiangnana.jnn" Date: Fri, 30 Aug 2024 15:58:40 +0800 Subject: [PATCH] support batch for part ops --- .../ops/mapper/chinese_convert_mapper.py | 10 ++-- .../ops/mapper/clean_copyright_mapper.py | 23 ++++++--- data_juicer/ops/mapper/clean_email_mapper.py | 20 ++++---- data_juicer/ops/mapper/clean_html_mapper.py | 10 ++-- data_juicer/ops/mapper/clean_ip_mapper.py | 21 ++++---- data_juicer/ops/mapper/clean_links_mapper.py | 20 ++++---- data_juicer/ops/mapper/expand_macro_mapper.py | 51 ++++++++++--------- data_juicer/ops/mapper/fix_unicode_mapper.py | 4 +- .../punctuation_normalization_mapper.py | 14 +++-- .../ops/mapper/remove_bibliography_mapper.py | 17 ++++--- .../ops/mapper/remove_comments_mapper.py | 34 ++++++++----- .../ops/mapper/remove_header_mapper.py | 24 +++++---- .../ops/mapper/test_chinese_convert_mapper.py | 9 ++-- .../ops/mapper/test_clean_copyright_mapper.py | 9 ++-- tests/ops/mapper/test_clean_email_mapper.py | 9 ++-- tests/ops/mapper/test_clean_html_mapper.py | 9 ++-- tests/ops/mapper/test_clean_ip_mapper.py | 9 ++-- tests/ops/mapper/test_clean_links_mapper.py | 9 ++-- tests/ops/mapper/test_exapnd_macro_mapper.py | 9 ++-- .../test_punctuation_normalization_mapper.py | 9 ++-- .../mapper/test_remove_bibliography_mapper.py | 9 ++-- .../ops/mapper/test_remove_comments_mapper.py | 9 ++-- tests/ops/mapper/test_remove_header_mapper.py | 9 ++-- 23 files changed, 212 insertions(+), 135 deletions(-) diff --git a/data_juicer/ops/mapper/chinese_convert_mapper.py b/data_juicer/ops/mapper/chinese_convert_mapper.py index 818f1b1d4..8e6bb9dc1 100644 --- a/data_juicer/ops/mapper/chinese_convert_mapper.py +++ b/data_juicer/ops/mapper/chinese_convert_mapper.py @@ -27,6 +27,8 @@ class ChineseConvertMapper(Mapper): """Mapper to convert Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji.""" + _batched_op = True + def __init__(self, mode: str = 's2t', *args, **kwargs): """ 
Initialization method. @@ -82,8 +84,10 @@ def __init__(self, mode: str = 's2t', *args, **kwargs): self.mode = mode prepare_converter(self.mode) - def process(self, sample): + def process(self, samples): prepare_converter(self.mode) - sample[self.text_key] = OPENCC_CONVERTER.convert(sample[self.text_key]) - return sample + samples[self.text_key] = [ + OPENCC_CONVERTER.convert(text) for text in samples[self.text_key] + ] + return samples diff --git a/data_juicer/ops/mapper/clean_copyright_mapper.py b/data_juicer/ops/mapper/clean_copyright_mapper.py index dabb0cd40..8908d33e9 100644 --- a/data_juicer/ops/mapper/clean_copyright_mapper.py +++ b/data_juicer/ops/mapper/clean_copyright_mapper.py @@ -12,6 +12,8 @@ class CleanCopyrightMapper(Mapper): """Mapper to clean copyright comments at the beginning of the text samples.""" + _batched_op = True + def __init__(self, *args, **kwargs): """ Initialization method. @@ -23,21 +25,19 @@ def __init__(self, *args, **kwargs): self.pat = re.compile('/\\*[^*]*\\*+(?:[^/*][^*]*\\*+)*/') self.cpat = re.compile('copyright', re.IGNORECASE) - def process(self, sample): - - r = self.pat.search(sample[self.text_key]) + def _process_single_sample(self, sample): + r = self.pat.search(sample) if r: # found one, now see if it contains "copyright", if so strip it span = r.span() - sub = sample[self.text_key][span[0]:span[1]] + sub = sample[span[0]:span[1]] if self.cpat.search(sub): # cut it - sample[self.text_key] = sample[ - self.text_key][:span[0]] + sample[self.text_key][span[1]:] + sample = sample[:span[0]] + sample[span[1]:] return sample - lines = sample[self.text_key].split('\n') + lines = sample.split('\n') skip = 0 # Greedy replace any file that begins with comment block, most @@ -51,5 +51,12 @@ def process(self, sample): if skip: # we skipped, consume it - sample[self.text_key] = '\n'.join(lines[skip:]) + sample = '\n'.join(lines[skip:]) return sample + + def process(self, samples): + samples[self.text_key] = [ + 
self._process_single_sample(text) + for text in samples[self.text_key] + ] + return samples diff --git a/data_juicer/ops/mapper/clean_email_mapper.py b/data_juicer/ops/mapper/clean_email_mapper.py index 9708363e5..e5eb180ba 100644 --- a/data_juicer/ops/mapper/clean_email_mapper.py +++ b/data_juicer/ops/mapper/clean_email_mapper.py @@ -7,6 +7,8 @@ class CleanEmailMapper(Mapper): """Mapper to clean email in text samples.""" + _batched_op = True + def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs): """ Initialization method. @@ -28,13 +30,13 @@ def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs): self.repl = repl - def process(self, sample): - - if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL): - return sample + def process(self, samples): + for i, text in enumerate(samples[self.text_key]): + if not re.search(self.pattern, text, flags=re.DOTALL): + continue + samples[self.text_key][i] = re.sub(pattern=self.pattern, + repl=self.repl, + string=text, + flags=re.DOTALL) - sample[self.text_key] = re.sub(pattern=self.pattern, - repl=self.repl, - string=sample[self.text_key], - flags=re.DOTALL) - return sample + return samples diff --git a/data_juicer/ops/mapper/clean_html_mapper.py b/data_juicer/ops/mapper/clean_html_mapper.py index 5c2c30c57..09e847dd0 100644 --- a/data_juicer/ops/mapper/clean_html_mapper.py +++ b/data_juicer/ops/mapper/clean_html_mapper.py @@ -16,6 +16,8 @@ class CleanHtmlMapper(Mapper): """Mapper to clean html code in text samples.""" + _batched_op = True + def __init__(self, *args, **kwargs): """ Initialization method. @@ -25,7 +27,7 @@ def __init__(self, *args, **kwargs): """ super().__init__(*args, **kwargs) - def process(self, sample): + def process(self, samples): def _clean_html(raw_html): raw_html = raw_html.replace('
  • ', '\n*') @@ -35,5 +37,7 @@ def _clean_html(raw_html): parser = HTMLParser(raw_html) return parser.text() - sample[self.text_key] = _clean_html(sample[self.text_key]) - return sample + samples[self.text_key] = [ + _clean_html(text) for text in samples[self.text_key] + ] + return samples diff --git a/data_juicer/ops/mapper/clean_ip_mapper.py b/data_juicer/ops/mapper/clean_ip_mapper.py index 607aeb585..53859521c 100644 --- a/data_juicer/ops/mapper/clean_ip_mapper.py +++ b/data_juicer/ops/mapper/clean_ip_mapper.py @@ -7,6 +7,8 @@ class CleanIpMapper(Mapper): """Mapper to clean ipv4 and ipv6 address in text samples.""" + _batched_op = True + def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs): """ Initialization method. @@ -32,13 +34,12 @@ def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs): self.pattern = pattern[2:-1] self.repl = repl - def process(self, sample): - - if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL): - return sample - - sample[self.text_key] = re.sub(pattern=self.pattern, - repl=self.repl, - string=sample[self.text_key], - flags=re.DOTALL) - return sample + def process(self, samples): + for i, text in enumerate(samples[self.text_key]): + if not re.search(self.pattern, text, flags=re.DOTALL): + continue + samples[self.text_key][i] = re.sub(pattern=self.pattern, + repl=self.repl, + string=text, + flags=re.DOTALL) + return samples diff --git a/data_juicer/ops/mapper/clean_links_mapper.py b/data_juicer/ops/mapper/clean_links_mapper.py index bcd90d524..ee7f663e8 100644 --- a/data_juicer/ops/mapper/clean_links_mapper.py +++ b/data_juicer/ops/mapper/clean_links_mapper.py @@ -10,6 +10,8 @@ class CleanLinksMapper(Mapper): """Mapper to clean links like http/https/ftp in text samples.""" + _batched_op = True + def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs): """ Initialization method. 
@@ -38,13 +40,13 @@ def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs): self.pattern = pattern[2:-1] self.repl = repl - def process(self, sample): - - if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL): - return sample + def process(self, samples): + for i, text in enumerate(samples[self.text_key]): + if not re.search(self.pattern, text, flags=re.DOTALL): + continue - sample[self.text_key] = re.sub(pattern=self.pattern, - repl=self.repl, - string=sample[self.text_key], - flags=re.DOTALL) - return sample + samples[self.text_key][i] = re.sub(pattern=self.pattern, + repl=self.repl, + string=text, + flags=re.DOTALL) + return samples diff --git a/data_juicer/ops/mapper/expand_macro_mapper.py b/data_juicer/ops/mapper/expand_macro_mapper.py index 2f5d7fe83..b41f8ba74 100644 --- a/data_juicer/ops/mapper/expand_macro_mapper.py +++ b/data_juicer/ops/mapper/expand_macro_mapper.py @@ -12,6 +12,8 @@ class ExpandMacroMapper(Mapper): """Mapper to expand macro definitions in the document body of Latex samples.""" + _batched_op = True + def __init__(self, *args, **kwargs): """ Initialization method. 
@@ -55,26 +57,29 @@ def _build_non_arg_macros_dict(self, file_content): macros[macro_name] = macro_val return macros - def process(self, sample): - non_arg_macros = self._build_non_arg_macros_dict(sample[self.text_key]) - - # TODO: macros that take arguments are not supported yet - arg_macros = {} - - # inline-expand all non-arg macros - for macro_name, macro_value in non_arg_macros.items(): - sample[self.text_key] = re.sub( - # make pattern grouped to make sure that the macro is not part - # of a longer alphanumeric word - pattern=r'(' + macro_name + r')' + r'([^a-zA-Z0-9])', - # replace the macro with its value and add back the character - # that was matched after the macro - repl=macro_value + r'\2', - string=sample[self.text_key]) - - # inline-expand all macros that use args - # TODO: inline-expand macros with args - for macro_name, macro_value in arg_macros.items(): - pass - - return sample + def process(self, samples): + for i, text in enumerate(samples[self.text_key]): + non_arg_macros = self._build_non_arg_macros_dict(text) + + # TODO: macros that take arguments are not supported yet + arg_macros = {} + + # inline-expand all non-arg macros + for macro_name, macro_value in non_arg_macros.items(): + text = re.sub( + # make pattern grouped to make sure that the macro + # is not part of a longer alphanumeric word + pattern=r'(' + macro_name + r')' + r'([^a-zA-Z0-9])', + # replace the macro with its value and add back the + # character that was matched after the macro + repl=macro_value + r'\2', + string=text) + + # inline-expand all macros that use args + # TODO: inline-expand macros with args + for macro_name, macro_value in arg_macros.items(): + pass + + samples[self.text_key][i] = text + + return samples diff --git a/data_juicer/ops/mapper/fix_unicode_mapper.py b/data_juicer/ops/mapper/fix_unicode_mapper.py index f1f219e47..b44005076 100644 --- a/data_juicer/ops/mapper/fix_unicode_mapper.py +++ b/data_juicer/ops/mapper/fix_unicode_mapper.py @@ -37,7 +37,7 @@ 
def __init__(self, normalization: str = None, *args, **kwargs): def process(self, samples): samples[self.text_key] = [ - ftfy.fix_text(i, normalization=self.normalization) - for i in samples[self.text_key] + ftfy.fix_text(text, normalization=self.normalization) + for text in samples[self.text_key] ] return samples diff --git a/data_juicer/ops/mapper/punctuation_normalization_mapper.py b/data_juicer/ops/mapper/punctuation_normalization_mapper.py index b6640e9eb..9217dda07 100644 --- a/data_juicer/ops/mapper/punctuation_normalization_mapper.py +++ b/data_juicer/ops/mapper/punctuation_normalization_mapper.py @@ -10,6 +10,8 @@ class PunctuationNormalizationMapper(Mapper): """Mapper to normalize unicode punctuations to English punctuations in text samples.""" + _batched_op = True + def __init__(self, *args, **kwargs): """ Initialization method. @@ -55,8 +57,10 @@ def __init__(self, *args, **kwargs): '►': '-', } - def process(self, sample): - sample[self.text_key] = ''.join([ - self.punctuation_unicode.get(c, c) for c in sample[self.text_key] - ]) - return sample + def process(self, samples): + samples[self.text_key] = [ + ''.join([self.punctuation_unicode.get(c, c) for c in text]) + for text in samples[self.text_key] + ] + + return samples diff --git a/data_juicer/ops/mapper/remove_bibliography_mapper.py b/data_juicer/ops/mapper/remove_bibliography_mapper.py index 2ce852d66..1eecd66d2 100644 --- a/data_juicer/ops/mapper/remove_bibliography_mapper.py +++ b/data_juicer/ops/mapper/remove_bibliography_mapper.py @@ -12,6 +12,8 @@ class RemoveBibliographyMapper(Mapper): """Mapper to remove bibliography at the end of documents in Latex samples.""" + _batched_op = True + def __init__(self, *args, **kwargs): """ Initialization method. 
@@ -27,9 +29,12 @@ def __init__(self, *args, **kwargs): self.pattern += r'\\bibliography\{.*\}' self.pattern += r').*$' - def process(self, sample): - sample[self.text_key] = re.sub(pattern=self.pattern, - repl=r'', - string=sample[self.text_key], - flags=re.DOTALL) - return sample + def process(self, samples): + samples[self.text_key] = [ + re.sub(pattern=self.pattern, + repl=r'', + string=text, + flags=re.DOTALL) for text in samples[self.text_key] + ] + + return samples diff --git a/data_juicer/ops/mapper/remove_comments_mapper.py b/data_juicer/ops/mapper/remove_comments_mapper.py index c5f083c14..5e63ee27b 100644 --- a/data_juicer/ops/mapper/remove_comments_mapper.py +++ b/data_juicer/ops/mapper/remove_comments_mapper.py @@ -17,6 +17,8 @@ class RemoveCommentsMapper(Mapper): Only support 'tex' for now. """ + _batched_op = True + def __init__(self, doc_type: Union[str, List[str]] = 'tex', inline: bool = True, @@ -37,19 +39,23 @@ def __init__(self, self.inline = inline self.multiline = multiline - def process(self, sample): + def process(self, samples): # TODO: remove different comments by sample type - if self.inline: - # remove all in comments within a line - sample[self.text_key] = re.sub(pattern=r'[^\\]%.+$', - repl=r'', - string=sample[self.text_key], - flags=re.MULTILINE) - - if self.multiline: - sample[self.text_key] = re.sub(pattern=r'(?m)^%.*\n?', - repl=r'', - string=sample[self.text_key], - flags=re.MULTILINE) - return sample + for i, text in enumerate(samples[self.text_key]): + if self.inline: + # remove all in comments within a line + text = re.sub(pattern=r'[^\\]%.+$', + repl=r'', + string=text, + flags=re.MULTILINE) + + if self.multiline: + text = re.sub(pattern=r'(?m)^%.*\n?', + repl=r'', + string=text, + flags=re.MULTILINE) + + samples[self.text_key][i] = text + + return samples diff --git a/data_juicer/ops/mapper/remove_header_mapper.py b/data_juicer/ops/mapper/remove_header_mapper.py index 8371d2f99..85e510d0c 100644 --- 
a/data_juicer/ops/mapper/remove_header_mapper.py +++ b/data_juicer/ops/mapper/remove_header_mapper.py @@ -12,6 +12,8 @@ class RemoveHeaderMapper(Mapper): """Mapper to remove headers at the beginning of documents in Latex samples.""" + _batched_op = True + def __init__(self, drop_no_head: bool = True, *args, **kwargs): """ Initialization method. @@ -34,15 +36,17 @@ def __init__(self, drop_no_head: bool = True, *args, **kwargs): self.drop_no_head = drop_no_head - def process(self, sample): + def process(self, samples): + for i, text in enumerate(samples[self.text_key]): + if not re.search(self.pattern, text, flags=re.DOTALL): + if self.drop_no_head: + samples[self.text_key][i] = '' + continue + text = re.sub(pattern=self.pattern, + repl=r'\2', + string=text, + flags=re.DOTALL) - if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL): - if self.drop_no_head: - sample[self.text_key] = '' - return sample + samples[self.text_key][i] = text - sample[self.text_key] = re.sub(pattern=self.pattern, - repl=r'\2', - string=sample[self.text_key], - flags=re.DOTALL) - return sample + return samples diff --git a/tests/ops/mapper/test_chinese_convert_mapper.py b/tests/ops/mapper/test_chinese_convert_mapper.py index 9bbe8e8df..bc21f40fe 100644 --- a/tests/ops/mapper/test_chinese_convert_mapper.py +++ b/tests/ops/mapper/test_chinese_convert_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.chinese_convert_mapper import ChineseConvertMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -10,9 +11,11 @@ def setUp(self, mode='s2t'): self.op = ChineseConvertMapper(mode) def _run_chinese_convert(self, samples): - for sample in samples: - result = self.op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(self.op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def 
test_s2t(self): diff --git a/tests/ops/mapper/test_clean_copyright_mapper.py b/tests/ops/mapper/test_clean_copyright_mapper.py index 726d829f7..a236988f7 100644 --- a/tests/ops/mapper/test_clean_copyright_mapper.py +++ b/tests/ops/mapper/test_clean_copyright_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.clean_copyright_mapper import CleanCopyrightMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -10,9 +11,11 @@ def setUp(self): self.op = CleanCopyrightMapper() def _run_clean_copyright(self, samples): - for sample in samples: - result = self.op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(self.op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_clean_copyright(self): diff --git a/tests/ops/mapper/test_clean_email_mapper.py b/tests/ops/mapper/test_clean_email_mapper.py index b3f0e5e9a..1ff7e389e 100644 --- a/tests/ops/mapper/test_clean_email_mapper.py +++ b/tests/ops/mapper/test_clean_email_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.clean_email_mapper import CleanEmailMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -7,9 +8,11 @@ class CleanEmailMapperTest(DataJuicerTestCaseBase): def _run_clean_email(self, op, samples): - for sample in samples: - result = op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_clean_email(self): diff --git a/tests/ops/mapper/test_clean_html_mapper.py b/tests/ops/mapper/test_clean_html_mapper.py index 69249b60a..71d4e11ee 100644 --- a/tests/ops/mapper/test_clean_html_mapper.py +++ 
b/tests/ops/mapper/test_clean_html_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.clean_html_mapper import CleanHtmlMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -10,9 +11,11 @@ def setUp(self): self.op = CleanHtmlMapper() def _run_helper(self, samples): - for sample in samples: - result = self.op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(self.op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_complete_html_text(self): diff --git a/tests/ops/mapper/test_clean_ip_mapper.py b/tests/ops/mapper/test_clean_ip_mapper.py index ccbaf52b7..479228263 100644 --- a/tests/ops/mapper/test_clean_ip_mapper.py +++ b/tests/ops/mapper/test_clean_ip_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.clean_ip_mapper import CleanIpMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -7,9 +8,11 @@ class CleanIpMapperTest(DataJuicerTestCaseBase): def _run_clean_ip(self, op, samples): - for sample in samples: - result = op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_ipv4(self): diff --git a/tests/ops/mapper/test_clean_links_mapper.py b/tests/ops/mapper/test_clean_links_mapper.py index 28e14b2d9..5efcd4acd 100644 --- a/tests/ops/mapper/test_clean_links_mapper.py +++ b/tests/ops/mapper/test_clean_links_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.clean_links_mapper import CleanLinksMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase 
@@ -10,9 +11,11 @@ def setUp(self): self.op = CleanLinksMapper() def _run_clean_links(self, op, samples): - for sample in samples: - result = op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_lower_ftp_links_text(self): diff --git a/tests/ops/mapper/test_exapnd_macro_mapper.py b/tests/ops/mapper/test_exapnd_macro_mapper.py index 68dbf047b..bdc758193 100644 --- a/tests/ops/mapper/test_exapnd_macro_mapper.py +++ b/tests/ops/mapper/test_exapnd_macro_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.expand_macro_mapper import ExpandMacroMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -10,9 +11,11 @@ def setUp(self): self.op = ExpandMacroMapper() def _run_expand_macro(self, samples): - for sample in samples: - result = self.op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(self.op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_case(self): diff --git a/tests/ops/mapper/test_punctuation_normalization_mapper.py b/tests/ops/mapper/test_punctuation_normalization_mapper.py index a69d4040e..080666ce8 100644 --- a/tests/ops/mapper/test_punctuation_normalization_mapper.py +++ b/tests/ops/mapper/test_punctuation_normalization_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.punctuation_normalization_mapper import \ PunctuationNormalizationMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -11,9 +12,11 @@ def setUp(self): self.op = PunctuationNormalizationMapper() def _run_punctuation_normalization(self, samples): - for sample in 
samples: - result = self.op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(self.op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_case(self): diff --git a/tests/ops/mapper/test_remove_bibliography_mapper.py b/tests/ops/mapper/test_remove_bibliography_mapper.py index 76096fe93..9d08c2a4d 100644 --- a/tests/ops/mapper/test_remove_bibliography_mapper.py +++ b/tests/ops/mapper/test_remove_bibliography_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.remove_bibliography_mapper import \ RemoveBibliographyMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -11,9 +12,11 @@ def setUp(self): self.op = RemoveBibliographyMapper() def _run_remove_bibliography(self, samples): - for sample in samples: - result = self.op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(self.op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_bibliography_case(self): diff --git a/tests/ops/mapper/test_remove_comments_mapper.py b/tests/ops/mapper/test_remove_comments_mapper.py index 81a0df5de..93a287460 100644 --- a/tests/ops/mapper/test_remove_comments_mapper.py +++ b/tests/ops/mapper/test_remove_comments_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.remove_comments_mapper import RemoveCommentsMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -7,9 +8,11 @@ class RemoveCommentsMapperTest(DataJuicerTestCaseBase): def _run_remove_comments(self, samples, op): - for sample in samples: - result = op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + 
dataset = dataset.map(op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_tex_case(self): diff --git a/tests/ops/mapper/test_remove_header_mapper.py b/tests/ops/mapper/test_remove_header_mapper.py index c91bfe790..0196b0317 100644 --- a/tests/ops/mapper/test_remove_header_mapper.py +++ b/tests/ops/mapper/test_remove_header_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.remove_header_mapper import RemoveHeaderMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -10,9 +11,11 @@ def setUp(self): self.op = RemoveHeaderMapper() def _run_remove_header(self, samples): - for sample in samples: - result = self.op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(self.op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_case(self):