Two additional filters for image captions (resubmit) (#122)
* text action filter

* text action filter

* text entity dependency filter

* complete config_all.yaml and Operators.md

* move spacy-pkuseg to science_requires.txt

* fix typo in comments
BeachWang authored Dec 8, 2023
1 parent 30c781b commit 66eda06
Showing 11 changed files with 414 additions and 7 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -11,6 +11,7 @@ repos:
rev: v0.32.0
hooks:
- id: yapf
args: ['--style', '{column_limit: 79}']
exclude: data_juicer/ops/common/special_characters.py
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
7 changes: 7 additions & 0 deletions configs/config_all.yaml
@@ -172,6 +172,13 @@ process:
use_words_aug: false # whether to augment words, especially for Chinese and Vietnamese
words_aug_group_sizes: [2] # the group size of words to augment
words_aug_join_char: "" # the join char between words to augment
  - text_action_filter:                                   # filter text according to the number of action verbs
      lang: en                                            # language of the text to process ('en' or 'zh')
      min_action_num: 1                                   # samples whose texts contain fewer action verbs than this number are filtered out
  - text_entity_dependency_filter:                        # filter text without non-independent entity nouns
      lang: en                                            # language of the text to process ('en' or 'zh')
      min_dependency_num: 1                               # the min number of dependency-tree edges an entity noun needs to be considered non-independent
      any_or_all: any                                     # keep the sample when any/all entity nouns are non-independent
- text_length_filter: # filter text with length out of specific range
min_len: 10 # the min length of filter range
max_len: 10000 # the max length of filter range
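For reference, a minimal standalone recipe exercising only the two new filters might look like the sketch below. This is a hedged example, not part of the commit: the top-level fields mirror those of configs/config_all.yaml, and the project name and paths are placeholders.

# hypothetical minimal recipe; dataset_path/export_path are placeholders
project_name: 'caption-filter-demo'
dataset_path: '/path/to/captions.jsonl'
export_path: '/path/to/filtered.jsonl'

process:
  - text_action_filter:
      lang: en
      min_action_num: 1
  - text_entity_dependency_filter:
      lang: en
      min_dependency_num: 1
      any_or_all: any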
3 changes: 2 additions & 1 deletion data_juicer/ops/filter/__init__.py
@@ -6,5 +6,6 @@
language_id_score_filter, maximum_line_length_filter,
perplexity_filter, special_characters_filter,
specified_field_filter, specified_numeric_field_filter,
- stopwords_filter, suffix_filter, text_length_filter,
+ stopwords_filter, suffix_filter, text_action_filter,
+ text_entity_dependency_filter, text_length_filter,
token_num_filter, word_num_filter, word_repetition_filter)
66 changes: 66 additions & 0 deletions data_juicer/ops/filter/text_action_filter.py
@@ -0,0 +1,66 @@
from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import remove_special_tokens
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Filter

OP_NAME = 'text_action_filter'


@OPERATORS.register_module(OP_NAME)
class TextActionFilter(Filter):
"""
    Filter to keep samples whose texts contain actions.
"""

def __init__(self,
lang: str = 'en',
min_action_num: int = 1,
*args,
**kwargs):
"""
Initialization method.
:param lang: language of the text in the samples. 'en' for detection of
actions in English and 'zh' for detection of actions in Chinese.
        :param min_action_num: the min number of actions required. Samples
            whose texts contain fewer actions than this parameter will be
            filtered out.
"""
super().__init__(*args, **kwargs)

if lang not in ['en', 'zh']:
raise ValueError(
                f'Language [{lang}] is not supported in action detection. '
f'Can only be one of ["en", "zh"].')
self.lang = lang
self.model_key = prepare_model(model_type='spacy', lang=lang)
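        # coarse POS plus fine-grained tags that count as actions:
        # 'VV' is the verb tag in spacy's Chinese models; the 'VB*' tags
        # cover the English verb forms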
self.action_poss = ['VERB']
self.action_tags = ['VV', 'VB', 'VBP', 'VBZ', 'VBD', 'VBG', 'VBN']
self.min_action_num = min_action_num

def compute_stats(self, sample, context=False):
# check if it's computed already
if StatsKeys.num_action in sample[Fields.stats]:
return sample

text = remove_special_tokens(sample[self.text_key])

# process text via spacy and count the actions in text
model = get_model(self.model_key)
doc = model(text)
num_action = 0
for token in doc:
if token.pos_ in self.action_poss \
and token.tag_ in self.action_tags:
num_action += 1
sample[Fields.stats][StatsKeys.num_action] = num_action

return sample

    def process(self, sample):
        num_action = sample[Fields.stats][StatsKeys.num_action]
        return num_action >= self.min_action_num
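As a quick sanity check, the new op can be applied to a single sample dict. The following is an illustrative sketch, not part of the commit; it assumes the default text_key 'text' and that the English spacy pipeline is available through data_juicer's model utilities.

# illustrative sketch only, not from this commit
from data_juicer.ops.filter.text_action_filter import TextActionFilter
from data_juicer.utils.constant import Fields

op = TextActionFilter(lang='en', min_action_num=1)
sample = {'text': 'Tom is playing piano.', Fields.stats: {}}
sample = op.compute_stats(sample)  # counts tokens tagged as action verbs
print(op.process(sample))          # True: 'playing' is tagged VERB/VBG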
103 changes: 103 additions & 0 deletions data_juicer/ops/filter/text_entity_dependency_filter.py
@@ -0,0 +1,103 @@
import numpy as np

from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import remove_special_tokens
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Filter

OP_NAME = 'text_entity_dependency_filter'


@OPERATORS.register_module(OP_NAME)
class TextEntityDependencyFilter(Filter):
"""
    Identify entity nouns in the text that are independent of other tokens,
    and filter out the corresponding samples. Texts containing no entities
    are omitted as well.
"""

def __init__(self,
lang: str = 'en',
min_dependency_num: int = 1,
any_or_all: str = 'all',
*args,
**kwargs):
"""
Initialization method.
:param lang: language of the text in the samples. 'en' for detection of
entities in English and 'zh' for detection of entities in Chinese.
        :param min_dependency_num: the min number of dependency-tree edges
            required. An entity noun is considered independent if its number
            of edges in the dependency tree is below this parameter.
        :param any_or_all: keep this sample with 'any' or 'all' strategy.
            'any': keep this sample if any entity noun is dependent. 'all':
            keep this sample only if all entity nouns are dependent.
"""
super().__init__(*args, **kwargs)

if lang not in ['en', 'zh']:
raise ValueError(
                f'Language [{lang}] is not supported in entity detection. '
f'Can only be one of ["en", "zh"].')
self.lang = lang
self.model_key = prepare_model(model_type='spacy', lang=lang)
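        # coarse POS plus fine-grained tags treated as entity nouns
        # (common nouns, proper nouns and pronouns in the en/zh pipelines)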
self.entity_poss = ['NOUN', 'PROPN', 'PRON']
self.entity_tags = ['NN', 'NR', 'PN', 'NNS', 'NNP', 'NNPS', 'PRP']
self.min_dependency_num = min_dependency_num
if any_or_all not in ['any', 'all']:
raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')

def compute_stats(self, sample, context=False):
# check if it's computed already
if StatsKeys.num_dependency_edges in sample[Fields.stats]:
return sample

text = remove_special_tokens(sample[self.text_key])

# identify entities
model = get_model(self.model_key)
doc = model(text)
entity_to_dependency_nums = {}
for token in doc:
if token.pos_ in self.entity_poss \
and token.tag_ in self.entity_tags:
entity_to_dependency_nums[token] = 0

# count the edges of each entity in dependency tree
for obj in entity_to_dependency_nums:
if obj.dep_ != 'ROOT':
entity_to_dependency_nums[obj] += 1
for token in doc:
            # skip punctuation marks such as ',' and '.'
if token.pos_ == 'PUNCT':
continue

            if token.head in entity_to_dependency_nums \
                    and token.dep_ != 'ROOT':
                entity_to_dependency_nums[token.head] += 1

        sample[Fields.stats][StatsKeys.num_dependency_edges] = list(
            entity_to_dependency_nums.values())

return sample

def process(self, sample):
num_dependency_edges = sample[Fields.stats][
StatsKeys.num_dependency_edges]
keep_bools = np.array([
self.min_dependency_num <= num_edge
for num_edge in num_dependency_edges
])
# omit the samples without entity
        if len(keep_bools) == 0:
return False

# different strategies
if self.any:
return keep_bools.any()
else:
return keep_bools.all()
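A similar sketch (again illustrative, not from the commit) shows how the per-entity edge counts drive the keep decision under the 'any' strategy; the printed values assume spacy's standard English parse of this sentence.

# illustrative sketch only, not from this commit
from data_juicer.ops.filter.text_entity_dependency_filter import \
    TextEntityDependencyFilter
from data_juicer.utils.constant import Fields

op = TextEntityDependencyFilter(lang='en', min_dependency_num=1,
                                any_or_all='any')
sample = {'text': 'The cat sleeps on the mat.', Fields.stats: {}}
sample = op.compute_stats(sample)  # one edge count per entity noun
print(op.process(sample))          # True: 'cat' and 'mat' both have edges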
2 changes: 2 additions & 0 deletions data_juicer/utils/constant.py
@@ -22,6 +22,8 @@ class StatsKeys(object):
special_char_ratio = 'special_char_ratio'
stopwords_ratio = 'stopwords_ratio'
text_len = 'text_len'
num_action = 'num_action'
num_dependency_edges = 'num_dependency_edges'
num_token = 'num_token'
num_words = 'num_words'
word_rep_ratio = 'word_rep_ratio'
12 changes: 7 additions & 5 deletions docs/Operators.md
@@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types.
|-----------------------------------|:------:|-------------------------------------------------|
| [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data |
| [ Mapper ]( #mapper ) | 21 | Edits and transforms samples |
- | [ Filter ]( #filter ) | 22 | Filters out low-quality samples |
+ | [ Filter ]( #filter ) | 24 | Filters out low-quality samples |
| [ Deduplicator ]( #deduplicator ) | 4 | Detects and removes duplicate samples |
| [ Selector ]( #selector ) | 2 | Selects top samples based on ranking |

@@ -77,11 +77,11 @@ All the specific operators are listed below, each featured with several capabili
| alphanumeric_filter | General | en, zh | Keeps samples with alphanumeric ratio within the specified range |
| average_line_length_filter | Code | en, zh | Keeps samples with average line length within the specified range |
| character_repetition_filter | General | en, zh | Keeps samples with char-level n-gram repetition ratio within the specified range |
- | face_area_filter | Image | - | Keeps samples contains images with face area ratios within the specified range |
+ | face_area_filter | Image | - | Keeps samples containing images with face area ratios within the specified range |
| flagged_words_filter | General | en, zh | Keeps samples with flagged-word ratio below the specified threshold |
- | image_aspect_ratio_filter | Image | - | Keeps samples contains images with aspect ratios within the specified range |
- | image_shape_filter | Image | - | Keeps samples contains images with widths and heights within the specified range |
- | image_size_filter | Image | - | Keeps samples contains images whose size in bytes are within the specified range |
+ | image_aspect_ratio_filter | Image | - | Keeps samples containing images with aspect ratios within the specified range |
+ | image_shape_filter | Image | - | Keeps samples containing images with widths and heights within the specified range |
+ | image_size_filter | Image | - | Keeps samples containing images whose size in bytes are within the specified range |
| image_text_matching_filter | Multimodal | - | Keeps samples with image-text classification matching score within the specified range based on a BLIP model |
| image_text_similarity_filter | Multimodal | - | Keeps samples with image-text feature cosine similarity within the specified range based on a CLIP model |
| language_id_score_filter | General | en, zh | Keeps samples of the specified language, judged by a predicted confidence score |
@@ -92,6 +92,8 @@ All the specific operators are listed below, each featured with several capabili
| specified_numeric_field_filter | General | en, zh | Filters samples based on field, with value lies in the specified range (for numeric types) |
| stopwords_filter | General | en, zh | Keeps samples with stopword ratio above the specified threshold |
| suffix_filter | General | en, zh | Keeps samples with specified suffixes |
| text_action_filter | General | en, zh | Keeps samples containing action verbs in their texts |
| text_entity_dependency_filter | General | en, zh | Keeps samples containing entity nouns related to other tokens in the dependency tree of the texts |
| text_length_filter | General | en, zh | Keeps samples with total text length within the specified range |
| token_num_filter | General | en, zh | Keeps samples with token count within the specified range |
| word_num_filter | General | en, zh | Keeps samples with word count within the specified range |
4 changes: 3 additions & 1 deletion docs/Operators_ZH.md
@@ -10,7 +10,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
|------------------------------------|:--:|---------------|
| [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 |
| [ Mapper ]( #mapper ) | 21 | 对数据样本进行编辑和转换 |
- | [ Filter ]( #filter ) | 22 | 过滤低质量样本 |
+ | [ Filter ]( #filter ) | 24 | 过滤低质量样本 |
| [ Deduplicator ]( #deduplicator ) | 4 | 识别、删除重复样本 |
| [ Selector ]( #selector ) | 2 | 基于排序选取高质量样本 |

@@ -89,6 +89,8 @@ Data-Juicer 中的算子分为以下 5 种类型。
| specified_numeric_field_filter | General | en, zh | 根据字段过滤样本,要求字段的值处于指定范围(针对数字类型) |
| stopwords_filter | General | en, zh | 保留停用词比率高于指定阈值的样本 |
| suffix_filter | General | en, zh | 保留包含特定后缀的样本 |
| text_action_filter | General | en, zh | 保留文本部分包含动作的样本 |
| text_entity_dependency_filter | General | en, zh | 保留文本部分的依存树中具有非独立实体的样本 |
| text_length_filter | General | en, zh | 保留总文本长度在指定范围内的样本 |
| token_num_filter | General | en, zh | 保留token数在指定范围内的样本 |
| word_num_filter | General | en, zh | 保留字数在指定范围内的样本 |
1 change: 1 addition & 0 deletions environments/science_requires.txt
@@ -13,3 +13,4 @@ opencc==1.1.6
imagededup
torch
dlib
spacy-pkuseg==0.0.32
114 changes: 114 additions & 0 deletions tests/ops/filter/test_text_action_filter.py
@@ -0,0 +1,114 @@
import os
import unittest

from datasets import Dataset

from data_juicer.ops.filter.text_action_filter import TextActionFilter
from data_juicer.utils.constant import Fields
from data_juicer.utils.mm_utils import SpecialTokens


class TextActionFilterTest(unittest.TestCase):

data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..',
'data')

cat_path = os.path.join(data_path, 'cat.jpg')
img3_path = os.path.join(data_path, 'img3.jpg')

    def _run_text_action_filter(self, dataset: Dataset, target_list, op,
                                column_names):
if Fields.stats not in dataset.features:
dataset = dataset.add_column(name=Fields.stats,
column=[{}] * dataset.num_rows)
dataset = dataset.map(op.compute_stats)
dataset = dataset.filter(op.process)
dataset = dataset.select_columns(column_names=column_names)
res_list = dataset.to_list()
self.assertEqual(res_list, target_list)

def test_en_text_case(self):

ds_list = [{
'text': 'Tom is playing piano.'
}, {
'text': 'Tom plays piano.'
}, {
'text': 'Tom played piano.'
}, {
'text': 'I play piano.'
}, {
'text': 'to play piano.'
}, {
'text': 'Tom 在打篮球'
}, {
'text': 'a v s e c s f e f g a a a '
}, {
'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►'
}, {
'text': 'that is a green tree'
}]
tgt_list = [{
'text': 'Tom is playing piano.'
}, {
'text': 'Tom plays piano.'
}, {
'text': 'Tom played piano.'
}, {
'text': 'I play piano.'
}, {
'text': 'to play piano.'
}]
dataset = Dataset.from_list(ds_list)
op = TextActionFilter(lang='en')
self._run_text_action_filter(dataset, tgt_list, op, ['text'])

def test_zh_text_case(self):

ds_list = [{
'text': '小明在 弹奏钢琴'
}, {
'text': 'Tom is playing 篮球'
}, {
'text': '上上下下左左右右'
}, {
'text': 'Tom在打篮球'
}, {
'text': '我有一只猫,它是一只猫'
}]
tgt_list = [{
'text': '小明在 弹奏钢琴'
}, {
'text': 'Tom在打篮球'
}]
dataset = Dataset.from_list(ds_list)
op = TextActionFilter(lang='zh')
self._run_text_action_filter(dataset, tgt_list, op, ['text'])

def test_image_text_case(self):
ds_list = [{
'text': f'{SpecialTokens.image}小猫咪正在睡觉。{SpecialTokens.eoc}',
'images': [self.cat_path]
}, {
'text': f'{SpecialTokens.image}小猫咪',
'images': [self.cat_path]
}, {
'text': f'{SpecialTokens.image}背影{SpecialTokens.eoc}',
'images': [self.img3_path]
}, {
'text': f'雨中行走的女人背影',
'images': [self.img3_path]
}]
tgt_list = [{
'text': f'{SpecialTokens.image}小猫咪正在睡觉。{SpecialTokens.eoc}',
'images': [self.cat_path]
}, {
'text': f'雨中行走的女人背影',
'images': [self.img3_path]
}]

dataset = Dataset.from_list(ds_list)
op = TextActionFilter(lang='zh')
self._run_text_action_filter(dataset, tgt_list, op, ['text', 'images'])

if __name__ == '__main__':
unittest.main()