From 0815c2976bb09e591aefff564e1a452bc67c4675 Mon Sep 17 00:00:00 2001 From: BeachWang <1400012807@pku.edu.cn> Date: Fri, 17 Jan 2025 16:12:55 +0800 Subject: [PATCH] Refine/llm api op unittest (#528) * * update unittests * tags specified field * doc done * + add reference * move mm tags * move meta key * done * test done * rm nested set * enable op error for unittest * enhance api unittest * expose skip_op_error * fix typo --------- Co-authored-by: null <3213204+drcege@users.noreply.github.com> Co-authored-by: gece.gc Co-authored-by: lielin.hyl --- configs/config_all.yaml | 2 + data_juicer/config/config.py | 10 ++- data_juicer/ops/base_op.py | 67 +++++++++++++------ tests/config/test_config_funcs.py | 7 ++ .../test_entity_attribute_aggregator.py | 8 ++- .../test_most_relavant_entities_aggregator.py | 8 ++- .../ops/aggregator/test_nested_aggregator.py | 6 +- .../test_dialog_intent_detection_mapper.py | 2 + .../test_dialog_sentiment_detection_mapper.py | 2 + .../test_dialog_sentiment_intensity_mapper.py | 1 + .../test_dialog_topic_detection_mapper.py | 2 + .../test_extract_entity_attribute_mapper.py | 4 ++ .../test_extract_entity_relation_mapper.py | 4 ++ tests/ops/mapper/test_extract_event_mapper.py | 3 +- .../ops/mapper/test_extract_keyword_mapper.py | 2 + .../mapper/test_extract_nickname_mapper.py | 1 + .../test_extract_support_text_mapper.py | 2 + .../mapper/test_relation_identity_mapper.py | 2 + 18 files changed, 103 insertions(+), 30 deletions(-) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index de74f724f..586c7509e 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -13,6 +13,8 @@ np: 4 # number of subproce text_keys: 'text' # the key name of field where the sample texts to be processed, e.g., `text`, `instruction`, `output`, ... # Note: currently, we support specify only ONE key for each op, for cases requiring multiple keys, users can specify the op multiple times. 
We will only use the first key of `text_keys` when you set multiple keys. suffixes: [] # the suffix of files that will be read. For example: '.txt', 'txt' or ['txt', '.pdf', 'docx'] +turbo: false # Enable Turbo mode to maximize processing speed when batch size is 1. +skip_op_error: true # Skip errors in OPs caused by unexpected invalid samples. use_cache: true # whether to use the cache management of Hugging Face datasets. It might take up lots of disk space when using cache ds_cache_dir: null # cache dir for Hugging Face datasets. In default, it\'s the same as the environment variable `HF_DATASETS_CACHE`, whose default value is usually "~/.cache/huggingface/datasets". If this argument is set to a valid path by users, it will override the default cache dir open_monitor: true # Whether to open the monitor to trace resource utilization for each OP during data processing. It\'s True in default. diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index e2b9252ad..7f3aa5a52 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -219,8 +219,13 @@ def init_configs(args: Optional[List[str]] = None, which_entry: object = None): '--turbo', type=bool, default=False, - help='Enable Turbo mode to maximize processing speed. 
Stability ' - 'features like fault tolerance will be disabled.') + help='Enable Turbo mode to maximize processing speed when batch size ' + 'is 1.') + parser.add_argument( + '--skip_op_error', + type=bool, + default=True, + help='Skip errors in OPs caused by unexpected invalid samples.') parser.add_argument( '--use_cache', type=bool, @@ -550,6 +555,7 @@ def init_setup_from_cfg(cfg: Namespace): 'video_key': cfg.video_key, 'num_proc': cfg.np, 'turbo': cfg.turbo, + 'skip_op_error': cfg.skip_op_error, 'work_dir': cfg.work_dir, } cfg.process = update_op_attr(cfg.process, op_attrs) diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 698203f37..58efbbb5e 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -47,7 +47,7 @@ def wrapper(sample, *args, **kwargs): return wrapper -def catch_map_batches_exception(method, op_name=None): +def catch_map_batches_exception(method, skip_op_error=False, op_name=None): """ For batched-map sample-level fault tolerance. """ @@ -61,6 +61,8 @@ def wrapper(samples, *args, **kwargs): try: return method(samples, *args, **kwargs) except Exception as e: + if not skip_op_error: + raise from loguru import logger logger.error(f'An error occurred in {op_name} when processing ' f'samples "{samples}" -- {type(e)}: {e}') @@ -72,7 +74,10 @@ def wrapper(samples, *args, **kwargs): return wrapper -def catch_map_single_exception(method, return_sample=True, op_name=None): +def catch_map_single_exception(method, + return_sample=True, + skip_op_error=False, + op_name=None): """ For single-map sample-level fault tolerance. The input sample is expected batch_size = 1. 
@@ -103,6 +108,8 @@ def wrapper(sample, *args, **kwargs): else: return [res] except Exception as e: + if not skip_op_error: + raise from loguru import logger logger.error(f'An error occurred in {op_name} when processing ' f'sample "{sample}" -- {type(e)}: {e}') @@ -157,6 +164,10 @@ def __init__(self, *args, **kwargs): self.batch_size = kwargs.get('batch_size', 1000) self.work_dir = kwargs.get('work_dir', None) + # for unittest, do not skip the error. + # It would be set to be True in config init. + self.skip_op_error = kwargs.get('skip_op_error', False) + # whether the model can be accelerated using cuda _accelerator = kwargs.get('accelerator', None) if _accelerator is not None: @@ -278,11 +289,15 @@ def __init__(self, *args, **kwargs): # runtime wrappers if self.is_batched_op(): - self.process = catch_map_batches_exception(self.process_batched, - op_name=self._name) + self.process = catch_map_batches_exception( + self.process_batched, + skip_op_error=self.skip_op_error, + op_name=self._name) else: - self.process = catch_map_single_exception(self.process_single, - op_name=self._name) + self.process = catch_map_single_exception( + self.process_single, + skip_op_error=self.skip_op_error, + op_name=self._name) # set the process method is not allowed to be overridden def __init_subclass__(cls, **kwargs): @@ -369,15 +384,23 @@ def __init__(self, *args, **kwargs): # runtime wrappers if self.is_batched_op(): self.compute_stats = catch_map_batches_exception( - self.compute_stats_batched, op_name=self._name) - self.process = catch_map_batches_exception(self.process_batched, - op_name=self._name) + self.compute_stats_batched, + skip_op_error=self.skip_op_error, + op_name=self._name) + self.process = catch_map_batches_exception( + self.process_batched, + skip_op_error=self.skip_op_error, + op_name=self._name) else: self.compute_stats = catch_map_single_exception( - self.compute_stats_single, op_name=self._name) - self.process = catch_map_single_exception(self.process_single, - 
return_sample=False, - op_name=self._name) + self.compute_stats_single, + skip_op_error=self.skip_op_error, + op_name=self._name) + self.process = catch_map_single_exception( + self.process_single, + return_sample=False, + skip_op_error=self.skip_op_error, + op_name=self._name) # set the process method is not allowed to be overridden def __init_subclass__(cls, **kwargs): @@ -486,11 +509,15 @@ def __init__(self, *args, **kwargs): # runtime wrappers if self.is_batched_op(): - self.compute_hash = catch_map_batches_exception(self.compute_hash, - op_name=self._name) + self.compute_hash = catch_map_batches_exception( + self.compute_hash, + skip_op_error=self.skip_op_error, + op_name=self._name) else: - self.compute_hash = catch_map_single_exception(self.compute_hash, - op_name=self._name) + self.compute_hash = catch_map_single_exception( + self.compute_hash, + skip_op_error=self.skip_op_error, + op_name=self._name) def compute_hash(self, sample): """ @@ -626,8 +653,10 @@ def __init__(self, *args, **kwargs): queries and responses """ super(Aggregator, self).__init__(*args, **kwargs) - self.process = catch_map_single_exception(self.process_single, - op_name=self._name) + self.process = catch_map_single_exception( + self.process_single, + skip_op_error=self.skip_op_error, + op_name=self._name) def process_single(self, sample): """ diff --git a/tests/config/test_config_funcs.py b/tests/config/test_config_funcs.py index 2d5578478..9ae5bd55e 100644 --- a/tests/config/test_config_funcs.py +++ b/tests/config/test_config_funcs.py @@ -56,6 +56,7 @@ def test_yaml_cfg_file(self): 'turbo': False, 'batch_size': 1000, 'index_key': None, + 'skip_op_error': True, 'work_dir': WORKDIR, } }, 'nested dict load fail, for nonparametric op') @@ -79,6 +80,7 @@ def test_yaml_cfg_file(self): 'turbo': False, 'batch_size': 1000, 'index_key': None, + 'skip_op_error': True, 'work_dir': WORKDIR, } }, 'nested dict load fail, un-expected internal value') @@ -151,6 +153,7 @@ def test_mixture_cfg(self): 
'turbo': False, 'batch_size': 1000, 'index_key': None, + 'skip_op_error': True, 'work_dir': WORKDIR, } }) @@ -174,6 +177,7 @@ def test_mixture_cfg(self): 'turbo': False, 'batch_size': 1000, 'index_key': None, + 'skip_op_error': True, 'work_dir': WORKDIR, } }) @@ -197,6 +201,7 @@ def test_mixture_cfg(self): 'turbo': False, 'batch_size': 1000, 'index_key': None, + 'skip_op_error': True, 'work_dir': WORKDIR, } }) @@ -220,6 +225,7 @@ def test_mixture_cfg(self): 'turbo': False, 'batch_size': 1000, 'index_key': None, + 'skip_op_error': True, 'work_dir': WORKDIR, } }) @@ -243,6 +249,7 @@ def test_mixture_cfg(self): 'turbo': False, 'batch_size': 1000, 'index_key': None, + 'skip_op_error': True, 'work_dir': WORKDIR, } }) diff --git a/tests/ops/aggregator/test_entity_attribute_aggregator.py b/tests/ops/aggregator/test_entity_attribute_aggregator.py index ff390b6fd..1d6a4b1df 100644 --- a/tests/ops/aggregator/test_entity_attribute_aggregator.py +++ b/tests/ops/aggregator/test_entity_attribute_aggregator.py @@ -5,13 +5,13 @@ from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.aggregator import EntityAttributeAggregator from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS -from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.constant import Fields, BatchMetaKeys, MetaKeys @SKIPPED_TESTS.register_module() class EntityAttributeAggregatorTest(DataJuicerTestCaseBase): - def _run_helper(self, op, samples): + def _run_helper(self, op, samples, output_key=BatchMetaKeys.entity_attribute): # before runing this test, set below environment variables: # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/ @@ -23,6 +23,8 @@ def _run_helper(self, op, samples): for data in new_dataset: for k in data: logger.info(f"{k}: {data[k]}") + self.assertIn(output_key, data[Fields.batch_meta]) + self.assertNotEqual(data[Fields.batch_meta][output_key], '') self.assertEqual(len(new_dataset), 
len(samples)) @@ -64,7 +66,7 @@ def test_input_output(self): input_key='sub_docs', output_key='text' ) - self._run_helper(op, samples) + self._run_helper(op, samples, output_key='text') def test_max_token_num(self): samples = [ diff --git a/tests/ops/aggregator/test_most_relavant_entities_aggregator.py b/tests/ops/aggregator/test_most_relavant_entities_aggregator.py index 062cad43d..5912877ca 100644 --- a/tests/ops/aggregator/test_most_relavant_entities_aggregator.py +++ b/tests/ops/aggregator/test_most_relavant_entities_aggregator.py @@ -6,13 +6,13 @@ from data_juicer.ops.aggregator import MostRelavantEntitiesAggregator from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS -from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.constant import Fields, BatchMetaKeys, MetaKeys @SKIPPED_TESTS.register_module() class MostRelavantEntitiesAggregatorTest(DataJuicerTestCaseBase): - def _run_helper(self, op, samples): + def _run_helper(self, op, samples, output_key=BatchMetaKeys.most_relavant_entities): # before runing this test, set below environment variables: # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/ @@ -24,6 +24,8 @@ def _run_helper(self, op, samples): for data in new_dataset: for k in data: logger.info(f"{k}: {data[k]}") + self.assertIn(output_key, data[Fields.batch_meta]) + self.assertNotEqual(data[Fields.batch_meta][output_key], '') self.assertEqual(len(new_dataset), len(samples)) @@ -67,7 +69,7 @@ def test_input_output(self): input_key='events', output_key='relavant_roles' ) - self._run_helper(op, samples) + self._run_helper(op, samples, output_key='relavant_roles') def test_max_token_num(self): samples = [ diff --git a/tests/ops/aggregator/test_nested_aggregator.py b/tests/ops/aggregator/test_nested_aggregator.py index 0d16648df..697e17e95 100644 --- a/tests/ops/aggregator/test_nested_aggregator.py +++ b/tests/ops/aggregator/test_nested_aggregator.py @@ -12,7 +12,7 @@ 
@SKIPPED_TESTS.register_module() class NestedAggregatorTest(DataJuicerTestCaseBase): - def _run_helper(self, op, samples): + def _run_helper(self, op, samples, output_key=MetaKeys.event_description): # before runing this test, set below environment variables: # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/ @@ -24,6 +24,8 @@ def _run_helper(self, op, samples): for data in new_dataset: for k in data: logger.info(f"{k}: {data[k]}") + self.assertIn(output_key, data[Fields.batch_meta]) + self.assertNotEqual(data[Fields.batch_meta][output_key], '') self.assertEqual(len(new_dataset), len(samples)) @@ -61,7 +63,7 @@ def test_input_output(self): input_key='sub_docs', output_key='text' ) - self._run_helper(op, samples) + self._run_helper(op, samples, output_key='text') def test_max_token_num_1(self): samples = [ diff --git a/tests/ops/mapper/test_dialog_intent_detection_mapper.py b/tests/ops/mapper/test_dialog_intent_detection_mapper.py index b6bb35cab..d2a44ab65 100644 --- a/tests/ops/mapper/test_dialog_intent_detection_mapper.py +++ b/tests/ops/mapper/test_dialog_intent_detection_mapper.py @@ -27,6 +27,8 @@ def _run_op(self, op, samples, target_len, labels_key=None, analysis_key=None): for analysis, labels in zip(analysis_list, labels_list): logger.info(f'分析:{analysis}') logger.info(f'意图:{labels}') + self.assertNotEqual(analysis, '') + self.assertNotEqual(labels, '') self.assertEqual(len(analysis_list), target_len) self.assertEqual(len(labels_list), target_len) diff --git a/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py b/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py index ac6236282..5f0763149 100644 --- a/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py +++ b/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py @@ -28,6 +28,8 @@ def _run_op(self, op, samples, target_len, labels_key=None, analysis_key=None): for analysis, labels in zip(analysis_list, labels_list): logger.info(f'分析:{analysis}') 
logger.info(f'情绪:{labels}') + self.assertNotEqual(analysis, '') + self.assertNotEqual(labels, '') self.assertEqual(len(analysis_list), target_len) self.assertEqual(len(labels_list), target_len) diff --git a/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py b/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py index ed7de409a..93fdc54f6 100644 --- a/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py +++ b/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py @@ -28,6 +28,7 @@ def _run_op(self, op, samples, target_len, intensities_key=None, analysis_key=No for analysis, intensity in zip(analysis_list, intensity_list): logger.info(f'分析:{analysis}') logger.info(f'情绪:{intensity}') + self.assertNotEqual(analysis, '') self.assertEqual(len(analysis_list), target_len) self.assertEqual(len(intensity_list), target_len) diff --git a/tests/ops/mapper/test_dialog_topic_detection_mapper.py b/tests/ops/mapper/test_dialog_topic_detection_mapper.py index f1dc1d9cb..d6d1f5e3d 100644 --- a/tests/ops/mapper/test_dialog_topic_detection_mapper.py +++ b/tests/ops/mapper/test_dialog_topic_detection_mapper.py @@ -29,6 +29,8 @@ def _run_op(self, op, samples, target_len, labels_key=None, analysis_key=None): for analysis, labels in zip(analysis_list, labels_list): logger.info(f'分析:{analysis}') logger.info(f'话题:{labels}') + self.assertNotEqual(analysis, '') + self.assertNotEqual(labels, '') self.assertEqual(len(analysis_list), target_len) self.assertEqual(len(labels_list), target_len) diff --git a/tests/ops/mapper/test_extract_entity_attribute_mapper.py b/tests/ops/mapper/test_extract_entity_attribute_mapper.py index 9707b2beb..0ef2579e2 100644 --- a/tests/ops/mapper/test_extract_entity_attribute_mapper.py +++ b/tests/ops/mapper/test_extract_entity_attribute_mapper.py @@ -49,6 +49,10 @@ def _run_op(self, api_model, response_path=None): dataset = Dataset.from_list(samples) dataset = op.run(dataset) for sample in dataset: + self.assertIn(MetaKeys.main_entities, 
sample[Fields.meta]) + self.assertIn(MetaKeys.attributes, sample[Fields.meta]) + self.assertIn(MetaKeys.attribute_descriptions, sample[Fields.meta]) + self.assertIn(MetaKeys.attribute_support_texts, sample[Fields.meta]) ents = sample[Fields.meta][MetaKeys.main_entities] attrs = sample[Fields.meta][MetaKeys.attributes] descs = sample[Fields.meta][MetaKeys.attribute_descriptions] diff --git a/tests/ops/mapper/test_extract_entity_relation_mapper.py b/tests/ops/mapper/test_extract_entity_relation_mapper.py index a4c413a33..053881e24 100644 --- a/tests/ops/mapper/test_extract_entity_relation_mapper.py +++ b/tests/ops/mapper/test_extract_entity_relation_mapper.py @@ -56,6 +56,10 @@ def _run_op(self, op): dataset = Dataset.from_list(samples) dataset = op.run(dataset) sample = dataset[0] + self.assertIn(MetaKeys.entity, sample[Fields.meta]) + self.assertIn(MetaKeys.relation, sample[Fields.meta]) + self.assertNotEqual(len(sample[Fields.meta][MetaKeys.entity]), 0) + self.assertNotEqual(len(sample[Fields.meta][MetaKeys.relation]), 0) logger.info(f"entitis: {sample[Fields.meta][MetaKeys.entity]}") logger.info(f"relations: {sample[Fields.meta][MetaKeys.relation]}") diff --git a/tests/ops/mapper/test_extract_event_mapper.py b/tests/ops/mapper/test_extract_event_mapper.py index 4c7f47a2b..8da2caf8a 100644 --- a/tests/ops/mapper/test_extract_event_mapper.py +++ b/tests/ops/mapper/test_extract_event_mapper.py @@ -59,8 +59,9 @@ def _run_op(self, api_model, response_path=None): dataset = Dataset.from_list(samples) dataset = op.run(dataset) - self.assertNotEqual(len(dataset), 0) for sample in dataset: + self.assertIn(MetaKeys.event_description, sample[Fields.meta]) + self.assertIn(MetaKeys.relevant_characters, sample[Fields.meta]) logger.info(f"chunk_id: {sample['chunk_id']}") self.assertEqual(sample['chunk_id'], 0) logger.info(f"event: {sample[Fields.meta][MetaKeys.event_description]}") diff --git a/tests/ops/mapper/test_extract_keyword_mapper.py 
b/tests/ops/mapper/test_extract_keyword_mapper.py index 8528be5d4..47b30d687 100644 --- a/tests/ops/mapper/test_extract_keyword_mapper.py +++ b/tests/ops/mapper/test_extract_keyword_mapper.py @@ -59,6 +59,8 @@ def _run_op(self, api_model, response_path=None): dataset = Dataset.from_list(samples) dataset = op.run(dataset) sample = dataset[0] + self.assertIn(MetaKeys.keyword, sample[Fields.meta]) + self.assertNotEqual(len(sample[Fields.meta][MetaKeys.keyword]), 0) logger.info(f"keywords: {sample[Fields.meta][MetaKeys.keyword]}") def test(self): diff --git a/tests/ops/mapper/test_extract_nickname_mapper.py b/tests/ops/mapper/test_extract_nickname_mapper.py index a869bda92..df204c13e 100644 --- a/tests/ops/mapper/test_extract_nickname_mapper.py +++ b/tests/ops/mapper/test_extract_nickname_mapper.py @@ -38,6 +38,7 @@ def _run_op(self, api_model, response_path=None): dataset = Dataset.from_list(samples) dataset = op.run(dataset) + self.assertIn(MetaKeys.nickname, dataset[0][Fields.meta]) result = dataset[0][Fields.meta][MetaKeys.nickname] result = [( d[MetaKeys.source_entity], diff --git a/tests/ops/mapper/test_extract_support_text_mapper.py b/tests/ops/mapper/test_extract_support_text_mapper.py index d4d920fe8..4ee2652c2 100644 --- a/tests/ops/mapper/test_extract_support_text_mapper.py +++ b/tests/ops/mapper/test_extract_support_text_mapper.py @@ -62,7 +62,9 @@ def _run_op(self, api_model): dataset = Dataset.from_list(samples) dataset = op.run(dataset) sample = dataset[0] + self.assertIn(MetaKeys.support_text, sample[Fields.meta]) logger.info(f"support_text: \n{sample[Fields.meta][MetaKeys.support_text]}") + self.assertNotEqual(sample[Fields.meta][MetaKeys.support_text], '') def test(self): # before runing this test, set below environment variables: diff --git a/tests/ops/mapper/test_relation_identity_mapper.py b/tests/ops/mapper/test_relation_identity_mapper.py index 3a243189b..57382f988 100644 --- a/tests/ops/mapper/test_relation_identity_mapper.py +++ 
b/tests/ops/mapper/test_relation_identity_mapper.py @@ -49,6 +49,8 @@ def _run_op(self, api_model, output_key=MetaKeys.role_relation): for data in dataset: for k in data: logger.info(f"{k}: {data[k]}") + self.assertIn(output_key, data[Fields.meta]) + self.assertNotEqual(data[Fields.meta][output_key], '') def test_default(self): self._run_op('qwen2.5-72b-instruct')