
Commit

add vInference evaluation script and improve sharding logic in tests #…
erfanzar committed Jan 7, 2025
1 parent b9c4bd2 commit 498338e
Showing 2 changed files with 184 additions and 52 deletions.
151 changes: 151 additions & 0 deletions tests/vinference_eval.py
@@ -0,0 +1,151 @@
import os
import sys

sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))


import jax
import torch
import transformers
from jax import numpy as jnp

import easydel as ed
from tqdm import tqdm
from datasets import load_dataset


def calc_accuracy(actuals, preds):
    total_correct = 0
    total_examples = len(actuals)
    for actual, pred in zip(actuals, preds):
        pred_letter = "A"
        if "A" in pred:
            pred_letter = "A"
        if "B" in pred:
            pred_letter = "B"
        if "C" in pred:
            pred_letter = "C"
        if "D" in pred:
            pred_letter = "D"
        if actual == pred_letter:
            total_correct += 1
    acc_score = total_correct / total_examples
    return acc_score


FORCE_SP = jax.device_count() > 4 # False


def main():
    if jax.device_count() > 4 and not FORCE_SP:
        sharding_axis_dims = (1, 1, 2, -1)
    else:
        sharding_axis_dims = (1, 1, 1, -1)

    max_length = 4096

    # pretrained_model_name_or_path = "meta-llama/Llama-3.2-1B-Instruct"
    pretrained_model_name_or_path = "Qwen/Qwen2.5-7B-Instruct"

    partition_axis = ed.PartitionAxis()

    dtype = jnp.bfloat16

    print("LOADING MODEL ... ")
    model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        auto_shard_model=True,
        sharding_axis_dims=sharding_axis_dims,
        config_kwargs=ed.EasyDeLBaseConfigDict(
            freq_max_position_embeddings=max_length,
            mask_max_position_embeddings=max_length,
            attn_dtype=dtype,
            gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NONE,
            kv_cache_quantization_method=ed.EasyDeLQuantizationMethods.NONE,
            attn_mechanism=ed.AttentionMechanisms.VANILLA,
        ),
        quantization_method=ed.EasyDeLQuantizationMethods.NONE,
        platform=ed.EasyDeLPlatforms.JAX,
        param_dtype=dtype,
        dtype=dtype,
        torch_dtype=torch.float16,
        partition_axis=partition_axis,
        precision=jax.lax.Precision("fastest"),
    )
    print("MODEL LOADED")
    tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    tokenizer.padding_side = "left"
    tokenizer.pad_token_id = tokenizer.eos_token_id
    print("TOKENIZER LOADED")
    model.eval()
    print("CREATING vInference")

    inference = ed.vInference(
        model=model,
        processor_class=tokenizer,
        generation_config=ed.vInferenceConfig(
            max_new_tokens=1024,
            temperature=0.0,
            do_sample=False,
            top_p=0.95,
            top_k=10,
            eos_token_id=model.generation_config.eos_token_id,
            streaming_chunks=32,
        ),
    )

    print(model.model_task)
    print(model.model_type)
    print("Compiling")
    inference.precompile(1, inference.model_prefill_length)

    print("Done Compiling")
    print("Evaluating on MMLU Lite")
    prompts = []
    pred_list = []
    actual_list = []
    data = load_dataset("CohereForAI/Global-MMLU-Lite", "en", split="test")
    for item in tqdm(data, total=len(data)):
        question = item["question"]
        option_a = item["option_a"]
        option_b = item["option_b"]
        option_c = item["option_c"]
        option_d = item["option_d"]
        actual_list.append(item["answer"])
        prompt = f"Answer the following question by writing the right answer letter which can be A,B,C or D. Write only the correct answer letter in your response. \nQuestion : {question}\nA. {option_a}. \nB. {option_b}. \nC. {option_c}. \nD. {option_d}"
        prompts.append(prompt)
        messages = [
            {"role": "system", "content": "You are a helpful AI assistant."},
            {"role": "user", "content": prompt},
        ]
        ids = tokenizer.apply_chat_template(
            messages,
            return_tensors="jax",
            return_dict=True,
            max_length=inference.model_prefill_length,
            padding="max_length",
            add_generation_prompt=True,
        )

        pad_seq = inference.model_prefill_length
        for response in inference.generate(**ids):
            next_slice = slice(
                pad_seq,
                pad_seq + inference.generation_config.streaming_chunks,
            )
            pad_seq += inference.generation_config.streaming_chunks
            output = tokenizer.decode(
                response.sequences[0][next_slice],
                skip_special_tokens=True,
            )
        pred_list.append(output)
    for prompt, pred in zip(prompts, pred_list):
        print("--------------------------------------")
        print(f"Prompt: {prompt}\nPrediction : {pred}")
    print("---------- Evaluation Score -----------------")
    acc_score = calc_accuracy(actual_list, pred_list)
    print(f"accuracy score : {acc_score}")


if __name__ == "__main__":
    main()
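
A note on the accuracy heuristic in the new script: the chain of `if` checks in `calc_accuracy` means that the last of the letters A-D found in a prediction wins, and "A" is the fallback when none is found. Below is a minimal, self-contained restatement of that rule with made-up predictions; the helper name and sample strings are illustrative only, not part of the commit.

def letter_match_accuracy(actuals, preds):
    # Same rule as calc_accuracy above: scan A..D in order and keep the
    # last letter that appears in the prediction text; default to "A".
    correct = 0
    for actual, pred in zip(actuals, preds):
        pred_letter = "A"
        for letter in ("A", "B", "C", "D"):
            if letter in pred:
                pred_letter = letter
        correct += int(actual == pred_letter)
    return correct / len(actuals)


print(letter_match_accuracy(["B", "D", "A"], ["The answer is B.", "D", "I think it is C."]))  # prints 2/3
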
85 changes: 33 additions & 52 deletions tests/vinference_test.py
@@ -1,39 +1,27 @@
# fmt:off
import os
import sys
import threading
import time

sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))

import easydel as ed
# fmt:on

import jax
import torch
import transformers
from huggingface_hub import HfApi
from jax import numpy as jnp
from jax import sharding


PartitionSpec, api = sharding.PartitionSpec, HfApi()


def log_mem():
    while True:
        ed.utils.analyze_memory.SMPMemoryMonitor(5).print_current_status()
        time.sleep(5)


threading.Thread(target=log_mem) # .start()
import easydel as ed


def main():
    sharding_axis_dims = (1, 1, 1, -1)
    if jax.device_count() > 4:
        sharding_axis_dims = (1, 1, 2, -1)
    else:
        sharding_axis_dims = (1, 1, 1, -1)

    max_length = 4096

    pretrained_model_name_or_path = "meta-llama/Llama-3.2-1B-Instruct"
    # pretrained_model_name_or_path = "AntonV/mamba2-370m-hf"
    # pretrained_model_name_or_path = "meta-llama/Llama-3.2-1B-Instruct"
    pretrained_model_name_or_path = "Qwen/Qwen2.5-7B-Instruct"

    partition_axis = ed.PartitionAxis()

@@ -51,12 +39,10 @@ def main():
            gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NONE,
            kv_cache_quantization_method=ed.EasyDeLQuantizationMethods.NONE,
            attn_mechanism=ed.AttentionMechanisms.VANILLA,
            # use_scan_mlp=True,
            # scan_mlp_chunk_size=128,
        ),
        quantization_method=ed.EasyDeLQuantizationMethods.NONE,
        platform=ed.EasyDeLPlatforms.TRITON,
        param_dtype=jnp.float8_e5m2,  # dtype, #
        platform=ed.EasyDeLPlatforms.JAX,
        param_dtype=dtype,
        dtype=dtype,
        torch_dtype=torch.float16,
        partition_axis=partition_axis,
@@ -68,11 +54,6 @@ def main():
    tokenizer.pad_token_id = tokenizer.eos_token_id
    print("TOKENIZER LOADED")
    model.eval()
    # model = model.quantize(
    # method=ed.EasyDeLQuantizationMethods.A8BIT,
    # block_size=128,
    # quantization_pattern=".*(gate_proj|up_proj).*",
    # )
    print("CREATING vInference")

    inference = ed.vInference(
@@ -96,12 +77,10 @@

    print("Done Compiling")
    messages = [
        # {
        # "role": "system",
        # "content": "Please reason step by step, and put your final answer within \\boxed{}. and give 3 different responses",
        # },
        # {"role": "user", "content": "Find the value of $x$ that satisfies the equation $4x+5 = 6x+7$."},
        {"role": "system", "content": "You are a helpful AI assistant."},
        {
            "role": "system",
            "content": "You are a helpful AI assistant.",
        },
        {
            "role": "user",
            "content": "Can you provide ways to eat combinations of bananas and dragonfruits?",
@@ -110,7 +89,10 @@
            "role": "assistant",
            "content": "Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey.",
        },
        {"role": "user", "content": "What about solving an 2x + 3 = 7 equation?"},
        {
            "role": "user",
            "content": "What about solving an 2x + 3 = 7 equation?",
        },
    ]

    ids = tokenizer.apply_chat_template(
@@ -125,21 +107,20 @@
    pad_seq = inference.model_prefill_length

    print("Start Generation Process.")
    with jax.profiler.trace("tmp-files/vinference"):
        for response in inference.generate(**ids):
            next_slice = slice(
                pad_seq,
                pad_seq + inference.generation_config.streaming_chunks,
            )
            pad_seq += inference.generation_config.streaming_chunks
            print(
                tokenizer.decode(response.sequences[0][next_slice], skip_special_tokens=True),
                end="",
            )

        print()
        print(response.generated_tokens)
        print("TPS :", response.tokens_pre_second)
    for response in inference.generate(**ids):
        next_slice = slice(
            pad_seq,
            pad_seq + inference.generation_config.streaming_chunks,
        )
        pad_seq += inference.generation_config.streaming_chunks
        print(
            tokenizer.decode(response.sequences[0][next_slice], skip_special_tokens=True),
            end="",
        )

    print()
    print(response.generated_tokens)
    print("TPS :", response.tokens_pre_second)


if __name__ == "__main__":
    main()
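
Both test files now derive `sharding_axis_dims` from the host's device count instead of hard-coding it. A standalone sketch of that selection follows; the helper name and the `force_sp` parameter are illustrative and not part of the commit (in `vinference_eval.py` the `FORCE_SP` flag is itself computed from the device count).

import jax


def pick_sharding_axis_dims(force_sp: bool = False):
    # Mirrors the tests: split the third axis two ways when more than
    # four devices are visible and the override flag is not set,
    # otherwise place all devices on the last axis.
    if jax.device_count() > 4 and not force_sp:
        return (1, 1, 2, -1)
    return (1, 1, 1, -1)


print(pick_sharding_axis_dims())
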
