Skip to content

Commit

Permalink
Removing pickler and adding CopyNumPyArrayToSharedMemory MapTransform.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 577879209
  • Loading branch information
Grain Team authored and copybara-github committed Oct 31, 2023
1 parent a079f68 commit fdcc75f
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 107 deletions.
1 change: 0 additions & 1 deletion grain/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ py_library(
"//grain/_src/core:transforms", # build_cleaner: keep
"//grain/_src/python/experimental/example_packing:packing", # build_cleaner: keep
"//grain/_src/python/experimental/proto_parsers:fast_proto_parser", # build_cleaner: keep
"//grain/_src/python/experimental/shared_memory:np_array_in_shared_memory", # build_cleaner: keep
],
)

Expand Down
2 changes: 0 additions & 2 deletions grain/_src/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ py_library(
"//grain/_src/core:sharding",
"//grain/_src/core:transforms",
"//grain/_src/core:usage_logging",
"//grain/_src/python/experimental/shared_memory:np_array_in_shared_memory",
],
)

Expand Down Expand Up @@ -168,7 +167,6 @@ py_library(
":multiprocessing_common",
":options",
"//grain/_src/core:parallel",
"//grain/_src/python/experimental/shared_memory:np_array_in_shared_memory",
],
)

Expand Down
36 changes: 27 additions & 9 deletions grain/_src/python/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@
from grain._src.python import options
from grain._src.python import record
from grain._src.python.data_sources import RandomAccessDataSource
from grain._src.python.experimental.shared_memory import np_array_in_shared_memory
from grain._src.python.operations import BatchOperation
from grain._src.python.operations import Operation
from grain._src.python.samplers import Sampler
from grain._src.python.shared_memory_array import SharedMemoryArray
from grain._src.python.shared_memory_array import SharedMemoryArrayMetadata
import numpy as np
import tree


Expand Down Expand Up @@ -117,6 +117,26 @@ def use_context_if_available(obj):
yield


@dataclasses.dataclass
class CopyNumPyArrayToSharedMemory(transforms.MapTransform):
"""If `element` contains NumPy array copy it to SharedMemoryArray."""

def map(self, element: Any) -> Any:
def copy_if_applied(element: Any) -> Any:
if (
not isinstance(element, np.ndarray)
or element.dtype.hasobject
or not element.flags.c_contiguous
):
return element

shared_memory_arr = SharedMemoryArray(element.shape, element.dtype)
np.copyto(shared_memory_arr, element, casting="no")
return shared_memory_arr.metadata

return tree.map_structure(copy_if_applied, element)


class DataLoader:
"""DataLoader loads and transforms input data."""

