-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathmain.py
332 lines (304 loc) · 12.7 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
import torch
from quant import *
from outlier import *
from eval import *
from collections import defaultdict
from pprint import pprint
from modelutils_llama import quantize_model_llama, reorder_model_llama, quantize_model_gptq_llama, add_act_quant_wrapper_llama
from modelutils_opt import quantize_model_opt, reorder_model_opt, quantize_model_gptq_opt, add_act_quant_wrapper_opt
from modelutils_mixtral import quantize_model_mixtral, add_act_quant_wrapper_mixtral, reorder_model_mixtral
from parallel_utils import map_layers_to_multi_gpus
from LMClass import LMClass
from eval import pattern_match
from lm_eval import tasks as lm_tasks
from lm_eval import evaluator as lm_evaluator
def get_llama(model):
    """Load a LLaMA checkpoint in float16 with weight-init routines stubbed out.

    Args:
        model: checkpoint identifier/path forwarded to
            ``LlamaForCausalLM.from_pretrained``.

    Returns:
        The loaded model, with a ``seqlen`` attribute fixed at 2048.

    Note:
        ``torch.nn.init.kaiming_uniform_``, ``uniform_`` and ``normal_`` are
        replaced process-wide with a no-op and never restored. Presumably this
        skips random initialisation that ``from_pretrained`` overwrites anyway;
        confirm nothing later in the process relies on real init.
    """
    import torch

    def _no_init(*_args, **_kwargs):
        return None

    # Patch before importing/constructing the model, exactly as before.
    for init_name in ("kaiming_uniform_", "uniform_", "normal_"):
        setattr(torch.nn.init, init_name, _no_init)

    from transformers import LlamaForCausalLM

    llama_model = LlamaForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
    llama_model.seqlen = 2048
    return llama_model
def get_opt(model):
    """Load an OPT checkpoint in float16 with weight-init routines stubbed out.

    Args:
        model: checkpoint identifier/path forwarded to
            ``OPTForCausalLM.from_pretrained``.

    Returns:
        The loaded model, with a ``seqlen`` attribute taken from
        ``config.max_position_embeddings``.

    Note:
        ``torch.nn.init.kaiming_uniform_``, ``uniform_`` and ``normal_`` are
        replaced process-wide with a no-op and never restored (presumably to
        skip random initialisation that ``from_pretrained`` overwrites anyway).
    """
    import torch

    def _no_init(*_args, **_kwargs):
        return None

    # Patch before importing/constructing the model, exactly as before.
    for init_name in ("kaiming_uniform_", "uniform_", "normal_"):
        setattr(torch.nn.init, init_name, _no_init)

    from transformers import OPTForCausalLM

    opt_model = OPTForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
    opt_model.seqlen = opt_model.config.max_position_embeddings
    return opt_model
def get_mixtral(model):
    """Load a Mixtral checkpoint in float16 with weight-init routines stubbed out.

    Args:
        model: checkpoint identifier/path forwarded to
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        The loaded model, with a ``seqlen`` attribute fixed at 2048.

    Note:
        ``torch.nn.init.kaiming_uniform_``, ``uniform_`` and ``normal_`` are
        replaced process-wide with a no-op and never restored (presumably to
        skip random initialisation that ``from_pretrained`` overwrites anyway).
    """
    import torch

    def skip(*args, **kwargs):
        pass

    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip

    # Only the model class is needed; the previously imported AutoTokenizer was unused.
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
    model.seqlen = 2048
    return model
def model_basename(model_path):
    """Return the lower-cased last ``/``-separated component of *model_path*.

    Trailing slashes are ignored, so ``"facebook/opt-125m/"`` gives
    ``"opt-125m"`` instead of an empty string (an empty name would produce
    cache files named ``_reorder_index_<dataset>.pt``).

    Raises:
        ValueError: if *model_path* has no non-empty final component.
    """
    name = model_path.rstrip("/").split("/")[-1].lower()
    if not name:
        raise ValueError("Please check the model path.")
    return name


