Skip to content

Commit

Permalink
default executor + local data; fix analyzer bug
Browse files Browse the repository at this point in the history
  • Loading branch information
cyruszhang committed Jan 27, 2025
1 parent 3c9caf5 commit b9f6a99
Show file tree
Hide file tree
Showing 15 changed files with 92 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
project_name: 'dataset-ondisk-json'
dataset:
configs:
- type: 'ondisk'
- type: 'local'
path: 'path/to/json/file'
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
project_name: 'dataset-ondisk-parquet'
dataset:
configs:
- type: 'ondisk'
- type: 'local'
path: 'path/to/parquet/file'
4 changes: 2 additions & 2 deletions configs/datasets/mixture.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ project_name: 'dataset-mixture'
dataset:
max_sample_num: 10000
configs:
- type: 'ondisk'
- type: 'local'
weight: 1.0
path: 'path/to/json/file'
- type: 'ondisk'
- type: 'local'
weight: 1.0
path: 'path/to/csv/file'
2 changes: 1 addition & 1 deletion configs/datasets/validation.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
dataset:
configs:
- type: ondisk
- type: local
path: path/to/data.json

validators:
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/core/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, cfg: Optional[Namespace] = None):

# setup dataset builder
logger.info('Setting up dataset builder...')
self.dataset_builder = DatasetBuilder(cfg, executor_type='local')
self.dataset_builder = DatasetBuilder(cfg, executor_type='default')

# prepare exporter and check export path suffix
# NOTICE: no need to export dataset texts for analyzer
Expand Down Expand Up @@ -86,7 +86,7 @@ def run(self,
load_data_np = self.cfg.np
if dataset is None:
logger.info('Loading dataset from data formatter...')
dataset = self.formatter.load_dataset(load_data_np, self.cfg)
dataset = self.dataset_builder.load_dataset(num_proc=load_data_np)
else:
logger.info(f'Using existing dataset {dataset}')
if self.cfg.auto:
Expand Down
9 changes: 5 additions & 4 deletions data_juicer/core/data/dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class DatasetBuilder(object):
DatasetBuilder is a class that builds a dataset from a configuration.
"""

def __init__(self, cfg: Namespace, executor_type: str = 'local'):
def __init__(self, cfg: Namespace, executor_type: str = 'default'):
# if generated_dataset_config present, prioritize
if hasattr(
cfg,
Expand Down Expand Up @@ -133,11 +133,12 @@ def load_dataset(self, **kwargs) -> Union[NestedDataset, RayDataset]:
_datasets.append(dataset)

# handle data mixture
if self.executor_type == 'local':
if self.executor_type == 'default':
return NestedDataset(concatenate_datasets(_datasets))
elif self.executor_type == 'ray':
# TODO: support multiple datasets and mixing for ray
assert len(_datasets) == 1, 'Ray setup supports one dataset now'
assert len(
_datasets) == 1, 'Ray setup only supports one dataset now'
return _datasets[0]

@classmethod
Expand Down Expand Up @@ -177,7 +178,7 @@ def rewrite_cli_datapath(dataset_path, max_sample_num=None) -> List:
for p, w in zip(paths, weights):
if os.path.isdir(p) or os.path.isfile(p):
# local files
ret['configs'].append({'type': 'ondisk', 'path': p, 'weight': w})
ret['configs'].append({'type': 'local', 'path': p, 'weight': w})
elif (not is_absolute_path(p) and not p.startswith('.')
and p.count('/') <= 1):
# remote huggingface
Expand Down
42 changes: 21 additions & 21 deletions data_juicer/core/data/load_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ def register(cls, executor_type: str, data_type: str, data_source: str):
"""
Decorator for registering data load strategies with wildcard support
:param executor_type: Type of executor (e.g., 'local', 'ray')
:param data_type: Type of data (e.g., 'ondisk', 'remote')
:param executor_type: Type of executor (e.g., 'default', 'ray')
:param data_type: Type of data (e.g., 'local', 'remote')
:param data_source: Specific data source (e.g., 'arxiv', 's3')
:return: Decorator function
"""
Expand Down Expand Up @@ -153,7 +153,7 @@ def load_data(self, **kwargs) -> RayDataset:
pass


class LocalDataLoadStrategy(DataLoadStrategy):
class DefaultDataLoadStrategy(DataLoadStrategy):
"""
abstract class for data load strategy for the default executor
"""
Expand All @@ -176,8 +176,8 @@ def load_data(self, **kwargs) -> DJDataset:
# pass


@DataLoadStrategyRegistry.register('ray', 'ondisk', '*')
class RayOndiskJsonDataLoadStrategy(RayDataLoadStrategy):
@DataLoadStrategyRegistry.register('ray', 'local', '*')
class RayLocalJsonDataLoadStrategy(RayDataLoadStrategy):

# TODO ray defaults to json

Expand Down Expand Up @@ -212,8 +212,8 @@ def load_data(self, **kwargs):
'Huggingface data load strategy is not implemented')


@DataLoadStrategyRegistry.register('local', 'ondisk', '*')
class LocalOndiskDataLoadStrategy(LocalDataLoadStrategy):
@DataLoadStrategyRegistry.register('default', 'local', '*')
class DefaultLocalDataLoadStrategy(DefaultDataLoadStrategy):
"""
data load strategy for local (on-disk) data for the default executor
rely on AutoFormatter for actual data loading
Expand All @@ -239,8 +239,8 @@ def load_data(self, **kwargs):
return formatter.load_dataset()


@DataLoadStrategyRegistry.register('local', 'remote', 'huggingface')
class LocalHuggingfaceDataLoadStrategy(LocalDataLoadStrategy):
@DataLoadStrategyRegistry.register('default', 'remote', 'huggingface')
class DefaultHuggingfaceDataLoadStrategy(DefaultDataLoadStrategy):
"""
data load strategy for Huggingface dataset for the default executor
"""
Expand Down Expand Up @@ -268,19 +268,19 @@ def load_data(self, **kwargs):
global_cfg=self.cfg)


