Skip to content

Commit

Permalink
add op optimize_instruction_mapper and support vllm
Browse files Browse the repository at this point in the history
  • Loading branch information
Cathy0908 committed Jul 16, 2024
1 parent b14d846 commit dcc0df3
Show file tree
Hide file tree
Showing 12 changed files with 248 additions and 48 deletions.
14 changes: 9 additions & 5 deletions configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ process:
- clean_copyright_mapper: # remove copyright comments.
- expand_macro_mapper: # expand macro definitions in Latex text.
- extract_qa_mapper: # mapper to extract question and answer pair from text.
hf_model: 'alibaba-pai/pai-qwen1_5-7b-doc2qa'
hf_model: 'alibaba-pai/pai-qwen1_5-7b-doc2qa' # model name on huggingface to extract question and answer pair.
- fix_unicode_mapper: # fix unicode errors in text.
- generate_instruction_mapper: # generate new instruction text data.
hf_model: 'Qwen/Qwen-7B-Chat'
seed_file: 'demos/data/demo-dataset-chatml.jsonl'
instruct_num: 3
similarity_threshold: 0.7
hf_model: 'Qwen/Qwen-7B-Chat' # model name on huggingface to generate instruction.
seed_file: 'demos/data/demo-dataset-chatml.jsonl' # Seed file as instruction samples to generate new instructions, chatml format.
instruct_num: 3 # the number of generated samples.
similarity_threshold: 0.7 # the similarity score threshold between the generated samples and the seed samples. Ranges from 0 to 1. Samples with a similarity score less than this threshold will be kept.
- image_blur_mapper: # mapper to blur images.
p: 0.2 # probability of the image being blurred
blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian']
Expand Down Expand Up @@ -134,6 +134,10 @@ process:
delete_random_char: false # whether to open the augmentation method of deleting random characters from the original texts. e.g. "这里一共有5种不同的数据增强方法" --> "这里一共有5种不同的数据增强"
swap_random_char: false # whether to open the augmentation method of swapping random contiguous characters in the original texts. e.g. "这里一共有5种不同的数据增强方法" --> "这里一共有5种不同的数据强增方法"
replace_equivalent_num: false # whether to open the augmentation method of replacing random numbers with their equivalent representations in the original texts. **Notice**: Only for numbers for now. e.g. "这里一共有5种不同的数据增强方法" --> "这里一共有伍种不同的数据增强方法"
- optimize_instruction_mapper: # optimize instruction.
hf_model: 'alibaba-pai/Qwen2-7B-Instruct-Refine' # model name on huggingface to optimize instruction
enable_vllm: false # whether to use vllm for inference acceleration.
tensor_parallel_size: 1 # it is only valid when enable_vllm is True. The number of GPUs to use for distributed execution with tensor parallelism.
- punctuation_normalization_mapper: # normalize unicode punctuations to English punctuations.
- remove_bibliography_mapper: # remove bibliography from Latex text.
- remove_comments_mapper: # remove comments from Latex text, code, etc.
Expand Down
4 changes: 3 additions & 1 deletion data_juicer/ops/mapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
extract_qa_mapper, fix_unicode_mapper, image_blur_mapper,
image_captioning_from_gpt4v_mapper, image_captioning_mapper,
image_diffusion_mapper, image_face_blur_mapper,
nlpaug_en_mapper, nlpcda_zh_mapper,
nlpaug_en_mapper, nlpcda_zh_mapper, optimize_instruction_mapper,
punctuation_normalization_mapper, remove_bibliography_mapper,
remove_comments_mapper, remove_header_mapper,
remove_long_words_mapper, remove_non_chinese_character_mapper,
Expand Down Expand Up @@ -42,6 +42,7 @@
from .image_face_blur_mapper import ImageFaceBlurMapper
from .nlpaug_en_mapper import NlpaugEnMapper
from .nlpcda_zh_mapper import NlpcdaZhMapper
from .optimize_instruction_mapper import OptimizeInstructionMapper
from .punctuation_normalization_mapper import PunctuationNormalizationMapper
from .remove_bibliography_mapper import RemoveBibliographyMapper
from .remove_comments_mapper import RemoveCommentsMapper
Expand Down Expand Up @@ -93,6 +94,7 @@
'VideoFFmpegWrappedMapper',
'ChineseConvertMapper',
'NlpcdaZhMapper',
'OptimizeInstructionMapper',
'ImageBlurMapper',
'CleanCopyrightMapper',
'RemoveNonChineseCharacterlMapper',
Expand Down
31 changes: 25 additions & 6 deletions data_juicer/ops/mapper/extract_qa_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,17 @@ def __init__(self,
hf_model: str = 'alibaba-pai/pai-qwen1_5-7b-doc2qa',
pattern: str = None,
qa_format: str = 'chatml',
enable_vllm=False,
tensor_parallel_size=1,
*args,
**kwargs):
"""
Initialization method.
:param hf_model: Hugginface model id.
:param pattern: regular expression pattern to search for within text.
:param qa_format: Output format of question and answer pair.
:param enable_vllm: Whether to use vllm for inference acceleration.
:param tensor_parallel_size: It is only valid when enable_vllm is True.
:param args: extra args
:param kwargs: extra args
Expand All @@ -59,8 +63,17 @@ def __init__(self,
self.pattern = pattern

self.qa_format = qa_format
self.model_key = prepare_model(model_type='huggingface',
pretrained_model_name_or_path=hf_model)
self.enable_vllm = enable_vllm

if enable_vllm:
self.model_key = prepare_model(
model_type='vllm',
pretrained_model_name_or_path=hf_model,
tensor_parallel_size=tensor_parallel_size)
else:
self.model_key = prepare_model(
model_type='huggingface',
pretrained_model_name_or_path=hf_model)

def _extract_qa(self, output):
"""Extract qestion and answer pair from model output response."""
Expand All @@ -78,10 +91,16 @@ def _extract_qa(self, output):
def process(self, sample, rank=None):
model, processor = get_model(self.model_key, rank=rank)

inputs = processor(sample[self.text_key],
return_tensors='pt').to(model.device)
response = model.generate(**inputs)
output = processor.decode(response.cpu()[0], skip_special_tokens=True)
if self.enable_vllm:
response = model.generate([sample[self.text_key]])
output = response[0].outputs[0].text
else:
inputs = processor(sample[self.text_key],
return_tensors='pt').to(model.device)
response = model.generate(**inputs)
output = processor.decode(response.cpu()[0],
skip_special_tokens=True)

qa_list = self._extract_qa(output)

if not len(qa_list):
Expand Down
71 changes: 44 additions & 27 deletions data_juicer/ops/mapper/generate_instruction_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def __init__(self,
instruct_num,
similarity_threshold=0.7,
prompt_template=None,
enable_vllm=False,
tensor_parallel_size=1,
*args,
**kwargs):
"""
Expand All @@ -61,6 +63,8 @@ def __init__(self,
:param prompt_template: Prompt template for generate samples.
Please make sure the template contains "{augmented_data}",
which corresponds to the augmented samples.
:param enable_vllm: Whether to use vllm for inference acceleration.
:param tensor_parallel_size: It is only valid when enable_vllm is True.
:param args: extra args
:param kwargs: extra args
"""
Expand All @@ -74,9 +78,17 @@ def __init__(self,
if prompt_template is None:
prompt_template = DEFAULT_PROMPT_TEMPLATE
self.prompt_template = prompt_template

self.model_key = prepare_model(model_type='huggingface',
pretrained_model_name_or_path=hf_model)
self.enable_vllm = enable_vllm

if enable_vllm:
self.model_key = prepare_model(
model_type='vllm',
pretrained_model_name_or_path=hf_model,
tensor_parallel_size=tensor_parallel_size)
else:
self.model_key = prepare_model(
model_type='huggingface',
pretrained_model_name_or_path=hf_model)

self.seed_qa_samples = self.load_seed_qa_samples(seed_file)

Expand Down Expand Up @@ -168,30 +180,35 @@ def process(self, sample=None, rank=None):
self.instruct_num)
input_prompt = self.build_prompt(random_qa_samples,
self.prompt_template)
inputs = processor(input_prompt, return_tensors='pt').to(model.device)
response = model.generate(**inputs)
output_response = processor.decode(response.cpu()[0],
skip_special_tokens=True)
if output_response:
out_qa_pairs, response_str = self.parse_response(output_response)

if self.similarity_type == 'rouge_l':
sim_score = self.max_rouge_l_score(response_str,
self.reference_samples)
else:
raise ValueError(
f'Not support similarity type "{self.similarity_type}"!')

message_list = []
if sim_score <= self.similarity_threshold:
for question, answer in out_qa_pairs:
message_list.append({'role': 'user', 'content': question})
message_list.append({
'role': 'assistant',
'content': answer
})
else:
logging.info('Filter one instance due to similarity.')
if self.enable_vllm:
response = model.generate([input_prompt])
response_str = response[0].outputs[0].text
else:
inputs = processor(input_prompt,
return_tensors='pt').to(model.device)
response = model.generate(**inputs)
response_str = processor.decode(response.cpu()[0],
skip_special_tokens=True)

message_list = []
out_qa_pairs, response_str = self.parse_response(response_str)

if not response_str:
return {self.text_key: json.dumps({'messages': message_list})}

if self.similarity_type == 'rouge_l':
sim_score = self.max_rouge_l_score(response_str,
self.reference_samples)
else:
raise ValueError(
f'Not support similarity type "{self.similarity_type}"!')

if sim_score <= self.similarity_threshold:
for question, answer in out_qa_pairs:
message_list.append({'role': 'user', 'content': question})
message_list.append({'role': 'assistant', 'content': answer})
else:
logging.info('Filter this generated sample due to similarity.')

return {
self.text_key:
Expand Down
76 changes: 76 additions & 0 deletions data_juicer/ops/mapper/optimize_instruction_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.model_utils import get_model, prepare_model

DEFAULT_SYSTEM_PROMPT = '请优化这个指令,将其修改为一个更详细具体的指令。'


@OPERATORS.register_module('optimize_instruction_mapper')
class OptimizeInstructionMapper(Mapper):
    """Mapper to optimize instruction.

    The text of each sample is sent to a chat model as the user message,
    preceded by a system prompt, and the text field of the sample is
    overwritten with the model's reply.

    Recommended model list: [
        alibaba-pai/Qwen2-1.5B-Instruct-Refine
        alibaba-pai/Qwen2-7B-Instruct-Refine
    ]
    """

    def __init__(self,
                 hf_model='alibaba-pai/Qwen2-7B-Instruct-Refine',
                 system_prompt=None,
                 enable_vllm=False,
                 tensor_parallel_size=1,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param hf_model: Huggingface model id.
        :param system_prompt: System prompt for optimize samples. Falls
            back to DEFAULT_SYSTEM_PROMPT when None.
        :param enable_vllm: Whether to use vllm for inference acceleration.
        :param tensor_parallel_size: It is only valid when enable_vllm is
            True. The number of GPUs to use for distributed execution with
            tensor parallelism.
        :param args: extra args
        :param kwargs: extra args
        """
        super().__init__(*args, **kwargs)

        if system_prompt is None:
            system_prompt = DEFAULT_SYSTEM_PROMPT
        self.system_prompt = system_prompt
        self.enable_vllm = enable_vllm

        if enable_vllm:
            self.model_key = prepare_model(
                model_type='vllm',
                pretrained_model_name_or_path=hf_model,
                tensor_parallel_size=tensor_parallel_size)
        else:
            self.model_key = prepare_model(
                model_type='huggingface',
                pretrained_model_name_or_path=hf_model)

    def process(self, sample=None, rank=None):
        """
        Replace the sample's text with the model-optimized instruction.

        :param sample: sample whose `self.text_key` field holds the
            instruction to optimize; it is modified in place.
        :param rank: passed through to `get_model`.
        :return: the same sample with its text field overwritten by the
            model's reply.
        """
        model, processor = get_model(self.model_key, rank=rank)

        messages = [{
            'role': 'system',
            'content': self.system_prompt
        }, {
            'role': 'user',
            'content': sample[self.text_key]
        }]
        input_prompt = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)

        if self.enable_vllm:
            # NOTE(review): generation runs with vLLM's default sampling
            # parameters here -- confirm the default output-length limit is
            # large enough for refined instructions.
            response = model.generate([input_prompt])
            output = response[0].outputs[0].text
        else:
            inputs = processor(input_prompt,
                               return_tensors='pt').to(model.device)
            response = model.generate(**inputs)
            generated_ids = response.cpu()[0]

            # Decoder-only models return the prompt tokens followed by the
            # continuation. Drop the prompt so that the system prompt and
            # the original instruction are not written back into the sample
            # together with the reply. Encoder-decoder models return only
            # the newly generated tokens, so nothing is stripped for them.
            model_config = getattr(model, 'config', None)
            if not getattr(model_config, 'is_encoder_decoder', False):
                prompt_length = inputs['input_ids'].shape[-1]
                generated_ids = generated_ids[prompt_length:]

            output = processor.decode(generated_ids,
                                      skip_special_tokens=True)

        sample[self.text_key] = output

        return sample
