Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Internal change #294

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions grain/_src/python/lazy_dataset/transformations/mix.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@

import dataclasses
import functools
import sys
from typing import Any, Sequence, Tuple, TypeVar, Union

from grain._src.core.exceptions import PyGrainInternalError
from grain._src.python.lazy_dataset import lazy_dataset
import numpy as np


Element = Any
Expand All @@ -46,11 +46,14 @@ def __init__(
assert len(parents) == len(proportions)
self._proportions = tuple(proportions)

# Compute length.
lengths = np.asarray([len(p) for p in parents])
float_proportions = np.asarray(proportions) / sum(proportions)
# Ensure all elements of constituent datasets appear at most once.
self._length = int((lengths / float_proportions).min())
# Compute length such that elements of constituent datasets appear at most
# once.
weight_sum = sum(proportions)
lengths = [
len(parent) / (weight / weight_sum)
for parent, weight in zip(parents, proportions)
]
self._length = min(sys.maxsize, int(min(lengths)))

def __len__(self) -> int:
  """Returns the dataset length stored in `self._length` by the constructor."""
  return self._length
Expand Down
16 changes: 16 additions & 0 deletions grain/_src/python/lazy_dataset/transformations/mix_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
# limitations under the License.
"""Tests for mixing transformation."""

import sys

from absl.testing import absltest
from grain._src.python.lazy_dataset import lazy_dataset
from grain._src.python.lazy_dataset.transformations import mix
from grain._src.python.lazy_dataset.transformations import repeat # pylint: disable=unused-import
import numpy as np


class MixedLazyMapDatasetTest(absltest.TestCase):
Expand Down Expand Up @@ -108,6 +111,19 @@ def test_mixing_zero_one_probability_fails_with_error(self):
parents=[self.even_ds, self.odd_ds], proportions=[0, 1]
)

def test_mix_infinite_datasets(self):
  """Checks length and sampling ratios when mixing repeated datasets.

  Two single-element datasets are repeated, mixed 4:1, and the result is mixed
  again 1:1 with the second dataset. The first 1000 elements are expected to
  contain 400 zeros and 600 ones.
  """
  # Single-value datasets (0 and 1), each repeated without an epoch limit.
  repeated_parents = [
      lazy_dataset.RangeLazyMapDataset(value, value + 1).repeat()
      for value in (0, 1)
  ]
  for parent in repeated_parents:
    self.assertLen(parent, sys.maxsize)
  zeros, ones = repeated_parents

  first_mix = mix.MixedLazyMapDataset([zeros, ones], proportions=[4, 1])
  self.assertLen(first_mix, sys.maxsize)

  # Mix again.
  second_mix = mix.MixedLazyMapDataset([first_mix, ones], proportions=[1, 1])
  num_samples = 1000
  samples = [second_mix[index] for index in range(num_samples)]
  value_counts = np.bincount(samples).tolist()
  self.assertEqual(value_counts, [400, 600])


# Run the tests through absltest when this module is executed as a script.
if __name__ == "__main__":
  absltest.main()