Skip to content

Commit

Permalink
[pt2][inductor] only use global cache on MAST (pytorch#105375)
Browse files Browse the repository at this point in the history
Summary:
Until we can further investigate the autotuning differences between MAST and non-MAST (devserver) environments, turn off the global cache for all non-MAST environments. This ensures we don't see unexpected regressions.

Also update Scuba logging for cache lookups, and add Scuba logging for autotuning results.

Test Plan: sandcastle + CI

Differential Revision: D47516633

Pull Request resolved: pytorch#105375
Approved by: https://github.com/jansel
  • Loading branch information
nmacchioni authored and pytorchmergebot committed Jul 18, 2023
1 parent 8010f6b commit 6ca3d7e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
36 changes: 25 additions & 11 deletions torch/_inductor/codecache.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,22 @@
from triton.fb import build_paths
from triton.fb.build import _run_build_command

from torch._inductor.fb.logging import global_cache_log
from torch._inductor.fb.utils import (
log_global_cache_stats,
log_global_cache_vals,
use_global_cache,
)
else:

def global_cache_log(*args, **kwargs):
def log_global_cache_stats(*args, **kwargs):
pass

def log_global_cache_vals(*args, **kwargs):
pass

def use_global_cache():
return False


LOCK_TIMEOUT = 600

Expand Down Expand Up @@ -224,7 +234,8 @@ def lookup(
b. `max_autotune_gemm=False`: don't benchmark the choice, return nothing.
"""

gc_log = partial(global_cache_log, self.system, name, inputs)
log_stats = partial(log_global_cache_stats, self.system, name, inputs)
log_vals = partial(log_global_cache_vals, self.system, name, inputs)
timings = {}

def check_cache(cache, callback=None):
Expand All @@ -235,20 +246,20 @@ def check_cache(cache, callback=None):
if choice_hash in cache.get(name, {}).get(inputs, {}):
# cache hit
timings[choice] = cache[name][inputs][choice_hash]
if callback:
callback(choice_hash, cached=True)
else:
# cache miss
hit = False
if callback:
callback(choice_hash, cached=False)
break
if callback:
callback(cached=hit)
return hit

if config.max_autotune or config.max_autotune_gemm:
local_cache = self.get_local_cache()
# check local cache first since it is data specific to the current machine
if not check_cache(local_cache) and not check_cache(
self.get_global_cache(), callback=gc_log
if not check_cache(local_cache) and not (
use_global_cache()
and check_cache(self.get_global_cache(), callback=log_stats)
):
# re-benchmark everything to try to get consistent numbers from the same machine
for choice in choices:
Expand All @@ -258,9 +269,12 @@ def check_cache(cache, callback=None):
local_cache[name][inputs][choice.hash_key()] = timings[choice]

self.update_local_cache(local_cache)
else:

if use_global_cache():
log_vals(timings)
elif use_global_cache():
# only check global cache, not local one
check_cache(self.get_global_cache(), callback=gc_log)
check_cache(self.get_global_cache(), callback=log_stats)
# may have a partial cache hit, where not everything is benchmarked

return timings
Expand Down
2 changes: 1 addition & 1 deletion torch/_inductor/compile_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from .virtualized import V

if config.is_fbcode():
from torch._inductor.fb.logging import time_and_log
from torch._inductor.fb.utils import time_and_log
else:
# no-op decorator
def time_and_log(attr: str):
Expand Down

0 comments on commit 6ca3d7e

Please sign in to comment.