def detect_model_family(model_path):
    """Return the supported family (``"llama"``, ``"opt"`` or ``"mixtral"``) named by *model_path*.

    The final path component is searched first, so a checkpoint stored under a
    directory such as ``/opt/models/Mixtral-8x7B`` is not misdetected as OPT
    merely because ``"opt"`` occurs earlier in the path. If the final component
    names no family, the whole lower-cased path is searched, which keeps the
    old behaviour for layouts like ``/models/llama/7B``. Within each candidate
    string the families are tried in the order llama, opt, mixtral, as before.

    Raises:
        ValueError: if no supported family name appears anywhere in the path.
    """
    for candidate in (model_basename(model_path), model_path.lower()):
        for family in ("llama", "opt", "mixtral"):
            if family in candidate:
                return family
    raise ValueError(
        f"Unsupported model '{model_path}': expected 'llama', 'opt' or 'mixtral' in the path."
    )


if __name__ == '__main__':
    # Pipeline: parse args -> load model -> optionally reorder channels
    # -> optionally wrap activations / quantize weights -> optionally evaluate.
    import argparse
    import os
    from datautils import *

    parser = argparse.ArgumentParser()
    parser.add_argument(
        'model', type=str,
        help='LlaMa model to load; pass location of hugginface converted checkpoint.'
    )
    parser.add_argument(
        'dataset', type=str, choices=['wikitext2', 'ptb', 'c4'],
        help='Where to extract calibration data from.'
    )
    parser.add_argument(
        '--seed',
        type=int, default=0,
        help='Seed for sampling the calibration data.'
    )
    parser.add_argument(
        '--nsamples', type=int, default=128,
        help='Number of calibration data samples.'
    )
    # Quantization Method
    parser.add_argument(
        '--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16],
        help='#bits to use for quantizing weight; use 16 for evaluating base model.'
    )
    parser.add_argument(
        '--abits', type=int, default=16, choices=[2, 3, 4, 8, 16],
        help='#bits to use for quantizing activation; use 16 for evaluating base model.'
    )
    parser.add_argument(
        '--exponential', action='store_true',
        help='Whether to use exponent-only for weight quantization.'
    )
    parser.add_argument(
        '--a_sym', action='store_true',
        help='Whether to perform symmetric quantization. Default is asymmetric.'
    )
    parser.add_argument(
        '--w_sym', action='store_true',
        help='Whether to perform symmetric quantization. Default is asymmetric.'
    )
    parser.add_argument(
        '--static', action='store_true',
        help='Whether to perform static quantization (For activtions). Default is dynamic. (Deprecated in Atom)'
    )
    parser.add_argument(
        '--weight_group_size', type=int, default=0, choices=[0, 32, 64, 128, 256, 384, 768],
        help='Group size when quantizing weights. Using 128 as default quantization group.'
    )
    parser.add_argument(
        '--weight_channel_group', type=int, default=1,
        help='Group size of channels that will quantize together. (only for weights now)'
    )
    parser.add_argument(
        '--act_group_size', type=int, default=0, choices=[0, 64, 128, 256, 384, 768],
        help='Group size when quantizing activations. Using 128 as default quantization group.'
    )
    parser.add_argument(
        '--reorder', action='store_true',
        help='Whether to keep salient weight unquantized.'
    )
    parser.add_argument(
        '--act_sort_metric', type=str, default='hessian', choices=['abs_mean', 'hessian'],
        help='The metric used to sort the activations.'
    )
    parser.add_argument(
        '--keeper', type=int, default=0,
        help='Group size to keep outliers.'
    )
    parser.add_argument(
        '--keeper_precision', type=int, default=0, choices=[0, 1, 2, 3],
        help='Precision to keep outliers. 0 for FP16; 1 for E5M2; 2 for E4M3; 3 for INT8 Quant.'
    )
    parser.add_argument(
        '--cache_index', action='store_true',
        help='Whether to use cached reorder index'
    )
    parser.add_argument(
        '--tiling', type=int, default=0, choices=[0, 16],
        help='Tile-wise quantization granularity (Deprecated in Atom).'
    )
    parser.add_argument(
        '--kv_cache', action='store_true',
        help='Whether to quant KV_Cache'
    )
    parser.add_argument(
        '--use_gptq', action='store_true',
        help='Whether to use GPTQ for weight quantization.'
    )
    parser.add_argument(
        '--percdamp', type=float, default=.01,
        help='Percent of the average Hessian diagonal to use for dampening.'
    )
    parser.add_argument(
        '--a_clip_ratio', type=float, default=1.0,
        help='Clip ratio for activation quantization. new_max = max * clip_ratio'
    )
    parser.add_argument(
        '--w_clip_ratio', type=float, default=1.0,
        help='Clip ratio for weight quantization. new_max = max * clip_ratio'
    )
    parser.add_argument(
        '--kv_clip_ratio', type=float, default=1.0,
        help='Clip ratio for kv cache quantization. new_max = max * clip_ratio'
    )
    parser.add_argument(
        "--eval_ppl", action="store_true",
        help='Whether to evaluate perplexity.'
    )
    parser.add_argument(
        "--eval_common_sense", action="store_true",
        help='Whether to evaluate zero-shot accuray on commonsense reasoning tasks.'
    )
    parser.add_argument(
        "--multigpu", action="store_true",
        help="at eval, map model to multiple gpus"
    )
    parser.add_argument(
        "--lm_eval_num_fewshot", type=int, default=0,
        help="Number of shots in lm evaluation. Default is 0 for zero-shot."
    )
    parser.add_argument(
        "--lm_eval_limit", type=int, default=-1,
        help="Limit the number of examples in lm evaluation"
    )
    parser.add_argument(
        '--save_dir', type=str, default='./saved',
        help='Path to store the reordering indices and quantized weights.'
    )
    parser.add_argument(
        '--quant_type', type=str, default='int', choices=['int', 'fp'],
        help='Determine the mapped data format by quant_type + n_bits. e.g. int8, fp4.'
    )
    args = parser.parse_args()

    # Used only to name the cached reorder-index file.
    model_name = model_basename(args.model)
    # Detected once and reused below (including the multi-GPU mapping) so the
    # two places cannot disagree; raises ValueError for unsupported models.
    model_family = detect_model_family(args.model)

    if model_family == "llama":
        model = get_llama(args.model)
        get_act_stats_func = get_act_stats_llama
        reorder_model_func = reorder_model_llama
        add_act_quant_wrapper_func = add_act_quant_wrapper_llama
        quantize_model_gptq_func = quantize_model_gptq_llama
        quantize_model_func = quantize_model_llama
        eval_func = llama_eval
    elif model_family == "opt":
        model = get_opt(args.model)
        get_act_stats_func = get_act_stats_opt
        reorder_model_func = reorder_model_opt
        add_act_quant_wrapper_func = add_act_quant_wrapper_opt
        quantize_model_gptq_func = quantize_model_gptq_opt
        quantize_model_func = quantize_model_opt
        eval_func = opt_eval
    else:  # "mixtral": reuses the LLaMA stats/GPTQ/eval helpers
        model = get_mixtral(args.model)
        get_act_stats_func = get_act_stats_llama
        reorder_model_func = reorder_model_mixtral
        add_act_quant_wrapper_func = add_act_quant_wrapper_mixtral
        quantize_model_gptq_func = quantize_model_gptq_llama
        quantize_model_func = quantize_model_mixtral
        eval_func = llama_eval
    model.eval()

    if args.reorder:
        index_filename = f'{args.save_dir}/{model_name}_reorder_index_{args.dataset}.pt'
        if not args.cache_index:
            dataloader, testloader = get_loaders(
                args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen
            )
            print("Getting activation stats...")
            act_scales = get_act_stats_func(
                model, dataloader, DEV, metric=args.act_sort_metric
            )
            print("Getting reording index...")
            reorder_index = get_reorder_index(model, act_scales)
            os.makedirs(args.save_dir, exist_ok=True)
            torch.save(reorder_index, index_filename)
        else:
            # Explicit raise instead of assert so the check survives `python -O`.
            if not os.path.isfile(index_filename):
                raise FileNotFoundError("reorder index file not found.")
            print("Loading cached reording index from disk...")
            # NOTE(review): torch.load unpickles; only point --save_dir at trusted files.
            reorder_index = torch.load(index_filename)
        print("Reordering model...")
        model = reorder_model_func(
            model, device=DEV, args=args, reorder_index=reorder_index
        )

    if args.abits < 16:
        print("Inserting activations quantizers ...")
        scales = defaultdict(lambda: None)
        model = add_act_quant_wrapper_func(model, device=DEV, args=args, scales=scales)

    if args.wbits < 16:
        print("Quantizing...")
        if args.use_gptq:
            dataloader, testloader = get_loaders(
                args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen
            )
            model = quantize_model_gptq_func(model, device=DEV, args=args, dataloader=dataloader)
        else:
            model = quantize_model_func(model, device=DEV, args=args)

    if args.eval_ppl:
        for dataset in ['wikitext2', 'ptb', 'c4']:
            _, testloader = get_loaders(
                dataset, seed=args.seed, model=args.model, seqlen=model.seqlen
            )
            print(f"Evaluating {dataset} ...")
            ppl = eval_func(model, testloader, DEV)
            print(f"targetResult,{dataset},{ppl:.3f}")

    # eval zero shot accuracy on commonsense datasets
    if args.eval_common_sense:
        lm = LMClass(args, model)
        lm.seqlen = 2048
        lm.model.eval()
        for param in lm.model.parameters():
            param.requires_grad = False

        if args.multigpu:
            if model_family in ("llama", "mixtral"):
                map_layers_to_multi_gpus(lm.model.model.layers)
                input_device = lm.model.model.layers[0].device
                output_device = lm.model.model.layers[-1].device
                assert input_device == output_device
                lm._device = input_device
                lm.model.model.embed_tokens.to(input_device)
                lm.model.model.norm.to(output_device)
                lm.model.lm_head.to(output_device)
            elif model_family == "opt":
                map_layers_to_multi_gpus(lm.model.model.decoder.layers)
                input_device = lm.model.model.decoder.layers[0].device
                output_device = lm.model.model.decoder.layers[-1].device
                assert input_device == output_device
                lm._device = input_device
                lm.model.model.decoder.embed_tokens.to(input_device)
                lm.model.model.decoder.embed_positions.to(input_device)
                # Final norm follows the last decoder layer, like the LLaMA branch.
                lm.model.model.decoder.final_layer_norm.to(output_device)
                lm.model.lm_head.to(output_device)
        else:
            lm._device = DEV
            lm.model = lm.model.to(lm.device)

        tasks_str = "piqa,arc_easy,arc_challenge,boolq,hellaswag,winogrande"
        task_names = pattern_match(tasks_str.split(","), lm_tasks.ALL_TASKS)
        print(f"Selected Tasks: {task_names}")
        task_dict = lm_tasks.get_task_dict(task_names)
        t_results = lm_evaluator.evaluate(
            lm,
            task_dict,
            num_fewshot=args.lm_eval_num_fewshot,
            limit=None if args.lm_eval_limit == -1 else args.lm_eval_limit
        )
        results = dict(t_results)
        pprint(results)

        results_dict = results['results']
        for task_name in tasks_str.split(','):
            # These tasks are reported with length-normalised accuracy.
            if task_name in ['piqa', 'arc_easy', 'arc_challenge', 'hellaswag']:
                print(f"INFO {task_name} : {results_dict[task_name]['acc_norm']*100:.2f}")
            else:
                print(f"INFO {task_name} : {results_dict[task_name]['acc']*100:.2f}")