humaneval.py (forked from turboderp-org/exllamav2)
from __future__ import annotations
import os, sys
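# Make the repository root importable so that exllamav2 and util resolve when this script is run directly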
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from human_eval.data import write_jsonl, read_problems
from exllamav2 import model_init
from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4, ExLlamaV2Cache_Q6, ExLlamaV2Cache_Q8
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler
import argparse, contextlib, subprocess
import util
# Args
parser = argparse.ArgumentParser(description = "Run HumanEval evaluation on EXL2 model")
parser.add_argument("-o", "--output", type = str, help = "Output .jsonl filename", required = True)
parser.add_argument("-cs", "--cache_size", type = int, default = None)
parser.add_argument("-spt", "--samples_per_task", type = int, default = 200)
parser.add_argument("-cq4", "--cache_q4", action = "store_true", help = "Use Q4 cache")
parser.add_argument("-cq6", "--cache_q6", action = "store_true", help = "Use Q6 cache")
parser.add_argument("-cq8", "--cache_q8", action = "store_true", help = "Use Q8 cache")
parser.add_argument("--max_tokens", type = int, default = 768, help = "Max number of tokens for each completion")
parser.add_argument("-pf", "--prompt_format", type = str, help = "Instruct format to apply. Default is raw completion (for base models) ")
parser.add_argument("-v", "--verbose", action = "store_true", help = "Spam completions to console while generating")
parser.add_argument("-e", "--eval", action = "store_true", help = "Run evaluation script on output file after sampling")
parser.add_argument("-temp", "--temperature", type = float, help = "Sampling temperature (0 for greedy), default: 0.6", default = 0.6)
parser.add_argument("-topk", "--top_k", type = int, help = "Top-k sampling, default: 50", default = 50)
parser.add_argument("-topp", "--top_p", type = float, help = "Top-p sampling, default: 0.6", default = 0.6)
parser.add_argument("-repp", "--repetition_penalty", type = float, help = "Token repetition penalty, default: 1.0", default = 1.0)
model_init.add_args(parser)
args = parser.parse_args()
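# Example invocation (a sketch; model loading flags such as -m/--model_dir are added by model_init.add_args
# and may differ between exllamav2 versions):
#
#   python humaneval.py -m /path/to/exl2-model -o humaneval_output.jsonl -pf llama3 -spt 10 --eval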
# Validate args
directory = os.path.dirname(args.output)
if directory and not os.path.isdir(directory):
    print(f" ## Directory for output file {args.output} does not exist.")
    sys.exit()

if os.path.exists(args.output):
    print(f" !! Warning: Output file exists and will be overwritten.")
# Prompt formats
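# Each entry maps a format name to (prompt template, completion prefix). {{problem}} is substituted with the raw
# HumanEval prompt, and the prefix is later prepended to the stripped completion so the function body stays indented.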
prompt_formats = {
    "raw": (
        "```python\n{{problem}}    ",
        "    "
    ),
    "granite": (
        "Question:\nComplete the following Python function:\n\n{{problem}}\n\nAnswer:\n"
        "Sure! Here is how you might implement the function:\n\n```python\n{{problem}}",
        "    "
    ),
    "llama": (
        "[INST] <<SYS>>\n"
        "You are a helpful AI coding assistant.\n"
        "<</SYS>>\n\n"
        "Complete the following Python function:\n\n"
        "{{problem}} [/INST] "
        "Sure! Here is how you might implement the function:\n\n```python\n{{problem}}",
        "    "
    ),
    "llama3": (
        "<|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful AI coding assistant.<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "Complete the following Python function:\n\n{{problem}}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "Sure! Here is how you might implement the function:\n\n```python\n{{problem}}",
        "    "
    ),
    "gemma": (
        "<bos><start_of_turn>user\n"
        "Complete the following Python function:\n\n{{problem}}<end_of_turn>\n"
        "<start_of_turn>model\n"
        "```python\n{{problem}}",
        "    "
    )
}
if args.prompt_format is None:
    prompt_format, prefix = "{{problem}}", "    "
elif args.prompt_format in prompt_formats:
    prompt_format, prefix = prompt_formats[args.prompt_format]
else:
    print("Prompt format is not supported. Available formats:")
    print("\n".join(prompt_formats.keys()))
    sys.exit()
# Init model and cache
model_init.check_args(args)
model_init.print_options(args)
model, tokenizer = model_init.init(
    args,
    allow_auto_split = True,
    progress = True,
    max_output_len = 4,
    max_input_len = 2048
)
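# Select an optionally quantized KV cache; the Q4/Q6/Q8 variants trade some cache precision for lower VRAM use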
if args.cache_q4: cache_type = ExLlamaV2Cache_Q4
elif args.cache_q6: cache_type = ExLlamaV2Cache_Q6
elif args.cache_q8: cache_type = ExLlamaV2Cache_Q8
else: cache_type = ExLlamaV2Cache
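# When the model is not fully loaded yet, the cache is created lazily and materialized during autosplit loading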
cache = cache_type(
    model,
    lazy = not model.loaded,
    max_seq_len = args.cache_size or model.config.max_seq_len
)
if not model.loaded:
    model.load_autosplit(cache, progress = True)
# Generator
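# The dynamic generator runs all enqueued jobs with continuous batching, packing as many as the cache allows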
generator = ExLlamaV2DynamicGenerator(
    model = model,
    cache = cache,
    tokenizer = tokenizer,
    max_batch_size = 256,
    max_q_size = 4
)
gen_settings = ExLlamaV2Sampler.Settings(
    token_repetition_penalty = args.repetition_penalty,
    temperature = args.temperature,
    top_k = args.top_k,
    top_p = args.top_p
)
# Get problems
problems = read_problems()
num_samples_per_task = args.samples_per_task
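# Multiple samples per task let the HumanEval evaluation script estimate pass@k rather than just greedy pass@1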
# Create jobs
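# Each job uses token healing so the completion can continue from the partial last token of the primed prompt;
# min_new_tokens keeps jobs from terminating on an immediate EOS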
with util.get_progress() as progress:

    task1 = progress.add_task("[red]Sample", total = len(problems) * num_samples_per_task, name = "Creating sample jobs")

    for problem_id, problem in problems.items():
        b_problem = problem["prompt"]
        f_problem = prompt_format.replace("{{problem}}", b_problem)
        input_ids = tokenizer.encode(f_problem, encode_special_tokens = True, add_bos = True)

        for s in range(num_samples_per_task):
            job = ExLlamaV2DynamicJob(
                input_ids = input_ids,
                gen_settings = gen_settings,
                max_new_tokens = args.max_tokens,
                stop_conditions = [tokenizer.eos_token_id],
                token_healing = True,
                identifier = (problem_id, s),
                min_new_tokens = 6
            )
            generator.enqueue(job)
            progress.update(task1, advance = 1)
# Collect samples here
samples = []
# Work
total_jobs = generator.num_remaining_jobs()
cm = contextlib.nullcontext() if args.verbose else util.get_progress()
with cm as progress:

    if not args.verbose:
        task1 = progress.add_task("[red]Sample", total = total_jobs, name = "Generating samples")

    while generator.num_remaining_jobs():
        results = generator.iterate()
        for result in results:

            # End sample if generator says EOS or if there is a non-indented line at the end of the output

            job = result["job"]
            eos = False
            completion = job.full_completion

            last_newline_index = completion.rfind("\n")
            if last_newline_index >= 0:
                last_line = completion[last_newline_index + 1:]
                if last_line != "" and not last_line[0].isspace():
                    completion = completion[:last_newline_index]
                    eos = True

            eos = eos or result["eos"]

            # Collect completed sample

            if eos:
                identifier = result["identifier"]
                sample = problems[identifier[0]]["prompt"] + prefix + completion.strip()
                if not result["eos"]:
                    generator.cancel(job)

                if args.verbose:
                    print("----------------------------------------------------------------------")
                    print(f" ** Problem {identifier[0]}, sample {identifier[1] + 1} / {num_samples_per_task}")
                    print("----------------------------------------------------------------------")
                    print(sample)
                    print()
                else:
                    progress.update(task1, advance = 1)

                samples.append(dict(task_id = identifier[0], completion = prefix + completion.strip()))
# Save output
print(f" -- Saving: {args.output}")
write_jsonl(args.output, samples)
# Optionally launch eval script
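# evaluate_functional_correctness is the console script installed by the human-eval package; it runs the
# sampled completions against the HumanEval tests and reports pass@k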
if args.eval:
    subprocess.run(["evaluate_functional_correctness", args.output])