diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml
index a2476fc85837..c7d3e3dbfd12 100644
--- a/.github/workflows/cloud-tpu-ci-nightly.yml
+++ b/.github/workflows/cloud-tpu-ci-nightly.yml
@@ -53,7 +53,12 @@ jobs:
       - name: Run tests
         env:
           JAX_PLATFORMS: tpu,cpu
-        run: python -m pytest --tb=short tests examples
+        run: |
+          # Run single-accelerator tests in parallel
+          JAX_ENABLE_TPU_XDIST=true python -m pytest -n=4 --tb=short \
+            -m "not multiaccelerator" tests examples
+          # Run multi-accelerator across all chips
+          python -m pytest -m "multiaccelerator" --tb=short tests
       - name: Send chat on failure
         # Don't notify when testing the workflow from a branch.
         if: ${{ failure() && github.ref_name == 'main' }}
diff --git a/conftest.py b/conftest.py
index dbc1926fb865..fed4564bbc1c 100644
--- a/conftest.py
+++ b/conftest.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 """pytest configuration"""
 
+import os
 import pytest
 
 
@@ -24,3 +25,42 @@ def add_imports(doctest_namespace):
   doctest_namespace["lax"] = jax.lax
   doctest_namespace["jnp"] = jax.numpy
   doctest_namespace["np"] = numpy
+
+
+# A pytest hook that runs immediately before test collection (i.e. when pytest
+# loads all the test cases to run). When running parallel tests via xdist on
+# Cloud TPU, we use this hook to set the env vars needed to run multiple test
+# processes across different TPU chips.
+#
+# It's important that the hook runs before test collection, since jax tests end
+# up initializing the TPU runtime on import (e.g. to query supported test
+# types). It's also important that the hook gets called by each xdist worker
+# process. Luckily each worker does its own test collection.
+#
+# The pytest_collection hook can be used to override the collection logic, but
+# we only use it to set the env vars and fall back to the default collection
+# logic by always returning None. See
+# https://docs.pytest.org/en/latest/how-to/writing_hook_functions.html#firstresult-stop-at-first-non-none-result
+# for details.
+#
+# The env var JAX_ENABLE_TPU_XDIST must be set for this hook to have an
+# effect. We do this to minimize any effect on non-TPU tests, and as a pointer
+# in test code to this "magic" hook. TPU tests should not specify more xdist
+# workers than the number of TPU chips.
+def pytest_collection() -> None:
+  """Pins each xdist worker process to its own TPU chip (opt-in)."""
+  if not os.environ.get("JAX_ENABLE_TPU_XDIST", None):
+    return
+  # When running as an xdist worker, will be something like "gw0"
+  xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "")
+  if not xdist_worker_name.startswith("gw"):
+    return
+  worker_suffix = xdist_worker_name[len("gw"):]
+  # A name that is not exactly "gw<digits>" cannot be mapped to a chip. Leave
+  # the environment untouched rather than letting int() raise ValueError out
+  # of this hook.
+  if not (worker_suffix.isascii() and worker_suffix.isdigit()):
+    return
+  xdist_worker_number = int(worker_suffix)
+  os.environ.setdefault("TPU_VISIBLE_CHIPS", str(xdist_worker_number))
+  os.environ.setdefault("ALLOW_MULTIPLE_LIBTPU_LOAD", "true")
diff --git a/pytest.ini b/pytest.ini
index ae7202141fdc..6d9971df3cd2 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -1,5 +1,6 @@
 [pytest]
 markers =
+    multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators
     SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI
 filterwarnings =
     error
@@ -19,7 +20,7 @@ filterwarnings =
     # numpy uses distutils which is deprecated
     ignore:The distutils.* is deprecated.*:DeprecationWarning
     ignore:`sharded_jit` is deprecated. Please use `pjit` instead.*:DeprecationWarning
-    # Print message for compilation_cache_test.py::CompilationCacheTest::test_cache_read/write_warning
+    # Print message for compilation_cache_test.py::CompilationCacheTest::test_cache_read/write_warning
     default:Error reading persistent compilation cache entry for 'jit__lambda_'
     default:Error writing persistent compilation cache entry for 'jit__lambda_'
 doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
diff --git a/tests/array_test.py b/tests/array_test.py
index 3e84e0e99fe6..30916bd1af2b 100644
--- a/tests/array_test.py
+++ b/tests/array_test.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Tests for GlobalDeviceArray."""
 
+import contextlib
 import os
 import unittest
 from absl.testing import absltest
@@ -43,6 +44,11 @@
 prev_xla_flags = None
 
 
+with contextlib.suppress(ImportError):
+  import pytest
+  pytestmark = pytest.mark.multiaccelerator
+
+
 # Run all tests with 8 CPU devices.
 def setUpModule():
   global prev_xla_flags
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 1fd19119895d..b31021bf5ee7 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 import os
 import re
 from functools import partial, lru_cache
@@ -57,6 +58,11 @@
 prev_xla_flags = None
 
 
+with contextlib.suppress(ImportError):
+  import pytest
+  pytestmark = pytest.mark.multiaccelerator
+
+
 def setUpModule():
   global prev_xla_flags
   prev_xla_flags = os.getenv("XLA_FLAGS")
diff --git a/tests/pmap_test.py b/tests/pmap_test.py
index b550ef60ad6e..26ea32a91d47 100644
--- a/tests/pmap_test.py
+++ b/tests/pmap_test.py
@@ -14,6 +14,7 @@
 
 
 from concurrent.futures import ThreadPoolExecutor
+import contextlib
 from functools import partial
 import itertools as it
 import gc
@@ -55,6 +56,11 @@
 
 prev_xla_flags = None
 
+
+with contextlib.suppress(ImportError):
+  import pytest
+  pytestmark = pytest.mark.multiaccelerator
+
 compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]]
 
 def all_bdims(*shapes, pmap):
diff --git a/tests/remote_transfer_test.py b/tests/remote_transfer_test.py
index 5941495ba50f..fb667c09532f 100644
--- a/tests/remote_transfer_test.py
+++ b/tests/remote_transfer_test.py
@@ -14,6 +14,7 @@
 """Tests for cross host device transfer."""
 
 from absl.testing import absltest
+import contextlib
 import unittest
 import numpy as np
 
@@ -24,6 +25,10 @@
 
 config.parse_flags_with_absl()
 
+with contextlib.suppress(ImportError):
+  import pytest
+  pytestmark = pytest.mark.multiaccelerator
+
 
 class RemoteTransferTest(jtu.JaxTestCase):
 
diff --git a/tests/xmap_test.py b/tests/xmap_test.py
index 1945ae146b17..57dde5d574a8 100644
--- a/tests/xmap_test.py
+++ b/tests/xmap_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 import functools
 import itertools as it
 import os
@@ -55,6 +56,11 @@ from jax.config import config
 config.parse_flags_with_absl()
 
 
+with contextlib.suppress(ImportError):
+  import pytest
+  pytestmark = pytest.mark.multiaccelerator
+
+
 # TODO(mattjj): de-duplicate setUpModule and tearDownModule with pmap_test.py
 # Run all tests with 8 CPU devices.
 def setUpModule():