Commit

fixes, add data
goliaro committed Dec 8, 2024
1 parent b094c12 commit 2d7910b
Showing 20 changed files with 317,026 additions and 278 deletions.


164 changes: 122 additions & 42 deletions benchmarking/get_wildchat_trace.py
@@ -1,64 +1,144 @@
 import datasets
 from transformers import AutoTokenizer
 from tqdm import tqdm
-import json, os
+import json, os, argparse
+from dataclasses import asdict, dataclass, field
+from typing import List, Optional
-
-def build_trace(dataset: datasets.Dataset, model_name: str, num_entries: int, seed: int):
+
+
+@dataclass
+class TraceEntry:
+    prompt: str
+    response: str
+    prompt_length: int
+    response_length: int
+
+@dataclass
+class TraceMetadata:
+    num_warmup_requests: int
+    avg_entries_per_partition: float
+    max_prompt_length: int
+    min_prompt_length: int
+    avg_prompt_length: float
+    max_response_length: int
+    min_response_length: int
+    avg_response_length: float
+    max_total_length: int
+
+@dataclass
+class Trace:
+    entries: List[TraceEntry] = field(default_factory=list)
+    metadata: TraceMetadata = field(default_factory=lambda: TraceMetadata(0, 0, 0, 0, 0, 0, 0, 0, 0))
+
+def build_trace(
+    dataset: datasets.Dataset, model_name: str, num_entries: int, max_length: int, seed: int, apply_chat_template: bool = False
+):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-    dataset = dataset["train"].filter(
-        lambda x: x["model"] == "gpt-4" and x["turn"] == 1 and x["language"] == "English"
-    ).shuffle(seed=seed).select(range(num_entries))
-
+    dataset = (
+        dataset["train"]
+        .filter(
+            lambda x: x["model"] == "gpt-4"
+            and x["turn"] == 1
+            and x["language"] == "English"
+        )
+        .shuffle(seed=seed)
+        .select(range(num_entries*3))
+    )
     pairs = []
     for row in dataset:
         assert len(row["conversation"]) == 2
         assert row["conversation"][0]["role"] == "user"
         assert row["conversation"][1]["role"] == "assistant"
-        pairs.append((
-            row["conversation"][0]["content"],
-            row["conversation"][1]["content"],
-        ))
+        pairs.append(
+            (
+                row["conversation"][0]["content"],
+                row["conversation"][1]["content"],
+            )
+        )
 
+    trace = Trace()
+    trace_metadata = TraceMetadata(
+        num_warmup_requests=0,
+        avg_entries_per_partition=0,
+        max_prompt_length=0,
+        min_prompt_length=float("inf"),
+        avg_prompt_length=0,
+        max_response_length=0,
+        min_response_length=float("inf"),
+        avg_response_length=0,
+        max_total_length=0,
+    )
+
-    prompts = []
-    avg_prompt_length = 0
-    min_prompt_length = float("inf")
-    max_prompt_length = 0
-    avg_response_length = 0
-    min_response_length = float("inf")
-    max_response_length = 0
-    max_total_length = 0
     for prompt, response in tqdm(pairs, desc="Processing HF trace"):
-        prompt = tokenizer.apply_chat_template(
-            [{"role": "user", "content": prompt}],
-            add_generation_prompt=True,
-            tokenize=False,
-        )
+        if apply_chat_template:
+            prompt = tokenizer.apply_chat_template(
+                [{"role": "user", "content": prompt}],
+                add_generation_prompt=True,
+                tokenize=False,
+            )
         prompt_length = len(tokenizer(prompt)["input_ids"])
         response_length = len(tokenizer(response)["input_ids"])
-        prompts.append(prompt)
-        avg_prompt_length += prompt_length
-        avg_response_length += response_length
-        min_prompt_length = min(min_prompt_length, prompt_length)
-        min_response_length = min(min_response_length, response_length)
-        max_prompt_length = max(max_prompt_length, prompt_length)
-        max_response_length = max(max_response_length, response_length)
-        max_total_length = max(max_total_length, prompt_length + response_length)
-    avg_prompt_length /= len(prompts)
-    avg_response_length /= len(prompts)
+        if prompt_length + response_length > max_length:
+            continue
+        new_entry = TraceEntry(prompt, response, prompt_length, response_length)
+        trace.entries.append(new_entry)
+        trace_metadata.max_prompt_length = max(trace_metadata.max_prompt_length, prompt_length)
+        trace_metadata.min_prompt_length = min(trace_metadata.min_prompt_length, prompt_length)
+        trace_metadata.avg_prompt_length += prompt_length
+        trace_metadata.max_response_length = max(trace_metadata.max_response_length, response_length)
+        trace_metadata.min_response_length = min(trace_metadata.min_response_length, response_length)
+        trace_metadata.avg_response_length += response_length
+        trace_metadata.max_total_length = max(trace_metadata.max_total_length, prompt_length + response_length)
+        if len(trace.entries) == num_entries:
+            break
+    trace_metadata.avg_prompt_length /= len(trace.entries)
+    trace_metadata.avg_response_length /= len(trace.entries)
+    trace_metadata.avg_entries_per_partition = len(trace.entries)
+
+    trace.metadata = trace_metadata
+
-    return prompts, max_prompt_length, max_response_length, avg_prompt_length, avg_response_length, min_prompt_length, min_response_length, max_total_length
+    return trace
 
+def save_trace(trace: Trace, output_path: str):
+    """
+    Save a Trace instance to a JSON file.
+    Args:
+        trace (Trace): The trace to save.
+        output_path (str): The path where the JSON file will be saved.
+    """
+    # Convert the Trace instance to a dictionary
+    trace_dict = asdict(trace)
+
+    # Save the dictionary as a JSON file
+    with open(output_path, 'w') as f:
+        json.dump(trace_dict, f, indent=2)
+
+    print(f"Trace saved to {output_path}")
+
 if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Build WildChat trace")
+    parser.add_argument("--model_name", type=str, default="meta-llama/Llama-3.1-70B-Instruct", help="Model name")
+    parser.add_argument("-m", "--max-length", type=int, default=8190, help="Maximum prompt + response length")
+    parser.add_argument("-n", "--num_entries", type=int, default=250, help="Number of entries")
+    parser.add_argument("-s", "--seed", type=int, default=12345, help="Random seed")
+    parser.add_argument("-o", "--output_file", type=str, default="./wildchat.json", help="Output file name")
+    args = parser.parse_args()
+
     # Change directory to that holding this script
     os.chdir(os.path.dirname(os.path.abspath(__file__)))
 
     dataset = datasets.load_dataset("allenai/WildChat")
-    prompts, max_prompt_length, max_response_length, avg_prompt_length, avg_response_length, min_prompt_length, min_response_length, max_total_length = build_trace(dataset, "meta-llama/Llama-3.1-70B-Instruct", 250, 42)
-    print(f"Number of prompts: {len(prompts)}")
-    print(f"Prompt lengths: [{min_prompt_length} -> {max_prompt_length}] (avg: {avg_prompt_length})")
-    print(f"Response lengths: [{min_response_length} -> {max_response_length}] (avg: {avg_response_length})")
-    print(f"Max total length: {max_total_length}")
+    trace = build_trace(dataset, args.model_name, args.num_entries, args.max_length, args.seed, apply_chat_template=False)
+    print("Build trace with the following metadata:")
+    print(trace.metadata)
 
-    # Save prompts list to a json file
-
-    with open("wildchat.json", "w") as f:
-        json.dump(prompts, f, indent=2)
+    num_above_2048 = 0
+    for entry in trace.entries:
+        if entry.prompt_length + entry.response_length > 2048:
+            num_above_2048 += 1
+    print(f"Number of entries above 2048 tokens: {num_above_2048}")
+    save_trace(trace, args.output_file)
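
A minimal reader sketch (not part of this commit): save_trace writes asdict(trace), so the output JSON has top-level "entries" and "metadata" keys whose field names follow the TraceEntry and TraceMetadata dataclasses above. Assuming the default --output_file of ./wildchat.json, a consumer could reload and sanity-check the trace roughly like this:

import json

# Hypothetical consumer of the file produced by get_wildchat_trace.py.
# Assumes the script was run with its default output path ./wildchat.json.
with open("./wildchat.json") as f:
    trace = json.load(f)

entries = trace["entries"]    # list of dicts matching the TraceEntry fields
metadata = trace["metadata"]  # dict matching the TraceMetadata fields

# Cross-check the stored max_total_length against the entries themselves.
max_total = max(e["prompt_length"] + e["response_length"] for e in entries)
assert max_total == metadata["max_total_length"]
print(f"{len(entries)} entries, max prompt+response length = {max_total}")
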
2 changes: 1 addition & 1 deletion benchmarking/overhead_test.sh
@@ -33,7 +33,7 @@ ZSIZE=200000
 # ZSIZE=20000
 
 OUTPUT_FOLDER="../inference/output/overhead_test"
-MAX_SEQ_LEN=2048
+MAX_SEQ_LEN=8192
 BATCH_SIZE=8
 
 max_tokens_per_batch_values=(
2 changes: 1 addition & 1 deletion benchmarking/plot_finetuning_overheads.py
@@ -169,7 +169,7 @@ def plot_bwd_overhead(filepath, num_tokens_per_batch):
 if not os.path.exists('./plots'):
     os.makedirs('./plots')
 
-tp_degree=8
+tp_degree=4
 
 for tokens_per_batch in [128, 256, 512]:
     fp=f"../inference/output/overhead_test/step_profiling_meta-llama_llama-3.1-70b_tensor_parallelism_{tp_degree}_max_requests_per_batch_8_max_tokens_per_batch_{tokens_per_batch}_arrival_rate_0.000000_num_warmup_requests_10.csv"
Binary file added benchmarking/plots/bwd_overhead_128.pdf
Binary file added benchmarking/plots/bwd_overhead_256.pdf
Binary file added benchmarking/plots/bwd_overhead_512.pdf
Binary file added benchmarking/plots/fwd_overhead_128.pdf
Binary file added benchmarking/plots/fwd_overhead_256.pdf
Binary file added benchmarking/plots/fwd_overhead_512.pdf