Expand Down Expand Up @@ -165,15 +185,13 @@ def __init__(

worker_count = _determine_worker_count(worker_count)

# Shared memory should be enabled in Batch operation iff worker_count > 0.
if (
not np_array_in_shared_memory.numpy_shared_memory_pickler_enabled()
and worker_count > 0
and len(operations)
and isinstance(operations[-1], BatchOperation)
):
# Shared memory should be enabled iff worker_count > 0.
if operations and isinstance(operations[-1], BatchOperation):
logging.info("Enabling SharedMemoryArray for BatchOperation.")
operations[-1]._enable_shared_memory()
logging.info("Enabling shared memory.")
else:
logging.info("Adding CopyNumPyArrayToSharedMemory MapTransform.")
operations = list(operations) + [CopyNumPyArrayToSharedMemory()]

self._data_source = data_source
self._sampler = sampler
Expand Down
65 changes: 51 additions & 14 deletions grain/_src/python/data_loader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,45 @@ def map(self, x):
return x


class CopyNumPyArrayToSharedMemoryTest(absltest.TestCase):

def test_copy_numpy_array_to_shared_memory(self):
element = np.array([1, 2, 3, 4, 5, 6, 7])
transform = data_loader_lib.CopyNumPyArrayToSharedMemory()
result = transform.map(element)
self.assertIsInstance(result, data_loader_lib.SharedMemoryArrayMetadata)

def test_copy_nested_numpy_array_to_shared_memory(self):
element_1 = np.arange(5)
element_2 = np.arange(5)
transform = data_loader_lib.CopyNumPyArrayToSharedMemory()
result = transform.map([element_1, element_2])
self.assertIsInstance(result[0], data_loader_lib.SharedMemoryArrayMetadata)
self.assertIsInstance(result[1], data_loader_lib.SharedMemoryArrayMetadata)

def test_copy_skipped_non_numpy_array(self):
element = "randomstring"
transform = data_loader_lib.CopyNumPyArrayToSharedMemory()
result = transform.map(element)
self.assertIs(result, element)

def test_copy_skipped_dtype_hasobject(self):
class DT:
pass

element = np.array([127, 128, 129], dtype=np.dtype(DT))
transform = data_loader_lib.CopyNumPyArrayToSharedMemory()
result = transform.map(element)
print(result)
self.assertIs(result, element)

def test_copy_skipped_flags_c_contiguous(self):
element = np.arange(9).reshape(3, 3)[:, (0, 1)]
transform = data_loader_lib.CopyNumPyArrayToSharedMemory()
result = transform.map(element)
self.assertIs(result, element)


class DataLoaderTest(parameterized.TestCase):

def setUp(self):
Expand Down Expand Up @@ -557,41 +596,39 @@ def test_batch_transform_mapped_to_batch_operation(self):
actual = list(data_loader)
np.testing.assert_equal(actual, expected)

@mock.patch.object(data_loader_lib, "np_array_in_shared_memory")
def test_global_shared_memory(self, mock_np_array_in_shared_memory):
@mock.patch.object(data_loader_lib, "CopyNumPyArrayToSharedMemory")
def test_shared_memory(self, mock_copy_numpy_array_to_shared_memory):
range_data_source = RangeDataSource(start=0, stop=8, step=1)
sampler = samplers.SequentialSampler(
num_records=len(range_data_source), shard_options=sharding.NoSharding()
)

batch_operation = mock.MagicMock(BatchOperation(batch_size=2))
operations = [
PlusOne(),
FilterEven(),
batch_operation,
]

mock_np_array_in_shared_memory.numpy_shared_memory_pickler_enabled.return_value = (
True
)
data_loader_lib.DataLoader(
batch_operation = mock.MagicMock(BatchOperation(batch_size=2))

data_loader = data_loader_lib.DataLoader(
data_source=range_data_source,
sampler=sampler,
operations=operations,
worker_count=2,
worker_count=0,
)
batch_operation._enable_shared_memory.assert_not_called()

mock_np_array_in_shared_memory.numpy_shared_memory_pickler_enabled.return_value = (
False
self.assertTrue(
data_loader._operations[-1], mock_copy_numpy_array_to_shared_memory
)
data_loader_lib.DataLoader(

data_loader = data_loader_lib.DataLoader(
data_source=range_data_source,
sampler=sampler,
operations=operations,
operations=operations + [batch_operation],
worker_count=2,
)
batch_operation._enable_shared_memory.assert_called_once()
self.assertTrue(data_loader._operations[-1], batch_operation)


class PyGrainDatasetIteratorTest(absltest.TestCase):
Expand Down
10 changes: 0 additions & 10 deletions grain/_src/python/experimental/shared_memory/BUILD

This file was deleted.

This file was deleted.

7 changes: 0 additions & 7 deletions grain/_src/python/grain_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
from grain._src.core import parallel
from grain._src.python import grain_logging
from grain._src.python import multiprocessing_common
from grain._src.python.experimental.shared_memory import np_array_in_shared_memory
from grain._src.python.options import MultiprocessingOptions # pylint: disable=g-importing-member

T = TypeVar("T")
Expand Down Expand Up @@ -153,15 +152,12 @@ def _worker_loop(
worker_index: int,
worker_count: int,
enable_profiling: bool,
enable_numpy_shared_memory: bool = False,
):
"""Code to be run on each child process."""
try:
grain_logging.set_process_identifier_prefix(
f"PyGrain Worker {worker_index}"
)
if enable_numpy_shared_memory:
np_array_in_shared_memory.enable_numpy_shared_memory_pickler()
logging.info("Starting work.")
element_producer = _get_element_producer_from_queue(
args_queue, worker_index=worker_index, worker_count=worker_count
Expand Down Expand Up @@ -265,9 +261,6 @@ def __init__(
"worker_index": worker_index,
"worker_count": options.num_workers,
"enable_profiling": options.enable_profiling,
"enable_numpy_shared_memory": (
np_array_in_shared_memory.numpy_shared_memory_pickler_enabled()
),
}
# The process kwargs must all be pickable and will be unpickle before
# absl.app.run() is called. We send arguments via a queue to ensure that
Expand Down
5 changes: 0 additions & 5 deletions grain/python_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,3 @@
)
from ._src.python.experimental.example_packing.packing import PackAndBatchOperation
from ._src.python.experimental.index_shuffle.python.index_shuffle_module import index_shuffle
from ._src.python.experimental.shared_memory.np_array_in_shared_memory import (
disable_numpy_shared_memory_pickler,
enable_numpy_shared_memory_pickler,
numpy_shared_memory_pickler_enabled,
)

0 comments on commit fdcc75f

Please sign in to comment.