From 120125f3dd6e516e1b7e727b2504182724f697b2 Mon Sep 17 00:00:00 2001
From: Skye Wanderman-Milne
Date: Thu, 17 Nov 2022 05:33:54 +0000
Subject: [PATCH] Make pytest-xdist work on TPU and update Cloud TPU CI.

This change also marks multiaccelerator test files in a way pytest can
understand (if pytest is installed).

By running each single-device test on a single TPU chip, the test suite
goes from 1h 45m to 35m (both timings include the slow tests).

I tried bazel at first, which already supports parallel execution across
TPU cores, but somehow it still takes 2h 20m! I'm not sure why it's so
slow. It appears that bazel keeps creating new test processes over time,
whereas pytest reuses the initially specified number of processes, and
since starting and stopping the TPU runtime takes a few seconds each
time, that may add up. It also appears that single-process bazel is
slower than single-process pytest, which I haven't looked into yet.
---
 .github/workflows/cloud-tpu-ci-nightly.yml |  7 ++++-
 conftest.py                                | 33 ++++++++++++++++++++++
 pytest.ini                                 |  3 +-
 tests/array_test.py                        |  6 ++++
 tests/pjit_test.py                         |  6 ++++
 tests/pmap_test.py                         |  6 ++++
 tests/remote_transfer_test.py              |  5 ++++
 tests/xmap_test.py                         |  6 ++++
 8 files changed, 70 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml
index a2476fc85837..c7d3e3dbfd12 100644
--- a/.github/workflows/cloud-tpu-ci-nightly.yml
+++ b/.github/workflows/cloud-tpu-ci-nightly.yml
@@ -53,7 +53,12 @@ jobs:
       - name: Run tests
         env:
           JAX_PLATFORMS: tpu,cpu
-        run: python -m pytest --tb=short tests examples
+        run: |
+          # Run single-accelerator tests in parallel
+          JAX_ENABLE_TPU_XDIST=true python -m pytest -n=4 --tb=short \
+            -m "not multiaccelerator" tests examples
+          # Run multi-accelerator tests across all chips
+          python -m pytest -m "multiaccelerator" --tb=short tests
       - name: Send chat on failure
         # Don't notify when testing the workflow from a branch.
         if: ${{ failure() && github.ref_name == 'main' }}
diff --git a/conftest.py b/conftest.py
index dbc1926fb865..fed4564bbc1c 100644
--- a/conftest.py
+++ b/conftest.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 """pytest configuration"""
 
+import os
 import pytest
 
@@ -24,3 +25,35 @@ def add_imports(doctest_namespace):
   doctest_namespace["lax"] = jax.lax
   doctest_namespace["jnp"] = jax.numpy
   doctest_namespace["np"] = numpy
+
+
+# A pytest hook that runs immediately before test collection (i.e. when pytest
+# loads all the test cases to run). When running parallel tests via xdist on
+# Cloud TPU, we use this hook to set the env vars needed to run multiple test
+# processes across different TPU chips.
+#
+# It's important that the hook runs before test collection, since jax tests end
+# up initializing the TPU runtime on import (e.g. to query supported test
+# types). It's also important that the hook gets called by each xdist worker
+# process. Luckily each worker does its own test collection.
+#
+# The pytest_collection hook can be used to overwrite the collection logic, but
+# we only use it to set the env vars and fall back to the default collection
+# logic by always returning None. See
+# https://docs.pytest.org/en/latest/how-to/writing_hook_functions.html#firstresult-stop-at-first-non-none-result
+# for details.
+#
+# The env var JAX_ENABLE_TPU_XDIST must be set for this hook to have an
+# effect. We do this to minimize any effect on non-TPU tests, and as a pointer
+# in test code to this "magic" hook. TPU tests should not specify more xdist
+# workers than the number of TPU chips.
+def pytest_collection() -> None:
+  if not os.environ.get("JAX_ENABLE_TPU_XDIST", None):
+    return
+  # When running as an xdist worker, will be something like "gw0"
+  xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "")
+  if not xdist_worker_name.startswith("gw"):
+    return
+  xdist_worker_number = int(xdist_worker_name[len("gw"):])
+  os.environ.setdefault("TPU_VISIBLE_CHIPS", str(xdist_worker_number))
+  os.environ.setdefault("ALLOW_MULTIPLE_LIBTPU_LOAD", "true")
diff --git a/pytest.ini b/pytest.ini
index ae7202141fdc..6d9971df3cd2 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -1,5 +1,6 @@
 [pytest]
 markers =
+  multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators
   SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI
 filterwarnings =
     error
@@ -19,7 +20,7 @@ filterwarnings =
     # numpy uses distutils which is deprecated
     ignore:The distutils.* is deprecated.*:DeprecationWarning
     ignore:`sharded_jit` is deprecated. Please use `pjit` instead.*:DeprecationWarning
-    # Print message for compilation_cache_test.py::CompilationCacheTest::test_cache_read/write_warning 
+    # Print message for compilation_cache_test.py::CompilationCacheTest::test_cache_read/write_warning
     default:Error reading persistent compilation cache entry for 'jit__lambda_'
     default:Error writing persistent compilation cache entry for 'jit__lambda_'
 doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
diff --git a/tests/array_test.py b/tests/array_test.py
index 3e84e0e99fe6..30916bd1af2b 100644
--- a/tests/array_test.py
+++ b/tests/array_test.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Tests for GlobalDeviceArray."""
 
+import contextlib
 import os
 import unittest
 from absl.testing import absltest
@@ -43,6 +44,11 @@
 prev_xla_flags = None
 
 
+with contextlib.suppress(ImportError):
+  import pytest
+  pytestmark = pytest.mark.multiaccelerator
+
+
 # Run all tests with 8 CPU devices.
 def setUpModule():
   global prev_xla_flags
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 1fd19119895d..b31021bf5ee7 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 import os
 import re
 from functools import partial, lru_cache
@@ -57,6 +58,11 @@
 prev_xla_flags = None
 
 
+with contextlib.suppress(ImportError):
+  import pytest
+  pytestmark = pytest.mark.multiaccelerator
+
+
 def setUpModule():
   global prev_xla_flags
   prev_xla_flags = os.getenv("XLA_FLAGS")
diff --git a/tests/pmap_test.py b/tests/pmap_test.py
index b550ef60ad6e..26ea32a91d47 100644
--- a/tests/pmap_test.py
+++ b/tests/pmap_test.py
@@ -14,6 +14,7 @@
 
 
 from concurrent.futures import ThreadPoolExecutor
+import contextlib
 from functools import partial
 import itertools as it
 import gc
@@ -55,6 +56,11 @@
 
 prev_xla_flags = None
 
+
+with contextlib.suppress(ImportError):
+  import pytest
+  pytestmark = pytest.mark.multiaccelerator
+
 compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]]
 
 def all_bdims(*shapes, pmap):
diff --git a/tests/remote_transfer_test.py b/tests/remote_transfer_test.py
index 5941495ba50f..fb667c09532f 100644
--- a/tests/remote_transfer_test.py
+++ b/tests/remote_transfer_test.py
@@ -14,6 +14,7 @@
 """Tests for cross host device transfer."""
 
 from absl.testing import absltest
+import contextlib
 import unittest
 
 import numpy as np
@@ -24,6 +25,10 @@
 
 config.parse_flags_with_absl()
 
+with contextlib.suppress(ImportError):
+  import pytest
+  pytestmark = pytest.mark.multiaccelerator
+
 
 class RemoteTransferTest(jtu.JaxTestCase):
 
diff --git a/tests/xmap_test.py b/tests/xmap_test.py
index 1945ae146b17..57dde5d574a8 100644
--- a/tests/xmap_test.py
+++ b/tests/xmap_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 import functools
 import itertools as it
 import os
@@ -55,6 +56,11 @@
 from jax.config import config
 config.parse_flags_with_absl()
 
+with contextlib.suppress(ImportError):
+  import pytest
+  pytestmark = pytest.mark.multiaccelerator
+
+
 # TODO(mattjj): de-duplicate setUpModule and tearDownModule with pmap_test.py
 # Run all tests with 8 CPU devices.
 def setUpModule():
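
Reviewer note (not part of the patch): below is a minimal sketch of how a new
multi-accelerator test file would opt into the marker added above, using the
same guarded-import pattern as the test files touched in this change. The file
name, test class, and test body are hypothetical and only illustrate the
pattern; the marker name and the pytest invocations mirror what the patch adds.

    # example_multiaccelerator_test.py -- hypothetical file, for illustration only.
    import contextlib

    from absl.testing import absltest
    import jax
    import jax.numpy as jnp
    import numpy as np

    # Same guarded pattern as tests/pmap_test.py etc.: mark every test in this
    # file as multi-accelerator, but only when pytest is installed.
    with contextlib.suppress(ImportError):
      import pytest
      pytestmark = pytest.mark.multiaccelerator


    class ExampleMultiAcceleratorTest(absltest.TestCase):

      def test_psum_across_all_devices(self):
        n = jax.device_count()
        x = jnp.arange(n)
        # pmap over every visible device; psum gives each device the sum of the
        # per-device values, i.e. 0 + 1 + ... + (n - 1).
        out = jax.pmap(lambda v: jax.lax.psum(v, "i"), axis_name="i")(x)
        np.testing.assert_array_equal(np.asarray(out), np.full(n, n * (n - 1) // 2))


    if __name__ == "__main__":
      absltest.main()

With the patch applied, such a file is deselected by the parallel single-chip
pass and picked up by the all-chips pass, mirroring the CI step above:

    JAX_ENABLE_TPU_XDIST=true python -m pytest -n=4 -m "not multiaccelerator" tests
    python -m pytest -m "multiaccelerator" tests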