Skip to content

Commit

Permalink
Update JAX's XlaExecutable.cost_analysis and related plumbing so it w…
Browse files Browse the repository at this point in the history
…orks on Cloud TPU

* Exposes LoadedExecutable.cost_analysis via pybind
* Updates XlaExecutable.cost_analysis to try
  LoadedExecutable.cost_analysis, then fallback to the client method.

PiperOrigin-RevId: 542671990
  • Loading branch information
skye authored and jax authors committed Jun 22, 2023
1 parent 9f4080a commit 10424c5
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 55 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Remember to align the itemized text with the first line of an item within a list
determine the output shardings.
* If the mesh context manager is provided, None will imply that the value
will be replicated on all devices of the mesh.
* Executable.cost_analysis() works on Cloud TPU

* Bug fixes
* Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel
Expand Down
33 changes: 17 additions & 16 deletions jax/_src/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,21 @@ def as_text(self) -> str:
else:
raise

# TODO(skyewm): this should return a single Dict (I think returning a list
# was to support MPMD executables, which never fully landed)
def cost_analysis(self) -> List[Dict[str, float]]:
xla_ext_exe = self.xla_extension_executable()
err_msg = ("cost analysis unsupported on current XLA backend: "
f"{type(xla_ext_exe)}")

# TODO(b/259255524): Unify/merge the two cost_analysis calls below.
if hasattr(xla_ext_exe, "cost_analysis"):
try:
return [xla_ext_exe.cost_analysis()]
except xla_extension.XlaRuntimeError as e:
msg, *_ = e.args
if not (type(msg) is str and msg.startswith("UNIMPLEMENTED")):
raise

# Try client method if executable cost_analysis method is unimplemented
if hasattr(xla_ext_exe, "client"):
try:
return [
Expand All @@ -245,21 +255,12 @@ def cost_analysis(self) -> List[Dict[str, float]]:
]
except xla_extension.XlaRuntimeError as e:
msg, *_ = e.args
if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
raise NotImplementedError(err_msg) from e
else:
raise
elif hasattr(xla_ext_exe, "cost_analysis"):
try:
return xla_ext_exe.cost_analysis()
except xla_extension.XlaRuntimeError as e:
msg, *_ = e.args
if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
raise NotImplementedError(err_msg) from e
else:
if not (type(msg) is str and msg.startswith("UNIMPLEMENTED")):
raise
else:
raise NotImplementedError(err_msg)

raise NotImplementedError(
f"cost analysis unsupported on current XLA backend: {type(xla_ext_exe)}"
)

def memory_analysis(self) -> Any:
xla_ext_exe = self.xla_extension_executable()
Expand Down
79 changes: 40 additions & 39 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,67 +15,66 @@

import collections
import collections.abc
import concurrent.futures
from contextlib import contextmanager
import copy
import enum
import functools
from functools import partial
import inspect
import gc
import importlib
import inspect
import itertools as it
import operator
import operator as op
import os
import platform
import re
import subprocess
import sys
import types
from typing import Callable, List, Optional, NamedTuple
from typing import Callable, List, NamedTuple, Optional
import unittest
import warnings
import weakref
import functools
import itertools as it
import operator as op
import gc

from absl import logging
from absl.testing import absltest, parameterized
import numpy as np

import concurrent.futures

import jax
import jax.custom_batching
import jax.custom_derivatives
import jax.custom_transpose
import jax.numpy as jnp
from jax import float0, jit, grad, device_put, jacfwd, jacrev, hessian
from jax._src import core
from jax._src import config as config_internal
from jax import config
from jax import custom_derivatives as custom_derivatives_public
from jax import device_put, float0, grad, hessian, jacfwd, jacrev, jit
from jax import lax
from jax._src import api, dtypes, lib, api_util
from jax.errors import UnexpectedTracerError
from jax.interpreters import ad
from jax._src.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax.sharding import PartitionSpec as P
from jax import tree_util
from jax._src import api, api_util, dtypes, lib
from jax._src import array
from jax.experimental import pjit
from jax._src import config as config_internal
from jax._src import core
from jax._src import custom_derivatives
from jax import custom_derivatives as custom_derivatives_public
from jax._src import linear_util as lu
from jax._src import prng
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.ad_checkpoint import saved_residuals
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
from jax._src import test_util as jtu
from jax import tree_util
from jax._src import linear_util as lu
from jax._src.lib import xla_extension_version
import jax._src.util as jax_util
from jax._src.ad_checkpoint import saved_residuals
from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_name
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.custom_batching
import jax.custom_derivatives
import jax.custom_transpose
from jax.errors import UnexpectedTracerError
from jax.experimental import pjit
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import xla
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import numpy as np

from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

Expand Down Expand Up @@ -305,7 +304,6 @@ def test_jit_default_platform(self):
def test_complex_support(self):
self.assertEqual(self.jit(lambda x: x + 1)(1 + 1j), 2 + 1j)


@parameterized.parameters("static_argnums", "donate_argnums")
def test_jit_argnums_overflow_error(self, argnum_type: str):
def f(a, b, c):
Expand Down Expand Up @@ -346,7 +344,6 @@ def i():
self.jit(h, **{argnum_type: (0, 999)})
self.jit(h, **{argnum_type: (0, -999)})


# No positional arguments
self.jit(i, static_argnums=())
self.jit(i)
Expand Down Expand Up @@ -385,7 +382,6 @@ def h(a, /, b, c, *args, **kwargs):
with self.assertWarns(SyntaxWarning):
self.jit(h, static_argnames=("args", "c"))


def test_jit_with_many_args_works(self):

@self.jit
Expand Down Expand Up @@ -468,7 +464,8 @@ def f(x):

def test_jit_cache_clear(self):
@self.jit
def f(x, y): return x + y
def f(x, y):
return x + y

client = jax.devices()[0].client
gc.collect()
Expand Down Expand Up @@ -1106,8 +1103,12 @@ def test_jit_lower_cost_analysis(self):
def test_jit_lower_compile_cost_analysis(self):
f = self.jit(lambda x: x).lower(1.).compile()
g = self.jit(lambda x: x + 4).lower(1.).compile()
f.cost_analysis() # doesn't raise
g.cost_analysis() # doesn't raise
if xla_extension_version >= 164:
self.assertIsNotNone(f.cost_analysis())
self.assertIsNotNone(g.cost_analysis())
else:
f.cost_analysis() # doesn't raise
g.cost_analysis() # doesn't raise

@jtu.skip_on_xla_cpu_mlir
def test_jit_lower_compile_memory_analysis(self):
Expand Down

0 comments on commit 10424c5

Please sign in to comment.