Skip to content

Commit

Permalink
[pytree] add tree_iter function (pytorch#123913)
Browse files Browse the repository at this point in the history
- Add a new `tree_iter` function.
- Bump `optree` version to `0.11.0` for C++ version of `tree_iter`.

This PR is split from pytorch#120300.

- pytorch#120300

Pull Request resolved: pytorch#123913
Approved by: https://github.com/zou3519
  • Loading branch information
XuehaiPan authored and pytorchmergebot committed Apr 16, 2024
1 parent 0eab740 commit 2e48f7b
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 32 deletions.
4 changes: 2 additions & 2 deletions .ci/docker/requirements-ci.txt
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ opt-einsum==3.3
#Pinned versions: 3.3
#test that import: test_linalg.py

optree==0.9.1
optree==0.11.0
#Description: A library for tree manipulation
#Pinned versions: 0.9.1
#Pinned versions: 0.11.0
#test that import: test_vmap.py, test_aotdispatch.py, test_dynamic_shapes.py,
#test_pytree.py, test_ops.py, test_control_flow.py, test_modules.py,
#common_utils.py, test_eager_transforms.py, test_python_dispatch.py,
Expand Down
2 changes: 1 addition & 1 deletion .github/requirements/pip-requirements-iOS.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# iOS simulator requirements
coremltools==5.0b5
protobuf==3.20.2
optree==0.9.1
optree==0.11.0
2 changes: 1 addition & 1 deletion .github/requirements/pip-requirements-macOS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pytest-cpp==2.3.0
rockset==1.0.3
z3-solver==4.12.2.0
tensorboard==2.13.0
optree==0.9.1
optree==0.11.0
# NB: test_hparams_* from test_tensorboard is failing with protobuf 5.26.0 in
# which the stringify metadata is wrong when escaping double quote
protobuf==3.20.2
2 changes: 1 addition & 1 deletion .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ init_command = [
'junitparser==2.1.1',
'rich==10.9.0',
'pyyaml==6.0.1',
'optree==0.10.0',
'optree==0.11.0',
]

[[linter]]
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ fsspec
# setuptools was removed from default python install
setuptools ; python_version >= "3.12"
packaging
optree>=0.9.1
optree>=0.11.0
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,7 +1169,7 @@ def main():
install_requires += extra_install_requires

extras_require = {
"optree": ["optree>=0.9.1"],
"optree": ["optree>=0.11.0"],
"opt-einsum": ["opt-einsum>=3.3"],
}

Expand Down
44 changes: 40 additions & 4 deletions torch/utils/_cxx_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
"tree_flatten",
"tree_flatten_with_path",
"tree_unflatten",
"tree_iter",
"tree_leaves",
"tree_leaves_with_path",
"tree_structure",
Expand Down Expand Up @@ -321,6 +322,41 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]


def tree_iter(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Iterable[Any]:
"""Get an iterator over the leaves of a pytree.
See also :func:`tree_flatten`.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> list(tree_iter(tree))
[1, 2, 3, 4, None, 5]
>>> list(tree_iter(1))
[1]
>>> list(tree_iter(None))
[None]
Args:
tree (pytree): A pytree to flatten.
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
flattening step. The function should have a single argument with signature
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
leaf or not. If the function is not specified, the default pytree registry will be used.
Returns:
An iterator over the leaf values.
"""
return optree.tree_iter(
tree,
is_leaf=is_leaf,
none_is_leaf=True,
namespace="torch",
)


def tree_leaves(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
Expand Down Expand Up @@ -670,7 +706,7 @@ def tree_all(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return all(map(pred, flat_args))


Expand All @@ -679,7 +715,7 @@ def tree_any(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return any(map(pred, flat_args))


Expand Down Expand Up @@ -719,7 +755,7 @@ def tree_all_only(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return all(pred(x) for x in flat_args if isinstance(x, __type_or_types))


Expand Down Expand Up @@ -759,7 +795,7 @@ def tree_any_only(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return any(pred(x) for x in flat_args if isinstance(x, __type_or_types))


Expand Down
40 changes: 19 additions & 21 deletions torch/utils/_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
"tree_flatten",
"tree_flatten_with_path",
"tree_unflatten",
"tree_iter",
"tree_leaves",
"tree_leaves_with_path",
"tree_structure",
Expand Down Expand Up @@ -865,32 +866,29 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
return treespec.unflatten(leaves)


def _tree_leaves_helper(
def tree_iter(
tree: PyTree,
leaves: List[Any],
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> None:
) -> Iterable[Any]:
"""Get an iterator over the leaves of a pytree."""
if _is_leaf(tree, is_leaf=is_leaf):
leaves.append(tree)
return

node_type = _get_node_type(tree)
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
child_pytrees, _ = flatten_fn(tree)
yield tree
else:
node_type = _get_node_type(tree)
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
child_pytrees, _ = flatten_fn(tree)

# Recursively flatten the children
for child in child_pytrees:
_tree_leaves_helper(child, leaves, is_leaf=is_leaf)
# Recursively flatten the children
for child in child_pytrees:
yield from tree_iter(child, is_leaf=is_leaf)


def tree_leaves(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> List[Any]:
"""Get a list of leaves of a pytree."""
leaves: List[Any] = []
_tree_leaves_helper(tree, leaves, is_leaf=is_leaf)
return leaves
return list(tree_iter(tree, is_leaf=is_leaf))


def tree_structure(
Expand Down Expand Up @@ -1171,7 +1169,7 @@ def tree_all(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return all(map(pred, flat_args))


Expand All @@ -1180,7 +1178,7 @@ def tree_any(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return any(map(pred, flat_args))


Expand Down Expand Up @@ -1220,7 +1218,7 @@ def tree_all_only(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return all(pred(x) for x in flat_args if isinstance(x, __type_or_types))


Expand Down Expand Up @@ -1260,7 +1258,7 @@ def tree_any_only(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return any(pred(x) for x in flat_args if isinstance(x, __type_or_types))


Expand Down Expand Up @@ -1468,9 +1466,9 @@ def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> List[Any]:
"""
leaves: List[Any] = []
for a in args:
_tree_leaves_helper(a, leaves)
leaves.extend(tree_iter(a))
for a in kwargs.values():
_tree_leaves_helper(a, leaves)
leaves.extend(tree_iter(a))
return leaves


Expand Down

0 comments on commit 2e48f7b

Please sign in to comment.