From b7407959375271da983972b81bf13bb3fdea01e6 Mon Sep 17 00:00:00 2001 From: Avinash Sharma Date: Tue, 3 Dec 2024 20:54:10 -0800 Subject: [PATCH] Add input len 2048 test for bs4 tp1 8b f16 non-decomposed prefill (#634) Add test for llama 3.1 8b f16 non-decomposed prefill bs4 tp1 for input len 2048. Signed-off-by: aviator19941 --- .../models/llama/benchmark_amdgpu_test.py | 49 ++++++++++++++++++- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_test.py b/sharktank/tests/models/llama/benchmark_amdgpu_test.py index 31d5fbb64..166b0b78a 100644 --- a/sharktank/tests/models/llama/benchmark_amdgpu_test.py +++ b/sharktank/tests/models/llama/benchmark_amdgpu_test.py @@ -110,6 +110,9 @@ def setUp(self): self.prefill_args_bs4_128_in_tokens_f16 = ( self.artifacts_dir / "prefill_args_bs4_128" ) + self.prefill_args_bs4_2048_in_tokens_f16 = ( + self.artifacts_dir / "prefill_args_bs4_2048" + ) self.decode_args_f16 = self.artifacts_dir / "decode_args" self.prefill_args_fp8 = self.artifacts_dir / "prefill_args_fp8" self.decode_args_fp8 = self.artifacts_dir / "decode_args_fp8" @@ -129,6 +132,14 @@ def setUp(self): f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/cs_f16.npy", "--benchmark_repetitions=3", ] + self.iree_run_prefill_nondecomposed_args_fp16_2048 = [ + "--function=prefill_bs4", + f"--input=@{self.prefill_args_bs4_2048_in_tokens_f16}/tokens_2048.npy", + f"--input=@{self.prefill_args_bs4_2048_in_tokens_f16}/seq_lens.npy", + f"--input=@{self.prefill_args_bs4_2048_in_tokens_f16}/seq_block_ids.npy", + f"--input=@{self.prefill_args_bs4_2048_in_tokens_f16}/cs_f16.npy", + "--benchmark_repetitions=3", + ] self.iree_run_decode_args = [ "--function=decode_bs4", f"--input=@{self.decode_args_f16}/tokens.npy", @@ -196,8 +207,42 @@ def testBenchmark8B_f16_Decomposed(self): ) @skipif_run_quick_llama_test - def testBenchmark8B_f16_Non_Decomposed_Prefill(self): - output_file_name = self.dir_path_8b / "f16_torch_prefill" + def testBenchmark8B_f16_Non_Decomposed_Prefill_Input_Len_128(self): + output_file_name = self.dir_path_8b / "f16_torch_prefill_128" + output_mlir = self.llama8b_f16_torch_sdpa_artifacts.create_file( + suffix=".mlir", prefix=output_file_name + ) + output_json = self.llama8b_f16_torch_sdpa_artifacts.create_file( + suffix=".json", prefix=output_file_name + ) + output_vmfb = self.llama8b_f16_torch_sdpa_artifacts.create_file( + suffix=".vmfb", prefix=output_file_name + ) + self.llama8b_f16_torch_sdpa_artifacts.attention_kernel = "torch" + export_return_code = self.llama8b_f16_torch_sdpa_artifacts.export_to_mlir( + mlir_path=output_mlir, + json_path=output_json, + skip_decode=True, + ) + self.llama8b_f16_torch_sdpa_artifacts.compile_to_vmfb( + mlir_path=str(output_mlir), + vmfb_path=output_vmfb, + hal_dump_path=output_file_name, + cwd=self.repo_root, + args=self.compile_args, + ) + # benchmark prefill + self.llama8b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb( + hip_device_id=self.iree_device, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_prefill_nondecomposed_args_fp16, + cwd=self.repo_root, + ) + + @skipif_run_quick_llama_test + def testBenchmark8B_f16_Non_Decomposed_Prefill_Input_Len_2048(self): + output_file_name = self.dir_path_8b / "f16_torch_prefill_2048" output_mlir = self.llama8b_f16_torch_sdpa_artifacts.create_file( suffix=".mlir", prefix=output_file_name )