diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index b10945985e4a..eaca8b0cc8a0 100644 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -96,7 +96,10 @@ def get_local_ordinal(defval=0): The replication local ordinal of the current process. """ - return xu.getenv_as(xenv.LOCAL_ORDINAL, int, defval=defval) + ordinal = xu.getenv_as(xenv.LOCAL_ORDINAL, int, defval=-1) + if ordinal >= 0: + return ordinal + return getattr(_TLS, 'device_index', defval) def is_master_ordinal(local=True): @@ -112,12 +115,8 @@ def is_master_ordinal(local=True): A boolean indicating whether the current process is the master ordinal. """ - ordinal = get_local_ordinal(defval=-1) if local else get_ordinal(defval=-1) - if ordinal >= 0: - # We are either on multi-processing, or on BigSlice (or both). - return ordinal == 0 - # We are in the multi-threaded DataParallel setup. - return getattr(_TLS, 'device_index', 0) == 0 + ordinal = get_local_ordinal() if local else get_ordinal() + return ordinal == 0 def master_print(s, fd=sys.stdout, local=True): diff --git a/torch_xla/debug/metrics_saver.py b/torch_xla/debug/metrics_saver.py index e294280e431c..253550bc4187 100644 --- a/torch_xla/debug/metrics_saver.py +++ b/torch_xla/debug/metrics_saver.py @@ -5,7 +5,6 @@ import torch_xla import torch_xla.debug.metrics as met -_STEP_METRICS_FILE = None _STEP_METRICS_FILE_LOCK = threading.Lock() _TLS = threading.local() @@ -22,17 +21,18 @@ def _extract_metrics_file(): import torch_xla.core.xla_model as xm metrics_file = os.environ.get('XLA_METRICS_FILE', None) if metrics_file is not None: - ordinal = xm.get_ordinal(defval=-1) - if ordinal >= 0: + ordinal = xm.get_local_ordinal(defval=-1) + if ordinal >= 0 and xm.xrt_world_size() > 1: metrics_file = '{}.{}'.format(metrics_file, ordinal) return metrics_file def _get_metrics_file(): - global _STEP_METRICS_FILE - if _STEP_METRICS_FILE is None: - _STEP_METRICS_FILE = _extract_metrics_file() - return _STEP_METRICS_FILE + metrics_file = getattr(_TLS, 'metrics_file', '') + if metrics_file == '': + metrics_file = _extract_metrics_file() + _TLS.metrics_file = metrics_file + return metrics_file def save_metrics(metrics_file=None):