Commit e4d00d9

support batch for WordsNumFilter
Cathy0908 committed Sep 3, 2024
1 parent a865b3d commit e4d00d9
Showing 2 changed files with 44 additions and 26 deletions.
61 changes: 39 additions & 22 deletions data_juicer/ops/filter/words_num_filter.py
@@ -23,6 +23,8 @@ class WordsNumFilter(Filter):
     """Filter to keep samples with total words number within a specific
     range."""
 
+    _batched_op = True
+
     def __init__(self,
                  lang: str = 'en',
                  tokenization: bool = False,
@@ -54,28 +56,43 @@ def __init__(self,
             self.model_key = prepare_model(model_type='sentencepiece',
                                            lang=lang)
 
-    def compute_stats(self, sample, context=False):
-        # check if it's computed already
-        if StatsKeys.num_words in sample[Fields.stats]:
-            return sample
-
+    def compute_stats(self, samples, context=False):
+        samples_list = samples[self.text_key]
+        samples_stats = samples[Fields.stats]
         words_key = f'{InterVars.words}-{self.model_key}'
-        if context and words_key in sample[Fields.context]:
-            words = sample[Fields.context][words_key]
-        else:
-            tokenizer = get_model(self.model_key)
-            words = get_words_from_document(
-                sample[self.text_key],
-                token_func=tokenizer.encode_as_pieces if tokenizer else None)
-            if context:
-                sample[Fields.context][words_key] = words
-        words = words_refinement(words, strip_chars=SPECIAL_CHARACTERS)
-        sample[Fields.stats][StatsKeys.num_words] = len(words)
-        return sample
+
+        for idx, stat in enumerate(samples_stats):
+            # check if it's computed already
+            if StatsKeys.num_words in stat:
+                continue
+            if context and words_key in samples[Fields.context][idx]:
+                words = samples[Fields.context][idx][words_key]
+            else:
+                tokenizer = get_model(self.model_key)
+                words = get_words_from_document(
+                    samples_list[idx],
+                    token_func=tokenizer.encode_as_pieces
+                    if tokenizer else None)
+                if context:
+                    samples[Fields.context][idx][words_key] = words
+            words = words_refinement(words, strip_chars=SPECIAL_CHARACTERS)
+            samples_stats[idx][StatsKeys.num_words] = len(words)
+
+        return samples
 
-    def process(self, sample):
-        if self.min_num <= sample[Fields.stats][
-                StatsKeys.num_words] <= self.max_num:
-            return True
+    def process(self, samples):
+        if isinstance(samples[Fields.stats], list):
+            bool_results = []
+            for stat in samples[Fields.stats]:
+                if self.min_num <= stat[StatsKeys.num_words] <= self.max_num:
+                    bool_results.append(True)
+                else:
+                    bool_results.append(False)
+            return bool_results
         else:
-            return False
+            # single sample for ray filter
+            if self.min_num <= samples[Fields.stats][
+                    StatsKeys.num_words] <= self.max_num:
+                return True
+            else:
+                return False
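
For reference, a minimal sketch of the batched calling convention this commit introduces, with compute_stats and process invoked directly on a hand-built batch (the import paths and sample values are assumptions for illustration, not part of the commit):

    # minimal sketch; import paths and values are assumed, not from this commit
    from data_juicer.ops.filter.words_num_filter import WordsNumFilter
    from data_juicer.utils.constant import Fields  # assumed location of Fields

    op = WordsNumFilter(min_num=2, max_num=5)

    # A batch is a single dict of columns; each column holds one entry per sample.
    samples = {
        'text': ['a b c d', 'a'],  # default text_key; whitespace split (tokenization=False)
        Fields.stats: [{}, {}],    # empty stats dicts, filled in by compute_stats
    }
    samples = op.compute_stats(samples)
    print(op.process(samples))     # [True, False]: 4 words is within [2, 5], 1 word is below min_num

With a single sample (stats as one dict rather than a list, as in the ray filter path), process falls back to the scalar branch and returns a single bool.
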
9 changes: 5 additions & 4 deletions tests/ops/filter/test_word_num_filter.py
@@ -16,8 +16,8 @@ def _run_words_num_filter(self, dataset: Dataset, target_list, op):
         # only add stats when calling filter op
         dataset = dataset.add_column(name=Fields.stats,
                                      column=[{}] * dataset.num_rows)
-        dataset = dataset.map(op.compute_stats)
-        dataset = dataset.filter(op.process)
+        dataset = dataset.map(op.compute_stats, batch_size=op.batch_size)
+        dataset = dataset.filter(op.process, batch_size=op.batch_size)
         dataset = dataset.select_columns(column_names=['text'])
         res_list = dataset.to_list()
         self.assertEqual(res_list, target_list)
@@ -41,7 +41,7 @@ def test_case(self):
             'text': 'a v s e c s f e f g a a a '
         }]
         dataset = Dataset.from_list(ds_list)
-        op = WordsNumFilter(min_num=5, max_num=15)
+        op = WordsNumFilter(min_num=5, max_num=15, batch_size=2)
         self._run_words_num_filter(dataset, tgt_list, op)
 
     def test_zh_case(self):
@@ -68,7 +68,8 @@ def test_zh_case(self):
         op = WordsNumFilter(lang='zh',
                             tokenization=True,
                             min_num=10,
-                            max_num=25)
+                            max_num=25,
+                            batch_size=1)
         self._run_words_num_filter(dataset, tgt_list, op)


