From 498338e3930f553f0edf4ebcf443ec55c2fc3c4e Mon Sep 17 00:00:00 2001
From: erfanzar
Date: Tue, 7 Jan 2025 09:08:35 +0000
Subject: [PATCH] add vInference evaluation script and improve sharding logic in tests #181

---
 tests/vinference_eval.py | 151 +++++++++++++++++++++++++++++++++++++++
 tests/vinference_test.py |  85 +++++++++-------------
 2 files changed, 184 insertions(+), 52 deletions(-)
 create mode 100644 tests/vinference_eval.py

diff --git a/tests/vinference_eval.py b/tests/vinference_eval.py
new file mode 100644
index 00000000..0e29646a
--- /dev/null
+++ b/tests/vinference_eval.py
@@ -0,0 +1,151 @@
+import os
+import sys
+
+sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
+
+
+import jax
+import torch
+import transformers
+from jax import numpy as jnp
+
+import easydel as ed
+from tqdm import tqdm
+from datasets import load_dataset
+
+
+def calc_accuracy(actuals, preds):
+    total_correct = 0
+    total_examples = len(actuals)
+    for actual, pred in zip(actuals, preds):
+        # Take the first answer letter that appears in the prediction; if none
+        # of A/B/C/D is present, leave it empty so the example counts as wrong.
+        pred_letter = ""
+        for letter in ("A", "B", "C", "D"):
+            if letter in pred:
+                pred_letter = letter
+                break
+        if actual == pred_letter:
+            total_correct += 1
+    acc_score = total_correct / total_examples
+    return acc_score
+
+
+FORCE_SP = jax.device_count() > 4  # False
+
+
+def main():
+    if jax.device_count() > 4 and not FORCE_SP:
+        sharding_axis_dims = (1, 1, 2, -1)
+    else:
+        sharding_axis_dims = (1, 1, 1, -1)
+
+    max_length = 4096
+
+    # pretrained_model_name_or_path = "meta-llama/Llama-3.2-1B-Instruct"
+    pretrained_model_name_or_path = "Qwen/Qwen2.5-7B-Instruct"
+
+    partition_axis = ed.PartitionAxis()
+
+    dtype = jnp.bfloat16
+
+    print("LOADING MODEL ... ")
+    model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
+        pretrained_model_name_or_path,
+        auto_shard_model=True,
+        sharding_axis_dims=sharding_axis_dims,
+        config_kwargs=ed.EasyDeLBaseConfigDict(
+            freq_max_position_embeddings=max_length,
+            mask_max_position_embeddings=max_length,
+            attn_dtype=dtype,
+            gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NONE,
+            kv_cache_quantization_method=ed.EasyDeLQuantizationMethods.NONE,
+            attn_mechanism=ed.AttentionMechanisms.VANILLA,
+        ),
+        quantization_method=ed.EasyDeLQuantizationMethods.NONE,
+        platform=ed.EasyDeLPlatforms.JAX,
+        param_dtype=dtype,
+        dtype=dtype,
+        torch_dtype=torch.float16,
+        partition_axis=partition_axis,
+        precision=jax.lax.Precision("fastest"),
+    )
+    print("MODEL LOADED")
+    tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
+    tokenizer.padding_side = "left"
+    tokenizer.pad_token_id = tokenizer.eos_token_id
+    print("TOKENIZER LOADED")
+    model.eval()
+    print("CREATING vInference")
+
+    inference = ed.vInference(
+        model=model,
+        processor_class=tokenizer,
+        generation_config=ed.vInferenceConfig(
+            max_new_tokens=1024,
+            temperature=0.0,
+            do_sample=False,
+            top_p=0.95,
+            top_k=10,
+            eos_token_id=model.generation_config.eos_token_id,
+            streaming_chunks=32,
+        ),
+    )
+
+    print(model.model_task)
+    print(model.model_type)
+    print("Compiling")
+    inference.precompile(1, inference.model_prefill_length)
+
+    print("Done Compiling")
+    print("Evaluating on MMLU Lite")
+    prompts = []
+    pred_list = []
+    actual_list = []
+    data = load_dataset("CohereForAI/Global-MMLU-Lite", "en", split="test")
+    for item in tqdm(data, total=len(data)):
+        question = item["question"]
+        option_a = item["option_a"]
+        option_b = item["option_b"]
+        option_c = item["option_c"]
+        option_d = item["option_d"]
+        actual_list.append(item["answer"])
+        prompt = f"Answer the following question by writing the right answer letter which can be A,B,C or D. Write only the correct answer letter in your response. \nQuestion : {question}\nA. {option_a}. \nB. {option_b}. \nC. {option_c}. \nD. {option_d}"
+        prompts.append(prompt)
+        messages = [
+            {"role": "system", "content": "You are a helpful AI assistant."},
+            {"role": "user", "content": prompt},
+        ]
+        ids = tokenizer.apply_chat_template(
+            messages,
+            return_tensors="jax",
+            return_dict=True,
+            max_length=inference.model_prefill_length,
+            padding="max_length",
+            add_generation_prompt=True,
+        )
+
+        # Collect every streamed chunk so the whole completion is scored.
+        output = ""
+        pad_seq = inference.model_prefill_length
+        for response in inference.generate(**ids):
+            next_slice = slice(
+                pad_seq,
+                pad_seq + inference.generation_config.streaming_chunks,
+            )
+            pad_seq += inference.generation_config.streaming_chunks
+            output += tokenizer.decode(
+                response.sequences[0][next_slice],
+                skip_special_tokens=True,
+            )
+        pred_list.append(output)
+    for prompt, pred in zip(prompts, pred_list):
+        print("--------------------------------------")
+        print(f"Prompt: {prompt}\nPrediction : {pred}")
+    print("---------- Evaluation Score -----------------")
+    acc_score = calc_accuracy(actual_list, pred_list)
+    print(f"accuracy score : {acc_score}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/vinference_test.py b/tests/vinference_test.py
index 2f02f838..999f9218 100644
--- a/tests/vinference_test.py
+++ b/tests/vinference_test.py
@@ -1,39 +1,27 @@
-# fmt:off
 import os
 import sys
-import threading
-import time
 
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
-import easydel as ed
-# fmt:on
+
 import jax
 import torch
 import transformers
-from huggingface_hub import HfApi
 from jax import numpy as jnp
-from jax import sharding
-
-
-PartitionSpec, api = sharding.PartitionSpec, HfApi()
-
-
-def log_mem():
-    while True:
-        ed.utils.analyze_memory.SMPMemoryMonitor(5).print_current_status()
-        time.sleep(5)
-
-threading.Thread(target=log_mem)  # .start()
+import easydel as ed
 
 
 def main():
-    sharding_axis_dims = (1, 1, 1, -1)
+    if jax.device_count() > 4:
+        sharding_axis_dims = (1, 1, 2, -1)
+    else:
+        sharding_axis_dims = (1, 1, 1, -1)
+
     max_length = 4096
-    pretrained_model_name_or_path = "meta-llama/Llama-3.2-1B-Instruct"
-    # pretrained_model_name_or_path = "AntonV/mamba2-370m-hf"
+    # pretrained_model_name_or_path = "meta-llama/Llama-3.2-1B-Instruct"
+    pretrained_model_name_or_path = "Qwen/Qwen2.5-7B-Instruct"
 
     partition_axis = ed.PartitionAxis()
 
@@ -51,12 +39,10 @@ def main():
             gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NONE,
             kv_cache_quantization_method=ed.EasyDeLQuantizationMethods.NONE,
             attn_mechanism=ed.AttentionMechanisms.VANILLA,
-            # use_scan_mlp=True,
-            # scan_mlp_chunk_size=128,
         ),
         quantization_method=ed.EasyDeLQuantizationMethods.NONE,
-        platform=ed.EasyDeLPlatforms.TRITON,
-        param_dtype=jnp.float8_e5m2,  # dtype, #
+        platform=ed.EasyDeLPlatforms.JAX,
+        param_dtype=dtype,
         dtype=dtype,
         torch_dtype=torch.float16,
         partition_axis=partition_axis,
@@ -68,11 +54,6 @@ def main():
     tokenizer.pad_token_id = tokenizer.eos_token_id
     print("TOKENIZER LOADED")
     model.eval()
-    # model = model.quantize(
-    #     method=ed.EasyDeLQuantizationMethods.A8BIT,
-    #     block_size=128,
-    #     quantization_pattern=".*(gate_proj|up_proj).*",
-    # )
     print("CREATING vInference")
 
     inference = ed.vInference(
@@ -96,12 +77,10 @@ def main():
 
     print("Done Compiling")
     messages = [
-        # {
-        #     "role": "system",
-        #     "content": "Please reason step by step, and put your final answer within \\boxed{}. and give 3 different responses",
-        # },
-        # {"role": "user", "content": "Find the value of $x$ that satisfies the equation $4x+5 = 6x+7$."},
-        {"role": "system", "content": "You are a helpful AI assistant."},
+        {
+            "role": "system",
+            "content": "You are a helpful AI assistant.",
+        },
         {
             "role": "user",
             "content": "Can you provide ways to eat combinations of bananas and dragonfruits?",
@@ -110,7 +89,10 @@ def main():
             "role": "assistant",
             "content": "Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey.",
         },
-        {"role": "user", "content": "What about solving an 2x + 3 = 7 equation?"},
+        {
+            "role": "user",
+            "content": "What about solving an 2x + 3 = 7 equation?",
+        },
     ]
 
     ids = tokenizer.apply_chat_template(
@@ -125,21 +107,20 @@ def main():
     pad_seq = inference.model_prefill_length
 
     print("Start Generation Process.")
-    with jax.profiler.trace("tmp-files/vinference"):
-        for response in inference.generate(**ids):
-            next_slice = slice(
-                pad_seq,
-                pad_seq + inference.generation_config.streaming_chunks,
-            )
-            pad_seq += inference.generation_config.streaming_chunks
-            print(
-                tokenizer.decode(response.sequences[0][next_slice], skip_special_tokens=True),
-                end="",
-            )
-
-            print()
-            print(response.generated_tokens)
-            print("TPS :", response.tokens_pre_second)
+    for response in inference.generate(**ids):
+        next_slice = slice(
+            pad_seq,
+            pad_seq + inference.generation_config.streaming_chunks,
+        )
+        pad_seq += inference.generation_config.streaming_chunks
+        print(
+            tokenizer.decode(response.sequences[0][next_slice], skip_special_tokens=True),
+            end="",
+        )
+
+    print()
+    print(response.generated_tokens)
+    print("TPS :", response.tokens_pre_second)
 
 
 if __name__ == "__main__":
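
Note on the streaming loops above: both scripts walk response.sequences in windows of generation_config.streaming_chunks tokens, starting right after the padded prompt (inference.model_prefill_length). A minimal pure-Python sketch of that slicing, using made-up numbers in place of the EasyDeL values:

    # Illustration only: fake token IDs, no EasyDeL objects involved.
    prefill_length = 4          # stands in for inference.model_prefill_length
    streaming_chunks = 3        # stands in for generation_config.streaming_chunks
    sequence = list(range(10))  # 4 "prompt" tokens followed by 6 "generated" tokens

    pad_seq = prefill_length
    pieces = []
    while pad_seq < len(sequence):
        next_slice = slice(pad_seq, pad_seq + streaming_chunks)
        pieces.append(sequence[next_slice])
        pad_seq += streaming_chunks

    print(pieces)  # [[4, 5, 6], [7, 8, 9]] -- each sublist is one decoded chunk

Joining the decoded chunks (as vinference_eval.py does with output) reconstructs the full completion, while printing each chunk as it arrives (as vinference_test.py does) streams it to stdout.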