@DataLoadStrategyRegistry.register('local', 'remote', 'modelscope')
class LocalModelScopeDataLoadStrategy(LocalDataLoadStrategy):
@DataLoadStrategyRegistry.register('default', 'remote', 'modelscope')
class DefaultModelScopeDataLoadStrategy(DefaultDataLoadStrategy):
"""
data load strategy for ModelScope dataset for the default executor
"""

def load_data(self):
def load_data(self, **kwargs):
raise NotImplementedError(
'ModelScope data load strategy is not implemented')


@DataLoadStrategyRegistry.register('local', 'remote', 'arxiv')
class LocalArxivDataLoadStrategy(LocalDataLoadStrategy):
@DataLoadStrategyRegistry.register('default', 'remote', 'arxiv')
class DefaultArxivDataLoadStrategy(DefaultDataLoadStrategy):
"""
data load strategy for arxiv dataset for the default executor
"""
Expand All @@ -293,13 +293,13 @@ class LocalArxivDataLoadStrategy(LocalDataLoadStrategy):
'custom_validators': {}
}

def load_data(self):
def load_data(self, **kwargs):
raise NotImplementedError(
'Arxiv data load strategy is not implemented')


@DataLoadStrategyRegistry.register('local', 'remote', 'wiki')
class LocalWikiDataLoadStrategy(LocalDataLoadStrategy):
@DataLoadStrategyRegistry.register('default', 'remote', 'wiki')
class DefaultWikiDataLoadStrategy(DefaultDataLoadStrategy):
"""
data load strategy for wiki dataset for the default executor
"""
Expand All @@ -312,12 +312,12 @@ class LocalWikiDataLoadStrategy(LocalDataLoadStrategy):
'custom_validators': {}
}

