Skip to content

Commit

Permalink
Use jax.tree_util for handling structures when available and tree
Browse files Browse the repository at this point in the history
… otherwise.

PiperOrigin-RevId: 579262209
  • Loading branch information
iindyk authored and copybara-github committed Nov 10, 2023
1 parent 4a486d4 commit 65f3b16
Show file tree
Hide file tree
Showing 13 changed files with 233 additions and 7 deletions.
35 changes: 35 additions & 0 deletions grain/_src/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,38 @@ py_library(
srcs = ["transforms.py"],
srcs_version = "PY3",
)

py_library(
name = "tree",
srcs = [
"tree.py",
],
srcs_version = "PY3",
)

py_library(
name = "tree_test_lib",
testonly = 1,
srcs = ["tree_test.py"],
srcs_version = "PY3",
deps = [":tree"],
)

py_test(
name = "tree_test",
srcs = ["tree_test.py"],
srcs_version = "PY3",
deps = [
":tree_test_lib",
],
)

py_test(
name = "tree_jax_test",
srcs = ["tree_jax_test.py"],
srcs_version = "PY3",
deps = [
":tree",
":tree_test_lib",
],
)
55 changes: 55 additions & 0 deletions grain/_src/core/tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for working with pytrees.
See https://jax.readthedocs.io/en/latest/pytrees.html for more details about
pytrees.
This module merely re-directs imports of the actual implementations. To avoid a
direct dependency on JAX, we check if it's already present and resort to the
`tree` package otherwise.
We should be able to remove this module once b/257971667 is resolved.
"""

try:
from jax import tree_util # pytype: disable=import-error # pylint: disable=g-import-not-at-top

map_structure = tree_util.tree_map
map_structure_with_path = tree_util.tree_map_with_path

def assert_same_structure(a, b):
a_structure = tree_util.tree_structure(a)
b_structure = tree_util.tree_structure(b)
if a_structure != b_structure:
raise ValueError(
f"Structures are not the same: a = {a_structure}, b = {b_structure}"
)

def flatten(structure):
return tree_util.tree_flatten(structure)[0]

def unflatten_as(structure, flat_sequence):
return tree_util.tree_unflatten(
tree_util.tree_structure(structure), flat_sequence
)

except ImportError:
import tree # pylint: disable=g-import-not-at-top

map_structure = tree.map_structure
map_structure_with_path = tree.map_structure_with_path
assert_same_structure = tree.assert_same_structure
flatten = tree.flatten
unflatten_as = tree.unflatten_as
44 changes: 44 additions & 0 deletions grain/_src/core/tree_jax_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testes for tree.py with JAX dependency present."""

from absl.testing import absltest
from grain._src.core import tree
from grain._src.core import tree_test
import jax


class MyTree:

def __init__(self, a, b):
self.a = a
self.b = b

def __eq__(self, other):
return self.a == other.a and self.b == other.b


class TreeJaxTest(tree_test.TreeTest):

def test_map_custom_tree(self):
jax.tree_util.register_pytree_node(
MyTree, lambda t: ((t.a, t.b), None), lambda _, args: MyTree(*args)
)
self.assertEqual(
tree.map_structure(lambda x: x + 1, MyTree(1, 2)), MyTree(2, 3)
)


