Make pytest-xdist work on TPU and update Cloud TPU CI.
This change also marks multiaccelerator test files in a way pytest can
understand (if pytest is installed).
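Concretely, each multiaccelerator test file opts in with a module-level marker, guarded so the file still imports when pytest isn't available (the same pattern appears in each test-file diff below):

import contextlib

# Mark every test in this module as multiaccelerator; skipped harmlessly
# when pytest isn't installed (e.g. under other test runners).
with contextlib.suppress(ImportError):
  import pytest
  pytestmark = pytest.mark.multiaccelerator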

By running each single-device test on a single TPU chip, with tests running
in parallel across chips, the test suite goes from 1h 45m to 35m (both
timings include the slow tests).

I tried using bazel at first, which already supported parallel execution
across TPU cores, but somehow it still takes 2h 20m! I'm not sure why
it's so slow. It appears that bazel creates many new test processes over
time, whereas pytest reuses the initially specified number of processes,
and since starting and stopping the TPU runtime takes a few seconds,
that may be adding up. It also appears that single-process bazel is
slower than single-process pytest, which I haven't looked into yet.
skye committed Nov 18, 2022
1 parent 3a837c8 commit 120125f
Showing 8 changed files with 70 additions and 2 deletions.
7 changes: 6 additions & 1 deletion .github/workflows/cloud-tpu-ci-nightly.yml
@@ -53,7 +53,12 @@ jobs:
     - name: Run tests
       env:
         JAX_PLATFORMS: tpu,cpu
-      run: python -m pytest --tb=short tests examples
+      run: |
+        # Run single-accelerator tests in parallel
+        JAX_ENABLE_TPU_XDIST=true python -m pytest -n=4 --tb=short \
+          -m "not multiaccelerator" tests examples
+        # Run multi-accelerator tests across all chips
+        python -m pytest -m "multiaccelerator" --tb=short tests
     - name: Send chat on failure
       # Don't notify when testing the workflow from a branch.
       if: ${{ failure() && github.ref_name == 'main' }}
33 changes: 33 additions & 0 deletions conftest.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 """pytest configuration"""

+import os
 import pytest


@@ -24,3 +25,35 @@ def add_imports(doctest_namespace):
   doctest_namespace["lax"] = jax.lax
   doctest_namespace["jnp"] = jax.numpy
   doctest_namespace["np"] = numpy


+# A pytest hook that runs immediately before test collection (i.e. when pytest
+# loads all the test cases to run). When running parallel tests via xdist on
+# Cloud TPU, we use this hook to set the env vars needed to run multiple test
+# processes across different TPU chips.
+#
+# It's important that the hook runs before test collection, since jax tests end
+# up initializing the TPU runtime on import (e.g. to query supported test
+# types). It's also important that the hook gets called by each xdist worker
+# process. Luckily each worker does its own test collection.
+#
+# The pytest_collection hook can be used to override the collection logic, but
+# we only use it to set the env vars and fall back to the default collection
+# logic by always returning None. See
+# https://docs.pytest.org/en/latest/how-to/writing_hook_functions.html#firstresult-stop-at-first-non-none-result
+# for details.
+#
+# The env var JAX_ENABLE_TPU_XDIST must be set for this hook to have an
+# effect. We do this to minimize any effect on non-TPU tests, and as a pointer
+# in test code to this "magic" hook. TPU tests should not specify more xdist
+# workers than the number of TPU chips.
+def pytest_collection() -> None:
+  if not os.environ.get("JAX_ENABLE_TPU_XDIST", None):
+    return
+  # When running as an xdist worker, PYTEST_XDIST_WORKER will be e.g. "gw0".
+  xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "")
+  if not xdist_worker_name.startswith("gw"):
+    return
+  xdist_worker_number = int(xdist_worker_name[len("gw"):])
+  os.environ.setdefault("TPU_VISIBLE_CHIPS", str(xdist_worker_number))
+  os.environ.setdefault("ALLOW_MULTIPLE_LIBTPU_LOAD", "true")
3 changes: 2 additions & 1 deletion pytest.ini
@@ -1,5 +1,6 @@
 [pytest]
 markers =
+  multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators
   SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI
 filterwarnings =
   error
@@ -19,7 +20,7 @@ filterwarnings =
   # numpy uses distutils which is deprecated
   ignore:The distutils.* is deprecated.*:DeprecationWarning
   ignore:`sharded_jit` is deprecated. Please use `pjit` instead.*:DeprecationWarning
-  # Print message for compilation_cache_test.py::CompilationCacheTest::test_cache_read/write_warning
+  # Print message for compilation_cache_test.py::CompilationCacheTest::test_cache_read/write_warning
   default:Error reading persistent compilation cache entry for 'jit__lambda_'
   default:Error writing persistent compilation cache entry for 'jit__lambda_'
 doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
6 changes: 6 additions & 0 deletions tests/array_test.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Tests for GlobalDeviceArray."""

+import contextlib
 import os
 import unittest
 from absl.testing import absltest
@@ -43,6 +44,11 @@

 prev_xla_flags = None

+with contextlib.suppress(ImportError):
+  import pytest
+  pytestmark = pytest.mark.multiaccelerator
+
+
 # Run all tests with 8 CPU devices.
 def setUpModule():
   global prev_xla_flags
6 changes: 6 additions & 0 deletions tests/pjit_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import contextlib
 import os
 import re
 from functools import partial, lru_cache
@@ -57,6 +58,11 @@

 prev_xla_flags = None

+with contextlib.suppress(ImportError):
+  import pytest
+  pytestmark = pytest.mark.multiaccelerator
+
+
 def setUpModule():
   global prev_xla_flags
   prev_xla_flags = os.getenv("XLA_FLAGS")
6 changes: 6 additions & 0 deletions tests/pmap_test.py
@@ -14,6 +14,7 @@


 from concurrent.futures import ThreadPoolExecutor
+import contextlib
 from functools import partial
 import itertools as it
 import gc
@@ -55,6 +56,11 @@

 prev_xla_flags = None

+
+with contextlib.suppress(ImportError):
+  import pytest
+  pytestmark = pytest.mark.multiaccelerator
+
 compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]]

 def all_bdims(*shapes, pmap):
5 changes: 5 additions & 0 deletions tests/remote_transfer_test.py
@@ -14,6 +14,7 @@
"""Tests for cross host device transfer."""

from absl.testing import absltest
import contextlib
import unittest
import numpy as np

@@ -24,6 +25,10 @@

 config.parse_flags_with_absl()

+with contextlib.suppress(ImportError):
+  import pytest
+  pytestmark = pytest.mark.multiaccelerator
+

 class RemoteTransferTest(jtu.JaxTestCase):

6 changes: 6 additions & 0 deletions tests/xmap_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import contextlib
 import functools
 import itertools as it
 import os
@@ -55,6 +56,11 @@
 from jax.config import config
 config.parse_flags_with_absl()

+with contextlib.suppress(ImportError):
+  import pytest
+  pytestmark = pytest.mark.multiaccelerator
+
+
 # TODO(mattjj): de-duplicate setUpModule and tearDownModule with pmap_test.py
 # Run all tests with 8 CPU devices.
 def setUpModule():