From 10424c597273aaf157987a5e10fac82bbba8cc1b Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Thu, 22 Jun 2023 14:42:14 -0700 Subject: [PATCH] Update JAX's XlaExecutable.cost_analysis and related plumbing so it works on Cloud TPU * Exposes LoadedExecutable.cost_analysis via pybind * Updates XlaExecutable.cost_analysis to try LoadedExecutable.cost_analysis, then fallback to the client method. PiperOrigin-RevId: 542671990 --- CHANGELOG.md | 1 + jax/_src/stages.py | 33 +++++++++---------- tests/api_test.py | 79 +++++++++++++++++++++++----------------------- 3 files changed, 58 insertions(+), 55 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 97c1f014c4a3..4598b9a5a9b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ Remember to align the itemized text with the first line of an item within a list determine the output shardings. * If the mesh context manager is provided, None will imply that the value will be replicated on all devices of the mesh. + * Executable.cost_analysis() works on Cloud TPU * Bug fixes * Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 9310b2227083..e79c1ddfa83d 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -232,11 +232,21 @@ def as_text(self) -> str: else: raise + # TODO(skyewm): this should return a single Dict (I think returning a list + # was to support MPMD executables, which never fully landed) def cost_analysis(self) -> List[Dict[str, float]]: xla_ext_exe = self.xla_extension_executable() - err_msg = ("cost analysis unsupported on current XLA backend: " - f"{type(xla_ext_exe)}") + # TODO(b/259255524): Unify/merge the two cost_analysis calls below. + if hasattr(xla_ext_exe, "cost_analysis"): + try: + return [xla_ext_exe.cost_analysis()] + except xla_extension.XlaRuntimeError as e: + msg, *_ = e.args + if not (type(msg) is str and msg.startswith("UNIMPLEMENTED")): + raise + + # Try client method if executable cost_analysis method is unimplemented if hasattr(xla_ext_exe, "client"): try: return [ @@ -245,21 +255,12 @@ def cost_analysis(self) -> List[Dict[str, float]]: ] except xla_extension.XlaRuntimeError as e: msg, *_ = e.args - if type(msg) is str and msg.startswith("UNIMPLEMENTED"): - raise NotImplementedError(err_msg) from e - else: - raise - elif hasattr(xla_ext_exe, "cost_analysis"): - try: - return xla_ext_exe.cost_analysis() - except xla_extension.XlaRuntimeError as e: - msg, *_ = e.args - if type(msg) is str and msg.startswith("UNIMPLEMENTED"): - raise NotImplementedError(err_msg) from e - else: + if not (type(msg) is str and msg.startswith("UNIMPLEMENTED")): raise - else: - raise NotImplementedError(err_msg) + + raise NotImplementedError( + f"cost analysis unsupported on current XLA backend: {type(xla_ext_exe)}" + ) def memory_analysis(self) -> Any: xla_ext_exe = self.xla_extension_executable() diff --git a/tests/api_test.py b/tests/api_test.py index fdf956499246..3e0cc5373c3f 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -15,67 +15,66 @@ import collections import collections.abc +import concurrent.futures from contextlib import contextmanager import copy import enum +import functools from functools import partial -import inspect +import gc import importlib +import inspect +import itertools as it import operator +import operator as op import os import platform import re import subprocess import sys import types -from typing import Callable, List, Optional, NamedTuple +from typing import Callable, List, NamedTuple, Optional import unittest import warnings import weakref -import functools -import itertools as it -import operator as op -import gc from absl import logging from absl.testing import absltest, parameterized -import numpy as np - -import concurrent.futures - import jax -import jax.custom_batching -import jax.custom_derivatives -import jax.custom_transpose -import jax.numpy as jnp -from jax import float0, jit, grad, device_put, jacfwd, jacrev, hessian -from jax._src import core -from jax._src import config as config_internal +from jax import config +from jax import custom_derivatives as custom_derivatives_public +from jax import device_put, float0, grad, hessian, jacfwd, jacrev, jit from jax import lax -from jax._src import api, dtypes, lib, api_util -from jax.errors import UnexpectedTracerError -from jax.interpreters import ad -from jax._src.interpreters import mlir -from jax.interpreters import xla -from jax.interpreters import batching -from jax._src.interpreters import partial_eval as pe -from jax.sharding import PartitionSpec as P +from jax import tree_util +from jax._src import api, api_util, dtypes, lib from jax._src import array -from jax.experimental import pjit +from jax._src import config as config_internal +from jax._src import core from jax._src import custom_derivatives -from jax import custom_derivatives as custom_derivatives_public +from jax._src import linear_util as lu from jax._src import prng +from jax._src import test_util as jtu from jax._src import xla_bridge +from jax._src.ad_checkpoint import saved_residuals +from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe from jax._src.lib import xla_client from jax._src.lib import xla_extension -from jax._src import test_util as jtu -from jax import tree_util -from jax._src import linear_util as lu +from jax._src.lib import xla_extension_version import jax._src.util as jax_util -from jax._src.ad_checkpoint import saved_residuals -from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_name +from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint +import jax.custom_batching +import jax.custom_derivatives +import jax.custom_transpose +from jax.errors import UnexpectedTracerError +from jax.experimental import pjit +from jax.interpreters import ad +from jax.interpreters import batching +from jax.interpreters import xla +import jax.numpy as jnp +from jax.sharding import PartitionSpec as P +import numpy as np -from jax import config config.parse_flags_with_absl() FLAGS = config.FLAGS @@ -305,7 +304,6 @@ def test_jit_default_platform(self): def test_complex_support(self): self.assertEqual(self.jit(lambda x: x + 1)(1 + 1j), 2 + 1j) - @parameterized.parameters("static_argnums", "donate_argnums") def test_jit_argnums_overflow_error(self, argnum_type: str): def f(a, b, c): @@ -346,7 +344,6 @@ def i(): self.jit(h, **{argnum_type: (0, 999)}) self.jit(h, **{argnum_type: (0, -999)}) - # No positional arguments self.jit(i, static_argnums=()) self.jit(i) @@ -385,7 +382,6 @@ def h(a, /, b, c, *args, **kwargs): with self.assertWarns(SyntaxWarning): self.jit(h, static_argnames=("args", "c")) - def test_jit_with_many_args_works(self): @self.jit @@ -468,7 +464,8 @@ def f(x): def test_jit_cache_clear(self): @self.jit - def f(x, y): return x + y + def f(x, y): + return x + y client = jax.devices()[0].client gc.collect() @@ -1106,8 +1103,12 @@ def test_jit_lower_cost_analysis(self): def test_jit_lower_compile_cost_analysis(self): f = self.jit(lambda x: x).lower(1.).compile() g = self.jit(lambda x: x + 4).lower(1.).compile() - f.cost_analysis() # doesn't raise - g.cost_analysis() # doesn't raise + if xla_extension_version >= 164: + self.assertIsNotNone(f.cost_analysis()) + self.assertIsNotNone(g.cost_analysis()) + else: + f.cost_analysis() # doesn't raise + g.cost_analysis() # doesn't raise @jtu.skip_on_xla_cpu_mlir def test_jit_lower_compile_memory_analysis(self):