Merge pull request #268 from yiliu30/itrex_woq
Integrate ITREX to support int8 models on CPU-only devices
StochasticRomanAgeev authored Nov 7, 2023
2 parents 688307f + a9dbb28 commit 54d1ec3
Showing 7 changed files with 124 additions and 15 deletions.
10 changes: 10 additions & 0 deletions examples/models/gpt2/gpt2_woq.py
@@ -0,0 +1,10 @@
# from xturing.datasets.instruction_dataset import InstructionDataset
from xturing.models import BaseModel

# Initialize the model: quantize it with a weight-only algorithm and
# replace the linear layers with ITREX's qbits_linear kernel
model = BaseModel.create("gpt2_int8")

# Once the model has been quantized, you can run inference directly
output = model.generate(texts=["Why LLM models are becoming so important?"])
print(output)
10 changes: 10 additions & 0 deletions examples/models/llama2/llama2_woq.py
@@ -0,0 +1,10 @@
# from xturing.datasets.instruction_dataset import InstructionDataset
from xturing.models import BaseModel

# Initialize the model: quantize it with a weight-only algorithm and
# replace the linear layers with ITREX's qbits_linear kernel
model = BaseModel.create("llama2_int8")

# Once the model has been quantized, you can run inference directly
output = model.generate(texts=["Why LLM models are becoming so important?"])
print(output)
10 changes: 9 additions & 1 deletion src/xturing/config/__init__.py
@@ -1,15 +1,23 @@
 import torch
 
 from xturing.utils.interactive import is_interactive_execution
+from xturing.utils.logging import configure_logger
+from xturing.utils.utils import assert_install_itrex
+
+logger = configure_logger(__name__)
 
 # check if cuda is available, if not use cpu and throw warning
 DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 DEFAULT_DTYPE = torch.float16 if DEFAULT_DEVICE.type == "cuda" else torch.float32
 IS_INTERACTIVE = is_interactive_execution()
 
 if DEFAULT_DEVICE.type == "cpu":
-    print("WARNING: CUDA is not available, using CPU instead, can be very slow")
+    logger.warning("CUDA is not available; using CPU instead, which can be very slow")
 
 
 def assert_not_cpu_int8():
     assert DEFAULT_DEVICE.type != "cpu", "Int8 models are not supported on CPU"
+
+def assert_cpu_int8_on_itrex():
+    if DEFAULT_DEVICE.type == "cpu":
+        assert_install_itrex()
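
For reference, a minimal sketch of how the two guards above differ on a CPU-only machine (illustrative usage, not part of the diff):

# Illustrative sketch: behavior of the guards defined above on a CPU-only box.
from xturing.config import assert_cpu_int8_on_itrex

# assert_not_cpu_int8() would raise here: "Int8 models are not supported on CPU".
# assert_cpu_int8_on_itrex() instead passes as long as
# intel-extension-for-transformers is importable (installing it if needed).
assert_cpu_int8_on_itrex()  # safe to create an int8 model after this point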
44 changes: 32 additions & 12 deletions src/xturing/engines/causal.py
@@ -19,9 +19,13 @@
 )
 from xturing.engines.quant_utils.peft_utils import LoraConfig as peftLoraConfig
 from xturing.engines.quant_utils.peft_utils import prepare_model_for_kbit_training
+from xturing.utils.logging import configure_logger
 from xturing.utils.loss_fns import CrossEntropyLoss
+from xturing.utils.utils import assert_install_itrex
 
 
+logger = configure_logger(__name__)
+
 class CausalEngine(BaseEngine):
     def __init__(
         self,
@@ -60,18 +64,34 @@ def __init__(
             self.tokenizer = tokenizer
         elif model_name is not None:
             if load_8bit:
-                device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    model_name,
-                    torch_dtype=DEFAULT_DTYPE,
-                    load_in_8bit=True,
-                    device_map=device_map,
-                    trust_remote_code=trust_remote_code,
-                    **kwargs,
-                )
-                for param in self.model.parameters():
-                    param.data = param.data.contiguous()
-                self.model = prepare_model_for_int8_training(self.model)
+                use_itrex = DEFAULT_DEVICE.type == "cpu"
+                if use_itrex:
+                    logger.info("CUDA is not available; running the model on CPU with ITREX.")
+                    assert_install_itrex()
+                    # Quantize the model with weight-only quantization
+                    from intel_extension_for_transformers.transformers import AutoModelForCausalLM as ItrexAutoModelForCausalLM
+                    from intel_extension_for_transformers.transformers import WeightOnlyQuantConfig
+                    woq_config = WeightOnlyQuantConfig(weight_dtype="int8")
+                    self.model = ItrexAutoModelForCausalLM.from_pretrained(
+                        model_name,
+                        quantization_config=woq_config,
+                        trust_remote_code=trust_remote_code,
+                        use_llm_runtime=False,
+                        **kwargs)
+                    logger.info("Loaded int8 model from ITREX.")
+                else:
+                    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
+                    self.model = AutoModelForCausalLM.from_pretrained(
+                        model_name,
+                        torch_dtype=DEFAULT_DTYPE,
+                        load_in_8bit=True,
+                        device_map=device_map,
+                        trust_remote_code=trust_remote_code,
+                        **kwargs,
+                    )
+                    for param in self.model.parameters():
+                        param.data = param.data.contiguous()
+                    self.model = prepare_model_for_int8_training(self.model)
         else:
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_name,
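
For orientation, the CPU branch above boils down to the following standalone sketch. The WeightOnlyQuantConfig and use_llm_runtime arguments are taken directly from the diff; "gpt2" is a stand-in model name:

# Standalone sketch of the ITREX weight-only-quantization path used above.
# Assumes intel-extension-for-transformers is installed; "gpt2" is illustrative.
from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM,
    WeightOnlyQuantConfig,
)

woq_config = WeightOnlyQuantConfig(weight_dtype="int8")
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    quantization_config=woq_config,
    use_llm_runtime=False,
)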
4 changes: 2 additions & 2 deletions src/xturing/models/causal.py
@@ -8,7 +8,7 @@
 from tqdm import tqdm
 from transformers import BatchEncoding
 
-from xturing.config import DEFAULT_DEVICE, assert_not_cpu_int8
+from xturing.config import DEFAULT_DEVICE, assert_not_cpu_int8, assert_cpu_int8_on_itrex
 from xturing.config.config_data_classes import FinetuningConfig, GenerationConfig
 from xturing.config.read_config import load_config
 from xturing.datasets.instruction_dataset import InstructionDataset
@@ -320,7 +320,7 @@ def __init__(
         model_name: Optional[str] = None,
         **kwargs,
     ):
-        assert_not_cpu_int8()
+        assert_cpu_int8_on_itrex()
         super().__init__(
             engine,
             weights_path=weights_path,
29 changes: 29 additions & 0 deletions src/xturing/utils/utils.py
@@ -150,3 +150,32 @@ def _index_samples(samples: List[Any], logger: logging.Logger):
     logger.info(f"Evaluating {len(indices)} samples")
     work_items = [(samples[i], i) for i in indices]
     return work_items
+
+
+def is_itrex_available():
+    """
+    Check the availability of 'intel_extension_for_transformers' as an optional
+    dependency, attempting a pip install if it is missing.
+
+    Returns:
+        bool: True if 'intel_extension_for_transformers' is available (or was
+        installed successfully), False otherwise.
+    """
+    import importlib.util
+
+    if importlib.util.find_spec("intel_extension_for_transformers") is not None:
+        return True
+    try:
+        import subprocess
+        import sys
+
+        subprocess.check_call(
+            [sys.executable, "-m", "pip", "install", "intel-extension-for-transformers"]
+        )
+        return importlib.util.find_spec("intel_extension_for_transformers") is not None
+    except Exception:
+        return False
+
+
+def assert_install_itrex():
+    assert is_itrex_available(), (
+        "To run int8 or k-bit models on CPU, please install the "
+        "`intel-extension-for-transformers` package. You can install it with "
+        "`pip install intel-extension-for-transformers`."
+    )
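
Note that is_itrex_available() attempts a pip install as a side effect. A caller that only wants a passive check could use a sketch like this (hypothetical helper, not part of the diff):

# Hypothetical side-effect-free variant: reports availability without
# attempting to install anything.
import importlib.util

def itrex_importable() -> bool:
    return importlib.util.find_spec("intel_extension_for_transformers") is not None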
32 changes: 32 additions & 0 deletions tests/xturing/models/test_gpt2_model.py
@@ -101,3 +101,35 @@ def test_saving_loading_model_lora():
 
     model2 = BaseModel.load(str(saving_path))
     model2.generate(texts=["Why are the LLM so important?"])
+
+
+import os
+
+
+def disable_cuda(func):
+    def wrapper(*args, **kwargs):
+        # Save the current value of CUDA_VISIBLE_DEVICES
+        original_cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+        # Set CUDA_VISIBLE_DEVICES to -1 to disable CUDA
+        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+        try:
+            # Call the decorated function; let test failures propagate
+            return func(*args, **kwargs)
+        finally:
+            # Restore the original value of CUDA_VISIBLE_DEVICES
+            if original_cuda_visible_devices is not None:
+                os.environ["CUDA_VISIBLE_DEVICES"] = original_cuda_visible_devices
+            elif "CUDA_VISIBLE_DEVICES" in os.environ:
+                # CUDA_VISIBLE_DEVICES was not set before, so remove it
+                del os.environ["CUDA_VISIBLE_DEVICES"]
+
+    return wrapper
+
+
+@disable_cuda
+def test_gpt2_int8_woq_cpu():
+    # Quantize GPT-2 with ITREX and check that generation produces output
+    other_model = BaseModel.create("gpt2_int8")
+    assert other_model.generate(texts="I want to") != ""
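
The disable_cuda decorator is reusable for other CPU-only tests; a hypothetical companion test (not part of this diff) could look like:

@disable_cuda
def test_llama2_int8_woq_cpu():
    # Hypothetical: exercises the llama2_int8 path added in this PR on CPU
    model = BaseModel.create("llama2_int8")
    assert model.generate(texts="I want to") != ""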