def load_data(self):
def load_data(self, **kwargs):
raise NotImplementedError('Wiki data load strategy is not implemented')


@DataLoadStrategyRegistry.register('local', 'remote', 'commoncrawl')
class LocalCommonCrawlDataLoadStrategy(LocalDataLoadStrategy):
@DataLoadStrategyRegistry.register('default', 'remote', 'commoncrawl')
class DefaultCommonCrawlDataLoadStrategy(DefaultDataLoadStrategy):
"""
data load strategy for commoncrawl dataset for the default executor
"""
Expand All @@ -336,6 +336,6 @@ class LocalCommonCrawlDataLoadStrategy(LocalDataLoadStrategy):
}
}

def load_data(self):
def load_data(self, **kwargs):
raise NotImplementedError(
'CommonCrawl data load strategy is not implemented')
2 changes: 1 addition & 1 deletion data_juicer/core/executor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .base import ExecutorBase
from .default_executor import Executor
from .factory import ExecutorFactory
from .local_executor import Executor
from .ray_executor import RayExecutor

__all__ = ['ExecutorBase'
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion data_juicer/core/executor/factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Union

from .local_executor import Executor
from .default_executor import Executor
from .ray_executor import RayExecutor


Expand Down
4 changes: 2 additions & 2 deletions tests/core/data/test_config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
project_name: 'dataset-ondisk-json'
project_name: 'dataset-local-json'
dataset:
configs:
- type: 'ondisk'
- type: 'local'
path: 'sample.json'
6 changes: 3 additions & 3 deletions tests/core/data/test_config_list.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
project_name: 'dataset-ondisk-list'
project_name: 'dataset-local-list'
dataset:
configs:
- type: 'ondisk'
- type: 'local'
path: 'sample.json'
- type: 'ondisk'
- type: 'local'
path: 'sample.txt'
2 changes: 1 addition & 1 deletion tests/core/data/test_config_ray.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
project_name: 'ray-demo-new-config'
dataset:
configs:
- type: ondisk
- type: local
path: ./demos/process_on_ray/data/demo-dataset.jsonl # path to your dataset directory or file
weight: 1.0

Expand Down
46 changes: 23 additions & 23 deletions tests/core/test_dataload_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,27 @@ def setUp(self):

def test_exact_match(self):
# Register a specific strategy
@DataLoadStrategyRegistry.register("local", 'ondisk', 'json')
@DataLoadStrategyRegistry.register("default", 'local', 'json')
class TestStrategy(MockStrategy):
pass

# Test exact match
strategy = DataLoadStrategyRegistry.get_strategy_class(
"local", 'ondisk', 'json')
"default", 'local', 'json')
self.assertEqual(strategy, TestStrategy)

# Test no match
strategy = DataLoadStrategyRegistry.get_strategy_class(
"local", 'ondisk', 'csv')
"default", 'local', 'csv')
self.assertIsNone(strategy)

def test_wildcard_matching(self):
# Register strategies with different wildcard patterns
@DataLoadStrategyRegistry.register("local", 'ondisk', '*')
@DataLoadStrategyRegistry.register("default", 'local', '*')
class AllFilesStrategy(MockStrategy):
pass

@DataLoadStrategyRegistry.register("local", '*', '*')
@DataLoadStrategyRegistry.register("default", '*', '*')
class AllLocalStrategy(MockStrategy):
pass

Expand All @@ -44,11 +44,11 @@ class FallbackStrategy(MockStrategy):

# Test specific matches
strategy = DataLoadStrategyRegistry.get_strategy_class(
"local", 'ondisk', 'json')
"default", 'local', 'json')
self.assertEqual(strategy, AllFilesStrategy) # Should match most specific wildcard

strategy = DataLoadStrategyRegistry.get_strategy_class(
"local", 'remote', 'json')
"default", 'remote', 'json')
self.assertEqual(strategy, AllLocalStrategy) # Should match second level wildcard

