Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 578835428
  • Loading branch information
Marvin182 authored and copybara-github committed Nov 2, 2023
1 parent 29bdd38 commit 4a486d4
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 8 deletions.
1 change: 1 addition & 0 deletions grain/_src/python/lazy_dataset/transformations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ py_test(
srcs_version = "PY3",
deps = [
":mix",
":repeat",
"//grain/_src/python/lazy_dataset",
],
)
Expand Down
4 changes: 2 additions & 2 deletions grain/_src/python/lazy_dataset/transformations/mix.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def __init__(
# Compute length.
lengths = np.asarray([len(p) for p in parents])
float_proportions = np.asarray(proportions) / sum(proportions)
# Ensure all elements of constituent datasets appear at least once.
self._length = int((lengths / float_proportions).max())
# Ensure all elements of constituent datasets appear at most once.
self._length = int((lengths / float_proportions).min())

def __len__(self) -> int:
return self._length
Expand Down
50 changes: 44 additions & 6 deletions grain/_src/python/lazy_dataset/transformations/mix_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from absl.testing import absltest
from grain._src.python.lazy_dataset import lazy_dataset
from grain._src.python.lazy_dataset.transformations import mix
from grain._src.python.lazy_dataset.transformations import repeat # pylint: disable=unused-import


class MixedLazyMapDatasetTest(absltest.TestCase):
Expand All @@ -25,6 +26,18 @@ def setUp(self):
self.even_ds = lazy_dataset.RangeLazyMapDataset(0, 10, 2)
self.odd_ds = lazy_dataset.RangeLazyMapDataset(1, 10, 2)

def test_len(self):
  """Checks the length reported by a mixture of three differently sized parents."""
  # The mixture is sized so that no parent element has to be seen twice.
  parents = [
      lazy_dataset.RangeLazyMapDataset(10),
      lazy_dataset.RangeLazyMapDataset(20),
      lazy_dataset.RangeLazyMapDataset(5),
  ]

  # Equal proportions.
  uniform_mix = mix.MixedLazyMapDataset(parents)
  self.assertLen(uniform_mix, 15)

  # Higher weight for the second dataset: with proportions 1:2:1 the
  # parents contribute 5, 10 and 5 elements respectively.
  weighted_mix = mix.MixedLazyMapDataset(parents, proportions=[1, 2, 1])
  self.assertLen(weighted_mix, 5 + 10 + 5)

def test_mixing_equal_probability_with_integer_proportions(self):
mixed_lzds = mix.MixedLazyMapDataset(
parents=[self.even_ds, self.odd_ds], proportions=[2, 2]
Expand Down Expand Up @@ -52,17 +65,42 @@ def test_mixing_with_float_proportions(self):
mixed_lzds = mix.MixedLazyMapDataset(
parents=[self.even_ds, self.odd_ds], proportions=[0.75, 0.25]
)
actual_vals = [mixed_lzds[i] for i in range(len(mixed_lzds))]
expected_vals = [0, 2, 4, 1, 6, 8, 0, 3, 2, 4, 6, 5, 8, 0, 2, 7, 4, 6, 8, 9]
self.assertEqual(expected_vals, actual_vals)
self.assertLen(mixed_lzds, 6)

actual_vals = list(mixed_lzds)
expected_frist_epoch = [0, 2, 4, 1, 6, 8]
self.assertEqual(actual_vals, expected_frist_epoch)

actual_vals = list(mixed_lzds.repeat(2))
expected_two_epochs = [
0,
2,
4,
1,
6,
8,
0,
3,
2,
4,
6,
5,
]
self.assertEqual(actual_vals, expected_two_epochs)

def test_mixing_with_integer_proportions(self):
mixed_lzds = mix.MixedLazyMapDataset(
parents=[self.even_ds, self.odd_ds], proportions=[1, 2]
)
actual_values = [mixed_lzds[i] for i in range(len(mixed_lzds))]
expected_values = [0, 1, 3, 2, 5, 7, 4, 9, 1, 6, 3, 5, 8, 7, 9]
self.assertEqual(expected_values, actual_values)
self.assertLen(mixed_lzds, 7)

actual_values = list(mixed_lzds)
expected_first_epoch = [0, 1, 3, 2, 5, 7, 4]
self.assertEqual(expected_first_epoch, actual_values)

actual_values = list(mixed_lzds.repeat(2))
expected_two_epochs = [0, 1, 3, 2, 5, 7, 4, 9, 1, 6, 3, 5, 8, 7]
self.assertEqual(expected_two_epochs, actual_values)

def test_mixing_zero_one_probability_fails_with_error(self):
with self.assertRaises(ValueError):
Expand Down

0 comments on commit 4a486d4

Please sign in to comment.