From dcc0df3bc0a6851cc1c2cc940e2aa4b21366c1db Mon Sep 17 00:00:00 2001 From: "jiangnana.jnn" Date: Tue, 16 Jul 2024 17:58:04 +0800 Subject: [PATCH] add op optimize_instruction_mapper and support vllm --- configs/config_all.yaml | 14 ++-- data_juicer/ops/mapper/__init__.py | 4 +- data_juicer/ops/mapper/extract_qa_mapper.py | 31 ++++++-- .../ops/mapper/generate_instruction_mapper.py | 71 ++++++++++------- .../ops/mapper/optimize_instruction_mapper.py | 76 +++++++++++++++++++ data_juicer/utils/model_utils.py | 31 +++++++- docs/Operators.md | 3 +- docs/Operators_ZH.md | 5 +- environments/science_requires.txt | 1 + tests/ops/mapper/test_extract_qa_mapper.py | 12 ++- .../test_generate_instruction_mapper.py | 13 +++- .../test_optimize_instruction_mapper.py | 35 +++++++++ 12 files changed, 248 insertions(+), 48 deletions(-) create mode 100644 data_juicer/ops/mapper/optimize_instruction_mapper.py create mode 100644 tests/ops/mapper/test_optimize_instruction_mapper.py diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 3e87ffa25..4428bde8f 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -69,13 +69,13 @@ process: - clean_copyright_mapper: # remove copyright comments. - expand_macro_mapper: # expand macro definitions in Latex text. - extract_qa_mapper: # mapper to extract question and answer pair from text. - hf_model: 'alibaba-pai/pai-qwen1_5-7b-doc2qa' + hf_model: 'alibaba-pai/pai-qwen1_5-7b-doc2qa' # model name on huggingface to extract question and answer pair. - fix_unicode_mapper: # fix unicode errors in text. - generate_instruction_mapper: # generate new instruction text data. - hf_model: 'Qwen/Qwen-7B-Chat' - seed_file: 'demos/data/demo-dataset-chatml.jsonl' - instruct_num: 3 - similarity_threshold: 0.7 + hf_model: 'Qwen/Qwen-7B-Chat' # model name on huggingface to generate instruction. + seed_file: 'demos/data/demo-dataset-chatml.jsonl' # Seed file as instruction samples to generate new instructions, chatml format. 
+ instruct_num: 3 # the number of generated samples. + similarity_threshold: 0.7 # the similarity score threshold between the generated samples and the seed samples. Range from 0 to 1. Samples with similarity score less than this threshold will be kept. - image_blur_mapper: # mapper to blur images. p: 0.2 # probability of the image being blured blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian'] @@ -134,6 +134,10 @@ process: delete_random_char: false # whether to open the augmentation method of deleting random characters from the original texts. e.g. "这里一共有5种不同的数据增强方法" --> "这里一共有5种不同的数据增强" swap_random_char: false # whether to open the augmentation method of swapping random contiguous characters in the original texts. e.g. "这里一共有5种不同的数据增强方法" --> "这里一共有5种不同的数据强增方法" replace_equivalent_num: false # whether to open the augmentation method of replacing random numbers with their equivalent representations in the original texts. **Notice**: Only for numbers for now. e.g. "这里一共有5种不同的数据增强方法" --> "这里一共有伍种不同的数据增强方法" + - optimize_instruction_mapper: # optimize instruction. + hf_model: 'alibaba-pai/Qwen2-7B-Instruct-Refine' # model name on huggingface to optimize instruction + enable_vllm: false # whether to use vllm for inference acceleration. + tensor_parallel_size: 1 # it is only valid when enable_vllm is True. The number of GPUs to use for distributed execution with tensor parallelism. - punctuation_normalization_mapper: # normalize unicode punctuations to English punctuations. - remove_bibliography_mapper: # remove bibliography from Latex text. - remove_comments_mapper: # remove comments from Latex text, code, etc. 
diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 0279df217..25aa7d5f6 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -5,7 +5,7 @@ extract_qa_mapper, fix_unicode_mapper, image_blur_mapper, image_captioning_from_gpt4v_mapper, image_captioning_mapper, image_diffusion_mapper, image_face_blur_mapper, - nlpaug_en_mapper, nlpcda_zh_mapper, + nlpaug_en_mapper, nlpcda_zh_mapper, optimize_instruction_mapper, punctuation_normalization_mapper, remove_bibliography_mapper, remove_comments_mapper, remove_header_mapper, remove_long_words_mapper, remove_non_chinese_character_mapper, @@ -42,6 +42,7 @@ from .image_face_blur_mapper import ImageFaceBlurMapper from .nlpaug_en_mapper import NlpaugEnMapper from .nlpcda_zh_mapper import NlpcdaZhMapper +from .optimize_instruction_mapper import OptimizeInstructionMapper from .punctuation_normalization_mapper import PunctuationNormalizationMapper from .remove_bibliography_mapper import RemoveBibliographyMapper from .remove_comments_mapper import RemoveCommentsMapper @@ -93,6 +94,7 @@ 'VideoFFmpegWrappedMapper', 'ChineseConvertMapper', 'NlpcdaZhMapper', + 'OptimizeInstructionMapper', 'ImageBlurMapper', 'CleanCopyrightMapper', 'RemoveNonChineseCharacterlMapper', diff --git a/data_juicer/ops/mapper/extract_qa_mapper.py b/data_juicer/ops/mapper/extract_qa_mapper.py index 373ad88ac..a0e93f943 100644 --- a/data_juicer/ops/mapper/extract_qa_mapper.py +++ b/data_juicer/ops/mapper/extract_qa_mapper.py @@ -26,6 +26,8 @@ def __init__(self, hf_model: str = 'alibaba-pai/pai-qwen1_5-7b-doc2qa', pattern: str = None, qa_format: str = 'chatml', + enable_vllm=False, + tensor_parallel_size=1, *args, **kwargs): """ @@ -33,6 +35,8 @@ def __init__(self, :param hf_model: Hugginface model id. :param pattern: regular expression pattern to search for within text. :param qa_format: Output format of question and answer pair. 
+ :param enable_vllm: Whether to use vllm for inference acceleration. + :param tensor_parallel_size: It is only valid when enable_vllm is True. :param args: extra args :param kwargs: extra args @@ -59,8 +63,17 @@ def __init__(self, self.pattern = pattern self.qa_format = qa_format - self.model_key = prepare_model(model_type='huggingface', - pretrained_model_name_or_path=hf_model) + self.enable_vllm = enable_vllm + + if enable_vllm: + self.model_key = prepare_model( + model_type='vllm', + pretrained_model_name_or_path=hf_model, + tensor_parallel_size=tensor_parallel_size) + else: + self.model_key = prepare_model( + model_type='huggingface', + pretrained_model_name_or_path=hf_model) def _extract_qa(self, output): """Extract qestion and answer pair from model output response.""" @@ -78,10 +91,16 @@ def _extract_qa(self, output): def process(self, sample, rank=None): model, processor = get_model(self.model_key, rank=rank) - inputs = processor(sample[self.text_key], - return_tensors='pt').to(model.device) - response = model.generate(**inputs) - output = processor.decode(response.cpu()[0], skip_special_tokens=True) + if self.enable_vllm: + response = model.generate([sample[self.text_key]]) + output = response[0].outputs[0].text + else: + inputs = processor(sample[self.text_key], + return_tensors='pt').to(model.device) + response = model.generate(**inputs) + output = processor.decode(response.cpu()[0], + skip_special_tokens=True) + qa_list = self._extract_qa(output) if not len(qa_list): diff --git a/data_juicer/ops/mapper/generate_instruction_mapper.py b/data_juicer/ops/mapper/generate_instruction_mapper.py index 5782ad118..2d8046b29 100644 --- a/data_juicer/ops/mapper/generate_instruction_mapper.py +++ b/data_juicer/ops/mapper/generate_instruction_mapper.py @@ -43,6 +43,8 @@ def __init__(self, instruct_num, similarity_threshold=0.7, prompt_template=None, + enable_vllm=False, + tensor_parallel_size=1, *args, **kwargs): """ @@ -61,6 +63,8 @@ def __init__(self, :param 
prompt_template: Prompt template for generate samples. Please make sure the template contains "{augmented_data}", which corresponds to the augmented samples. + :param enable_vllm: Whether to use vllm for inference acceleration. + :param tensor_parallel_size: It is only valid when enable_vllm is True. :param args: extra args :param kwargs: extra args """ @@ -74,9 +78,17 @@ def __init__(self, if prompt_template is None: prompt_template = DEFAULT_PROMPT_TEMPLATE self.prompt_template = prompt_template - - self.model_key = prepare_model(model_type='huggingface', - pretrained_model_name_or_path=hf_model) + self.enable_vllm = enable_vllm + + if enable_vllm: + self.model_key = prepare_model( + model_type='vllm', + pretrained_model_name_or_path=hf_model, + tensor_parallel_size=tensor_parallel_size) + else: + self.model_key = prepare_model( + model_type='huggingface', + pretrained_model_name_or_path=hf_model) self.seed_qa_samples = self.load_seed_qa_samples(seed_file) @@ -168,30 +180,35 @@ def process(self, sample=None, rank=None): self.instruct_num) input_prompt = self.build_prompt(random_qa_samples, self.prompt_template) - inputs = processor(input_prompt, return_tensors='pt').to(model.device) - response = model.generate(**inputs) - output_response = processor.decode(response.cpu()[0], - skip_special_tokens=True) - if output_response: - out_qa_pairs, response_str = self.parse_response(output_response) - - if self.similarity_type == 'rouge_l': - sim_score = self.max_rouge_l_score(response_str, - self.reference_samples) - else: - raise ValueError( - f'Not support similarity type "{self.similarity_type}"!') - - message_list = [] - if sim_score <= self.similarity_threshold: - for question, answer in out_qa_pairs: - message_list.append({'role': 'user', 'content': question}) - message_list.append({ - 'role': 'assistant', - 'content': answer - }) - else: - logging.info('Filter one instance due to similarity.') + if self.enable_vllm: + response = model.generate([input_prompt]) + 
response_str = response[0].outputs[0].text + else: + inputs = processor(input_prompt, + return_tensors='pt').to(model.device) + response = model.generate(**inputs) + response_str = processor.decode(response.cpu()[0], + skip_special_tokens=True) + + message_list = [] + out_qa_pairs, response_str = self.parse_response(response_str) + + if not response_str: + return {self.text_key: json.dumps({'messages': message_list})} + + if self.similarity_type == 'rouge_l': + sim_score = self.max_rouge_l_score(response_str, + self.reference_samples) + else: + raise ValueError( + f'Not support similarity type "{self.similarity_type}"!') + + if sim_score <= self.similarity_threshold: + for question, answer in out_qa_pairs: + message_list.append({'role': 'user', 'content': question}) + message_list.append({'role': 'assistant', 'content': answer}) + else: + logging.info('Filter this generated sample due to similarity.') return { self.text_key: diff --git a/data_juicer/ops/mapper/optimize_instruction_mapper.py b/data_juicer/ops/mapper/optimize_instruction_mapper.py new file mode 100644 index 000000000..48039a040 --- /dev/null +++ b/data_juicer/ops/mapper/optimize_instruction_mapper.py @@ -0,0 +1,76 @@ +from data_juicer.ops.base_op import OPERATORS, Mapper +from data_juicer.utils.model_utils import get_model, prepare_model + +DEFAULT_SYSTEM_PROMPT = '请优化这个指令,将其修改为一个更详细具体的指令。' + + +@OPERATORS.register_module('optimize_instruction_mapper') +class OptimizeInstructionMapper(Mapper): + """Mapper to optimize instruction. + Recommended model list: [ + alibaba-pai/Qwen2-1.5B-Instruct-Refine + alibaba-pai/Qwen2-7B-Instruct-Refine + ] + """ + + def __init__(self, + hf_model='alibaba-pai/Qwen2-7B-Instruct-Refine', + system_prompt=None, + enable_vllm=False, + tensor_parallel_size=1, + *args, + **kwargs): + """ + Initialization method. + :param hf_model: Hugginface model id. + :param system_prompt: System prompt for optimize samples. 
+ :param enable_vllm: Whether to use vllm for inference acceleration. + :param tensor_parallel_size: It is only valid when enable_vllm is True. + The number of GPUs to use for distributed execution with tensor + parallelism. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + + if system_prompt is None: + system_prompt = DEFAULT_SYSTEM_PROMPT + self.system_prompt = system_prompt + self.enable_vllm = enable_vllm + + if enable_vllm: + self.model_key = prepare_model( + model_type='vllm', + pretrained_model_name_or_path=hf_model, + tensor_parallel_size=tensor_parallel_size) + else: + self.model_key = prepare_model( + model_type='huggingface', + pretrained_model_name_or_path=hf_model) + + def process(self, sample=None, rank=None): + model, processor = get_model(self.model_key, rank=rank) + + messages = [{ + 'role': 'system', + 'content': self.system_prompt + }, { + 'role': 'user', + 'content': sample[self.text_key] + }] + input_prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True) + + if self.enable_vllm: + response = model.generate([input_prompt]) + output = response[0].outputs[0].text + else: + inputs = processor(input_prompt, + return_tensors='pt').to(model.device) + response = model.generate(**inputs) + output = processor.decode(response.cpu()[0], + skip_special_tokens=True) + + sample[self.text_key] = output + + return sample diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index e8612db2d..7d567064a 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -400,6 +400,34 @@ def prepare_huggingface_model(pretrained_model_name_or_path, return (model, processor) if return_model else processor +def prepare_vllm_model(pretrained_model_name_or_path, + return_model=True, + trust_remote_code=False, + tensor_parallel_size=1): + """ + Prepare and load a vllm model with the corresponding processor. 
+ + :param pretrained_model_name_or_path: model name or path + :param return_model: return model or not + :param trust_remote_code: passed to transformers + :param tensor_parallel_size: The number of GPUs to use for distributed + execution with tensor parallelism. + :return: a tuple (model, input processor) if `return_model` is True; + otherwise, only the processor is returned. + """ + from transformers import AutoProcessor + from vllm import LLM as vLLM + + processor = AutoProcessor.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code) + + if return_model: + model = vLLM(model=pretrained_model_name_or_path, + tensor_parallel_size=tensor_parallel_size) + + return (model, processor) if return_model else processor + + def prepare_spacy_model(lang, name_pattern='{}_core_web_md-3.5.0'): """ Prepare spacy model for specific language. @@ -530,7 +558,8 @@ def prepare_recognizeAnything_model( 'spacy': prepare_spacy_model, 'diffusion': prepare_diffusion_model, 'video_blip': prepare_video_blip_model, - 'recognizeAnything': prepare_recognizeAnything_model + 'recognizeAnything': prepare_recognizeAnything_model, + 'vllm': prepare_vllm_model } diff --git a/docs/Operators.md b/docs/Operators.md index 75c667359..9e33cd790 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types. 
| Type | Number | Description | |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data | -| [ Mapper ]( #mapper ) | 45 | Edits and transforms samples | +| [ Mapper ]( #mapper ) | 46 | Edits and transforms samples | | [ Filter ]( #filter ) | 41 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 5 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 2 | Selects top samples based on ranking | @@ -68,6 +68,7 @@ All the specific operators are listed below, each featured with several capabili | image_face_blur_mapper | Image | - | Blur faces detected in images | | nlpaug_en_mapper | General | en | Simply augments texts in English based on the `nlpaug` library | | nlpcda_zh_mapper | General | zh | Simply augments texts in Chinese based on the `nlpcda` library | +| optimize_instruction_mapper | General | en, zh | Optimize instruction text samples.| | punctuation_normalization_mapper | General | en, zh | Normalizes various Unicode punctuations to their ASCII equivalents | | remove_bibliography_mapper | LaTeX | en, zh | Removes the bibliography of TeX documents | | remove_comments_mapper | LaTeX | en, zh | Removes the comments of TeX documents | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 0aa4965d3..a5da43e05 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -11,7 +11,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | 类型 | 数量 | 描述 | |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 | -| [ Mapper ]( #mapper ) | 45 | 对数据样本进行编辑和转换 | +| [ Mapper ]( #mapper ) | 46 | 对数据样本进行编辑和转换 | | [ Filter ]( #filter ) | 41 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 5 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 2 | 基于排序选取高质量样本 | @@ -59,7 +59,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | expand_macro_mapper | LaTeX | en, zh | 扩展通常在 TeX 文档顶部定义的宏 | | 
extract_qa_mapper | General | en, zh | 从文本中抽取问答对 | | fix_unicode_mapper | General | en, zh | 修复损坏的 Unicode(借助 [ftfy](https://ftfy.readthedocs.io/)) | -| generate_instruction_mapper | General | en, zh | 数据增强,生成新样本。 | +| generate_instruction_mapper | General | en, zh | 指令扩充,根据种子数据,生成新的样本。 | | image_blur_mapper | Image | - | 对图像进行模糊处理 | | image_captioning_from_gpt4v_mapper | Multimodal | - | 基于gpt-4-vision和图像生成文本 | | image_captioning_mapper | Multimodal | - | 生成样本,其标题是根据另一个辅助模型(例如 blip2)和原始样本中的图形生成的。 | @@ -67,6 +67,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | image_face_blur_mapper | Image | - | 对图像中的人脸进行模糊处理 | | nlpaug_en_mapper | General | en | 使用`nlpaug`库对英语文本进行简单增强 | | nlpcda_zh_mapper | General | zh | 使用`nlpcda`库对中文文本进行简单增强 | +| optimize_instruction_mapper | General | en, zh | 指令优化,优化prompt。| | punctuation_normalization_mapper | General | en, zh | 将各种 Unicode 标点符号标准化为其 ASCII 等效项 | | remove_bibliography_mapper | LaTeX | en, zh | 删除 TeX 文档的参考文献 | | remove_comments_mapper | LaTeX | en, zh | 删除 TeX 文档中的注释 | diff --git a/environments/science_requires.txt b/environments/science_requires.txt index 0060ffeeb..1afb98427 100644 --- a/environments/science_requires.txt +++ b/environments/science_requires.txt @@ -25,3 +25,4 @@ simple-aesthetics-predictor scenedetect[opencv] ffmpeg-python opencv-python +vllm diff --git a/tests/ops/mapper/test_extract_qa_mapper.py b/tests/ops/mapper/test_extract_qa_mapper.py index 6d659b61f..c9c34103b 100644 --- a/tests/ops/mapper/test_extract_qa_mapper.py +++ b/tests/ops/mapper/test_extract_qa_mapper.py @@ -10,10 +10,11 @@ class ExtractQAMapperTest(DataJuicerTestCaseBase): text_key = 'text' - def _run_extract_qa(self, samples): + def _run_extract_qa(self, samples, enable_vllm=False): op = ExtractQAMapper( hf_model='alibaba-pai/pai-qwen1_5-7b-doc2qa', - qa_format='chatml' + qa_format='chatml', + enable_vllm=enable_vllm ) for sample in samples: result = op.process(sample) @@ -31,6 +32,13 @@ def test_extract_qa(self): }] self._run_extract_qa(samples) + def 
test_extract_qa_vllm(self): + samples = [ + { + self.text_key: '蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n' + }] + self._run_extract_qa(samples, enable_vllm=True) + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/mapper/test_generate_instruction_mapper.py b/tests/ops/mapper/test_generate_instruction_mapper.py index 135579ce6..41ae41dce 100644 --- a/tests/ops/mapper/test_generate_instruction_mapper.py +++ b/tests/ops/mapper/test_generate_instruction_mapper.py @@ -11,11 +11,12 @@ class GenerateInstructionMapperTest(DataJuicerTestCaseBase): text_key = 'text' - def test_generate_instruction(self): + def _run_generate_instruction(self, enable_vllm=False): op = GenerateInstructionMapper( - hf_model='Qwen/Qwen-7B-Chat', + hf_model='Qwen/Qwen-7B-Chat', seed_file='demos/data/demo-dataset-chatml.jsonl', - instruct_num=2 + instruct_num=2, + enable_vllm=enable_vllm ) from data_juicer.format.empty_formatter import EmptyFormatter @@ -28,6 +29,12 @@ def test_generate_instruction(self): # test one output qa sample self.assertIn('role', out_sample['messages'][0]) self.assertIn('content', out_sample['messages'][0]) + + def test_generate_instruction(self): + self._run_generate_instruction() + + def test_generate_instruction_vllm(self): + self._run_generate_instruction(enable_vllm=True) if __name__ == '__main__': diff --git a/tests/ops/mapper/test_optimize_instruction_mapper.py b/tests/ops/mapper/test_optimize_instruction_mapper.py new file mode 100644 index 000000000..9f33c00a2 --- /dev/null +++ b/tests/ops/mapper/test_optimize_instruction_mapper.py @@ -0,0 +1,35 @@ +import unittest +from data_juicer.ops.mapper.optimize_instruction_mapper import OptimizeInstructionMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) + +# Skip tests for this OP in the GitHub actions due to disk space limitation. +# These tests have been tested locally. 
+@SKIPPED_TESTS.register_module() +class OptimizeInstructionMapperTest(DataJuicerTestCaseBase): + + text_key = 'text' + + def _run_optimize_instruction(self, enable_vllm=False): + op = OptimizeInstructionMapper( + hf_model='alibaba-pai/Qwen2-7B-Instruct-Refine', + enable_vllm=enable_vllm + ) + + samples = [ + {self.text_key: '鱼香肉丝怎么做?'} + ] + + for sample in samples: + result = op.process(sample) + self.assertIn(self.text_key, result) + + def test_optimize_instruction(self): + self._run_optimize_instruction() + + def test_optimize_instruction_vllm(self): + self._run_optimize_instruction(enable_vllm=True) + + +if __name__ == '__main__': + unittest.main()