Skip to content

Commit

Permalink
Merge pull request #182 from ROCm/ci-upstream-sync-55_1
Browse files Browse the repository at this point in the history
CI: 12/10/24 upstream sync
  • Loading branch information
github-actions[bot] authored Dec 10, 2024
2 parents f04f164 + 20c5c71 commit 6dc4dee
Show file tree
Hide file tree
Showing 33 changed files with 949 additions and 741 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.

## jax 0.4.38

* Deprecations
* a number of APIs in the internal `jax.core` namespace have been deprecated, including
`ClosedJaxpr`, `full_lower`, `Jaxpr`, `JaxprEqn`, `jaxpr_as_fun`, `lattice_join`,
`Literal`, `Primitive`, `raise_to_shaped`, `Token`, `Var`. Most can be replaced by
APIs of the same name in {mod}`jax.extend.core`; see the documentation for
{mod}`jax.extend` for information on the compatibility guarantees of these
semi-public extensions.

## jax 0.4.37 (Dec 9, 2024)

This is a patch release of jax 0.4.36. Only "jax" was released at this version.
Expand Down
1 change: 0 additions & 1 deletion docs/contributor_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,3 @@ some of JAX's (extensible) internals.

autodidax
jep/index
jax_internal_api
18 changes: 18 additions & 0 deletions docs/jax.extend.core.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
``jax.extend.core`` module
==========================

.. automodule:: jax.extend.core

.. autosummary::
:toctree: _autosummary

ClosedJaxpr
Jaxpr
JaxprEqn
Literal
Primitive
Token
Var
array_types
jaxpr_as_fun
primitives
1 change: 1 addition & 0 deletions docs/jax.extend.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Modules
.. toctree::
:maxdepth: 1

jax.extend.core
jax.extend.ffi
jax.extend.linear_util
jax.extend.mlir
Expand Down
14 changes: 0 additions & 14 deletions docs/jax_internal_api.rst

This file was deleted.