if __name__ == "__main__":
absltest.main()
84 changes: 84 additions & 0 deletions grain/_src/core/tree_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testes for tree.py.
Since the tree.py only re-directs the actual implementations this test does not
try to cover the actual functionality, but rather the re-direction correctness.
"""

from typing import Protocol, runtime_checkable

from absl.testing import absltest
from grain._src.core import tree


@runtime_checkable
class TreeImpl(Protocol):

def map_structure(self, f, *structures):
...

def map_structure_with_path(self, f, *structures):
...

def assert_same_structure(self, a, b):
...

def flatten(self, structure):
...

def unflatten_as(self, structure, flat_sequence):
...


# Static check that the module implements the necessary functions.
tree: TreeImpl = tree


class TreeTest(absltest.TestCase):

def test_implements_tree_protocol(self):
# Run time check that the module implements the necessary functions.
# The module impl branching happens at run time, so the static check does
# not cover both branches.
self.assertIsInstance(tree, TreeImpl)

def test_map_structure(self):
self.assertEqual(
tree.map_structure(lambda x: x + 1, ({"B": 10, "A": 20}, [1, 2], 3)),
({"B": 11, "A": 21}, [2, 3], 4),
)

def test_map_structure_with_path(self):
self.assertEqual(
tree.map_structure_with_path(
lambda path, x: x if path else None, {"B": "v1", "A": "v2"}
),
{"B": "v1", "A": "v2"},
)

def test_assert_same_structure(self):
tree.assert_same_structure({"B": "v1", "A": "v2"}, {"B": 10, "A": 20})

def test_flatten(self):
self.assertEqual(tree.flatten({"A": "v2", "B": "v1"}), ["v2", "v1"])

def test_unflatten_as(self):
self.assertEqual(
tree.unflatten_as({"A": "v2", "B": "v1"}, [1, 2]), {"A": 1, "B": 2}
)


if __name__ == "__main__":
absltest.main()
2 changes: 2 additions & 0 deletions grain/_src/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ py_library(
deps = [
":record",
":shared_memory_array",
"//grain/_src/core:tree",
],
)

Expand Down Expand Up @@ -132,6 +133,7 @@ py_library(
":shared_memory_array",
"//grain/_src/core:sharding",
"//grain/_src/core:transforms",
"//grain/_src/core:tree",
"//grain/_src/core:usage_logging",
],
)
Expand Down
2 changes: 1 addition & 1 deletion grain/_src/python/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from concurrent import futures
from grain._src.core import sharding
from grain._src.core import transforms
from grain._src.core import tree
from grain._src.core import usage_logging
import multiprocessing as mp
from grain._src.python import grain_pool
Expand All @@ -45,7 +46,6 @@
from grain._src.python.shared_memory_array import SharedMemoryArray
from grain._src.python.shared_memory_array import SharedMemoryArrayMetadata
import numpy as np
import tree


_T = TypeVar("_T")
Expand Down
2 changes: 2 additions & 0 deletions grain/_src/python/experimental/example_packing/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ py_library(
srcs = ["packing.py"],
srcs_version = "PY3",
deps = [
"//grain/_src/core:tree",
"//grain/_src/python:record",
"//third_party/py/jaxtyping",
],
Expand All @@ -18,6 +19,7 @@ py_test(
srcs_version = "PY3",
deps = [
":packing",
"//grain/_src/core:tree",
"//grain/_src/python:record",
],
)
2 changes: 1 addition & 1 deletion grain/_src/python/experimental/example_packing/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
import dataclasses
from typing import Generic, Iterator, TypeVar, cast

from grain._src.core import tree
from grain._src.python import record
import jax
import jaxtyping as jt
import numpy as np
import tree

_T = TypeVar("_T")

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Tests for packing.py."""

from absl.testing import absltest
from grain._src.core import tree
from grain._src.python import record
from grain._src.python.experimental.example_packing import packing
import numpy as np
import tensorflow as tf
import tree


def create_input_dataset(input_dataset_elements):
Expand Down
6 changes: 5 additions & 1 deletion grain/_src/python/lazy_dataset/transformations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ py_library(
name = "batch",
srcs = ["batch.py"],
srcs_version = "PY3",
deps = ["//grain/_src/python/lazy_dataset"],
deps = [
"//grain/_src/core:tree",
"//grain/_src/python/lazy_dataset",
],
)

py_test(
Expand Down Expand Up @@ -108,6 +111,7 @@ py_library(
srcs = ["packing.py"],
srcs_version = "PY3",
deps = [
"//grain/_src/core:tree",
"//grain/_src/python/lazy_dataset",
"//third_party/py/jaxtyping",
],
Expand Down
2 changes: 1 addition & 1 deletion grain/_src/python/lazy_dataset/transformations/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import math
from typing import TypeVar

from grain._src.core import tree
from grain._src.python.lazy_dataset import lazy_dataset
import numpy as np
import tree

T = TypeVar("T")

Expand Down
2 changes: 1 addition & 1 deletion grain/_src/python/lazy_dataset/transformations/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
import copy
from typing import Any

from grain._src.core import tree
from grain._src.python.lazy_dataset import lazy_dataset
from jaxtyping import PyTree # pylint: disable=g-importing-member
import numpy as np
import tree


# SingleBinPackLazyDatasetIterator's state is defined by the a state of the
Expand Down
2 changes: 1 addition & 1 deletion grain/_src/python/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
from typing import Any, Callable, Generic, Iterator, Protocol, Sequence, TypeVar

from absl import logging
from grain._src.core import tree
from grain._src.python import record
from grain._src.python.shared_memory_array import SharedMemoryArray
import numpy as np
import tree

_IN = TypeVar("_IN")
_OUT = TypeVar("_OUT")
Expand Down

0 comments on commit 65f3b16

Please sign in to comment.