From aa0e2184ca910d8e72ddcbb394888efeb0184d1a Mon Sep 17 00:00:00 2001 From: Janos Gabler Date: Wed, 15 Jun 2022 16:24:00 +0200 Subject: [PATCH] Add JAX array to registry (#20) --- .github/workflows/main.yml | 28 ++++++++++++++++++++- .pre-commit-config.yaml | 14 +++++------ environment.yml | 2 ++ src/pybaum/config.py | 9 +++++++ src/pybaum/equality.py | 10 ++++++++ src/pybaum/registry.py | 1 + src/pybaum/registry_entries.py | 22 +++++++++++++++++ src/pybaum/tree_util.py | 10 +++----- tests/test_tree_util.py | 17 ++++--------- tests/test_with_jax.py | 45 ++++++++++++++++++++++++++++++++++ tox.ini | 18 ++++++++++++++ 11 files changed, 149 insertions(+), 27 deletions(-) create mode 100644 tests/test_with_jax.py diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 488da2c..f78e484 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -22,7 +22,7 @@ jobs: strategy: fail-fast: false matrix: - os: ['ubuntu-latest', 'macos-latest', 'windows-latest'] + os: ['ubuntu-latest', 'macos-latest'] python-version: ['3.7', '3.8', '3.9', '3.10'] steps: @@ -46,6 +46,32 @@ jobs: with: token: ${{ secrets.CODECOV_TOKEN }} + run-tests-windows: + + name: Run tests for ${{ matrix.os }} on ${{ matrix.python-version }} + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: ['windows-latest'] + python-version: ['3.7', '3.8', '3.9', '3.10'] + + steps: + - uses: actions/checkout@v2 + - uses: conda-incubator/setup-miniconda@v2 + with: + auto-update-conda: true + python-version: ${{ matrix.python-version }} + + - name: Install core dependencies. + shell: bash -l {0} + run: conda install -c conda-forge tox-conda + + - name: Run pytest. + shell: bash -l {0} + run: tox -e pytest-windows -- -m "not slow" + docs: name: Run documentation. diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 40fd5fd..258fe23 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.2.0 + rev: v4.3.0 hooks: - id: check-merge-conflict - id: debug-statements @@ -11,7 +11,7 @@ repos: - id: reorder-python-imports types: [python] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.2.0 + rev: v4.3.0 hooks: - id: check-added-large-files args: ['--maxkb=100'] @@ -42,13 +42,13 @@ repos: rev: v1.12.1 hooks: - id: blacken-docs - additional_dependencies: [black] + additional_dependencies: [black==22.3.0] types: [rst] - repo: https://github.com/psf/black rev: 22.3.0 hooks: - id: black - types: [python] + language_version: python3.9 - repo: https://github.com/PyCQA/flake8 rev: 4.0.1 hooks: @@ -71,7 +71,7 @@ repos: Pygments, ] - repo: https://github.com/PyCQA/doc8 - rev: 0.11.1 + rev: 0.11.2 hooks: - id: doc8 - repo: meta @@ -84,7 +84,7 @@ repos: hooks: - id: check-manifest - repo: https://github.com/PyCQA/doc8 - rev: 0.11.1 + rev: 0.11.2 hooks: - id: doc8 - repo: https://github.com/asottile/setup-cfg-fmt @@ -102,7 +102,7 @@ repos: hooks: - id: codespell - repo: https://github.com/asottile/pyupgrade - rev: v2.32.1 + rev: v2.34.0 hooks: - id: pyupgrade args: [--py37-plus] diff --git a/environment.yml b/environment.yml index 0cf0dd2..8c03220 100644 --- a/environment.yml +++ b/environment.yml @@ -30,3 +30,5 @@ dependencies: - pdbpp - numpy - pandas + - jax + - jaxlib diff --git a/src/pybaum/config.py b/src/pybaum/config.py index 5645ec4..f021057 100644 --- a/src/pybaum/config.py +++ b/src/pybaum/config.py @@ -12,3 +12,12 @@ IS_PANDAS_INSTALLED = False else: IS_PANDAS_INSTALLED = True + + +try: + import jax # noqa: F401 + import jaxlib # noqa: F401 +except ImportError: + IS_JAX_INSTALLED = False +else: + IS_JAX_INSTALLED = True diff --git a/src/pybaum/equality.py b/src/pybaum/equality.py index 33ccd63..dd6b146 100644 --- a/src/pybaum/equality.py +++ b/src/pybaum/equality.py @@ -1,4 +1,5 @@ """Functions to check equality of pytree leaves.""" +from pybaum.config import IS_JAX_INSTALLED from pybaum.config import IS_NUMPY_INSTALLED from pybaum.config import IS_PANDAS_INSTALLED @@ -10,6 +11,9 @@ if IS_PANDAS_INSTALLED: import pandas as pd +if IS_JAX_INSTALLED: + import jaxlib + EQUALITY_CHECKERS = {} @@ -21,3 +25,9 @@ if IS_PANDAS_INSTALLED: EQUALITY_CHECKERS[pd.Series] = lambda a, b: a.equals(b) EQUALITY_CHECKERS[pd.DataFrame] = lambda a, b: a.equals(b) + + +if IS_JAX_INSTALLED: + EQUALITY_CHECKERS[jaxlib.xla_extension.DeviceArray] = lambda a, b: bool( + (a == b).all() + ) diff --git a/src/pybaum/registry.py b/src/pybaum/registry.py index f78304a..8a73d4c 100644 --- a/src/pybaum/registry.py +++ b/src/pybaum/registry.py @@ -15,6 +15,7 @@ def get_registry(types=None, include_defaults=True): - :obj:`None` - :class:`collections.OrderedDict` - "numpy.ndarray" + - "jax.numpy.ndarray" - "pandas.Series" - "pandas.DataFrame" include_defaults (bool): Whether the default pytree containers "tuple", "dict" diff --git a/src/pybaum/registry_entries.py b/src/pybaum/registry_entries.py index 3d2d244..3fa0bd2 100644 --- a/src/pybaum/registry_entries.py +++ b/src/pybaum/registry_entries.py @@ -3,6 +3,7 @@ from collections import OrderedDict from itertools import product +from pybaum.config import IS_JAX_INSTALLED from pybaum.config import IS_NUMPY_INSTALLED from pybaum.config import IS_PANDAS_INSTALLED @@ -12,6 +13,10 @@ if IS_PANDAS_INSTALLED: import pandas as pd +if IS_JAX_INSTALLED: + import jax + import jaxlib + def _none(): """Create registry entry for NoneType.""" @@ -117,6 +122,22 @@ def _array_element_names(arr): return names +def _jax_array(): + if IS_JAX_INSTALLED: + entry = { + jaxlib.xla_extension.DeviceArray: { + "flatten": lambda arr: (arr.flatten().tolist(), arr.shape), + "unflatten": lambda aux_data, leaves: jax.numpy.array(leaves).reshape( + aux_data + ), + "names": _array_element_names, + }, + } + else: + entry = {} + return entry + + def _pandas_series(): """Create registry entry for pandas.Series.""" if IS_PANDAS_INSTALLED: @@ -186,6 +207,7 @@ def _index_element_to_string(element): "tuple": _tuple, "dict": _dict, "numpy.ndarray": _numpy_array, + "jax.numpy.ndarray": _jax_array, "pandas.Series": _pandas_series, "pandas.DataFrame": _pandas_dataframe, "None": _none, diff --git a/src/pybaum/tree_util.py b/src/pybaum/tree_util.py index 0281330..769fcfe 100644 --- a/src/pybaum/tree_util.py +++ b/src/pybaum/tree_util.py @@ -6,8 +6,6 @@ - The treedef containing information to unflatten pytrees is implemented differently. """ -import itertools - from pybaum.equality import EQUALITY_CHECKERS from pybaum.registry import get_registry from pybaum.typecheck import get_type @@ -42,8 +40,8 @@ def tree_flatten(tree, is_leaf=None, registry=None): is_leaf = _process_is_leaf(is_leaf) flat = _tree_flatten(tree, is_leaf=is_leaf, registry=registry) - dummy_flat = ["*"] * len(flat) - treedef = tree_unflatten(tree, dummy_flat, is_leaf=is_leaf, registry=registry) + # unflatten the flat tree to make a copy + treedef = tree_unflatten(tree, flat, is_leaf=is_leaf, registry=registry) return flat, treedef @@ -124,9 +122,7 @@ def tree_yield(tree, is_leaf=None, registry=None): is_leaf = _process_is_leaf(is_leaf) flat = _tree_yield(tree, is_leaf=is_leaf, registry=registry) - dummy_flat = itertools.repeat("*") - treedef = tree_unflatten(tree, dummy_flat, is_leaf=is_leaf, registry=registry) - return flat, treedef + return flat, tree def tree_just_yield(tree, is_leaf=None, registry=None): diff --git a/tests/test_tree_util.py b/tests/test_tree_util.py index f076d1c..0db0f3e 100644 --- a/tests/test_tree_util.py +++ b/tests/test_tree_util.py @@ -31,20 +31,13 @@ def example_flat(): @pytest.fixture -def example_treedef(): - return (["*", "*", {"a": "*", "b": "*"}], "*") +def example_treedef(example_tree): + return example_tree @pytest.fixture -def extended_treedef(): - return ( - [ - "*", - np.array(["*", "*"]), - {"a": pd.Series(["*", "*"], index=["c", "d"]), "b": "*"}, - ], - "*", - ) +def extended_treedef(example_tree): + return example_tree @pytest.fixture @@ -195,7 +188,7 @@ def test_flatten_df_all_columns(): def test_tree_yield(example_tree, example_treedef, example_flat): generator, treedef = tree_yield(example_tree) - assert treedef == example_treedef + assert tree_equal(treedef, example_treedef) assert inspect.isgenerator(generator) for a, b in zip(generator, example_flat): if isinstance(a, (np.ndarray, pd.Series)): diff --git a/tests/test_with_jax.py b/tests/test_with_jax.py new file mode 100644 index 0000000..bbd5be2 --- /dev/null +++ b/tests/test_with_jax.py @@ -0,0 +1,45 @@ +import pytest +from pybaum.config import IS_JAX_INSTALLED +from pybaum.registry import get_registry +from pybaum.tree_util import leaf_names +from pybaum.tree_util import tree_equal +from pybaum.tree_util import tree_flatten +from pybaum.tree_util import tree_just_flatten + +if IS_JAX_INSTALLED: + import jax.numpy as jnp +else: + # run the tests with normal numpy instead + import numpy as jnp + + +@pytest.fixture +def tree(): + return {"a": {"b": jnp.arange(4).reshape(2, 2)}, "c": jnp.ones(2)} + + +@pytest.fixture +def flat(): + return [0, 1, 2, 3, 1, 1] + + +@pytest.fixture +def registry(): + return get_registry(types=["jax.numpy.ndarray", "numpy.ndarray"]) + + +def test_tree_just_flatten_with_jax(tree, registry, flat): + got = tree_just_flatten(tree, registry=registry) + assert got == flat + + +def test_tree_flatten_with_jax(tree, registry, flat): + got_flat, got_treedef = tree_flatten(tree, registry=registry) + assert got_flat == flat + assert tree_equal(got_treedef, tree) + + +def test_leaf_names_with_jax(tree, registry): + got = leaf_names(tree, registry=registry) + expected = ["a_b_0_0", "a_b_0_1", "a_b_1_0", "a_b_1_1", "c_0", "c_1"] + assert got == expected diff --git a/tox.ini b/tox.ini index 8a79d33..8d18e58 100644 --- a/tox.ini +++ b/tox.ini @@ -7,6 +7,24 @@ skip_missing_interpreters = True basepython = python [testenv:pytest] +setenv = + CONDA_DLL_SEARCH_MODIFICATION_ENABLE = 1 +conda_channels = + conda-forge + defaults +conda_deps = + conda-build + numpy + pandas + pytest + pytest-cov + pytest-mock + pytest-xdist + jax + jaxlib +commands = pytest {posargs} + +[testenv:pytest-windows] setenv = CONDA_DLL_SEARCH_MODIFICATION_ENABLE = 1 conda_channels =