diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py index 1b3b8d09..7762e0a1 100644 --- a/grain/_src/python/dataset/transformations/prefetch.py +++ b/grain/_src/python/dataset/transformations/prefetch.py @@ -20,6 +20,7 @@ import contextlib import copy import functools +import math import queue import sys import threading @@ -275,6 +276,7 @@ def _copy_leaf_to_shm(leaf: Any) -> Any: not isinstance(leaf, np.ndarray) or leaf.dtype.hasobject or not leaf.flags.c_contiguous + or math.prod(leaf.shape) == 0 ): return leaf diff --git a/grain/_src/python/dataset/transformations/prefetch_test.py b/grain/_src/python/dataset/transformations/prefetch_test.py index d5f1a555..92ab1ddb 100644 --- a/grain/_src/python/dataset/transformations/prefetch_test.py +++ b/grain/_src/python/dataset/transformations/prefetch_test.py @@ -29,6 +29,7 @@ from grain._src.python.dataset.transformations import filter as filter_lazy_dataset from grain._src.python.dataset.transformations import map as map_lazy_dataset from grain._src.python.dataset.transformations import prefetch +import numpy as np _T = TypeVar('_T') @@ -263,6 +264,22 @@ def test_prefetch_data(self, num_workers: int, per_worker_buffer_size: int): expected = list(range(1, 20, 2)) self.assertSequenceEqual(actual, expected) + def test_prefetch_size_zero_data(self): + ds = dataset.MapDataset.source( + [np.zeros(shape=(0,), dtype=np.int64)] + ).repeat(3) + iter_ds = ds.to_iter_dataset() + prefetch_lazy_iter_ds = prefetch.MultiprocessPrefetchIterDataset( + iter_ds, + options.MultiprocessingOptions(num_workers=1), + ) + actual = list(prefetch_lazy_iter_ds) + expected = [np.zeros(shape=(0,), dtype=np.int64)] * 3 + self.assertLen(actual, 3) + self.assertLen(expected, 3) + for i in range(3): + np.testing.assert_array_equal(actual[i], expected[i]) + @parameterized.named_parameters( dict( testcase_name='0_workers', @@ -576,6 +593,22 @@ def test_prefetch_with_random_map(self): distinct_elements = set(elements) self.assertLen(distinct_elements, len(elements)) + def test_concurrent_start_prefetch(self): + num_iters = 10 # Can't set this much higher without Forge OOMing. + + def make_iter(i): + ds = dataset.MapDataset.source([i]) + ds = ds.to_iter_dataset() + ds = ds.mp_prefetch(options=options.MultiprocessingOptions(num_workers=1)) + return ds.__iter__() + + iters = [make_iter(i) for i in range(num_iters)] + with futures.ThreadPoolExecutor(max_workers=num_iters) as executor: + for it in iters: + executor.submit(it.start_prefetch) + for it in iters: + _ = next(it) + class ThreadPrefetchIterDatasetTest(parameterized.TestCase): @@ -663,22 +696,6 @@ def test_fails_with_negative_prefetch_buffer_size(self): ): prefetch.ThreadPrefetchIterDataset(self.ds, prefetch_buffer_size=-1) - def test_concurrent_start_prefetch(self): - num_iters = 10 # Can't set this much higher without Forge OOMing. - - def make_iter(i): - ds = dataset.MapDataset.source([i]) - ds = ds.to_iter_dataset() - ds = ds.mp_prefetch(options=options.MultiprocessingOptions(num_workers=1)) - return ds.__iter__() - - iters = [make_iter(i) for i in range(num_iters)] - with futures.ThreadPoolExecutor(max_workers=num_iters) as executor: - for it in iters: - executor.submit(it.start_prefetch) - for it in iters: - _ = next(it) - if __name__ == '__main__': absltest.main()