[common] add device option for TorchConfig (intel#126)
* add device option for TorchConfig

* update

* update

* update
harborn authored Nov 22, 2023
1 parent 5157502 commit f7b6e86
Showing 2 changed files with 18 additions and 12 deletions.
11 changes: 8 additions & 3 deletions common/torch_config.py
@@ -2,6 +2,7 @@
 from ray.train.torch.config import TorchConfig as RayTorchConfig
 from ray.train._internal.worker_group import WorkerGroup
 from dataclasses import dataclass
+from typing import Optional
 import os
 import sys
 # The package importlib_metadata is in a different place, depending on the Python version.
@@ -13,9 +14,11 @@
 
 @dataclass
 class TorchConfig(RayTorchConfig):
+    device: Optional[str] = None
 
     @property
     def backend_cls(self):
+        EnableCCLBackend.device = self.device
         return EnableCCLBackend
 
 
@@ -41,11 +44,13 @@ def libs_import():
         ) from ccl_not_exist
 
 
-def _del_torch_distributed_env_vars():
-    del os.environ["ACCELERATE_TORCH_DEVICE"]
+def _set_torch_distributed_env_vars(device):
+    if device is not None:
+        os.environ["ACCELERATE_TORCH_DEVICE"] = device
 
 
 class EnableCCLBackend(_TorchBackend):
+    device: Optional[str] = None
 
     def on_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig):
         for i in range(len(worker_group)):
@@ -54,4 +59,4 @@ def on_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig):
 
     def on_training_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig):
         super().on_training_start(worker_group, backend_config)
-        worker_group.execute(_del_torch_distributed_env_vars)
+        worker_group.execute(_set_torch_distributed_env_vars, self.device)
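
For clarity, here is a minimal, Ray-free sketch of how the new device field appears intended to flow: accessing TorchConfig.backend_cls stamps the device onto the backend class, and _set_torch_distributed_env_vars later exports it on each worker as ACCELERATE_TORCH_DEVICE. The simplified classes below are stand-ins, not the repo's real base classes; only the names taken from the diff are real.

# Minimal sketch of the mechanism above (assumption: simplified stand-in
# classes; the real TorchConfig/EnableCCLBackend subclass Ray internals).
import os
from dataclasses import dataclass
from typing import Optional


class EnableCCLBackend:
    device: Optional[str] = None


@dataclass
class TorchConfig:
    backend: Optional[str] = None
    device: Optional[str] = None

    @property
    def backend_cls(self):
        # Side effect: the chosen device is stamped onto the backend class.
        EnableCCLBackend.device = self.device
        return EnableCCLBackend


def _set_torch_distributed_env_vars(device):
    # Runs on each worker; device=None leaves the environment untouched.
    if device is not None:
        os.environ["ACCELERATE_TORCH_DEVICE"] = device


cfg = TorchConfig(backend="ccl", device="cpu")
backend_cls = cfg.backend_cls
_set_torch_distributed_env_vars(backend_cls.device)
assert os.environ["ACCELERATE_TORCH_DEVICE"] == "cpu"

Note the design choice visible in the diff: device lives on the class, not the instance, so one TorchConfig's device is shared by every EnableCCLBackend constructed afterwards.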
19 changes: 10 additions & 9 deletions finetune/finetune.py
@@ -76,15 +76,15 @@ def train_func(config: Dict[str, Any]):
         } if config["General"].get("checkpoint_dir") else None
     })
 
-    try :
+    try:
         common.logger.info(f"trainer prepare start")
         trainer.prepare(model, tokenizer, datasets, optimizer, accelerator)
     except Exception as e:
         common.logger.critical(e, exc_info=True)
         exit(1)
     common.logger.info(f"trainer prepare finish")
 
-    try :
+    try:
         common.logger.info(f"train start")
         trainer.train()
     except Exception as e:
@@ -101,12 +101,12 @@ def main(external_config = None):
     num_training_workers = config["Training"].get("num_training_workers")
     resources_per_worker = config["Training"].get("resources_per_worker")
 
-    device = config["Training"]["device"]
+    device = config["Training"]["device"].lower()
     if not ray.is_initialized():
         runtime_env = {
             "env_vars": {
                 "OMP_NUM_THREADS": str(resources_per_worker["CPU"]),
-                "ACCELERATE_USE_CPU": "True" if device == "CPU" else "False",
+                "ACCELERATE_USE_CPU": "True" if device == "cpu" else "False",
                 "ACCELERATE_USE_IPEX": "False",
                 "ACCELERATE_MIXED_PRECISION": "no",
                 "CCL_WORKER_COUNT": "1",
@@ -122,14 +122,14 @@ def main(external_config = None):
         num_workers = num_training_workers,
         resources_per_worker = resources_per_worker,
         placement_strategy = "SPREAD",
-        use_gpu = False if device == "CPU" else True
+        use_gpu = False if device == "cpu" else True
     )
 
     if config.get("torch_config", None) is None:
-        torch_config = common.TorchConfig(backend = "ccl" if device == "CPU" else None)
+        torch_config = common.TorchConfig(backend = "ccl" if device == "cpu" else None, device=device)
     else:
         customer_torch_config = config.get("torch_config")
-        torch_config = common.TorchConfig(**customer_torch_config)
+        torch_config = common.TorchConfig(**customer_torch_config, device=device)
 
     if config.get("failure_config", None) is None:
         failure_config = FailureConfig()
@@ -149,10 +149,11 @@ def main(external_config = None):
         train_func,
         train_loop_config=config,
         scaling_config=scaling_config,
-        torch_config = torch_config,
-        run_config = run_config
+        torch_config=torch_config,
+        run_config=run_config
     )
     results = trainer.fit()
+
     return results
 
 if __name__ == "__main__":
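A small sketch of the device plumbing in main() above, assuming config["Training"]["device"] is a free-form string such as "CPU" or "GPU" (the diff only confirms the "device" key): lowercasing it once lets every later comparison stay simple.

# Hedged sketch of the normalization and fan-out; the config dict here is a
# hypothetical stand-in for the repo's real schema.
config = {"Training": {"device": "GPU"}, "torch_config": None}

device = config["Training"]["device"].lower()   # "GPU", "gpu", "Gpu" all work
use_gpu = False if device == "cpu" else True
backend = "ccl" if device == "cpu" else None

customer_torch_config = config.get("torch_config")
if customer_torch_config is None:
    kwargs = {"backend": backend, "device": device}
else:
    # Caveat: the diff's TorchConfig(**customer_torch_config, device=device)
    # raises TypeError if the user-supplied dict already has a "device" key.
    kwargs = {**customer_torch_config, "device": device}

print(kwargs)   # -> {'backend': None, 'device': 'gpu'}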
