[DevX] Support filtering benchmark configs in on-demand workflow (#7639)
Support config filtering in on-demand benchmark flow

Co-authored-by: Github Executorch <[email protected]>
guangy10 and Github Executorch authored Jan 22, 2025
1 parent 99912cd commit f2720fa
Showing 6 changed files with 318 additions and 49 deletions.
Empty file added .ci/scripts/__init__.py
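For context, the new --configs flag added by this commit can be exercised roughly as in the CLI tests below. A minimal sketch, assuming it is run from the repository root with Python on PATH; the model, OS, device, and config names are taken from this diff:

import subprocess

# Request two models on the iPhone 15 pool, restricted to two benchmark configs.
cmd = [
    "python",
    ".ci/scripts/gather_benchmark_configs.py",
    "--models", "mv2,dl3",
    "--os", "ios",
    "--devices", "apple_iphone_15",
    "--configs", "coreml_fp16,xnnpack_q8",  # new in this commit
]
result = subprocess.run(cmd, capture_output=True, text=True)

# Without GITHUB_OUTPUT set, the matrix is echoed as a ::set-output line on stdout;
# an unsupported config (e.g. "qnn_q8" on iOS) makes the script exit non-zero.
print(result.returncode, result.stdout)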
145 changes: 104 additions & 41 deletions .ci/scripts/gather_benchmark_configs.py
@@ -9,8 +9,10 @@
import logging
import os
import re
-from typing import Any, Dict
+import sys
+from typing import Any, Dict, List

+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from examples.models import MODEL_NAME_TO_MODEL


@@ -45,6 +47,79 @@
}


def extract_all_configs(data, target_os=None):
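    """Recursively flatten a nested benchmark-config mapping into a flat list of
    leaf values (config names), keeping the "xplat" branch plus the branch for
    target_os (all branches when target_os is None).
    """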
if isinstance(data, dict):
# If target_os is specified, include "xplat" and the specified branch
include_branches = {"xplat", target_os} if target_os else data.keys()
return [
v
for key, value in data.items()
if key in include_branches
for v in extract_all_configs(value, target_os)
]
elif isinstance(data, list):
return [v for item in data for v in extract_all_configs(item, target_os)]
else:
return [data]


def generate_compatible_configs(model_name: str, target_os=None) -> List[str]:
"""
Generate a list of compatible benchmark configurations for a given model name and target OS.
Args:
model_name (str): The name of the model to generate configurations for.
target_os (Optional[str]): The target operating system (e.g., 'android', 'ios').
Returns:
List[str]: A list of compatible benchmark configurations.
Raises:
None
Example:
generate_compatible_configs('meta-llama/Llama-3.2-1B', 'ios') -> ['llama3_fb16', 'llama3_coreml_ane']
"""
configs = []
if is_valid_huggingface_model_id(model_name):
if model_name.startswith("meta-llama/"):
# LLaMA models
repo_name = model_name.split("meta-llama/")[1]
if "qlora" in repo_name.lower():
configs.append("llama3_qlora")
elif "spinquant" in repo_name.lower():
configs.append("llama3_spinquant")
else:
configs.append("llama3_fb16")
configs.extend(
[
config
for config in BENCHMARK_CONFIGS.get(target_os, [])
if config.startswith("llama")
]
)
else:
# Non-LLaMA models
configs.append("hf_xnnpack_fp32")
elif model_name in MODEL_NAME_TO_MODEL:
# ExecuTorch in-tree non-GenAI models
configs.append("xnnpack_q8")
if target_os != "xplat":
# Add OS-specific configs
configs.extend(
[
config
for config in BENCHMARK_CONFIGS.get(target_os, [])
if not config.startswith("llama")
]
)
else:
# Skip unknown models with a warning
logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")

return configs


def parse_args() -> Any:
"""
Parse command-line arguments.
@@ -82,6 +157,11 @@ def comma_separated(value: str):
type=comma_separated, # Use the custom parser for comma-separated values
help=f"Comma-separated device names. Available devices: {list(DEVICE_POOLS.keys())}",
)
parser.add_argument(
"--configs",
type=comma_separated, # Use the custom parser for comma-separated values
help=f"Comma-separated benchmark configs. Available configs: {extract_all_configs(BENCHMARK_CONFIGS)}",
)

return parser.parse_args()

@@ -98,11 +178,16 @@ def set_output(name: str, val: Any) -> None:
set_output("benchmark_configs", {"include": [...]})
"""

if os.getenv("GITHUB_OUTPUT"):
print(f"Setting {val} to GitHub output")
with open(str(os.getenv("GITHUB_OUTPUT")), "a") as env:
print(f"{name}={val}", file=env)
else:
github_output = os.getenv("GITHUB_OUTPUT")
if not github_output:
print(f"::set-output name={name}::{val}")
return

try:
with open(github_output, "a") as env:
env.write(f"{name}={val}\n")
except PermissionError:
# Fall back to printing in case of permission error in unit tests
print(f"::set-output name={name}::{val}")


@@ -123,7 +208,7 @@ def is_valid_huggingface_model_id(model_name: str) -> bool:
return bool(re.match(pattern, model_name))


-def get_benchmark_configs() -> Dict[str, Dict]:
+def get_benchmark_configs() -> Dict[str, Dict]:  # noqa: C901
"""
Gather benchmark configurations for a given set of models on the target operating system and devices.
@@ -153,48 +238,26 @@ def get_benchmark_configs() -> Dict[str, Dict]:
}
"""
    args = parse_args()
-    target_os = args.os
    devices = args.devices
    models = args.models
+    target_os = args.os
+    target_configs = args.configs

benchmark_configs = {"include": []}

    for model_name in models:
        configs = []
-        if is_valid_huggingface_model_id(model_name):
-            if model_name.startswith("meta-llama/"):
-                # LLaMA models
-                repo_name = model_name.split("meta-llama/")[1]
-                if "qlora" in repo_name.lower():
-                    configs.append("llama3_qlora")
-                elif "spinquant" in repo_name.lower():
-                    configs.append("llama3_spinquant")
-                else:
-                    configs.append("llama3_fb16")
-                    configs.extend(
-                        [
-                            config
-                            for config in BENCHMARK_CONFIGS.get(target_os, [])
-                            if config.startswith("llama")
-                        ]
-                    )
-            else:
-                # Non-LLaMA models
-                configs.append("hf_xnnpack_fp32")
-        elif model_name in MODEL_NAME_TO_MODEL:
-            # ExecuTorch in-tree non-GenAI models
-            configs.append("xnnpack_q8")
-            configs.extend(
-                [
-                    config
-                    for config in BENCHMARK_CONFIGS.get(target_os, [])
-                    if not config.startswith("llama")
-                ]
-            )
-        else:
-            # Skip unknown models with a warning
-            logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
-            continue
+        configs.extend(generate_compatible_configs(model_name, target_os))
+        print(f"Discovered all supported configs for model '{model_name}': {configs}")
+        if target_configs is not None:
+            for config in target_configs:
+                if config not in configs:
+                    raise Exception(
+                        f"Unsupported config '{config}' for model '{model_name}' on '{target_os}'. Skipped.\n"
+                        f"Supported configs are: {configs}"
+                    )
+            configs = target_configs
+            print(f"Using provided configs {configs} for model '{model_name}'")

# Add configurations for each valid device
for device in devices:
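The rest of the loop (collapsed in this view) pairs every selected config with each requested device pool and emits the result via set_output("benchmark_configs", ...). A hedged sketch of that output shape, inferred from the set_output docstring and the test assertions below; any field beyond "model" and "config" is an assumption:

benchmark_configs = {
    "include": [
        # One entry per (model, config, device) combination; the "device" field
        # here is illustrative, the collapsed loop defines the real field names.
        {"model": "mv2", "config": "coreml_fp16", "device": "apple_iphone_15"},
        {"model": "mv2", "config": "xnnpack_q8", "device": "apple_iphone_15"},
    ]
}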
189 changes: 189 additions & 0 deletions .ci/scripts/tests/test_gather_benchmark_configs.py
@@ -0,0 +1,189 @@
import importlib.util
import os
import subprocess
import sys
import unittest
from unittest.mock import mock_open, patch

import pytest

# Dynamically import the script
script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
spec = importlib.util.spec_from_file_location("gather_benchmark_configs", script_path)
gather_benchmark_configs = importlib.util.module_from_spec(spec)
spec.loader.exec_module(gather_benchmark_configs)


@pytest.mark.skipif(
sys.platform != "linux", reason="The script under test runs on Linux runners only"
)
class TestGatherBenchmarkConfigs(unittest.TestCase):

def test_extract_all_configs_android(self):
android_configs = gather_benchmark_configs.extract_all_configs(
gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
)
self.assertIn("xnnpack_q8", android_configs)
self.assertIn("qnn_q8", android_configs)
self.assertIn("llama3_spinquant", android_configs)
self.assertIn("llama3_qlora", android_configs)

def test_extract_all_configs_ios(self):
ios_configs = gather_benchmark_configs.extract_all_configs(
gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
)

self.assertIn("xnnpack_q8", ios_configs)
self.assertIn("coreml_fp16", ios_configs)
self.assertIn("mps", ios_configs)
self.assertIn("llama3_coreml_ane", ios_configs)
self.assertIn("llama3_spinquant", ios_configs)
self.assertIn("llama3_qlora", ios_configs)

def test_generate_compatible_configs_llama_model(self):
model_name = "meta-llama/Llama-3.2-1B"
target_os = "ios"
result = gather_benchmark_configs.generate_compatible_configs(
model_name, target_os
)
expected = ["llama3_fb16", "llama3_coreml_ane"]
self.assertEqual(result, expected)

target_os = "android"
result = gather_benchmark_configs.generate_compatible_configs(
model_name, target_os
)
expected = ["llama3_fb16"]
self.assertEqual(result, expected)

def test_generate_compatible_configs_quantized_llama_model(self):
model_name = "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"
result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
expected = ["llama3_spinquant"]
self.assertEqual(result, expected)

model_name = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"
result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
expected = ["llama3_qlora"]
self.assertEqual(result, expected)

def test_generate_compatible_configs_non_genai_model(self):
model_name = "mv2"
target_os = "xplat"
result = gather_benchmark_configs.generate_compatible_configs(
model_name, target_os
)
expected = ["xnnpack_q8"]
self.assertEqual(result, expected)

target_os = "android"
result = gather_benchmark_configs.generate_compatible_configs(
model_name, target_os
)
expected = ["xnnpack_q8", "qnn_q8"]
self.assertEqual(result, expected)

target_os = "ios"
result = gather_benchmark_configs.generate_compatible_configs(
model_name, target_os
)
expected = ["xnnpack_q8", "coreml_fp16", "mps"]
self.assertEqual(result, expected)

def test_generate_compatible_configs_unknown_model(self):
model_name = "unknown_model"
target_os = "ios"
result = gather_benchmark_configs.generate_compatible_configs(
model_name, target_os
)
self.assertEqual(result, [])

def test_is_valid_huggingface_model_id_valid(self):
valid_model = "meta-llama/Llama-3.2-1B"
self.assertTrue(
gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
)

@patch("builtins.open", new_callable=mock_open)
@patch("os.getenv", return_value=None)
def test_set_output_no_github_env(self, mock_getenv, mock_file):
with patch("builtins.print") as mock_print:
gather_benchmark_configs.set_output("test_name", "test_value")
mock_print.assert_called_with("::set-output name=test_name::test_value")

def test_device_pools_contains_all_devices(self):
expected_devices = [
"apple_iphone_15",
"apple_iphone_15+ios_18",
"samsung_galaxy_s22",
"samsung_galaxy_s24",
"google_pixel_8_pro",
]
for device in expected_devices:
self.assertIn(device, gather_benchmark_configs.DEVICE_POOLS)

def test_gather_benchmark_configs_cli(self):
args = {
"models": "mv2,dl3",
"os": "ios",
"devices": "apple_iphone_15",
"configs": None,
}

cmd = ["python", ".ci/scripts/gather_benchmark_configs.py"]
for key, value in args.items():
if value is not None:
cmd.append(f"--{key}")
cmd.append(value)

result = subprocess.run(cmd, capture_output=True, text=True)
self.assertEqual(result.returncode, 0, f"Error: {result.stderr}")
self.assertIn('"model": "mv2"', result.stdout)
self.assertIn('"model": "dl3"', result.stdout)
self.assertIn('"config": "coreml_fp16"', result.stdout)
self.assertIn('"config": "xnnpack_q8"', result.stdout)
self.assertIn('"config": "mps"', result.stdout)

def test_gather_benchmark_configs_cli_specified_configs(self):
args = {
"models": "mv2,dl3",
"os": "ios",
"devices": "apple_iphone_15",
"configs": "coreml_fp16,xnnpack_q8",
}

cmd = ["python", ".ci/scripts/gather_benchmark_configs.py"]
for key, value in args.items():
if value is not None:
cmd.append(f"--{key}")
cmd.append(value)

result = subprocess.run(cmd, capture_output=True, text=True)
self.assertEqual(result.returncode, 0, f"Error: {result.stderr}")
self.assertIn('"model": "mv2"', result.stdout)
self.assertIn('"model": "dl3"', result.stdout)
self.assertIn('"config": "coreml_fp16"', result.stdout)
self.assertIn('"config": "xnnpack_q8"', result.stdout)
self.assertNotIn('"config": "mps"', result.stdout)

def test_gather_benchmark_configs_cli_specified_configs_raise(self):
args = {
"models": "mv2,dl3",
"os": "ios",
"devices": "apple_iphone_15",
"configs": "qnn_q8",
}

cmd = ["python", ".ci/scripts/gather_benchmark_configs.py"]
for key, value in args.items():
if value is not None:
cmd.append(f"--{key}")
cmd.append(value)

result = subprocess.run(cmd, capture_output=True, text=True)
self.assertEqual(result.returncode, 1, f"Error: {result.stderr}")
self.assertIn("Unsupported config 'qnn_q8'", result.stderr)


if __name__ == "__main__":
unittest.main()
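To make the flattening behavior checked by test_extract_all_configs_android/ios concrete, here is a small, self-contained illustration with made-up config names, loading the script the same way the tests do (assumes it is run from the repository root):

import importlib.util
import os

script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
spec = importlib.util.spec_from_file_location("gather_benchmark_configs", script_path)
gather_benchmark_configs = importlib.util.module_from_spec(spec)
spec.loader.exec_module(gather_benchmark_configs)

toy_configs = {
    "xplat": ["config_a", "config_b"],
    "android": ["config_android_only"],
    "ios": ["config_ios_only"],
}
# "xplat" is always included; passing a target OS adds only that OS branch.
print(gather_benchmark_configs.extract_all_configs(toy_configs, "android"))
# ['config_a', 'config_b', 'config_android_only']
print(gather_benchmark_configs.extract_all_configs(toy_configs))
# ['config_a', 'config_b', 'config_android_only', 'config_ios_only']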
