From 59a933a09d61603fdd5b9b9c3fd7c557a408ada8 Mon Sep 17 00:00:00 2001 From: Marvin Ritter Date: Sun, 12 Nov 2023 03:27:54 -0800 Subject: [PATCH] Internal change PiperOrigin-RevId: 581695313 --- .../python/lazy_dataset/transformations/mix.py | 15 +++++++++------ .../lazy_dataset/transformations/mix_test.py | 16 ++++++++++++++++ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/grain/_src/python/lazy_dataset/transformations/mix.py b/grain/_src/python/lazy_dataset/transformations/mix.py index 6c9e20ce..68a43853 100644 --- a/grain/_src/python/lazy_dataset/transformations/mix.py +++ b/grain/_src/python/lazy_dataset/transformations/mix.py @@ -15,11 +15,11 @@ import dataclasses import functools +import sys from typing import Any, Sequence, Tuple, TypeVar, Union from grain._src.core.exceptions import PyGrainInternalError from grain._src.python.lazy_dataset import lazy_dataset -import numpy as np Element = Any @@ -46,11 +46,14 @@ def __init__( assert len(parents) == len(proportions) self._proportions = tuple(proportions) - # Compute length. - lengths = np.asarray([len(p) for p in parents]) - float_proportions = np.asarray(proportions) / sum(proportions) - # Ensure all elements of constituent datasets appear at most once. - self._length = int((lengths / float_proportions).min()) + # Compute length such that elements of constituent datasets appear at most + # once. + weight_sum = sum(proportions) + lengths = [ + len(parent) / (weight / weight_sum) + for parent, weight in zip(parents, proportions) + ] + self._length = min(sys.maxsize, int(min(lengths))) def __len__(self) -> int: return self._length diff --git a/grain/_src/python/lazy_dataset/transformations/mix_test.py b/grain/_src/python/lazy_dataset/transformations/mix_test.py index 92a5fb95..29f69bae 100644 --- a/grain/_src/python/lazy_dataset/transformations/mix_test.py +++ b/grain/_src/python/lazy_dataset/transformations/mix_test.py @@ -13,10 +13,13 @@ # limitations under the License. """Tests for mixing transformation.""" +import sys + from absl.testing import absltest from grain._src.python.lazy_dataset import lazy_dataset from grain._src.python.lazy_dataset.transformations import mix from grain._src.python.lazy_dataset.transformations import repeat # pylint: disable=unused-import +import numpy as np class MixedLazyMapDatasetTest(absltest.TestCase): @@ -108,6 +111,19 @@ def test_mixing_zero_one_probability_fails_with_error(self): parents=[self.even_ds, self.odd_ds], proportions=[0, 1] ) + def test_mix_infinite_datasets(self): + zeros = lazy_dataset.RangeLazyMapDataset(0, 1).repeat() + ones = lazy_dataset.RangeLazyMapDataset(1, 2).repeat() + self.assertLen(zeros, sys.maxsize) + self.assertLen(ones, sys.maxsize) + ld = mix.MixedLazyMapDataset([zeros, ones], proportions=[4, 1]) + self.assertLen(ld, sys.maxsize) + # Mix again. + ld = mix.MixedLazyMapDataset([ld, ones], proportions=[1, 1]) + num_samples = 1000 + value_counts = np.bincount([ld[i] for i in range(num_samples)]).tolist() + self.assertEqual(value_counts, [400, 600]) + if __name__ == "__main__": absltest.main()