#grain Fix bug in transferring size-0 np arrays between processes.
Before the change to prefetch.py, the added test would fail with:

  File "/<embedded stdlib>/multiprocessing/shared_memory.py", line 81, in __init__
    raise ValueError("'size' must be a positive number different from zero")

PiperOrigin-RevId: 713359364
aaudiber authored and copybara-github committed Jan 8, 2025
1 parent d458dfc commit 7888c0c
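
For context, the failure comes straight from the standard library: multiprocessing.shared_memory.SharedMemory refuses to create a zero-byte segment, and a size-0 ndarray has nbytes == 0. A minimal stdlib-only repro sketch (assuming Python 3.8+, where shared_memory was introduced):

    # Repro sketch: requesting a zero-byte shared memory segment raises the
    # same ValueError quoted in the commit message above.
    from multiprocessing import shared_memory

    import numpy as np

    arr = np.zeros(shape=(0,), dtype=np.int64)
    print(arr.nbytes)  # 0 -- nothing to back with shared memory

    try:
      shared_memory.SharedMemory(create=True, size=arr.nbytes)
    except ValueError as err:
      print(err)  # 'size' must be a positive number different from zero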
Showing 2 changed files with 35 additions and 16 deletions.
grain/_src/python/dataset/transformations/prefetch.py (2 additions, 0 deletions)

@@ -20,6 +20,7 @@
 import contextlib
 import copy
 import functools
+import math
 import queue
 import sys
 import threading
@@ -275,6 +276,7 @@ def _copy_leaf_to_shm(leaf: Any) -> Any:
       not isinstance(leaf, np.ndarray)
       or leaf.dtype.hasobject
       or not leaf.flags.c_contiguous
+      or math.prod(leaf.shape) == 0
   ):
     return leaf
 
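The fix is the math.prod(leaf.shape) == 0 guard: empty arrays are now returned as-is and travel through the regular pickle path rather than shared memory. A sketch of the resulting predicate, using a hypothetical helper name should_use_shm (the real code inlines this check in _copy_leaf_to_shm):

    import math

    import numpy as np

    def should_use_shm(leaf) -> bool:
      # Shared memory only for non-empty, C-contiguous numpy arrays without
      # object dtype; everything else is pickled as-is.
      return (
          isinstance(leaf, np.ndarray)
          and not leaf.dtype.hasobject
          and leaf.flags.c_contiguous
          and math.prod(leaf.shape) != 0
      )

    print(should_use_shm(np.zeros((0,), dtype=np.int64)))  # False: empty array
    print(should_use_shm(np.zeros(())))  # True: 0-d array, math.prod(()) == 1
    print(should_use_shm(np.zeros((4,), dtype=np.int64)))  # True

Note the 0-d case: math.prod(()) is 1, so scalar-shaped arrays still qualify for shared memory; only arrays with a zero-length dimension are skipped.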
grain/_src/python/dataset/transformations/prefetch_test.py (33 additions, 16 deletions)

@@ -29,6 +29,7 @@
 from grain._src.python.dataset.transformations import filter as filter_lazy_dataset
 from grain._src.python.dataset.transformations import map as map_lazy_dataset
 from grain._src.python.dataset.transformations import prefetch
+import numpy as np
 
 
 _T = TypeVar('_T')
@@ -263,6 +264,22 @@ def test_prefetch_data(self, num_workers: int, per_worker_buffer_size: int):
     expected = list(range(1, 20, 2))
     self.assertSequenceEqual(actual, expected)
 
+  def test_prefetch_size_zero_data(self):
+    ds = dataset.MapDataset.source(
+        [np.zeros(shape=(0,), dtype=np.int64)]
+    ).repeat(3)
+    iter_ds = ds.to_iter_dataset()
+    prefetch_lazy_iter_ds = prefetch.MultiprocessPrefetchIterDataset(
+        iter_ds,
+        options.MultiprocessingOptions(num_workers=1),
+    )
+    actual = list(prefetch_lazy_iter_ds)
+    expected = [np.zeros(shape=(0,), dtype=np.int64)] * 3
+    self.assertLen(actual, 3)
+    self.assertLen(expected, 3)
+    for i in range(3):
+      np.testing.assert_array_equal(actual[i], expected[i])
+
   @parameterized.named_parameters(
       dict(
           testcase_name='0_workers',
@@ -576,6 +593,22 @@ def test_prefetch_with_random_map(self):
     distinct_elements = set(elements)
     self.assertLen(distinct_elements, len(elements))
 
+  def test_concurrent_start_prefetch(self):
+    num_iters = 10  # Can't set this much higher without Forge OOMing.
+
+    def make_iter(i):
+      ds = dataset.MapDataset.source([i])
+      ds = ds.to_iter_dataset()
+      ds = ds.mp_prefetch(options=options.MultiprocessingOptions(num_workers=1))
+      return ds.__iter__()
+
+    iters = [make_iter(i) for i in range(num_iters)]
+    with futures.ThreadPoolExecutor(max_workers=num_iters) as executor:
+      for it in iters:
+        executor.submit(it.start_prefetch)
+      for it in iters:
+        _ = next(it)
+
 
 class ThreadPrefetchIterDatasetTest(parameterized.TestCase):
 
@@ -663,22 +696,6 @@ def test_fails_with_negative_prefetch_buffer_size(self):
     ):
       prefetch.ThreadPrefetchIterDataset(self.ds, prefetch_buffer_size=-1)
 
-  def test_concurrent_start_prefetch(self):
-    num_iters = 10  # Can't set this much higher without Forge OOMing.
-
-    def make_iter(i):
-      ds = dataset.MapDataset.source([i])
-      ds = ds.to_iter_dataset()
-      ds = ds.mp_prefetch(options=options.MultiprocessingOptions(num_workers=1))
-      return ds.__iter__()
-
-    iters = [make_iter(i) for i in range(num_iters)]
-    with futures.ThreadPoolExecutor(max_workers=num_iters) as executor:
-      for it in iters:
-        executor.submit(it.start_prefetch)
-      for it in iters:
-        _ = next(it)
-
 
 if __name__ == '__main__':
   absltest.main()

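For a local sanity check, a hedged round-trip sketch mirroring the new test (import paths inferred from this test module's usage; the public grain API may expose these modules under different names):

    # With the fix, size-0 arrays survive a round trip through multiprocess
    # prefetch instead of crashing the worker process.
    import numpy as np

    from grain._src.python import options
    from grain._src.python.dataset import dataset
    from grain._src.python.dataset.transformations import prefetch


    def main():
      ds = dataset.MapDataset.source(
          [np.zeros(shape=(0,), dtype=np.int64)]
      ).repeat(3)
      iter_ds = prefetch.MultiprocessPrefetchIterDataset(
          ds.to_iter_dataset(),
          options.MultiprocessingOptions(num_workers=1),
      )
      for elem in iter_ds:
        np.testing.assert_array_equal(elem, np.zeros(shape=(0,), dtype=np.int64))


    if __name__ == '__main__':  # guard needed: workers run as subprocesses
      main()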