Skip to content

Commit

Permalink
Fixed the PT2E LLM example (#2082)
Browse files Browse the repository at this point in the history
Signed-off-by: yiliu30 <[email protected]>
  • Loading branch information
yiliu30 authored Dec 6, 2024
1 parent 697c5be commit 619f77b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,23 @@ Here is how to run the scripts:
#### Quantization

```bash
python run_clm_no_trainer.py --model facebook/opt-125m --quantize --accuracy
```
python run_clm_no_trainer.py --model facebook/opt-125m --quantize --output_dir qmodel_save_path
```

#### Accuracy
```bash
# Measure the accuracy of the floating-point model
python run_clm_no_trainer.py --model facebook/opt-125m --accuracy --tasks lambada_openai

# Measure the accuracy of the quantized model
python run_clm_no_trainer.py --model facebook/opt-125m --accuracy --tasks lambada_openai --int8 --output_dir qmodel_save_path
```

#### Performance
```bash
# Measure the performance of the floating-point model
python run_clm_no_trainer.py --model facebook/opt-125m --performance

# Measure the performance of the quantized model
python run_clm_no_trainer.py --model facebook/opt-125m --performance --int8 --output_dir qmodel_save_path
```
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@
help="For accuracy measurement only.")
parser.add_argument("--tasks", default="lambada_openai,hellaswag,winogrande,piqa,wikitext",
type=str, help="tasks for accuracy validation")
parser.add_argument("--eval_limits", default=None, type=int,
help="Number of samples to evaluate, default is None, which means all samples")
parser.add_argument("--max_num_tokens", default=2048, type=int,
help="Max number of tokens")
parser.add_argument("--max_batch_size", default=16, type=int,
help="Max batch size")
parser.add_argument("--peft_model_id", type=str, default=None, help="model_name_or_path of peft model")
# =======================================

Expand Down Expand Up @@ -75,8 +81,8 @@ def get_example_inputs(tokenizer):
tuple_inputs = (input_ids_batch,)
return tuple_inputs
# torch._dynamo.config.cache_size_limit = 4  # uncomment to cap the Dynamo cache size if you run out of memory
batch = Dim(name="batch_size")
seq_len = Dim(name="seq_len")
batch = Dim(name="batch_size", max=args.max_batch_size)
seq_len = Dim(name="seq_len", max=args.max_num_tokens)
dynamic_shapes = {"input_ids": (batch, seq_len)}
example_inputs = get_example_inputs(tokenizer)
exported_model = export(user_model, example_inputs=example_inputs, dynamic_shapes=dynamic_shapes)
Expand All @@ -96,26 +102,16 @@ def get_example_inputs(tokenizer):
converted_model.save(example_inputs=example_inputs, output_dir = args.output_dir)



if args.int8:
if args.output_dir:
print("Load int8 model.")
from neural_compressor.torch.quantization import load
model = load(args.output_dir)

model.config = user_model.config # for lm eval

# Compile the quantized model and replace the Q/DQ pattern with Q-operator
from torch._inductor import config

config.freezing = True
opt_model = torch.compile(model)
model_config = user_model.config
user_model = load(args.output_dir)
user_model.config = model_config

opt_model.config = user_model.config # for lm eval
user_model = opt_model

if args.accuracy:

from neural_compressor.evaluation.lm_eval import evaluate, LMEvalParser
eval_args = LMEvalParser(
model="hf",
Expand All @@ -124,6 +120,7 @@ def get_example_inputs(tokenizer):
batch_size=args.batch_size,
tasks=args.tasks,
device="cpu",
limit=args.eval_limits,
)
results = evaluate(eval_args)
for task_name in args.tasks.split(","):
Expand All @@ -137,6 +134,14 @@ def get_example_inputs(tokenizer):
if args.performance:
batch_size, input_leng = args.batch_size, 512
example_inputs = torch.ones((batch_size, input_leng), dtype=torch.long)
# Compile the model; for a quantized model this replaces the Q/DQ patterns with quantized operators (Q-operators)
from torch._inductor import config

config.freezing = True
model_config = user_model.config
user_model = torch.compile(user_model)
user_model.config = model_config

print("Batch size = {:d}".format(batch_size))
print("The length of input tokens = {:d}".format(input_leng))
import time
Expand Down

0 comments on commit 619f77b

Please sign in to comment.