Commit dd95df0: fix bugs; use str for executor_type

cyruszhang committed Jan 23, 2025 (1 parent: 195aff8)
Showing 13 changed files with 98 additions and 113 deletions.
data_juicer/core/__init__.py (5 changes: 2 additions & 3 deletions)

@@ -1,8 +1,8 @@
 from .adapter import Adapter
 from .analyzer import Analyzer
 from .data import NestedDataset
-from .executor import (Executor, ExecutorBase, ExecutorFactory, ExecutorType,
-                       RayExecutor)
+from .executor import Executor, ExecutorFactory, RayExecutor
+from .executor.base import ExecutorBase
 from .exporter import Exporter
 from .monitor import Monitor
 from .tracer import Tracer
@@ -15,7 +15,6 @@
     'Executor',
     'RayExecutor',
     'ExecutorBase',
-    'ExecutorType',
     'Exporter',
     'Monitor',
     'Tracer',
data_juicer/core/analyzer.py (13 changes: 4 additions & 9 deletions)

@@ -8,7 +8,7 @@
 
 from data_juicer.analysis import ColumnWiseAnalysis, OverallAnalysis
 from data_juicer.config import init_configs
-from data_juicer.format import load_formatter
+from data_juicer.core.data.dataset_builder import DatasetBuilder
 from data_juicer.ops import NON_STATS_FILTERS, TAGGING_OPS, Filter, load_ops
 from data_juicer.ops.op_fusion import fuse_operators
 from data_juicer.utils import cache_utils
@@ -44,14 +44,9 @@ def __init__(self, cfg: Optional[Namespace] = None):
                     f'[{self.cfg.cache_compress}]')
         cache_utils.CACHE_COMPRESS = self.cfg.cache_compress
 
-        # setup formatter
-        logger.info('Setting up data formatter...')
-        self.formatter = load_formatter(
-            dataset_path=self.cfg.dataset_path,
-            generated_dataset_config=self.cfg.generated_dataset_config,
-            text_keys=self.cfg.text_keys,
-            suffixes=self.cfg.suffixes,
-            add_suffix=self.cfg.add_suffix)
+        # setup dataset builder
+        logger.info('Setting up dataset builder...')
+        self.dataset_builder = DatasetBuilder(cfg, executor_type='local')
 
         # prepare exporter and check export path suffix
        # NOTICE: no need to export dataset texts for analyzer
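A minimal sketch (not from the commit) of the wiring this hunk sets up; the cfg fields beyond dataset_path and text_keys are illustrative, and 'local' routing is assumed to return a NestedDataset:

    # Sketch under assumptions; cfg values are hypothetical.
    from jsonargparse import Namespace

    from data_juicer.core.data.dataset_builder import DatasetBuilder

    cfg = Namespace(dataset_path='sample.json', text_keys=['text'],
                    suffixes=None, add_suffix=False)
    builder = DatasetBuilder(cfg, executor_type='local')  # plain str, not an enum
    dataset = builder.load_dataset()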
data_juicer/core/data/dataset_builder.py (25 changes: 15 additions & 10 deletions)

@@ -11,7 +11,6 @@
 from data_juicer.core.data.data_validator import DataValidatorRegistry
 from data_juicer.core.data.load_strategy import DataLoadStrategyRegistry
 from data_juicer.core.data.ray_dataset import RayDataset
-from data_juicer.core.executor.base import ExecutorType
 from data_juicer.utils.file_utils import is_absolute_path
 from data_juicer.utils.sample import random_sample
@@ -23,17 +22,20 @@ class DatasetBuilder(object):
 
     def __init__(self, cfg: Namespace, executor_type: str = 'local'):
         # if generated_dataset_config present, prioritize
-        if cfg.generated_dataset_config:
+        if hasattr(
+                cfg,
+                'generated_dataset_config') and cfg.generated_dataset_config:
             self.use_generated_dataset_config = True
             self.generated_dataset_config = cfg.generated_dataset_config
             return
         self.use_generated_dataset_config = False
 
         self.cfg = cfg
-        self.executor_type = ExecutorType(executor_type)
+        self.executor_type = executor_type
 
-        if cfg.dataset_path is not None:
+        if hasattr(cfg, 'dataset_path') and cfg.dataset_path is not None:
             ds_configs = rewrite_cli_datapath(cfg.dataset_path)
-        elif cfg.dataset is not None:
+        elif hasattr(cfg, 'dataset') and cfg.dataset is not None:
             ds_configs = cfg.dataset
         else:
             raise ConfigValidationError(
@@ -80,8 +82,8 @@ def __init__(self, cfg: Namespace, executor_type: str = 'local'):
             data_source = ds_config.get('source', None)
             self.load_strategies.append(
                 DataLoadStrategyRegistry.get_strategy_class(
-                    self.executor_type.value, data_type,
-                    data_source)(ds_config, cfg=self.cfg))
+                    self.executor_type, data_type, data_source)(ds_config,
+                                                                cfg=self.cfg))
 
         # initialize the sample numbers
         self.max_sample_num = ds_configs.get('max_sample_num', None)
@@ -90,6 +92,9 @@ def __init__(self, cfg: Namespace, executor_type: str = 'local'):
             self.weights = [stra.weight for stra in self.load_strategies]
             self.sample_numbers = get_sample_numbers(self.weights,
                                                      self.max_sample_num)
+        else:
+            self.weights = [1.0 for stra in self.load_strategies]
+            self.sample_numbers = [None for stra in self.load_strategies]
 
         # initialize data validators
         self.validators = []
@@ -125,9 +130,9 @@ def load_dataset(self, **kwargs) -> Union[NestedDataset, RayDataset]:
             _datasets.append(dataset)
 
         # handle data mixture
-        if self.executor_type == ExecutorType.LOCAL:
+        if self.executor_type == 'local':
             return NestedDataset(concatenate_datasets(_datasets))
-        elif self.executor_type == ExecutorType.RAY:
+        elif self.executor_type == 'ray':
             # TODO: support multiple datasets and mixing for ray
             assert len(_datasets) == 1, 'Ray setup supports one dataset now'
             return _datasets[0]
@@ -169,7 +174,7 @@ def rewrite_cli_datapath(dataset_path, max_sample_num=None) -> List:
     for p, w in zip(paths, weights):
         if os.path.isdir(p) or os.path.isfile(p):
             # local files
-            ret['configs'].append({'type': 'ondisk', 'path': [p], 'weight': w})
+            ret['configs'].append({'type': 'ondisk', 'path': p, 'weight': w})
         elif (not is_absolute_path(p) and not p.startswith('.')
                 and p.count('/') <= 1):
             # remote huggingface
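The last hunk also changes the per-config shape that rewrite_cli_datapath emits: 'path' becomes a plain string instead of a one-element list. A rough illustration, with the output shape inferred from the diff (the path must exist on disk to hit the ondisk branch):

    from data_juicer.core.data.dataset_builder import rewrite_cli_datapath

    ds_configs = rewrite_cli_datapath('sample.json')
    # roughly, after this commit:
    # {'configs': [{'type': 'ondisk', 'path': 'sample.json', 'weight': 1.0}]}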
data_juicer/core/data/load_strategy.py (45 changes: 22 additions & 23 deletions)

@@ -8,7 +8,6 @@
 
 from data_juicer.core.data import DJDataset, RayDataset
 from data_juicer.core.data.config_validator import ConfigValidator
-from data_juicer.core.executor.base import ExecutorType
 from data_juicer.download.downloader import validate_snapshot_format
 from data_juicer.format.formatter import unify_format
 from data_juicer.format.load import load_formatter
@@ -27,7 +26,7 @@ class StrategyKey:
     """
     Immutable key for strategy registration with wildcard support
     """
-    executor_type: ExecutorType
+    executor_type: str
     data_type: str
     data_source: str
@@ -41,8 +40,7 @@ def matches(self, other: 'StrategyKey') -> bool:
         - '[seq]' matches any character in seq
         - '[!seq]' matches any character not in seq
         """
-        return (fnmatch.fnmatch(other.executor_type.value,
-                                self.executor_type.value)
+        return (fnmatch.fnmatch(other.executor_type, self.executor_type)
                 and fnmatch.fnmatch(other.data_type, self.data_type)
                 and fnmatch.fnmatch(other.data_source, self.data_source))
@@ -71,7 +69,7 @@ class DataLoadStrategyRegistry:
 
     @classmethod
     def get_strategy_class(
-            cls, executor_type: ExecutorType, data_type: str,
+            cls, executor_type: str, data_type: str,
             data_source: str) -> Optional[Type[DataLoadStrategy]]:
         """
         Retrieve the most specific matching strategy
@@ -81,7 +79,7 @@ def get_strategy_class(
         2. Wildcard matches from most specific to most general
         """
         # default to wildcard if not provided
-        executor_type = executor_type or ExecutorType.ANY
+        executor_type = executor_type or '*'
         data_type = data_type or '*'
         data_source = data_source or '*'
@@ -121,8 +119,8 @@ def specificity_score(key: StrategyKey) -> int:
         return None
 
     @classmethod
-    def register(cls, executor_type: ExecutorType, data_type: str,
-                 data_source: str):
+    def register(cls, executor_type: str, data_type: str, data_source: str):
        """
         Decorator for registering data load strategies with wildcard support
@@ -179,7 +176,7 @@ def load_data(self, **kwargs) -> DJDataset:
     # pass
 
 
-@DataLoadStrategyRegistry.register(ExecutorType.RAY, 'ondisk', 'json')
+@DataLoadStrategyRegistry.register('ray', 'ondisk', 'json')
 class RayOndiskJsonDataLoadStrategy(RayDataLoadStrategy):
 
     CONFIG_VALIDATION_RULES = {
@@ -191,13 +188,13 @@ class RayOndiskJsonDataLoadStrategy(RayDataLoadStrategy):
     }
 
     def load_data(self, **kwargs):
-        dataset = rd.read_json(self.ds_config.path)
+        dataset = rd.read_json(self.ds_config['path'])
         return RayDataset(dataset,
-                          dataset_path=self.ds_config.path,
+                          dataset_path=self.ds_config['path'],
                           cfg=self.cfg)
 
 
-@DataLoadStrategyRegistry.register(ExecutorType.RAY, 'remote', 'huggingface')
+@DataLoadStrategyRegistry.register('ray', 'remote', 'huggingface')
 class RayHuggingfaceDataLoadStrategy(RayDataLoadStrategy):
 
     CONFIG_VALIDATION_RULES = {
@@ -213,7 +210,7 @@ def load_data(self, **kwargs):
             'Huggingface data load strategy is not implemented')
 
 
-@DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'ondisk', '*')
+@DataLoadStrategyRegistry.register('local', 'ondisk', '*')
 class LocalOndiskDataLoadStrategy(LocalDataLoadStrategy):
     """
     data load strategy for on disk data for LocalExecutor
@@ -229,16 +226,18 @@ class LocalOndiskDataLoadStrategy(LocalDataLoadStrategy):
     }
 
     def load_data(self, **kwargs):
+        print(f'kwargs: {kwargs}')
         # use proper formatter to load data
-        formatter = load_formatter(dataset_path=self.ds_config.path,
+        formatter = load_formatter(dataset_path=self.ds_config['path'],
                                    suffixes=self.cfg.suffixes,
                                    text_keys=self.cfg.text_keys,
-                                   add_suffix=self.cfg.add_suffix**kwargs)
+                                   add_suffix=self.cfg.add_suffix,
+                                   **kwargs)
         # TODO more sophisticated localformatter routing
-        return formatter.load_data()
+        return formatter.load_dataset()
 
 
-@DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'remote', 'huggingface')
+@DataLoadStrategyRegistry.register('local', 'remote', 'huggingface')
 class LocalHuggingfaceDataLoadStrategy(LocalDataLoadStrategy):
     """
     data load strategy for Huggingface dataset for LocalExecutor
@@ -254,8 +253,8 @@ class LocalHuggingfaceDataLoadStrategy(LocalDataLoadStrategy):
     }
 
     def load_data(self, **kwargs):
-        num_proc = kwargs.get('num_proc', 1)
-        ds = datasets.load_dataset(self.ds_config.path,
+        num_proc = kwargs.pop('num_proc', 1)
+        ds = datasets.load_dataset(self.ds_config['path'],
                                    split=self.ds_config.split,
                                    name=self.ds_config.name,
                                    limit=self.ds_config.limit,
@@ -264,7 +263,7 @@ def load_data(self, **kwargs):
         ds = unify_format(ds, text_keys=self.text_keys, num_proc=num_proc)
 
 
-@DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'remote', 'modelscope')
+@DataLoadStrategyRegistry.register('local', 'remote', 'modelscope')
 class LocalModelScopeDataLoadStrategy(LocalDataLoadStrategy):
     """
     data load strategy for ModelScope dataset for LocalExecutor
@@ -275,7 +274,7 @@ def load_data(self):
             'ModelScope data load strategy is not implemented')
 
 
-@DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'remote', 'arxiv')
+@DataLoadStrategyRegistry.register('local', 'remote', 'arxiv')
 class LocalArxivDataLoadStrategy(LocalDataLoadStrategy):
     """
     data load strategy for arxiv dataset for LocalExecutor
@@ -294,7 +293,7 @@ def load_data(self):
             'Arxiv data load strategy is not implemented')
 
 
-@DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'remote', 'wiki')
+@DataLoadStrategyRegistry.register('local', 'remote', 'wiki')
 class LocalWikiDataLoadStrategy(LocalDataLoadStrategy):
     """
     data load strategy for wiki dataset for LocalExecutor
@@ -312,7 +311,7 @@ def load_data(self):
         raise NotImplementedError('Wiki data load strategy is not implemented')
 
 
-@DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'remote', 'commoncrawl')
+@DataLoadStrategyRegistry.register('local', 'remote', 'commoncrawl')
 class LocalCommonCrawlDataLoadStrategy(LocalDataLoadStrategy):
     """
     data load strategy for commoncrawl dataset for LocalExecutor
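With StrategyKey fields now plain strings, the wildcard matching reduces to direct fnmatch calls. A self-contained sketch of that behavior (stand-in class, not the library's):

    import fnmatch
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class Key:  # stand-in for StrategyKey
        executor_type: str
        data_type: str
        data_source: str

        def matches(self, other: 'Key') -> bool:
            # registered pattern (self) checked against a concrete query (other)
            return (fnmatch.fnmatch(other.executor_type, self.executor_type)
                    and fnmatch.fnmatch(other.data_type, self.data_type)
                    and fnmatch.fnmatch(other.data_source, self.data_source))

    registered = Key('local', 'ondisk', '*')  # wildcard data source
    query = Key('local', 'ondisk', 'json')
    assert registered.matches(query)          # '*' matches 'json'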
data_juicer/core/executor/__init__.py (8 changes: 3 additions & 5 deletions)

@@ -1,9 +1,7 @@
-from .base import ExecutorBase, ExecutorType
+from .base import ExecutorBase
 from .factory import ExecutorFactory
 from .local_executor import Executor
 from .ray_executor import RayExecutor
 
-__all__ = [
-    'ExecutorBase', 'ExecutorFactory', 'Executor', 'RayExecutor',
-    'ExecutorType'
-]
+__all__ = ['ExecutorBase',
+           'ExecutorFactory', 'Executor', 'RayExecutor']
data_juicer/core/executor/base.py (7 changes: 0 additions & 7 deletions)

@@ -1,5 +1,4 @@
 from abc import ABC, abstractmethod
-from enum import Enum
 from typing import Optional
 
 from jsonargparse import Namespace
@@ -8,12 +7,6 @@
 from data_juicer.config import init_configs
 
 
-class ExecutorType(Enum):
-    LOCAL = 'local'
-    RAY = 'ray'
-    ANY = '*'
-
-
 class ExecutorBase(ABC):
 
     @abstractmethod
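Removing the enum also removes the implicit validation that ExecutorType(executor_type) used to perform. A lightweight guard like this hypothetical helper (not in the commit) could take its place:

    VALID_EXECUTOR_TYPES = ('local', 'ray', '*')  # hypothetical constant

    def check_executor_type(executor_type: str) -> str:
        # mirrors what ExecutorType(executor_type) used to enforce
        if executor_type not in VALID_EXECUTOR_TYPES:
            raise ValueError(f'unknown executor_type: {executor_type!r}')
        return executor_type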
data_juicer/core/executor/local_executor.py (3 changes: 1 addition & 2 deletions)

@@ -7,7 +7,6 @@
 from loguru import logger
 from pydantic import PositiveInt
 
-from data_juicer.core import ExecutorType
 from data_juicer.core.adapter import Adapter
 from data_juicer.core.data.dataset_builder import DatasetBuilder
 from data_juicer.core.executor import ExecutorBase
@@ -54,7 +53,7 @@ def __init__(self, cfg: Optional[Namespace] = None):
         # setup dataset builder
         logger.info('Setting up dataset builder...')
         self.dataset_builder = DatasetBuilder(cfg,
-                                              executor_type=ExecutorType.LOCAL)
+                                              executor_type=self.executor_type)
 
         # whether to use checkpoint mechanism. If it's true, Executor will
         # check if there are existing checkpoints first and try to load the
data_juicer/core/executor/ray_executor.py (5 changes: 2 additions & 3 deletions)

@@ -9,7 +9,7 @@
 
 from data_juicer.core.adapter import Adapter
 from data_juicer.core.data.dataset_builder import DatasetBuilder
-from data_juicer.core.executor import ExecutorBase, ExecutorType
+from data_juicer.core.executor import ExecutorBase
 from data_juicer.ops import load_ops
 from data_juicer.ops.op_fusion import fuse_operators
 from data_juicer.utils.lazy_loader import LazyLoader
@@ -63,8 +63,7 @@ def __init__(self, cfg: Optional[Namespace] = None):
             ray.get_runtime_context().get_job_id())
 
         # init dataset builder
-        self.datasetbuilder = DatasetBuilder(self.cfg,
-                                             executor_type=ExecutorType.RAY)
+        self.datasetbuilder = DatasetBuilder(self.cfg, executor_type='ray')
 
     def run(self,
             load_data_np: Optional[PositiveInt] = None,
data_juicer/format/formatter.py (2 changes: 2 additions & 0 deletions)

@@ -59,6 +59,8 @@ def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset:
         :param global_cfg: global cfg used in consequent processes,
         :return: formatted dataset
         """
+        _num_proc = self.kwargs.pop('num_proc', 1)
+        num_proc = num_proc or _num_proc
         datasets = load_dataset(self.type,
                                 data_files={
                                     key.strip('.'): self.data_files[key]
tests/core/data/test_config.yaml (3 changes: 1 addition & 2 deletions)

@@ -2,5 +2,4 @@ project_name: 'dataset-ondisk-json'
 dataset:
   configs:
     - type: 'ondisk'
-      path:
-        - 'sample.json'
+      path: 'sample.json'
tests/core/data/test_config_list.yaml (6 changes: 2 additions & 4 deletions)

@@ -2,8 +2,6 @@ project_name: 'dataset-ondisk-list'
 dataset:
   configs:
     - type: 'ondisk'
-      path:
-        - 'sample.json'
+      path: 'sample.json'
     - type: 'ondisk'
-      path:
-        - 'sample.txt'
+      path: 'sample.txt'
(Diffs for the remaining 2 changed files are not shown.)