Commit
No need to mix all output streams at once: in the end, pulling is driven split by split, and each split has to be computed from scratch anyway.

Signed-off-by: dafnapension <[email protected]>
dafnapension committed Nov 3, 2024
1 parent 132a030 commit ba308ce
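
In other words: since the consumer pulls one split at a time and each pulled split is regenerated from scratch, there is nothing to gain from eagerly mixing every output split up front. A minimal standalone sketch of that per-split lazy-loader pattern (all names here are illustrative, not the library's API):

import itertools

def read_split(split_name):
    # Hypothetical raw source; the real loaders wrap HF datasets.
    sizes = {"train": 1000, "test": 100}
    for i in range(sizes[split_name]):
        yield {"split": split_name, "id": i}

def make_split_loader(split_name, limit=None):
    # A fresh, lazy generator per split: nothing is read until pulled,
    # and each pull recomputes the split from scratch.
    def gen():
        yield from itertools.islice(read_split(split_name), limit)
    return gen

loaders = {name: make_split_loader(name, limit=5) for name in ("train", "test")}
# Pulling "train" never touches "test".
print(list(loaders["train"]()))  # five instances from "train" only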
Showing 3 changed files with 8 additions and 50 deletions.
28 changes: 7 additions & 21 deletions src/unitxt/loaders.py
@@ -177,7 +177,6 @@ class LoadHF(Loader):
     num_proc: Optional[int] = None
     _cache: dict = InternalField(default=None)
     requirements_list: List[str] = OptionalField(default_factory=list)
-    needed_for_random_mix: Optional[List[str]] = None
 
     def verify(self):
         for requirement in self.requirements_list:
@@ -275,21 +274,15 @@ def load_dataset(self):

         return dataset
 
-    def load_split(self, dataset, split_name):
-        limit = self.get_limit()
-        if limit is not None:
-            self.log_limited_loading()
-            logger.info(f"limiting {split_name}")
-        else:
-            logger.info(f"\nloading split {split_name} unlimited")
-        yield from itertools.islice(dataset[split_name], limit)
+    def split_limited_load(self, dataset, split_name):
+        yield from itertools.islice(dataset[split_name], self.get_limit())
 
     def limited_load(self, dataset):
         self.log_limited_loading()
         return MultiStream(
             {
                 name: DynamicStream(
-                    generator=self.load_split,
+                    generator=self.split_limited_load,
                     gen_kwargs={"dataset": dataset, "split_name": name},
                 )
                 for name in self._cache.keys()
@@ -312,17 +305,10 @@ def load_data(self):
         ):  # streaming is not supported for zipped files so we load without streaming
             dataset = self.load_dataset()
 
-        return MultiStream(
-            {
-                name: DynamicStream(
-                    generator=self.load_split,
-                    gen_kwargs={"dataset": dataset, "split_name": name},
-                )
-                for name in self._cache.keys()
-                if self.needed_for_random_mix is None
-                or name in self.needed_for_random_mix
-            }
-        )
+        if self.get_limit() is not None:
+            return self.limited_load(dataset=dataset)
+
+        return MultiStream.from_iterables(dataset)
 
 
 class LoadCSV(Loader):
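
One detail worth noting about the islice pattern above: itertools.islice(iterable, None) simply iterates to exhaustion, so a single stop=self.get_limit() argument could serve both cases; the commit still guards with get_limit() is not None, sending the unlimited path straight through MultiStream.from_iterables. A quick standalone check of the islice behavior:

import itertools

items = list(range(10))
assert list(itertools.islice(items, 3)) == [0, 1, 2]  # stop=3 truncates
assert list(itertools.islice(items, None)) == items   # stop=None is a full pass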
24 changes: 1 addition & 23 deletions src/unitxt/split_utils.py
@@ -1,16 +1,13 @@
 import itertools
 import re
-from collections import defaultdict
-from typing import Any, Dict, List
+from typing import Dict
 
 from .generator_utils import ReusableGenerator
 from .logging_utils import get_logger
 from .random_utils import new_random_generator
-from .settings_utils import get_settings
 from .stream import MissingStreamError, Stream
 
 logger = get_logger()
-settings = get_settings()
 
 
 def parse_random_mix_string(input_str):
@@ -103,9 +100,6 @@ def parse_slices_string(input_str):


 def slice_stream(stream, start, end):
-    if settings.use_eager_execution:
-        yield from stream.stream.instances_list[start:end]
-        return
     # If start is None, consume from the beginning
     if start is not None:
         stream = itertools.islice(stream, start, None)
@@ -240,22 +234,6 @@ def rename_split(input_streams: Dict[str, Stream], mapping: Dict[str, str]):
     return {**input_streams, **new_streams}
 
 
-def random_stream_mixer(multi_stream, mapping) -> Dict[str, List[Dict[str, Any]]]:
-    stream_routing = build_stream_routing(mapping)
-    new_streams = defaultdict(list)
-    for old_stream_name in sorted(stream_routing.keys()):
-        # sorted to canonize the order by which the old streams contribute to each new stream
-        assert (
-            old_stream_name in multi_stream
-        ), f"'{old_stream_name}' split not found. Possibles options: {multi_stream.keys()}"
-        optional_streams, weights = stream_routing[old_stream_name]
-        random_generator = new_random_generator(sub_seed=old_stream_name)
-        for item in multi_stream[old_stream_name]:
-            choice = random_generator.choices(optional_streams, weights=weights, k=1)[0]
-            new_streams[choice].append(item)
-    return new_streams
-
-
 def random_mix_generator(
     new_stream_name, new_stream_sources, stream_routing, input_streams
 ):
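
The deleted random_stream_mixer materialized every output split in one eager pass. The lazy scheme the commit keeps instead recomputes just the requested split by replaying the source streams with a deterministically seeded RNG, so all splits agree on the same global routing. A minimal standalone sketch of that idea (not the library's actual random_mix_generator; the names and the plain random.Random seed are illustrative stand-ins):

import random

def lazy_mix_split(wanted, stream_routing, input_streams):
    # Yield only the items routed to `wanted`, replaying sources lazily.
    # Seeding the RNG per source stream makes the routing deterministic,
    # so every split, computed from scratch on demand, sees the same
    # assignment of items.
    for source_name in sorted(stream_routing):
        targets, weights = stream_routing[source_name]
        rng = random.Random(source_name)  # stand-in for new_random_generator
        for item in input_streams[source_name]():
            choice = rng.choices(targets, weights=weights, k=1)[0]
            if choice == wanted:
                yield item

# Route 80% of "train" to "train" and 20% to "validation".
routing = {"train": (["train", "validation"], [0.8, 0.2])}
streams = {"train": lambda: ({"id": i} for i in range(10))}
print([x["id"] for x in lazy_mix_split("validation", routing, streams)])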
6 changes: 0 additions & 6 deletions src/unitxt/splitters.py
@@ -7,12 +7,10 @@
 from .dict_utils import dict_get
 from .operator import InstanceOperatorWithMultiStreamAccess, MultiStreamOperator
 from .random_utils import new_random_generator
-from .settings_utils import get_settings
 from .split_utils import (
     parse_random_mix_string,
     parse_slices_string,
     random_mix_streams,
-    random_stream_mixer,
     rename_split,
     slice_streams,
 )
@@ -60,11 +58,7 @@ class SplitRandomMix(Splitter):
     mix: Dict[str, str]
 
     def process(self, multi_stream: MultiStream) -> MultiStream:
-        settings = get_settings()
         mapping = {k: parse_random_mix_string(v) for k, v in self.mix.items()}
-        if settings.use_eager_execution:
-            dict_of_instance_lists = random_stream_mixer(multi_stream, mapping)
-            return MultiStream.from_iterables(dict_of_instance_lists)
         generators = random_mix_streams(multi_stream, mapping)
         return MultiStream.from_generators(generators)
 
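
For orientation, SplitRandomMix is configured with a mapping from new split names to mix expressions over the source splits; after this change the result is a MultiStream of generators, so each mixed split is computed only when pulled. A sketch of such a configuration (the percentage strings below are illustrative; the exact syntax is whatever parse_random_mix_string accepts):

from unitxt.splitters import SplitRandomMix

# Carve a validation split out of train, lazily.
splitter = SplitRandomMix(
    mix={"train": "train[90%]", "validation": "train[10%]", "test": "test"}
)
# mixed = splitter.process(multi_stream)  # multi_stream: a unitxt MultiStream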
