diff --git a/grain/_src/core/BUILD b/grain/_src/core/BUILD index 087e65e8..b42f946c 100644 --- a/grain/_src/core/BUILD +++ b/grain/_src/core/BUILD @@ -72,3 +72,38 @@ py_library( srcs = ["transforms.py"], srcs_version = "PY3", ) + +py_library( + name = "tree", + srcs = [ + "tree.py", + ], + srcs_version = "PY3", +) + +py_library( + name = "tree_test_lib", + testonly = 1, + srcs = ["tree_test.py"], + srcs_version = "PY3", + deps = [":tree"], +) + +py_test( + name = "tree_test", + srcs = ["tree_test.py"], + srcs_version = "PY3", + deps = [ + ":tree_test_lib", + ], +) + +py_test( + name = "tree_jax_test", + srcs = ["tree_jax_test.py"], + srcs_version = "PY3", + deps = [ + ":tree", + ":tree_test_lib", + ], +) diff --git a/grain/_src/core/tree.py b/grain/_src/core/tree.py new file mode 100644 index 00000000..d3eaaf33 --- /dev/null +++ b/grain/_src/core/tree.py @@ -0,0 +1,55 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for working with pytrees. + +See https://jax.readthedocs.io/en/latest/pytrees.html for more details about +pytrees. + +This module merely re-directs imports of the actual implementations. To avoid a +direct dependency on JAX, we check if it's already present and resort to the +`tree` package otherwise. + +We should be able to remove this module once b/257971667 is resolved. +""" + +try: + from jax import tree_util # pytype: disable=import-error # pylint: disable=g-import-not-at-top + + map_structure = tree_util.tree_map + map_structure_with_path = tree_util.tree_map_with_path + + def assert_same_structure(a, b): + a_structure = tree_util.tree_structure(a) + b_structure = tree_util.tree_structure(b) + if a_structure != b_structure: + raise ValueError( + f"Structures are not the same: a = {a_structure}, b = {b_structure}" + ) + + def flatten(structure): + return tree_util.tree_flatten(structure)[0] + + def unflatten_as(structure, flat_sequence): + return tree_util.tree_unflatten( + tree_util.tree_structure(structure), flat_sequence + ) + +except ImportError: + import tree # pylint: disable=g-import-not-at-top + + map_structure = tree.map_structure + map_structure_with_path = tree.map_structure_with_path + assert_same_structure = tree.assert_same_structure + flatten = tree.flatten + unflatten_as = tree.unflatten_as diff --git a/grain/_src/core/tree_jax_test.py b/grain/_src/core/tree_jax_test.py new file mode 100644 index 00000000..784ae316 --- /dev/null +++ b/grain/_src/core/tree_jax_test.py @@ -0,0 +1,44 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testes for tree.py with JAX dependency present.""" + +from absl.testing import absltest +from grain._src.core import tree +from grain._src.core import tree_test +import jax + + +class MyTree: + + def __init__(self, a, b): + self.a = a + self.b = b + + def __eq__(self, other): + return self.a == other.a and self.b == other.b + + +class TreeJaxTest(tree_test.TreeTest): + + def test_map_custom_tree(self): + jax.tree_util.register_pytree_node( + MyTree, lambda t: ((t.a, t.b), None), lambda _, args: MyTree(*args) + ) + self.assertEqual( + tree.map_structure(lambda x: x + 1, MyTree(1, 2)), MyTree(2, 3) + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/grain/_src/core/tree_test.py b/grain/_src/core/tree_test.py new file mode 100644 index 00000000..85c38828 --- /dev/null +++ b/grain/_src/core/tree_test.py @@ -0,0 +1,84 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testes for tree.py. + +Since the tree.py only re-directs the actual implementations this test does not +try to cover the actual functionality, but rather the re-direction correctness. +""" + +from typing import Protocol, runtime_checkable + +from absl.testing import absltest +from grain._src.core import tree + + +@runtime_checkable +class TreeImpl(Protocol): + + def map_structure(self, f, *structures): + ... + + def map_structure_with_path(self, f, *structures): + ... + + def assert_same_structure(self, a, b): + ... + + def flatten(self, structure): + ... + + def unflatten_as(self, structure, flat_sequence): + ... + + +# Static check that the module implements the necessary functions. +tree: TreeImpl = tree + + +class TreeTest(absltest.TestCase): + + def test_implements_tree_protocol(self): + # Run time check that the module implements the necessary functions. + # The module impl branching happens at run time, so the static check does + # not cover both branches. + self.assertIsInstance(tree, TreeImpl) + + def test_map_structure(self): + self.assertEqual( + tree.map_structure(lambda x: x + 1, ({"B": 10, "A": 20}, [1, 2], 3)), + ({"B": 11, "A": 21}, [2, 3], 4), + ) + + def test_map_structure_with_path(self): + self.assertEqual( + tree.map_structure_with_path( + lambda path, x: x if path else None, {"B": "v1", "A": "v2"} + ), + {"B": "v1", "A": "v2"}, + ) + + def test_assert_same_structure(self): + tree.assert_same_structure({"B": "v1", "A": "v2"}, {"B": 10, "A": 20}) + + def test_flatten(self): + self.assertEqual(tree.flatten({"A": "v2", "B": "v1"}), ["v2", "v1"]) + + def test_unflatten_as(self): + self.assertEqual( + tree.unflatten_as({"A": "v2", "B": "v1"}, [1, 2]), {"A": 1, "B": 2} + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/grain/_src/python/BUILD b/grain/_src/python/BUILD index f169584a..de2499b0 100644 --- a/grain/_src/python/BUILD +++ b/grain/_src/python/BUILD @@ -73,6 +73,7 @@ py_library( deps = [ ":record", ":shared_memory_array", + "//grain/_src/core:tree", ], ) @@ -132,6 +133,7 @@ py_library( ":shared_memory_array", "//grain/_src/core:sharding", "//grain/_src/core:transforms", + "//grain/_src/core:tree", "//grain/_src/core:usage_logging", ], ) diff --git a/grain/_src/python/data_loader.py b/grain/_src/python/data_loader.py index 305a69d6..280ed720 100644 --- a/grain/_src/python/data_loader.py +++ b/grain/_src/python/data_loader.py @@ -32,6 +32,7 @@ from concurrent import futures from grain._src.core import sharding from grain._src.core import transforms +from grain._src.core import tree from grain._src.core import usage_logging import multiprocessing as mp from grain._src.python import grain_pool @@ -45,7 +46,6 @@ from grain._src.python.shared_memory_array import SharedMemoryArray from grain._src.python.shared_memory_array import SharedMemoryArrayMetadata import numpy as np -import tree _T = TypeVar("_T") diff --git a/grain/_src/python/experimental/example_packing/BUILD b/grain/_src/python/experimental/example_packing/BUILD index f1634731..a40aaaa4 100644 --- a/grain/_src/python/experimental/example_packing/BUILD +++ b/grain/_src/python/experimental/example_packing/BUILD @@ -7,6 +7,7 @@ py_library( srcs = ["packing.py"], srcs_version = "PY3", deps = [ + "//grain/_src/core:tree", "//grain/_src/python:record", "//third_party/py/jaxtyping", ], @@ -18,6 +19,7 @@ py_test( srcs_version = "PY3", deps = [ ":packing", + "//grain/_src/core:tree", "//grain/_src/python:record", ], ) diff --git a/grain/_src/python/experimental/example_packing/packing.py b/grain/_src/python/experimental/example_packing/packing.py index e9190dc6..50045b9a 100644 --- a/grain/_src/python/experimental/example_packing/packing.py +++ b/grain/_src/python/experimental/example_packing/packing.py @@ -14,11 +14,11 @@ import dataclasses from typing import Generic, Iterator, TypeVar, cast +from grain._src.core import tree from grain._src.python import record import jax import jaxtyping as jt import numpy as np -import tree _T = TypeVar("_T") diff --git a/grain/_src/python/experimental/example_packing/packing_test.py b/grain/_src/python/experimental/example_packing/packing_test.py index 7d3ad83b..8f886211 100644 --- a/grain/_src/python/experimental/example_packing/packing_test.py +++ b/grain/_src/python/experimental/example_packing/packing_test.py @@ -1,11 +1,11 @@ """Tests for packing.py.""" from absl.testing import absltest +from grain._src.core import tree from grain._src.python import record from grain._src.python.experimental.example_packing import packing import numpy as np import tensorflow as tf -import tree def create_input_dataset(input_dataset_elements): diff --git a/grain/_src/python/lazy_dataset/transformations/BUILD b/grain/_src/python/lazy_dataset/transformations/BUILD index e42b78b5..5a03ecb9 100644 --- a/grain/_src/python/lazy_dataset/transformations/BUILD +++ b/grain/_src/python/lazy_dataset/transformations/BUILD @@ -6,7 +6,10 @@ py_library( name = "batch", srcs = ["batch.py"], srcs_version = "PY3", - deps = ["//grain/_src/python/lazy_dataset"], + deps = [ + "//grain/_src/core:tree", + "//grain/_src/python/lazy_dataset", + ], ) py_test( @@ -108,6 +111,7 @@ py_library( srcs = ["packing.py"], srcs_version = "PY3", deps = [ + "//grain/_src/core:tree", "//grain/_src/python/lazy_dataset", "//third_party/py/jaxtyping", ], diff --git a/grain/_src/python/lazy_dataset/transformations/batch.py b/grain/_src/python/lazy_dataset/transformations/batch.py index 74e82476..f40ccf5f 100644 --- a/grain/_src/python/lazy_dataset/transformations/batch.py +++ b/grain/_src/python/lazy_dataset/transformations/batch.py @@ -17,9 +17,9 @@ import math from typing import TypeVar +from grain._src.core import tree from grain._src.python.lazy_dataset import lazy_dataset import numpy as np -import tree T = TypeVar("T") diff --git a/grain/_src/python/lazy_dataset/transformations/packing.py b/grain/_src/python/lazy_dataset/transformations/packing.py index 42d8ccc5..cd193ee3 100644 --- a/grain/_src/python/lazy_dataset/transformations/packing.py +++ b/grain/_src/python/lazy_dataset/transformations/packing.py @@ -16,10 +16,10 @@ import copy from typing import Any +from grain._src.core import tree from grain._src.python.lazy_dataset import lazy_dataset from jaxtyping import PyTree # pylint: disable=g-importing-member import numpy as np -import tree # SingleBinPackLazyDatasetIterator's state is defined by the a state of the diff --git a/grain/_src/python/operations.py b/grain/_src/python/operations.py index 8f56e8d1..2bb0b635 100644 --- a/grain/_src/python/operations.py +++ b/grain/_src/python/operations.py @@ -21,10 +21,10 @@ from typing import Any, Callable, Generic, Iterator, Protocol, Sequence, TypeVar from absl import logging +from grain._src.core import tree from grain._src.python import record from grain._src.python.shared_memory_array import SharedMemoryArray import numpy as np -import tree _IN = TypeVar("_IN") _OUT = TypeVar("_OUT")