strategy = DataLoadStrategyRegistry.get_strategy_class(
Expand All @@ -60,29 +60,29 @@ def test_specificity_priority(self):
class GeneralStrategy(MockStrategy):
pass

@DataLoadStrategyRegistry.register("local", '*', '*')
@DataLoadStrategyRegistry.register("default", '*', '*')
class LocalStrategy(MockStrategy):
pass

@DataLoadStrategyRegistry.register("local", 'ondisk', '*')
@DataLoadStrategyRegistry.register("default", 'local', '*')
class LocalOndiskStrategy(MockStrategy):
pass

@DataLoadStrategyRegistry.register("local", 'ondisk', 'json')
@DataLoadStrategyRegistry.register("default", 'local', 'json')
class ExactStrategy(MockStrategy):
pass

# Test matching priority
strategy = DataLoadStrategyRegistry.get_strategy_class(
"local", 'ondisk', 'json')
"default", 'local', 'json')
self.assertEqual(strategy, ExactStrategy) # Should match exact first

strategy = DataLoadStrategyRegistry.get_strategy_class(
"local", 'ondisk', 'csv')
"default", 'local', 'csv')
self.assertEqual(strategy, LocalOndiskStrategy) # Should match one wildcard

strategy = DataLoadStrategyRegistry.get_strategy_class(
"local", 'remote', 'json')
"default", 'remote', 'json')
self.assertEqual(strategy, LocalStrategy) # Should match two wildcards

strategy = DataLoadStrategyRegistry.get_strategy_class(
Expand All @@ -91,41 +91,41 @@ class ExactStrategy(MockStrategy):

def test_pattern_matching(self):
@DataLoadStrategyRegistry.register(
"local", 'ondisk', '*.json')
"default", 'local', '*.json')
class JsonStrategy(MockStrategy):
pass

@DataLoadStrategyRegistry.register(
"local", 'ondisk', 'data_[0-9]*')
"default", 'local', 'data_[0-9]*')
class NumberedDataStrategy(MockStrategy):
pass

# Test pattern matching
strategy = DataLoadStrategyRegistry.get_strategy_class(
"local", 'ondisk', 'test.json')
"default", 'local', 'test.json')
self.assertEqual(strategy, JsonStrategy)

strategy = DataLoadStrategyRegistry.get_strategy_class(
"local", 'ondisk', 'data_123')
"default", 'local', 'data_123')
self.assertEqual(strategy, NumberedDataStrategy)

strategy = DataLoadStrategyRegistry.get_strategy_class(
"local", 'ondisk', 'test.csv')
"default", 'local', 'test.csv')
self.assertIsNone(strategy)

def test_strategy_key_matches(self):
# Test StrategyKey matching directly
wildcard_key = StrategyKey("*", 'ondisk', '*.json')
specific_key = StrategyKey("local", 'ondisk', 'test.json')
wildcard_key = StrategyKey("*", 'local', '*.json')
specific_key = StrategyKey("default", 'local', 'test.json')

# A wildcard key matches a specific key, but a specific key does not match a wildcard key
self.assertTrue(wildcard_key.matches(specific_key))
self.assertFalse(specific_key.matches(wildcard_key))

# Test pattern matching
pattern_key = StrategyKey("local", '*', 'data_[0-9]*')
match_key = StrategyKey("local", 'ondisk', 'data_123')
no_match_key = StrategyKey("local", 'ondisk', 'data_abc')
pattern_key = StrategyKey("default", '*', 'data_[0-9]*')
match_key = StrategyKey("default", 'local', 'data_123')
no_match_key = StrategyKey("default", 'local', 'data_abc')

self.assertTrue(pattern_key.matches(match_key))
self.assertFalse(pattern_key.matches(no_match_key))
Loading

0 comments on commit b9f6a99

Please sign in to comment.