Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 720699320
  • Loading branch information
Grain Team authored and copybara-github committed Jan 28, 2025
1 parent c42ce44 commit ad93d88
Show file tree
Hide file tree
Showing 18 changed files with 99 additions and 96 deletions.
18 changes: 9 additions & 9 deletions grain/_src/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -85,36 +85,36 @@ py_test(
)

py_library(
name = "tree",
name = "treelib",
srcs = [
"tree.py",
"treelib.py",
],
srcs_version = "PY3",
)

py_library(
name = "tree_test_lib",
testonly = 1,
srcs = ["tree_test.py"],
srcs = ["treelib_test.py"],
srcs_version = "PY3",
deps = [":tree"],
deps = [":treelib"],
)

py_test(
name = "tree_test",
srcs = ["tree_test.py"],
name = "treelib_test",
srcs = ["treelib_test.py"],
srcs_version = "PY3",
deps = [
":tree_test_lib",
],
)

py_test(
name = "tree_jax_test",
srcs = ["tree_jax_test.py"],
name = "treelib_jax_test",
srcs = ["treelib_jax_test.py"],
srcs_version = "PY3",
deps = [
":tree",
":tree_test_lib",
":treelib",
],
)
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testes for tree.py with JAX dependency present."""
"""Testes for treelib.py with JAX dependency present."""

from absl.testing import absltest
import attrs
from grain._src.core import tree
from grain._src.core import tree_test
from grain._src.core import treelib
import jax
import numpy as np

Expand Down Expand Up @@ -50,12 +50,12 @@ def test_map_custom_tree(self):
MyTree, lambda t: ((t.a, t.b), None), lambda _, args: MyTree(*args)
)
self.assertEqual(
tree.map_structure(lambda x: x + 1, MyTree(1, 2)), MyTree(2, 3)
treelib.map_structure(lambda x: x + 1, MyTree(1, 2)), MyTree(2, 3)
)

def test_spec_like_with_class(self):
self.assertEqual(
tree.spec_like({"B": 1232.4, "C": MyClass(1)}),
treelib.spec_like({"B": 1232.4, "C": MyClass(1)}),
{
"B": "<class 'float'>[]",
"C": "<class '__main__.MyClass'>[]",
Expand All @@ -64,7 +64,7 @@ def test_spec_like_with_class(self):

def test_spec_like_with_list(self):
self.assertEqual(
tree.spec_like({
treelib.spec_like({
"B": 1232.4,
"C": [
tree_test.TestClass(a=1, b="v2"),
Expand All @@ -79,7 +79,7 @@ def test_spec_like_with_list(self):

def test_spec_like_with_unknown_shape(self):
self.assertEqual(
tree.spec_like({
treelib.spec_like({
"B": [np.zeros([2]), np.zeros([1])],
"C": [],
}),
Expand All @@ -88,14 +88,14 @@ def test_spec_like_with_unknown_shape(self):

def test_spec_like_with_dataclass(self):
self.assertEqual(
tree.spec_like(tree_test.TestClass(a=1, b="v2")),
treelib.spec_like(tree_test.TestClass(a=1, b="v2")),
"<class 'grain._src.core.tree_test.TestClass'>\n"
"{'a': \"<class 'int'>[]\", 'b': \"<class 'str'>[]\"}[]",
)

def test_spec_like_with_attrs(self):
self.assertEqual(
tree.spec_like(MyAttrs(d=1, e="v2")),
treelib.spec_like(MyAttrs(d=1, e="v2")),
"<class '__main__.MyAttrs'>\n"
"{'d': \"<class 'int'>[]\", 'e': \"<class 'str'>[]\"}[]",
)
Expand Down
29 changes: 15 additions & 14 deletions grain/_src/core/tree_test.py → grain/_src/core/treelib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testes for tree.py.
"""Testes for treelib.py.
Since the tree.py only re-directs the actual implementations this test does not
try to cover the actual functionality, but rather the re-direction correctness.
Since the treelib.py only re-directs the actual implementations this test does
not try to cover the actual functionality, but rather the re-direction
correctness.
"""
import dataclasses
from typing import Protocol, runtime_checkable

from absl.testing import absltest
from absl.testing import parameterized
from grain._src.core import tree
from grain._src.core import treelib
import numpy as np


Expand Down Expand Up @@ -51,7 +52,7 @@ def spec_like(self, structure):


# Static check that the module implements the necessary functions.
tree: TreeImpl = tree
treelib: TreeImpl = treelib


@dataclasses.dataclass
Expand All @@ -66,37 +67,37 @@ def test_implements_tree_protocol(self):
# Run time check that the module implements the necessary functions.
# The module impl branching happens at run time, so the static check does
# not cover both branches.
self.assertIsInstance(tree, TreeImpl)
self.assertIsInstance(treelib, TreeImpl)

def test_map_structure(self):
self.assertEqual(
tree.map_structure(lambda x: x + 1, ({"B": 10, "A": 20}, [1, 2], 3)),
treelib.map_structure(lambda x: x + 1, ({"B": 10, "A": 20}, [1, 2], 3)),
({"B": 11, "A": 21}, [2, 3], 4),
)

def test_map_structure_with_path(self):
self.assertEqual(
tree.map_structure_with_path(
treelib.map_structure_with_path(
lambda path, x: x if path else None, {"B": "v1", "A": "v2"}
),
{"B": "v1", "A": "v2"},
)

def test_assert_same_structure(self):
tree.assert_same_structure({"B": "v1", "A": "v2"}, {"B": 10, "A": 20})
treelib.assert_same_structure({"B": "v1", "A": "v2"}, {"B": 10, "A": 20})

def test_flatten(self):
self.assertEqual(tree.flatten({"A": "v2", "B": "v1"}), ["v2", "v1"])
self.assertEqual(treelib.flatten({"A": "v2", "B": "v1"}), ["v2", "v1"])

def test_flatten_with_path(self):
result = tree.flatten_with_path({"A": "v2", "B": "v1"})
result = treelib.flatten_with_path({"A": "v2", "B": "v1"})
# Maybe extract keys from path elements.
result = tree.map_structure(lambda x: getattr(x, "key", x), result)
result = treelib.map_structure(lambda x: getattr(x, "key", x), result)
self.assertEqual(result, [(("A",), "v2"), (("B",), "v1")])

def test_unflatten_as(self):
self.assertEqual(
tree.unflatten_as({"A": "v2", "B": "v1"}, [1, 2]), {"A": 1, "B": 2}
treelib.unflatten_as({"A": "v2", "B": "v1"}, [1, 2]), {"A": 1, "B": 2}
)

@parameterized.named_parameters(
Expand Down Expand Up @@ -124,7 +125,7 @@ def test_unflatten_as(self):
),
)
def test_spec_like(self, structure, expected_output):
self.assertEqual(tree.spec_like(structure), expected_output)
self.assertEqual(treelib.spec_like(structure), expected_output)

# The two tests below exercise behavior only without a Jax dependency present.
# The OSS testing runs with Jax always present so we skip them.
Expand Down
6 changes: 3 additions & 3 deletions grain/_src/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ py_library(
deps = [
":record",
":shared_memory_array",
"//grain/_src/core:tree",
"//grain/_src/core:treelib",
],
)

Expand Down Expand Up @@ -133,7 +133,7 @@ py_library(
"//grain/_src/core:monitoring",
"//grain/_src/core:sharding",
"//grain/_src/core:transforms",
"//grain/_src/core:tree",
"//grain/_src/core:treelib",
"//grain/_src/core:usage_logging",
],
)
Expand Down Expand Up @@ -173,7 +173,7 @@ py_library(
":record",
":shared_memory_array",
"//grain/_src/core:parallel",
"//grain/_src/core:tree",
"//grain/_src/core:treelib",
],
)

Expand Down
4 changes: 2 additions & 2 deletions grain/_src/python/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from grain._src.core import monitoring as grain_monitoring
from grain._src.core import sharding
from grain._src.core import transforms
from grain._src.core import tree
from grain._src.core import treelib
from grain._src.core import usage_logging
import multiprocessing as mp
from grain._src.python import grain_pool
Expand Down Expand Up @@ -150,7 +150,7 @@ def copy_if_applied(element: Any) -> Any:
np.copyto(shared_memory_arr, element, casting="no")
return shared_memory_arr.metadata

return tree.map_structure(copy_if_applied, element)
return treelib.map_structure(copy_if_applied, element)


class DataLoader:
Expand Down
6 changes: 3 additions & 3 deletions grain/_src/python/dataset/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ py_library(
"//grain/_src/core:exceptions",
"//grain/_src/core:monitoring",
"//grain/_src/core:transforms",
"//grain/_src/core:tree",
"//grain/_src/core:treelib",
"//grain/_src/core:usage_logging",
"//grain/_src/python:grain_pool",
"//grain/_src/python:options",
Expand All @@ -61,7 +61,7 @@ py_library(
srcs_version = "PY3",
deps = [
":dataset",
"//grain/_src/core:tree",
"//grain/_src/core:treelib",
"//grain/_src/python:options",
],
)
Expand All @@ -86,7 +86,7 @@ py_library(
":base",
"//grain/_src/core:config",
"//grain/_src/core:monitoring",
"//grain/_src/core:tree",
"//grain/_src/core:treelib",
],
)

Expand Down
4 changes: 2 additions & 2 deletions grain/_src/python/dataset/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from absl import logging
from grain._src.core import config as grain_config
from grain._src.core import monitoring as grain_monitoring
from grain._src.core import tree
from grain._src.core import treelib
from grain._src.python.dataset import base

from grain._src.core import monitoring
Expand Down Expand Up @@ -546,7 +546,7 @@ def record_self_time(self, offset_ns: int = 0):
def record_output_spec(self, element: T) -> T:
# Visualize the dataset graph once last node had seen a non-None element.
if self._self_output_spec is None:
self._self_output_spec = tree.spec_like(element)
self._self_output_spec = treelib.spec_like(element)
if self._is_output and not self._reported:
# The check above with update without a lock is not atomic, need to
# check again under a lock.
Expand Down
2 changes: 1 addition & 1 deletion grain/_src/python/dataset/transformations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ py_library(
srcs = ["packing_packed_batch.py"],
srcs_version = "PY3",
deps = [
"//grain/_src/core:tree",
"//grain/_src/core:treelib",
],
)

Expand Down
6 changes: 3 additions & 3 deletions grain/_src/python/dataset/transformations/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pprint
from typing import Callable, TypeVar

from grain._src.core import tree
from grain._src.core import treelib
from grain._src.python.dataset import dataset
import numpy as np

Expand All @@ -35,14 +35,14 @@ def _make_batch(values: Sequence[T]) -> T:
raise ValueError("Cannot batch 0 values. Please file a bug.")

try:
return tree.map_structure(lambda *xs: np.stack(xs), *values)
return treelib.map_structure(lambda *xs: np.stack(xs), *values)

except ValueError as e:
# NumPy error message doesn't include actual shapes and dtypes. Provide a
# more helpful error message.
raise ValueError(
"Expected all input elements to have the same structure but got:\n"
f"{pprint.pformat(tree.spec_like(values))}"
f"{pprint.pformat(treelib.spec_like(values))}"
) from e


Expand Down
Loading

0 comments on commit ad93d88

Please sign in to comment.