31 changes: 30 additions & 1 deletion data_juicer/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,34 @@ def prepare_huggingface_model(pretrained_model_name_or_path,
return (model, processor) if return_model else processor


def prepare_vllm_model(pretrained_model_name_or_path,
                       return_model=True,
                       trust_remote_code=False,
                       tensor_parallel_size=1):
    """
    Prepare and load a vLLM model with the corresponding HuggingFace
    processor.

    :param pretrained_model_name_or_path: model name or path
    :param return_model: return model or not
    :param trust_remote_code: passed to both the transformers processor
        and the vLLM engine
    :param tensor_parallel_size: The number of GPUs to use for distributed
        execution with tensor parallelism.
    :return: a tuple (model, input processor) if `return_model` is True;
        otherwise, only the processor is returned.
    """
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained(
        pretrained_model_name_or_path, trust_remote_code=trust_remote_code)

    if not return_model:
        return processor

    # Imported only when the engine is actually built, so that a
    # processor-only call does not require vllm to be importable.
    from vllm import LLM as vLLM

    # Forward trust_remote_code so the engine and the processor are loaded
    # under the same setting.
    model = vLLM(model=pretrained_model_name_or_path,
                 tensor_parallel_size=tensor_parallel_size,
                 trust_remote_code=trust_remote_code)

    return model, processor


def prepare_spacy_model(lang, name_pattern='{}_core_web_md-3.5.0'):
"""
Prepare spacy model for specific language.
Expand Down Expand Up @@ -530,7 +558,8 @@ def prepare_recognizeAnything_model(
'spacy': prepare_spacy_model,
'diffusion': prepare_diffusion_model,
'video_blip': prepare_video_blip_model,
'recognizeAnything': prepare_recognizeAnything_model
'recognizeAnything': prepare_recognizeAnything_model,
'vllm': prepare_vllm_model
}


Expand Down
3 changes: 2 additions & 1 deletion docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types.
| Type | Number | Description |
|-----------------------------------|:------:|-------------------------------------------------|
| [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data |
| [ Mapper ]( #mapper ) | 45 | Edits and transforms samples |
| [ Mapper ]( #mapper ) | 46 | Edits and transforms samples |
| [ Filter ]( #filter ) | 41 | Filters out low-quality samples |
| [ Deduplicator ]( #deduplicator ) | 5 | Detects and removes duplicate samples |
| [ Selector ]( #selector ) | 2 | Selects top samples based on ranking |
Expand Down Expand Up @@ -68,6 +68,7 @@ All the specific operators are listed below, each featured with several capabili
| image_face_blur_mapper | Image | - | Blur faces detected in images |
| nlpaug_en_mapper | General | en | Simply augments texts in English based on the `nlpaug` library |
| nlpcda_zh_mapper | General | zh | Simply augments texts in Chinese based on the `nlpcda` library |
| optimize_instruction_mapper | General | en, zh | Optimizes instruction text samples |
| punctuation_normalization_mapper | General | en, zh | Normalizes various Unicode punctuations to their ASCII equivalents |
| remove_bibliography_mapper | LaTeX | en, zh | Removes the bibliography of TeX documents |
| remove_comments_mapper | LaTeX | en, zh | Removes the comments of TeX documents |
Expand Down
Loading

0 comments on commit dcc0df3

Please sign in to comment.