diff --git a/examples/models/gpt2/gpt2_woq.py b/examples/models/gpt2/gpt2_woq.py
new file mode 100644
index 0000000..0539f45
--- /dev/null
+++ b/examples/models/gpt2/gpt2_woq.py
@@ -0,0 +1,10 @@
+# from xturing.datasets.instruction_dataset import InstructionDataset
+from xturing.models import BaseModel
+
+# Initialize the model: quantize it with weight-only algorithms and
+# replace the linear layers with itrex's qbits_linear kernel
+model = BaseModel.create("gpt2_int8")
+
+# Once the model has been quantized, you can run inference directly
+output = model.generate(texts=["Why are LLM models becoming so important?"])
+print(output)
\ No newline at end of file
diff --git a/examples/models/llama2/llama2_woq.py b/examples/models/llama2/llama2_woq.py
new file mode 100644
index 0000000..f6cda94
--- /dev/null
+++ b/examples/models/llama2/llama2_woq.py
@@ -0,0 +1,10 @@
+# from xturing.datasets.instruction_dataset import InstructionDataset
+from xturing.models import BaseModel
+
+# Initialize the model: quantize it with weight-only algorithms and
+# replace the linear layers with itrex's qbits_linear kernel
+model = BaseModel.create("llama2_int8")
+
+# Once the model has been quantized, you can run inference directly
+output = model.generate(texts=["Why are LLM models becoming so important?"])
+print(output)
\ No newline at end of file
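Note: the commented-out InstructionDataset import in both examples hints at fine-tuning the quantized models as well. Below is a minimal sketch of that flow, assuming the int8 variants expose xTuring's usual finetune() entry point (not verified in this PR); "./alpaca_data" is a placeholder path, not part of the change.

# Hypothetical fine-tuning sketch for the int8 variants
from xturing.datasets.instruction_dataset import InstructionDataset
from xturing.models import BaseModel

# Placeholder path; point this at a real instruction dataset
instruction_dataset = InstructionDataset("./alpaca_data")
model = BaseModel.create("gpt2_int8")
model.finetune(dataset=instruction_dataset)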
diff --git a/src/xturing/config/__init__.py b/src/xturing/config/__init__.py
index 56820fb..c925602 100644
--- a/src/xturing/config/__init__.py
+++ b/src/xturing/config/__init__.py
@@ -1,6 +1,10 @@
 import torch
 
 from xturing.utils.interactive import is_interactive_execution
+from xturing.utils.logging import configure_logger
+from xturing.utils.utils import assert_install_itrex
+
+logger = configure_logger(__name__)
 
 # check if cuda is available, if not use cpu and throw warning
 DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -8,8 +12,12 @@
 IS_INTERACTIVE = is_interactive_execution()
 
 if DEFAULT_DEVICE.type == "cpu":
-    print("WARNING: CUDA is not available, using CPU instead, can be very slow")
+    logger.warning("CUDA is not available, using CPU instead; this can be very slow")
 
 
 def assert_not_cpu_int8():
     assert DEFAULT_DEVICE.type != "cpu", "Int8 models are not supported on CPU"
+
+def assert_cpu_int8_on_itrex():
+    if DEFAULT_DEVICE.type == "cpu":
+        assert_install_itrex()
\ No newline at end of file
diff --git a/src/xturing/engines/causal.py b/src/xturing/engines/causal.py
index 64d068e..8f6b7e8 100644
--- a/src/xturing/engines/causal.py
+++ b/src/xturing/engines/causal.py
@@ -19,9 +19,13 @@
 )
 from xturing.engines.quant_utils.peft_utils import LoraConfig as peftLoraConfig
 from xturing.engines.quant_utils.peft_utils import prepare_model_for_kbit_training
+from xturing.utils.logging import configure_logger
 from xturing.utils.loss_fns import CrossEntropyLoss
+from xturing.utils.utils import assert_install_itrex
+
+logger = configure_logger(__name__)
 
 
 class CausalEngine(BaseEngine):
     def __init__(
         self,
@@ -60,18 +64,34 @@ def __init__(
             self.tokenizer = tokenizer
         elif model_name is not None:
             if load_8bit:
-                device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    model_name,
-                    torch_dtype=DEFAULT_DTYPE,
-                    load_in_8bit=True,
-                    device_map=device_map,
-                    trust_remote_code=trust_remote_code,
-                    **kwargs,
-                )
-                for param in self.model.parameters():
-                    param.data = param.data.contiguous()
-                self.model = prepare_model_for_int8_training(self.model)
+                use_itrex = DEFAULT_DEVICE.type == "cpu"
+                if use_itrex:
+                    logger.info("CUDA is not available; using CPU and running the model with itrex.")
+                    assert_install_itrex()
+                    # Quantize the model with weight-only quantization via itrex
+                    from intel_extension_for_transformers.transformers import AutoModelForCausalLM as ItrexAutoModelForCausalLM
+                    from intel_extension_for_transformers.transformers import WeightOnlyQuantConfig
+                    woq_config = WeightOnlyQuantConfig(weight_dtype="int8")
+                    self.model = ItrexAutoModelForCausalLM.from_pretrained(
+                        model_name,
+                        quantization_config=woq_config,
+                        trust_remote_code=trust_remote_code,
+                        use_llm_runtime=False,
+                        **kwargs)
+                    logger.info("Loaded int8 model from itrex.")
+                else:
+                    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
+                    self.model = AutoModelForCausalLM.from_pretrained(
+                        model_name,
+                        torch_dtype=DEFAULT_DTYPE,
+                        load_in_8bit=True,
+                        device_map=device_map,
+                        trust_remote_code=trust_remote_code,
+                        **kwargs,
+                    )
+                    for param in self.model.parameters():
+                        param.data = param.data.contiguous()
+                    self.model = prepare_model_for_int8_training(self.model)
             else:
                 self.model = AutoModelForCausalLM.from_pretrained(
                     model_name,
diff --git a/src/xturing/models/causal.py b/src/xturing/models/causal.py
index 62bb274..ca085ff 100644
--- a/src/xturing/models/causal.py
+++ b/src/xturing/models/causal.py
@@ -8,7 +8,7 @@
 from tqdm import tqdm
 from transformers import BatchEncoding
 
-from xturing.config import DEFAULT_DEVICE, assert_not_cpu_int8
+from xturing.config import DEFAULT_DEVICE, assert_not_cpu_int8, assert_cpu_int8_on_itrex
 from xturing.config.config_data_classes import FinetuningConfig, GenerationConfig
 from xturing.config.read_config import load_config
 from xturing.datasets.instruction_dataset import InstructionDataset
@@ -320,7 +320,7 @@ def __init__(
         model_name: Optional[str] = None,
         **kwargs,
     ):
-        assert_not_cpu_int8()
+        assert_cpu_int8_on_itrex()
         super().__init__(
             engine,
             weights_path=weights_path,
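For reviewers: below is a standalone sanity check of the itrex weight-only path that the CausalEngine change above depends on. It mirrors the calls from the hunk (WeightOnlyQuantConfig, use_llm_runtime=False); the "gpt2" checkpoint and max_new_tokens value are illustrative choices, not part of this PR.

# Minimal sketch: load an int8 weight-only model directly with itrex,
# assuming intel-extension-for-transformers is installed
from intel_extension_for_transformers.transformers import AutoModelForCausalLM
from intel_extension_for_transformers.transformers import WeightOnlyQuantConfig
from transformers import AutoTokenizer

# Same quantization config as in src/xturing/engines/causal.py above
woq_config = WeightOnlyQuantConfig(weight_dtype="int8")
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    quantization_config=woq_config,
    use_llm_runtime=False,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("Why are LLM models becoming so important?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))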
diff --git a/src/xturing/utils/utils.py b/src/xturing/utils/utils.py
index a1bc22a..c8af938 100644
--- a/src/xturing/utils/utils.py
+++ b/src/xturing/utils/utils.py
@@ -150,3 +150,32 @@
     logger.info(f"Evaluating {len(indices)} samples")
     work_items = [(samples[i], i) for i in indices]
     return work_items
+
+
+def is_itrex_available():
+    """
+    Check the availability of 'intel_extension_for_transformers' as an optional dependency.
+
+    If the package is missing, attempt to install it with pip before giving up.
+
+    Returns:
+        bool: True if 'intel_extension_for_transformers' is importable, possibly
+        after installation, False otherwise.
+    """
+    import importlib.util
+    if importlib.util.find_spec("intel_extension_for_transformers") is not None:
+        return True
+    try:
+        import subprocess
+        import sys
+        subprocess.check_call([sys.executable, "-m", "pip", "install", "intel-extension-for-transformers"])
+        importlib.invalidate_caches()  # make the fresh install importable
+        return importlib.util.find_spec("intel_extension_for_transformers") is not None
+    except Exception:
+        return False
+
+def assert_install_itrex():
+    assert is_itrex_available(), (
+        "To run int8 or k-bit models on CPU, please install the `intel-extension-for-transformers` package. "
+        "You can install it with `pip install intel-extension-for-transformers`."
+    )
\ No newline at end of file
diff --git a/tests/xturing/models/test_gpt2_model.py b/tests/xturing/models/test_gpt2_model.py
index 3ad29d2..d79fd82 100644
--- a/tests/xturing/models/test_gpt2_model.py
+++ b/tests/xturing/models/test_gpt2_model.py
@@ -101,3 +101,35 @@
     model2 = BaseModel.load(str(saving_path))
 
     model2.generate(texts=["Why are the LLM so important?"])
+
+
+import os
+
+def disable_cuda(func):
+    def wrapper(*args, **kwargs):
+        # Save the current value of CUDA_VISIBLE_DEVICES
+        original_cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+        # Set CUDA_VISIBLE_DEVICES to -1 to hide all CUDA devices
+        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+        try:
+            # Call the decorated function with CUDA disabled
+            return func(*args, **kwargs)
+        except Exception as e:
+            print(f"An error occurred: {e}")
+            raise  # re-raise so the test fails instead of passing silently
+        finally:
+            # Restore the original value of CUDA_VISIBLE_DEVICES
+            if original_cuda_visible_devices is not None:
+                os.environ["CUDA_VISIBLE_DEVICES"] = original_cuda_visible_devices
+            else:
+                # CUDA_VISIBLE_DEVICES was not set before, so remove it again
+                if "CUDA_VISIBLE_DEVICES" in os.environ:
+                    del os.environ["CUDA_VISIBLE_DEVICES"]
+
+    return wrapper
+
+@disable_cuda
+def test_gpt2_int8_woq_cpu():
+    # Quantize gpt2 with itrex and check that CPU inference produces output
+    other_model = BaseModel.create("gpt2_int8")
+    assert other_model.generate(texts="I want to") != ""
\ No newline at end of file
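One caveat on disable_cuda: src/xturing/config/__init__.py resolves DEFAULT_DEVICE at import time, so hiding CUDA via the environment only takes effect if xturing has not already been imported with CUDA visible. A possible alternative sketch using pytest's built-in monkeypatch fixture, which restores the environment automatically (hypothetical, not part of this PR; the same import-time caveat applies):

from xturing.models import BaseModel


def test_gpt2_int8_woq_cpu(monkeypatch):
    # Hide CUDA devices for the duration of the test; monkeypatch
    # undoes the change when the test finishes, pass or fail
    monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "-1")
    model = BaseModel.create("gpt2_int8")
    assert model.generate(texts="I want to") != ""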