187 changes: 123 additions & 64 deletions jax/_src/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,33 @@ def compile_or_get_cached(

use_compilation_cache = compilation_cache.is_cache_used(backend)

is_multi_process = (
len({device.process_index for device in devices.flatten()}) > 1
)
min_device_process_id = min(
devices.flatten(), key=lambda device: device.id
).process_index
is_auto_pgle_used = (
config.enable_pgle.value and config.pgle_profiling_runs.value > 0
)

if not use_compilation_cache:
if (
is_multi_process
and is_auto_pgle_used
and distributed.global_state.client is not None
):
compile_options.executable_build_options.fdo_profile = (
_share_fdo_profiles(
computation,
devices,
compile_options,
backend,
distributed.global_state.client,
min_device_process_id,
)
)

return backend_compile(backend, computation, compile_options,
host_callbacks)

Expand All @@ -373,61 +399,18 @@ def compile_or_get_cached(
return backend_compile(backend, computation, compile_options,
host_callbacks)

is_multi_process = (
len({device.process_index for device in devices.flatten()}) > 1)
min_device_process_id = (
min(devices.flatten(), key=lambda device: device.id).process_index)

# When PGLE is enabled there might be 3 types of situations:
# 1. PGLE profiled module (the one which was recompiled with FDO profile) is
# in the persistent cache. In this case the module should be returned from
# cache and PGLE should be disabled for this module. Is module is stored in
# the persistent cache under the "pgle_profiled_module_key" which calculated
# with replacing FDO profile with flag which identify that module were PGLE
# profiled.
# 2. PGLE profiled module is not in the persistent cache and the module is
# getting built with an FDO profile. In this case we need to share FDO profile
# with other processes and store the result under the
# "pgle_profiled_module_key" so later in case 1 we will be able to find the
# module.
# 3. PGLE profiled module is not in the persistent cache and the module is
# getting compiled to be PGLEd (FDO profile is empty). In this case we need to
# simply return the non-PGLE profiled module from the persistent cache.
if (config.enable_pgle.value
and config.pgle_profiling_runs.value > 0):
fdo_profile = compile_options.executable_build_options.fdo_profile
compile_options.executable_build_options.fdo_profile = b"pgle profiled"

pgle_profiled_module_key = compilation_cache.get_cache_key(
if is_auto_pgle_used:
cache_key = _resolve_pgle_module_cache_key(
computation,
devices,
compile_options,
backend,
cache_key_type.IgnoreCallbacks.ALL,
pgle_profiler,
is_multi_process,
cache_key,
module_name,
min_device_process_id,
)
compile_options.executable_build_options.fdo_profile = fdo_profile

if _is_executable_in_cache(backend, pgle_profiled_module_key):
# Load PGLE profiled module from the persistent cache.
cache_key = pgle_profiled_module_key
if pgle_profiler is not None:
pgle_profiler.disable()
elif fdo_profile is not None and len(fdo_profile) > 0:
# Store module under PGLE profiled module cache key.
cache_key = pgle_profiled_module_key
if is_multi_process and distributed.global_state.client is not None:
compile_options.executable_build_options.fdo_profile = _share_fdo_profiles(
computation, devices, compile_options, backend,
distributed.global_state.client,
min_device_process_id
)
else:
compile_options.executable_build_options.fdo_profile = fdo_profile
logger.debug(
"Compiling module %s with FDO profile: %s",
module_name,
compile_options.executable_build_options.fdo_profile,
)

cache_retrieval_start = time.monotonic()
retrieved_executable, retrieved_compile_time = _cache_read(
Expand Down Expand Up @@ -493,6 +476,75 @@ def compile_or_get_cached(
cache_key,
)


# When PGLE is enabled there are three possible situations:
# 1. The PGLE profiled module (the one which was recompiled with an FDO
# profile) is in the persistent cache. In this case the module should be
# returned from the cache and PGLE should be disabled for this module. The
# module is stored in the persistent cache under the
# "pgle_profiled_module_key", which is calculated by replacing the FDO profile
# with a flag identifying that the module was PGLE profiled.
# 2. The PGLE profiled module is not in the persistent cache and the module is
# being built with an FDO profile. In this case we need to share the FDO
# profile with the other processes and store the result under the
# "pgle_profiled_module_key", so that later, in case 1, we will be able to find
# the module.
# 3. The PGLE profiled module is not in the persistent cache and the module is
# being compiled in order to be PGLE profiled (the FDO profile is empty). In
# this case we simply need to return the non-PGLE profiled module from the
# persistent cache.
def _resolve_pgle_module_cache_key(
    computation: ir.Module,
    devices: np.ndarray,
    compile_options: xc.CompileOptions,
    backend: xc.Client,
    pgle_profiler: profiler.PGLEProfiler | None,
    is_multi_process: bool,
    cache_key: str,
    module_name: str,
    min_device_process_id: int,
) -> str:
  """Chooses the persistent-cache key for a module when PGLE is in use.

  A second cache key, the "PGLE profiled module key", is computed with the FDO
  profile in `compile_options` temporarily replaced by the fixed marker
  `b"pgle profiled"`. The result is then one of:

  * the PGLE profiled module key, when an executable is already cached under
    it; `pgle_profiler`, if given, is disabled in this case;
  * the PGLE profiled module key, when nothing is cached under it but
    `compile_options` carries a non-empty FDO profile; in a multi-process run
    with a distributed client available, the FDO profile in `compile_options`
    is replaced by the value returned from `_share_fdo_profiles`;
  * `cache_key` unchanged, otherwise.

  Args:
    computation: module forwarded to the cache-key computation.
    devices: devices forwarded to the cache-key computation.
    compile_options: compile options; its
      `executable_build_options.fdo_profile` is read and may be updated.
    backend: client used for the cache-key computation and cache lookup.
    pgle_profiler: optional profiler, disabled when a PGLE profiled
      executable is found in the cache.
    is_multi_process: whether the devices span more than one process.
    cache_key: the regular cache key, returned when no PGLE key applies.
    module_name: module name, used only in the debug log message.
    min_device_process_id: process index forwarded to `_share_fdo_profiles`.

  Returns:
    The cache key under which the executable should be looked up or stored.

  Raises:
    Any exception raised by `compilation_cache.get_cache_key` propagates. The
    original FDO profile is put back into `compile_options` first.
  """
  fdo_profile = compile_options.executable_build_options.fdo_profile
  # Swap in a fixed marker so the key identifies "this module was PGLE
  # profiled" independently of the actual profile contents.
  compile_options.executable_build_options.fdo_profile = b"pgle profiled"
  try:
    pgle_profiled_module_key = compilation_cache.get_cache_key(
        computation,
        devices,
        compile_options,
        backend,
        cache_key_type.IgnoreCallbacks.ALL,
    )
  finally:
    # Always restore the caller's profile; otherwise a failure while computing
    # the key would leave the marker in place and it could later be handed to
    # the compiler as if it were a real FDO profile.
    compile_options.executable_build_options.fdo_profile = fdo_profile

  result_key = cache_key
  if _is_executable_in_cache(backend, pgle_profiled_module_key):
    # Load PGLE profiled module from the persistent cache.
    result_key = pgle_profiled_module_key
    if pgle_profiler is not None:
      pgle_profiler.disable()
  elif fdo_profile is not None and len(fdo_profile) > 0:
    # Store module under PGLE profiled module cache key.
    result_key = pgle_profiled_module_key
    if is_multi_process and distributed.global_state.client is not None:
      compile_options.executable_build_options.fdo_profile = (
          _share_fdo_profiles(
              computation,
              devices,
              compile_options,
              backend,
              distributed.global_state.client,
              min_device_process_id,
          )
      )
    else:
      compile_options.executable_build_options.fdo_profile = fdo_profile
      # Log only the length: the profile itself is an opaque byte blob.
      logger.debug(
          "Compiling module %s with FDO profile of length %d",
          module_name,
          len(compile_options.executable_build_options.fdo_profile),
      )
  return result_key


# The process that has the lowest device ID should share FDO profile before
# compilation with other processes.
def _share_fdo_profiles(
Expand All @@ -510,32 +562,39 @@ def _share_fdo_profiles(
return fdo_profile

compile_options.executable_build_options.fdo_profile = b""
profile_key = (
compilation_cache.get_cache_key(
computation,
devices,
compile_options,
backend,
cache_key_type.IgnoreCallbacks.ALL,
)
+ "_fdo_sync"
)
try:
profile_key = (
compilation_cache.get_cache_key(
computation,
devices,
compile_options,
backend,
cache_key_type.IgnoreCallbacks.ALL,
)
+ "_fdo_sync"
)
except xc._xla.XlaRuntimeError as ex:
logger.error(
"compile_or_get_cached: unable to generate cache key, "
"skipping the fdo profile sharing: %s",
ex,
)
return fdo_profile

if profile_key in _share_fdo_profiles.modules_profiles:
return _share_fdo_profiles.modules_profiles[profile_key]

share_timeout = config.share_binary_between_hosts_timeout_ms.value
if distributed.global_state.process_id == min_process_id:
logger.debug(
"Sharing FDO profile: %s. For module %s. Process %d.",
fdo_profile,
"Module %s. Sharing FDO profile. Process %d.",
module_name,
min_process_id,
)
global_client.key_value_set_bytes(profile_key, fdo_profile)
else:
logger.debug(
"Waiting for FDO profile: %s. For module %s. Should be set by process %d.",
fdo_profile,
"Module %s. Waiting for FDO profile which should be set by process %d.",
module_name,
min_process_id,
)
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/cudnn/fused_attention_stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import math

import jax
from jax import core
from jax import dtypes
from jax._src import core
from jax._src import dispatch
from jax._src.custom_partitioning import custom_partitioning
from jax._src.interpreters import batching
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/cudnn/fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import functools
import jax
from jax import core as jax_core
from jax._src import core as jax_core
from jax.interpreters import mlir
from jax.interpreters.mlir import hlo
from jax.interpreters.mlir import ir
Expand Down
Loading

0 comments on commit 6dc4dee

Please sign in to comment.