Expose more parameters in ModelConfig and add generation (#166)
* replace use_auth_token with token; refactor get_torch_dtype; add model config items

* handle torch_dtype not configured case

* update inference_config

* rollback to use_auth_token

* fix utils

* revert unwanted changes

* generate more tests

* fix generated tests

* fix test

* use float for bigdl

* add ipex for tests

* upd

* consider hf config first
kira-lin authored Apr 12, 2024
1 parent 97dc0b8 commit 64e337b
Showing 7 changed files with 111 additions and 56 deletions.
7 changes: 3 additions & 4 deletions llm_on_ray/inference/deepspeed_predictor.py
@@ -29,7 +29,7 @@
from typing import List
import os
from llm_on_ray.inference.predictor import Predictor
from llm_on_ray.inference.utils import get_torch_dtype
from llm_on_ray.inference.utils import decide_torch_dtype
from llm_on_ray.inference.inference_config import (
InferenceConfig,
GenerateResult,
@@ -54,12 +54,11 @@ def __init__(self, infer_conf: InferenceConfig, pad_token_id, stopping_criteria)
use_auth_token=infer_conf.model_description.config.use_auth_token,
)

# get correct torch type for loading HF model
torch_dtype = get_torch_dtype(infer_conf, hf_config)
# decide correct torch type for loading HF model
decide_torch_dtype(infer_conf, hf_config)
self.model = AutoModelForCausalLM.from_pretrained(
model_desc.model_id_or_path,
config=hf_config,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
**model_config.dict(),
)
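The same pattern repeats in `transformer_predictor.py` and `mllm_predictor.py` below: the predictor no longer computes a local `torch_dtype`; `decide_torch_dtype` writes its choice into `infer_conf.model_description.config`, so the dtype reaches `from_pretrained` through `**model_config.dict()`. A minimal sketch of that flow — `load_causal_lm` is an illustrative helper name, not part of the repository:

```python
# Sketch only: mirrors the diff above, with decide_torch_dtype mutating the
# config in place instead of returning a dtype.
from transformers import AutoConfig, AutoModelForCausalLM

from llm_on_ray.inference.inference_config import InferenceConfig
from llm_on_ray.inference.utils import decide_torch_dtype


def load_causal_lm(infer_conf: InferenceConfig):
    model_desc = infer_conf.model_description
    model_config = model_desc.config
    hf_config = AutoConfig.from_pretrained(
        model_desc.model_id_or_path,
        trust_remote_code=model_config.trust_remote_code,
        use_auth_token=model_config.use_auth_token,
    )
    # Writes model_config.torch_dtype in place; returns None.
    decide_torch_dtype(infer_conf, hf_config)
    return AutoModelForCausalLM.from_pretrained(
        model_desc.model_id_or_path,
        config=hf_config,
        low_cpu_mem_usage=True,
        **model_config.dict(),  # torch_dtype (and revision, if set) ride along here
    )
```

Routing the dtype through `ModelConfig` keeps every `from_pretrained` call site identical and lets new fields such as `revision` flow through without touching the predictors.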
26 changes: 12 additions & 14 deletions llm_on_ray/inference/hpu_predictor.py
@@ -60,13 +60,17 @@
GenerateResult,
)
from llm_on_ray.inference.predictor import Predictor
from llm_on_ray.inference.utils import decide_torch_dtype


class HPUPredictor(Predictor):
def __init__(self, infer_conf: InferenceConfig):
super().__init__(infer_conf)

model_desc = infer_conf.model_description
model_config = model_desc.config
# decide correct torch type for loading HF model
decide_torch_dtype(infer_conf)
self.use_lazy_mode = True
self.use_hpu_graphs = model_desc.use_hpu_graphs
# TODO add torch_compile, i.e. hpu specific configs. including quant
@@ -96,10 +100,7 @@ def __init__(self, infer_conf: InferenceConfig):
# Not using DeepSpeed, load model locally
self.device = torch.device("hpu")
model = AutoModelForCausalLM.from_pretrained(
model_desc.model_id_or_path,
# TODO expose torch_dtype in model_config
# **model_config
# torch_dtype=model_dtype,
model_desc.model_id_or_path, **model_config.dict()
)
self.model = model.eval().to(self.device)
if self.use_hpu_graphs:
@@ -241,14 +242,8 @@ def get_streamer(self):
def load_model(self, model_desc, model_config):
import deepspeed

# bf16 is used for deepspeed
model_dtype = torch.bfloat16

config = AutoConfig.from_pretrained(
model_desc.model_id_or_path,
torch_dtype=model_dtype,
trust_remote_code=model_config.trust_remote_code,
)
model_dtype = model_config.torch_dtype
config = AutoConfig.from_pretrained(model_desc.model_id_or_path, **model_config.dict())
load_to_meta = model_on_meta(config)

if load_to_meta:
@@ -257,12 +252,15 @@ def load_model(self, model_desc, model_config):

checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
write_checkpoints_json(
model_desc.model_id_or_path, self.local_rank, checkpoints_json, token=""
model_desc.model_id_or_path,
self.local_rank,
checkpoints_json,
token=model_config.use_auth_token,
)
else:
with deepspeed.OnDevice(dtype=model_dtype, device="cpu"):
model = AutoModelForCausalLM.from_pretrained(
model_desc.model_id_or_path, torch_dtype=model_dtype, **model_config.dict()
model_desc.model_id_or_path, **model_config.dict()
)
model.eval()

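On Gaudi the predictor now calls `decide_torch_dtype(infer_conf)` up front (no `hf_config` is available at that point), and `load_model` reads the decided dtype back from `model_config.torch_dtype` instead of hard-coding `torch.bfloat16`. A condensed sketch of the DeepSpeed staging path, omitting the meta-device branch and checkpoint JSON handling shown in the diff (`load_for_deepspeed` is an illustrative name):

```python
# Condensed sketch of the DeepSpeed/HPU path (assumes decide_torch_dtype(infer_conf)
# has already populated model_config.torch_dtype).
import deepspeed
from transformers import AutoConfig, AutoModelForCausalLM


def load_for_deepspeed(model_desc, model_config):
    model_dtype = model_config.torch_dtype  # decided earlier, not hard-coded bf16
    config = AutoConfig.from_pretrained(
        model_desc.model_id_or_path, **model_config.dict()
    )
    # The real load_model switches to the meta device when model_on_meta(config)
    # is true; only the plain CPU-staging branch is shown here.
    with deepspeed.OnDevice(dtype=model_dtype, device="cpu"):
        model = AutoModelForCausalLM.from_pretrained(
            model_desc.model_id_or_path, **model_config.dict()
        )
    return model.eval()
```

Because `decide_torch_dtype` forces bfloat16 whenever DeepSpeed is enabled on HPU, the default behavior matches the old hard-coded value while still honoring an explicit user override.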
15 changes: 11 additions & 4 deletions llm_on_ray/inference/inference_config.py
@@ -39,6 +39,8 @@ class ModelConfig(BaseModel):
trust_remote_code: bool = False
use_auth_token: Union[str, None] = None
load_in_4bit: bool = False
torch_dtype: Union[str, None] = None
revision: Union[str, None] = None


class Ipex(BaseModel):
@@ -83,19 +85,24 @@ class GenerateResult(BaseModel):

class ModelDescription(BaseModel):
model_id_or_path: Union[str, None] = None
bigdl: bool = False
tokenizer_name_or_path: Union[str, None] = None
config: ModelConfig = ModelConfig()
prompt: Prompt = Prompt()
chat_processor: Union[str, None] = None

gpt_base_model: bool = False

quantized_model_id_or_path: Union[str, None] = None
quantization_type: Union[str, None] = None

peft_model_id_or_path: Union[str, None] = None
peft_type: Union[str, None] = None

bigdl: bool = False
bigdl_config: BigDLModelConfig = BigDLModelConfig()

# only effective when device is hpu
use_hpu_graphs: bool = True
prompt: Prompt = Prompt()
config: ModelConfig = ModelConfig()
bigdl_config: BigDLModelConfig = BigDLModelConfig()

# prevent warning of protected namespaces
# DO NOT TOUCH
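`ModelConfig` gains two optional fields, `torch_dtype` and `revision`, and `ModelDescription`'s fields are regrouped for readability. A hypothetical configuration exercising the new knobs — the model id and revision below are placeholders, not project defaults:

```python
# Hypothetical usage of the new fields; values are placeholders.
from llm_on_ray.inference.inference_config import (
    InferenceConfig,
    ModelConfig,
    ModelDescription,
)

infer_conf = InferenceConfig(
    model_description=ModelDescription(
        model_id_or_path="meta-llama/Llama-2-7b-chat-hf",  # placeholder
        config=ModelConfig(
            trust_remote_code=False,
            torch_dtype="bfloat16",  # new: respected as-is by decide_torch_dtype
            revision="main",         # new: pin a specific revision of the HF repo
        ),
    ),
)

# Everything in ModelConfig is forwarded verbatim to from_pretrained:
#   AutoModelForCausalLM.from_pretrained(path, **infer_conf.model_description.config.dict())
```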
5 changes: 2 additions & 3 deletions llm_on_ray/inference/mllm_predictor.py
@@ -17,7 +17,7 @@
import torch
from transformers import TextIteratorStreamer
from llm_on_ray.inference.inference_config import InferenceConfig, GenerateResult, PRECISION_BF16
from llm_on_ray.inference.utils import get_torch_dtype, module_import
from llm_on_ray.inference.utils import decide_torch_dtype, module_import
from llm_on_ray.inference.predictor import Predictor


@@ -33,13 +33,12 @@ def __init__(self, infer_conf: InferenceConfig):

adapt_transformers_to_gaudi()
# get correct torch type for loading HF model
torch_dtype = get_torch_dtype(infer_conf, None)
decide_torch_dtype(infer_conf)

model_loader_name = infer_conf.model_description.model_loader
input_processor_name = infer_conf.model_description.input_processor
model = module_import("transformers", model_loader_name).from_pretrained(
model_desc.model_id_or_path,
torch_dtype=torch_dtype,
**model_desc.config.dict(),
)
processor = module_import("transformers", input_processor_name).from_pretrained(
8 changes: 3 additions & 5 deletions llm_on_ray/inference/transformer_predictor.py
@@ -17,7 +17,7 @@
import torch
from transformers import AutoModelForCausalLM, AutoConfig, TextIteratorStreamer
from llm_on_ray.inference.inference_config import InferenceConfig, GenerateResult, PRECISION_BF16
from llm_on_ray.inference.utils import get_torch_dtype
from llm_on_ray.inference.utils import decide_torch_dtype
from llm_on_ray.inference.predictor import Predictor


@@ -33,8 +33,8 @@ def __init__(self, infer_conf: InferenceConfig):
use_auth_token=infer_conf.model_description.config.use_auth_token,
)

# get correct torch type for loading HF model
torch_dtype = get_torch_dtype(infer_conf, hf_config)
# decide correct torch type for loading HF model
decide_torch_dtype(infer_conf, hf_config)
if model_desc.bigdl:
from bigdl.llm.transformers import (
AutoModelForCausalLM as BigDLAutoModelForCLM,
@@ -46,15 +46,13 @@ def __init__(self, infer_conf: InferenceConfig):
bmodel_config.update(model_desc.bigdl_config.dict())
model = BigDLAutoModelForCLM.from_pretrained(
model_desc.model_id_or_path,
torch_dtype=torch_dtype,
config=hf_config,
low_cpu_mem_usage=True,
**bmodel_config,
)
else:
model = AutoModelForCausalLM.from_pretrained(
model_desc.model_id_or_path,
torch_dtype=torch_dtype,
config=hf_config,
low_cpu_mem_usage=True,
**model_config.dict(),
44 changes: 29 additions & 15 deletions llm_on_ray/inference/utils.py
@@ -19,7 +19,7 @@
import torch
from typing import Dict, Any, List, Optional, Union
from enum import Enum
from llm_on_ray.inference.inference_config import InferenceConfig, DEVICE_CPU
from llm_on_ray.inference.inference_config import InferenceConfig, DEVICE_CPU, DEVICE_HPU
from llm_on_ray.inference.api_openai_backend.openai_protocol import ChatMessage


@@ -116,22 +116,36 @@ def max_input_len(input_text_length):
return 4096


def get_torch_dtype(infer_conf: InferenceConfig, hf_config) -> torch.dtype:
def decide_torch_dtype(infer_conf: InferenceConfig, hf_config=None):
"""
return torch default dtype, a.k.a float32, if it's cpu only inference without
ipex because bfloat16 is too slow and float16 is not supported in CPU
Decide torch dtype based on user config and model config.
This function modifies `torch_dtype` in infer_conf.model_description.config.
"""
if hf_config is None or is_cpu_without_ipex(infer_conf):
return torch.get_default_dtype()
if hasattr(hf_config, "torch_dtype"):
t = hf_config.torch_dtype
if t:
return t
if hasattr(hf_config, "__getitem__"):
t = hf_config["torch_dtype"]
if t:
return t
return torch.get_default_dtype()
# First, create torch_dtype if it does not exist
if not hasattr(infer_conf.model_description.config, "torch_dtype"):
infer_conf.model_description.config.torch_dtype = None

if infer_conf.model_description.config.torch_dtype:
# respect user config
return
elif hf_config is None:
# default to float32 if hf_config is not supplied
infer_conf.model_description.config.torch_dtype = torch.get_default_dtype()
# if hf_config contains recommended torch_dtype, use it
elif hasattr(hf_config, "torch_dtype") and hf_config.torch_dtype:
infer_conf.model_description.config.torch_dtype = hf_config.torch_dtype
elif hasattr(hf_config, "__getitem__") and "torch_dtype" in hf_config:
infer_conf.model_description.config.torch_dtype = hf_config["torch_dtype"]

# if using hpu
if infer_conf.device == DEVICE_HPU:
# if using deepspeed, we should use bfloat16
# TODO if quantization is enabled, we should use bfloat16
if infer_conf.deepspeed:
infer_conf.model_description.config.torch_dtype = torch.bfloat16
elif is_cpu_without_ipex(infer_conf):
# cpu without ipex use default float32
infer_conf.model_description.config.torch_dtype = torch.get_default_dtype()


def is_cpu_without_ipex(infer_conf: InferenceConfig) -> bool:
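In short, the renamed `decide_torch_dtype` applies this precedence: an explicit `torch_dtype` in the user's model config always wins; otherwise the dtype recommended by the HF model config (attribute or dict style) is taken; and device-specific fallbacks come last, forcing bfloat16 for DeepSpeed on HPU and torch's default float32 on CPU without IPEX. A standalone restatement of that ordering, simplified and detached from the real `InferenceConfig` types:

```python
# Simplified restatement of the precedence decide_torch_dtype applies; for
# illustration only, not the library function itself.
import torch


def pick_dtype(user_dtype, hf_recommended, device, use_deepspeed, cpu_without_ipex):
    if user_dtype:                      # 1. explicit user config always wins
        return user_dtype
    if hf_recommended:                  # 2. otherwise trust the HF model config
        chosen = hf_recommended
    else:                               # 3. otherwise fall back to torch's default
        chosen = torch.get_default_dtype()
    if device == "hpu" and use_deepspeed:
        chosen = torch.bfloat16         # DeepSpeed on Gaudi runs in bfloat16
    elif cpu_without_ipex:
        chosen = torch.get_default_dtype()  # plain CPU: float32 by default
    return chosen


print(pick_dtype(None, torch.float16, "cpu", False, cpu_without_ipex=True))
# -> torch.float32
```

Note that the CPU-without-IPEX fallback can override the HF recommendation, which is why the example prints `torch.float32` rather than `torch.float16`.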
62 changes: 51 additions & 11 deletions tests/inference/test_utils.py
@@ -21,10 +21,10 @@
get_deployment_actor_options,
StoppingCriteriaSub,
max_input_len,
get_torch_dtype,
decide_torch_dtype,
is_cpu_without_ipex,
)
from llm_on_ray.inference.inference_config import InferenceConfig, DEVICE_CPU
from llm_on_ray.inference.inference_config import InferenceConfig, DEVICE_CPU, DEVICE_HPU


# Mock the InferenceConfig for testing
@@ -73,20 +73,60 @@ def test_max_input_len():
# Add more tests for edge cases


def test_get_torch_dtype_cpu_without_ipex(mock_infer_conf):
hf_config = None
dtype = get_torch_dtype(mock_infer_conf, hf_config)
assert dtype == torch.get_default_dtype()
def test_is_cpu_without_ipex(mock_infer_conf):
assert is_cpu_without_ipex(mock_infer_conf) is True

mock_infer_conf.ipex.enabled = True
assert is_cpu_without_ipex(mock_infer_conf) is False

# Add more tests for different configurations and hf_config values

def test_decide_torch_dtype_cpu_without_ipex(mock_infer_conf):
decide_torch_dtype(mock_infer_conf)

def test_is_cpu_without_ipex(mock_infer_conf):
assert is_cpu_without_ipex(mock_infer_conf) is True
assert mock_infer_conf.model_description.config.torch_dtype == torch.get_default_dtype()


def test_decide_torch_dtype_respect_user_config(mock_infer_conf):
mock_infer_conf.model_description.config.torch_dtype = torch.float16
mock_infer_conf.device = DEVICE_CPU

decide_torch_dtype(mock_infer_conf)

assert mock_infer_conf.model_description.config.torch_dtype == torch.float16


def test_decide_torch_dtype_hpu_with_deepspeed(mock_infer_conf):
mock_infer_conf.device = DEVICE_HPU
mock_infer_conf.deepspeed = True

decide_torch_dtype(mock_infer_conf)

assert mock_infer_conf.model_description.config.torch_dtype == torch.bfloat16


def test_decide_torch_dtype_with_hf_config_torch_dtype(mock_infer_conf):
mock_infer_conf.device = DEVICE_CPU
mock_infer_conf.ipex.enabled = True
assert is_cpu_without_ipex(mock_infer_conf) is False
hf_config = {"torch_dtype": torch.float16}

decide_torch_dtype(mock_infer_conf, hf_config)

assert mock_infer_conf.model_description.config.torch_dtype == torch.float16


def test_decide_torch_dtype_with_hf_config_torch_dtype_as_attribute(mock_infer_conf):
class HFConfig:
def __init__(self):
self.torch_dtype = torch.float16

mock_infer_conf.device = DEVICE_CPU
mock_infer_conf.ipex.enabled = True
hf_config = HFConfig()

decide_torch_dtype(mock_infer_conf, hf_config)

assert mock_infer_conf.model_description.config.torch_dtype == torch.float16


# Add more tests for different configurations
if __name__ == "__main__":
pytest.main(["-v", __file__])
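All of the new tests lean on a `mock_infer_conf` pytest fixture defined in the collapsed part of the file. A plausible shape for it, written here as an assumption rather than a copy of the real fixture:

```python
# Assumed fixture shape; the real definition sits earlier in
# tests/inference/test_utils.py and may differ in detail.
import pytest

from llm_on_ray.inference.inference_config import InferenceConfig, DEVICE_CPU


@pytest.fixture
def mock_infer_conf():
    infer_conf = InferenceConfig()
    infer_conf.device = DEVICE_CPU
    infer_conf.ipex.enabled = False
    infer_conf.deepspeed = False
    infer_conf.model_description.config.torch_dtype = None
    return infer_conf
```

Any `InferenceConfig` whose device is CPU with IPEX disabled and `torch_dtype` unset would exercise the same code paths the tests above cover.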
