Skip to content

Commit

Permalink
Do not add extra .0 if running single core. (pytorch#1437)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlibenzi authored and ailzhang committed Dec 2, 2019
1 parent f66bf94 commit b5ac7b1
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 14 deletions.
13 changes: 6 additions & 7 deletions torch_xla/core/xla_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ def get_local_ordinal(defval=0):
The replication local ordinal of the current process.
"""

return xu.getenv_as(xenv.LOCAL_ORDINAL, int, defval=defval)
ordinal = xu.getenv_as(xenv.LOCAL_ORDINAL, int, defval=-1)
if ordinal >= 0:
return ordinal
return getattr(_TLS, 'device_index', defval)


def is_master_ordinal(local=True):
  """Checks whether the current process is the master ordinal (ordinal 0).

  Args:
    local (bool): Whether the local (per host) master ordinal should be
      checked, instead of the global (cross host) one. Default: True.

  Returns:
    A boolean indicating whether the current process is the master ordinal.
  """
  # Both get_local_ordinal() and get_ordinal() already fall back to the
  # thread-local device index (or 0) when no ordinal environment variable is
  # set, so no separate multi-threaded DataParallel special case is needed.
  ordinal = get_local_ordinal() if local else get_ordinal()
  return ordinal == 0


def master_print(s, fd=sys.stdout, local=True):
Expand Down
14 changes: 7 additions & 7 deletions torch_xla/debug/metrics_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch_xla
import torch_xla.debug.metrics as met

_STEP_METRICS_FILE = None
_STEP_METRICS_FILE_LOCK = threading.Lock()

_TLS = threading.local()
Expand All @@ -22,17 +21,18 @@ def _extract_metrics_file():
import torch_xla.core.xla_model as xm
metrics_file = os.environ.get('XLA_METRICS_FILE', None)
if metrics_file is not None:
ordinal = xm.get_ordinal(defval=-1)
if ordinal >= 0:
ordinal = xm.get_local_ordinal(defval=-1)
if ordinal >= 0 and xm.xrt_world_size() > 1:
metrics_file = '{}.{}'.format(metrics_file, ordinal)
return metrics_file


def _get_metrics_file():
  """Returns the metrics file path for the calling thread, caching it.

  The resolved path is memoized in thread-local storage (``_TLS``) so each
  thread computes its (possibly ordinal-qualified) file name only once. The
  empty string is used as the "not yet resolved" sentinel because
  ``_extract_metrics_file()`` may legitimately return None (no
  XLA_METRICS_FILE configured), and that None result must also be cached.

  Returns:
    The metrics file path, or None if no metrics file is configured.
  """
  # NOTE: the previous implementation cached the path in a module-level
  # global, which could not differ per thread; the stale global-based body
  # was unreachable dead code and has been removed.
  metrics_file = getattr(_TLS, 'metrics_file', '')
  if metrics_file == '':
    metrics_file = _extract_metrics_file()
    _TLS.metrics_file = metrics_file
  return metrics_file


def save_metrics(metrics_file=None):
Expand Down

0 comments on commit b5ac7b1

Please sign in to comment.