diff --git a/src/unitxt/loaders.py b/src/unitxt/loaders.py index c50afc68dc..13e7d7f936 100644 --- a/src/unitxt/loaders.py +++ b/src/unitxt/loaders.py @@ -177,7 +177,6 @@ class LoadHF(Loader): num_proc: Optional[int] = None _cache: dict = InternalField(default=None) requirements_list: List[str] = OptionalField(default_factory=list) - needed_for_random_mix: Optional[List[str]] = None def verify(self): for requirement in self.requirements_list: @@ -275,21 +274,15 @@ def load_dataset(self): return dataset - def load_split(self, dataset, split_name): - limit = self.get_limit() - if limit is not None: - self.log_limited_loading() - logger.info(f"limiting {split_name}") - else: - logger.info(f"\nloading split {split_name} unlimited") - yield from itertools.islice(dataset[split_name], limit) + def split_limited_load(self, dataset, split_name): + yield from itertools.islice(dataset[split_name], self.get_limit()) def limited_load(self, dataset): self.log_limited_loading() return MultiStream( { name: DynamicStream( - generator=self.load_split, + generator=self.split_limited_load, gen_kwargs={"dataset": dataset, "split_name": name}, ) for name in self._cache.keys() @@ -312,17 +305,10 @@ def load_data(self): ): # streaming is not supported for zipped files so we load without streaming dataset = self.load_dataset() - return MultiStream( - { - name: DynamicStream( - generator=self.load_split, - gen_kwargs={"dataset": dataset, "split_name": name}, - ) - for name in self._cache.keys() - if self.needed_for_random_mix is None - or name in self.needed_for_random_mix - } - ) + if self.get_limit() is not None: + return self.limited_load(dataset=dataset) + + return MultiStream.from_iterables(dataset) class LoadCSV(Loader): diff --git a/src/unitxt/split_utils.py b/src/unitxt/split_utils.py index d91da9ba7c..ed0c47989f 100644 --- a/src/unitxt/split_utils.py +++ b/src/unitxt/split_utils.py @@ -1,16 +1,13 @@ import itertools import re -from collections import defaultdict -from typing import Any, Dict, List +from typing import Dict from .generator_utils import ReusableGenerator from .logging_utils import get_logger from .random_utils import new_random_generator -from .settings_utils import get_settings from .stream import MissingStreamError, Stream logger = get_logger() -settings = get_settings() def parse_random_mix_string(input_str): @@ -103,9 +100,6 @@ def parse_slices_string(input_str): def slice_stream(stream, start, end): - if settings.use_eager_execution: - yield from stream.stream.instances_list[start:end] - return # If start is None, consume from the beginning if start is not None: stream = itertools.islice(stream, start, None) @@ -240,22 +234,6 @@ def rename_split(input_streams: Dict[str, Stream], mapping: Dict[str, str]): return {**input_streams, **new_streams} -def random_stream_mixer(multi_stream, mapping) -> Dict[str, List[Dict[str, Any]]]: - stream_routing = build_stream_routing(mapping) - new_streams = defaultdict(list) - for old_stream_name in sorted(stream_routing.keys()): - # sorted to canonize the order by which the old streams contribute to each new stream - assert ( - old_stream_name in multi_stream - ), f"'{old_stream_name}' split not found. Possibles options: {multi_stream.keys()}" - optional_streams, weights = stream_routing[old_stream_name] - random_generator = new_random_generator(sub_seed=old_stream_name) - for item in multi_stream[old_stream_name]: - choice = random_generator.choices(optional_streams, weights=weights, k=1)[0] - new_streams[choice].append(item) - return new_streams - - def random_mix_generator( new_stream_name, new_stream_sources, stream_routing, input_streams ): diff --git a/src/unitxt/splitters.py b/src/unitxt/splitters.py index e0299d7435..9968e2f7bb 100644 --- a/src/unitxt/splitters.py +++ b/src/unitxt/splitters.py @@ -7,12 +7,10 @@ from .dict_utils import dict_get from .operator import InstanceOperatorWithMultiStreamAccess, MultiStreamOperator from .random_utils import new_random_generator -from .settings_utils import get_settings from .split_utils import ( parse_random_mix_string, parse_slices_string, random_mix_streams, - random_stream_mixer, rename_split, slice_streams, ) @@ -60,11 +58,7 @@ class SplitRandomMix(Splitter): mix: Dict[str, str] def process(self, multi_stream: MultiStream) -> MultiStream: - settings = get_settings() mapping = {k: parse_random_mix_string(v) for k, v in self.mix.items()} - if settings.use_eager_execution: - dict_of_instance_lists = random_stream_mixer(multi_stream, mapping) - return MultiStream.from_iterables(dict_of_instance_lists) generators = random_mix_streams(multi_stream, mapping) return MultiStream.from_generators(generators)