Two additional filters for image captions (resubmit) (#122)
* text action filter
* text entity dependency filter
* complete config_all.yaml and Operators.md
* move spacy-pkuseg to science_requires.txt
* fix typo in comments
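Taken together, the two new operators can be chained like any other data-juicer filter. Below is a minimal sketch, illustrative only: it mirrors the stats-column setup used in the test file further down and assumes the default text_key of 'text' and that the spaCy model for the chosen language can be prepared locally.

from datasets import Dataset

from data_juicer.ops.filter.text_action_filter import TextActionFilter
from data_juicer.ops.filter.text_entity_dependency_filter import \
    TextEntityDependencyFilter
from data_juicer.utils.constant import Fields

ds = Dataset.from_list([{'text': 'Tom is playing piano.'},
                        {'text': 'a v s e c s f e f g'}])
ds = ds.add_column(name=Fields.stats, column=[{}] * ds.num_rows)
# keep captions that contain at least one action and whose entities are
# all attached to the dependency tree
for op in (TextActionFilter(lang='en'),
           TextEntityDependencyFilter(lang='en', any_or_all='all')):
    ds = ds.map(op.compute_stats)
    ds = ds.filter(op.process)
print(ds.to_list())  # only 'Tom is playing piano.' should survive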
Showing 11 changed files with 414 additions and 7 deletions.
66 changes: 66 additions & 0 deletions
data_juicer/ops/filter/text_action_filter.py
@@ -0,0 +1,66 @@
from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import remove_special_tokens
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Filter

OP_NAME = 'text_action_filter'


@OPERATORS.register_module(OP_NAME)
class TextActionFilter(Filter):
    """
    Filter to keep texts that contain actions.
    """

    def __init__(self,
                 lang: str = 'en',
                 min_action_num: int = 1,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param lang: language of the text in the samples. 'en' for detection
            of actions in English and 'zh' for detection of actions in
            Chinese.
        :param min_action_num: the minimum number of actions required to keep
            a sample. Samples whose texts contain fewer actions than this
            number will be filtered out.
        """
        super().__init__(*args, **kwargs)

        if lang not in ['en', 'zh']:
            raise ValueError(
                f'Language [{lang}] is not supported in action detection. '
                f'Can only be one of ["en", "zh"].')
        self.lang = lang
        self.model_key = prepare_model(model_type='spacy', lang=lang)
        # coarse POS and fine-grained tags that mark actions (verbs)
        self.action_poss = ['VERB']
        self.action_tags = ['VV', 'VB', 'VBP', 'VBZ', 'VBD', 'VBG', 'VBN']
        self.min_action_num = min_action_num

    def compute_stats(self, sample, context=False):
        # check if it's computed already
        if StatsKeys.num_action in sample[Fields.stats]:
            return sample

        text = remove_special_tokens(sample[self.text_key])

        # process text via spacy and count the actions in the text
        model = get_model(self.model_key)
        doc = model(text)
        num_action = 0
        for token in doc:
            if token.pos_ in self.action_poss \
                    and token.tag_ in self.action_tags:
                num_action += 1
        sample[Fields.stats][StatsKeys.num_action] = num_action

        return sample

    def process(self, sample):
        num_action = sample[Fields.stats][StatsKeys.num_action]
        # keep the sample only if it contains enough actions
        return self.min_action_num <= num_action
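For illustration, a minimal standalone run of this filter — a sketch assuming the default text_key of 'text' and that the spaCy 'en' model can be prepared; the stats key name comes from StatsKeys.num_action above.

from data_juicer.ops.filter.text_action_filter import TextActionFilter
from data_juicer.utils.constant import Fields, StatsKeys

op = TextActionFilter(lang='en', min_action_num=1)
# the stats field must exist before compute_stats is called,
# as in the test helper below
sample = {'text': 'Tom is playing piano.', Fields.stats: {}}
sample = op.compute_stats(sample)
print(sample[Fields.stats][StatsKeys.num_action])  # expected 1 ('playing')
print(op.process(sample))  # True -> the sample is kept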
103 changes: 103 additions & 0 deletions
data_juicer/ops/filter/text_entity_dependency_filter.py
@@ -0,0 +1,103 @@
import numpy as np

from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import remove_special_tokens
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Filter

OP_NAME = 'text_entity_dependency_filter'


@OPERATORS.register_module(OP_NAME)
class TextEntityDependencyFilter(Filter):
    """
    Identify entities in the text that are independent of any other token,
    and filter out the samples containing them. Texts that contain no
    entities at all are omitted as well.
    """

    def __init__(self,
                 lang: str = 'en',
                 min_dependency_num: int = 1,
                 any_or_all: str = 'all',
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param lang: language of the text in the samples. 'en' for detection
            of entities in English and 'zh' for detection of entities in
            Chinese.
        :param min_dependency_num: the minimum number of dependency-tree
            edges an entity must have. An entity is considered independent
            if its number of edges in the dependency tree is below this
            parameter.
        :param any_or_all: keep this sample with 'any' or 'all' strategy.
            'any': keep the sample if any entity is dependent. 'all': keep
            the sample only if all entities are dependent.
        """
        super().__init__(*args, **kwargs)

        if lang not in ['en', 'zh']:
            raise ValueError(
                f'Language [{lang}] is not supported in entity detection. '
                f'Can only be one of ["en", "zh"].')
        self.lang = lang
        self.model_key = prepare_model(model_type='spacy', lang=lang)
        # coarse POS and fine-grained tags that mark entities (nominals)
        self.entity_poss = ['NOUN', 'PROPN', 'PRON']
        self.entity_tags = ['NN', 'NR', 'PN', 'NNS', 'NNP', 'NNPS', 'PRP']
        self.min_dependency_num = min_dependency_num
        if any_or_all not in ['any', 'all']:
            raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
                             f'Can only be one of ["any", "all"].')
        self.any = (any_or_all == 'any')

    def compute_stats(self, sample, context=False):
        # check if it's computed already
        if StatsKeys.num_dependency_edges in sample[Fields.stats]:
            return sample

        text = remove_special_tokens(sample[self.text_key])

        # identify entities
        model = get_model(self.model_key)
        doc = model(text)
        entity_to_dependency_nums = {}
        for token in doc:
            if token.pos_ in self.entity_poss \
                    and token.tag_ in self.entity_tags:
                entity_to_dependency_nums[token] = 0

        # count the edges of each entity in the dependency tree
        for obj in entity_to_dependency_nums:
            if obj.dep_ != 'ROOT':
                entity_to_dependency_nums[obj] += 1
        for token in doc:
            # skip punctuation marks such as ',' and '.'
            if token.pos_ == 'PUNCT':
                continue

            if token.head in entity_to_dependency_nums.keys(
            ) and token.dep_ != 'ROOT':
                entity_to_dependency_nums[token.head] += 1

        sample[Fields.stats][StatsKeys.num_dependency_edges] = list(
            entity_to_dependency_nums.values())

        return sample

    def process(self, sample):
        num_dependency_edges = sample[Fields.stats][
            StatsKeys.num_dependency_edges]
        keep_bools = np.array([
            self.min_dependency_num <= num_edge
            for num_edge in num_dependency_edges
        ])
        # omit the samples without any entity
        if len(keep_bools) == 0:
            return False

        # different strategies
        if self.any:
            return keep_bools.any()
        else:
            return keep_bools.all()
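A matching sketch for this filter, under the same assumptions as the action-filter example above ('text' as the default text_key, spaCy 'en' model available):

from data_juicer.ops.filter.text_entity_dependency_filter import \
    TextEntityDependencyFilter
from data_juicer.utils.constant import Fields, StatsKeys

op = TextEntityDependencyFilter(lang='en', min_dependency_num=1,
                                any_or_all='all')
# 'tree' hangs off the verb in the parse and picks up 'green' as a
# modifier, so the entity should have at least one dependency edge
sample = {'text': 'that is a green tree', Fields.stats: {}}
sample = op.compute_stats(sample)
print(sample[Fields.stats][StatsKeys.num_dependency_edges])
print(op.process(sample))  # True under the 'all' strategy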
science_requires.txt
@@ -13,3 +13,4 @@ opencc==1.1.6
imagededup
torch
dlib
spacy-pkuseg==0.0.32
114 changes: 114 additions & 0 deletions
tests/ops/filter/test_text_action_filter.py
@@ -0,0 +1,114 @@
import os
import unittest

from datasets import Dataset

from data_juicer.ops.filter.text_action_filter import TextActionFilter
from data_juicer.utils.constant import Fields
from data_juicer.utils.mm_utils import SpecialTokens


class TextActionFilterTest(unittest.TestCase):

    data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                             '..', 'data')

    cat_path = os.path.join(data_path, 'cat.jpg')
    img3_path = os.path.join(data_path, 'img3.jpg')

    def _run_text_action_filter(self, dataset: Dataset, target_list, op,
                                column_names):
        # add the stats column if it's missing, then run the op end to end
        if Fields.stats not in dataset.features:
            dataset = dataset.add_column(name=Fields.stats,
                                         column=[{}] * dataset.num_rows)
        dataset = dataset.map(op.compute_stats)
        dataset = dataset.filter(op.process)
        dataset = dataset.select_columns(column_names=column_names)
        res_list = dataset.to_list()
        self.assertEqual(res_list, target_list)

    def test_en_text_case(self):

        ds_list = [{
            'text': 'Tom is playing piano.'
        }, {
            'text': 'Tom plays piano.'
        }, {
            'text': 'Tom played piano.'
        }, {
            'text': 'I play piano.'
        }, {
            'text': 'to play piano.'
        }, {
            'text': 'Tom 在打篮球'
        }, {
            'text': 'a v s e c s f e f g a a a '
        }, {
            'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►'
        }, {
            'text': 'that is a green tree'
        }]
        tgt_list = [{
            'text': 'Tom is playing piano.'
        }, {
            'text': 'Tom plays piano.'
        }, {
            'text': 'Tom played piano.'
        }, {
            'text': 'I play piano.'
        }, {
            'text': 'to play piano.'
        }]
        dataset = Dataset.from_list(ds_list)
        op = TextActionFilter(lang='en')
        self._run_text_action_filter(dataset, tgt_list, op, ['text'])

    def test_zh_text_case(self):

        ds_list = [{
            'text': '小明在 弹奏钢琴'
        }, {
            'text': 'Tom is playing 篮球'
        }, {
            'text': '上上下下左左右右'
        }, {
            'text': 'Tom在打篮球'
        }, {
            'text': '我有一只猫,它是一只猫'
        }]
        tgt_list = [{
            'text': '小明在 弹奏钢琴'
        }, {
            'text': 'Tom在打篮球'
        }]
        dataset = Dataset.from_list(ds_list)
        op = TextActionFilter(lang='zh')
        self._run_text_action_filter(dataset, tgt_list, op, ['text'])

    def test_image_text_case(self):
        ds_list = [{
            'text': f'{SpecialTokens.image}小猫咪正在睡觉。{SpecialTokens.eoc}',
            'images': [self.cat_path]
        }, {
            'text': f'{SpecialTokens.image}小猫咪',
            'images': [self.cat_path]
        }, {
            'text': f'{SpecialTokens.image}背影{SpecialTokens.eoc}',
            'images': [self.img3_path]
        }, {
            'text': '雨中行走的女人背影',
            'images': [self.img3_path]
        }]
        tgt_list = [{
            'text': f'{SpecialTokens.image}小猫咪正在睡觉。{SpecialTokens.eoc}',
            'images': [self.cat_path]
        }, {
            'text': '雨中行走的女人背影',
            'images': [self.img3_path]
        }]

        dataset = Dataset.from_list(ds_list)
        op = TextActionFilter(lang='zh')
        self._run_text_action_filter(dataset, tgt_list, op,
                                     ['text', 'images'])


if __name__ == '__main__':
    unittest.main()