From 838b74d05b8e35ea289233ec62b6ad9b90b9a460 Mon Sep 17 00:00:00 2001
From: Mengqing Cao
Date: Thu, 21 Nov 2024 10:28:41 +0800
Subject: [PATCH] Add Ascend NPU support (#1758)

---
 src/axolotl/utils/bench.py                        | 15 +++++-
 src/axolotl/utils/config/__init__.py              |  8 ++-
 .../config/models/input/v0_4_1/__init__.py        | 35 +++++++++++++
 src/axolotl/utils/distributed.py                  | 52 ++++++++++++++++---
 src/axolotl/utils/models.py                       | 20 ++++---
 5 files changed, 114 insertions(+), 16 deletions(-)

diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py
index 11c25160d..57471ae0d 100644
--- a/src/axolotl/utils/bench.py
+++ b/src/axolotl/utils/bench.py
@@ -4,6 +4,9 @@
 import pynvml
 import torch
 from pynvml.nvml import NVMLError
+from transformers.utils.import_utils import is_torch_npu_available
+
+from axolotl.utils.distributed import get_device_type
 
 
 def check_cuda_device(default_value):
@@ -53,6 +56,12 @@ def mps_memory_usage_all():
     return usage, reserved - usage, 0
 
 
+def npu_memory_usage_all(device=0):
+    usage = torch.npu.memory_allocated(device) / 1024.0**3
+    reserved = torch.npu.memory_reserved(device) / 1024.0**3
+    return usage, reserved - usage, 0
+
+
 @check_cuda_device(0.0)
 def gpu_memory_usage_smi(device=0):
     if isinstance(device, torch.device):
@@ -69,8 +78,11 @@ def gpu_memory_usage_smi(device=0):
 
 
 def log_gpu_memory_usage(log, msg, device):
+    cur_device = get_device_type()
     if torch.backends.mps.is_available():
         usage, cache, misc = mps_memory_usage_all()
+    elif "npu" in str(cur_device) and is_torch_npu_available():
+        usage, cache, misc = npu_memory_usage_all(device)
     else:
         usage, cache, misc = gpu_memory_usage_all(device)
     extras = []
@@ -79,6 +91,7 @@ def log_gpu_memory_usage(log, msg, device):
     if misc > 0:
         extras.append(f"+{misc:.03f}GB misc")
     log.info(
-        f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
+        f"{str(cur_device)} memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})",
+        stacklevel=2,
     )
     return usage, cache, misc
diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index b12ad8113..0100f23ea 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -5,6 +5,7 @@
 
 import torch
 from transformers.utils import is_torch_bf16_gpu_available
+from transformers.utils.import_utils import is_torch_npu_available
 
 from axolotl.integrations.config import merge_input_args
 from axolotl.utils.bench import log_gpu_memory_usage
@@ -29,7 +30,10 @@ def get_device():
             if torch.backends.mps.is_available():
                 return "mps"
 
-            raise SystemError("No CUDA/mps device found")
+            if is_torch_npu_available():
+                return f"npu:{cfg.local_rank}"
+
+            raise SystemError("No CUDA/mps/npu device found")
         except Exception:  # pylint: disable=broad-exception-caught
             return "cpu"
 
@@ -39,6 +43,8 @@ def get_device():
     else:
         if cfg.device.startswith("cuda"):
            cfg.device_map = {"": torch.cuda.current_device()}
+        elif cfg.device.startswith("npu"):
+            cfg.device_map = {"npu": torch.npu.current_device()}
         else:
             cfg.device_map = {"": cfg.device}
 
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index f4420ae2c..42cbe52c1 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -19,6 +19,7 @@
 )
 from transformers import SchedulerType
 from transformers.training_args import OptimizerNames
+from transformers.utils.import_utils import is_torch_npu_available
 
 from axolotl.utils.config.models.internals import GPUCapabilities
 
@@ -1433,6 +1434,40 @@ def check_torch_compile_deepspeed(cls, data):
             )
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_npu_config(cls, data):
+        if is_torch_npu_available():
+            # check attention config
+            attn_list = ["flash_attention", "sdp_attention", "s2_attention"]
+            for attn in attn_list:
+                if data.get(attn):
+                    raise NotImplementedError(
+                        f"{attn} is currently not supported on Ascend NPU, please disable this configuration."
+                    )
+
+            # check quant config
+            if data.get("optimizer") is not None and "bit" in data.get("optimizer"):
+                optimizer = data.get("optimizer")
+                raise NotImplementedError(
+                    f"{optimizer} is currently not supported on Ascend NPU, please choose another optimizer."
+                )
+
+            quant_list = ["load_in_8bit", "load_in_4bit"]
+            for quant in quant_list:
+                if data.get(quant):
+                    raise NotImplementedError(
+                        f"Quantization is currently not supported on Ascend NPU, please disable {quant}."
+                    )
+
+            # check dtype config
+            if data.get("tf32"):
+                raise NotImplementedError(
+                    "tf32 dtype is currently not supported on Ascend NPU, please disable this configuration."
+                )
+
+        return data
+
 
 class AxolotlConfigWCapabilities(AxolotlInputConfig):
     """wrapper to valdiate gpu capabilities with the configured options"""
diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index 3a559f5f5..81a928b6e 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -9,10 +9,44 @@
 import torch
 import torch.distributed as dist
 from accelerate import PartialState
+from transformers.utils.import_utils import (
+    is_torch_cuda_available,
+    is_torch_mps_available,
+    is_torch_npu_available,
+)
 
 distributed_state = None  # pylint: disable=invalid-name
 
 
+def get_device_type():
+    device = torch.device("cpu")
+    if is_torch_cuda_available():
+        device = torch.device("cuda")
+    elif is_torch_mps_available():
+        device = torch.device("mps")
+    elif is_torch_npu_available():
+        device = torch.device("npu")
+    return device
+
+
+def get_device_count():
+    cur_device = get_device_type()
+    if "cuda" in str(cur_device):
+        return torch.cuda.device_count()
+    if "npu" in str(cur_device):
+        return torch.npu.device_count()
+    return 1
+
+
+def get_current_device():
+    cur_device = get_device_type()
+    if "cuda" in str(cur_device):
+        return torch.cuda.current_device()
+    if "npu" in str(cur_device):
+        return torch.npu.current_device()
+    return 0
+
+
 def is_distributed():
     """
     Check if distributed training is initialized.
@@ -91,7 +125,7 @@ def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-n
     if not is_distributed():
         return [value_scalar]
     value_tensor = torch.tensor(
-        value_scalar, device=torch.cuda.current_device()
+        value_scalar, device=f"{get_device_type()}:{get_current_device()}"
     ).float()
 
     if not is_main_process():
@@ -115,13 +149,14 @@ def broadcast_dict(vals: dict):
     if not is_distributed():
         return vals
 
+    cur_device = get_device_type()
     if is_main_process():
         data_byte = pickle.dumps(vals)
-        data_tensor = torch.ByteTensor(list(data_byte)).to("cuda")
-        data_size = torch.IntTensor([len(data_byte)]).to("cuda")
+        data_tensor = torch.ByteTensor(list(data_byte)).to(cur_device)
+        data_size = torch.IntTensor([len(data_byte)]).to(cur_device)
     else:
-        data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
-        data_size = torch.IntTensor([0]).to("cuda")
+        data_tensor = torch.empty([1024], dtype=torch.uint8, device=cur_device)
+        data_size = torch.IntTensor([0]).to(cur_device)
 
     dist.broadcast(data_size, 0)
     if not is_main_process():
@@ -150,14 +185,15 @@ def compute_and_broadcast(fn):  # pylint: disable=invalid-name
     Returns:
     - The computed value (int or float).
     """
+    cur_device = f"{get_device_type()}:{get_current_device()}"
     if is_main_process():
         value_scalar = fn()
         value_tensor = torch.tensor(
-            value_scalar, device=torch.cuda.current_device(), dtype=torch.float32
+            value_scalar, device=cur_device, dtype=torch.float32
         )
     else:
         value_tensor = torch.tensor(
-            0.0, device=torch.cuda.current_device(), dtype=torch.float32
+            0.0, device=cur_device, dtype=torch.float32
         )  # Placeholder tensor
 
     # Broadcast the tensor to all processes.
@@ -184,7 +220,7 @@ def gather_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
     """
     value_scalar = fn()
     value_tensor = torch.tensor(
-        value_scalar, device=torch.cuda.current_device()
+        value_scalar, device=f"{get_device_type()}:{get_current_device()}"
     ).float()
 
     # Placeholder tensor for gathering results
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 75c93fa2a..082df7c27 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -55,7 +55,7 @@
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.chat_templates import get_chat_template_from_config
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import zero_only
+from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
 from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
 from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -570,7 +570,8 @@ def set_device_map_config(self) -> None:
             )
 
             max_memory = {}
-            for i in range(torch.cuda.device_count()):
+            num_device = get_device_count()
+            for i in range(num_device):
                 max_memory[i] = gpu_memory_limit
             max_memory["cpu"] = "256GiB"  # something sufficiently large to fit anything
 
@@ -595,8 +596,11 @@ def set_device_map_config(self) -> None:
         self.model_kwargs["device_map"] = device_map
         self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
 
-        if torch.backends.mps.is_available():
+        cur_device = get_device_type()
+        if "mps" in str(cur_device):
             self.model_kwargs["device_map"] = "mps:0"
+        elif "npu" in str(cur_device):
+            self.model_kwargs["device_map"] = "npu:0"
 
         # TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
         # if cfg.rl:
@@ -1050,7 +1054,11 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
         self.ajust_model_config()
 
         # log device memory usage
-        if hasattr(self.model, "device") and self.model.device.type in ("cuda", "mps"):
+        if hasattr(self.model, "device") and self.model.device.type in (
+            "cuda",
+            "mps",
+            "npu",
+        ):
             log_gpu_memory_usage(LOG, "after model load", self.model.device)
 
         # make sure these are fp32 per Ramesh et al. (2021)
@@ -1118,9 +1126,9 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
             and not skip_move_to_device
         ):
             # TODO revaldate this conditional
-            self.model.to(f"cuda:{self.cfg.local_rank}")
+            self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")
 
-        if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
+        if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
             setattr(self.model, "is_parallelizable", True)
             setattr(self.model, "model_parallel", True)
 
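Note: the snippet below is a usage sketch, not part of the diff above. It assumes axolotl with this patch installed and, on Ascend hardware, the torch_npu plugin (assumed here to be what gates is_torch_npu_available()); it exercises only the helpers the patch adds to src/axolotl/utils/distributed.py.

    import torch

    from axolotl.utils.distributed import (
        get_current_device,
        get_device_count,
        get_device_type,
    )

    # get_device_type() returns a torch.device: "cuda", "mps", "npu", or "cpu".
    device = get_device_type()
    print(f"backend: {device}, visible devices: {get_device_count()}")

    # Code that previously hard-coded "cuda" can target the detected backend
    # instead, e.g. "npu:0" on an Ascend host or "cpu:0" when no accelerator
    # is present.
    tensor = torch.zeros(1, device=f"{device}:{get_current_device()}")
    print(tensor.device)

On a CUDA machine this prints cuda:0; on an Ascend host it prints npu:0, the same backend:index pattern load_model() now builds (with cfg.local_rank) when moving the model in src/axolotl/utils/models.py.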