diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py
index 37df9bf7d..44e00efeb 100644
--- a/tuner/tuner/libtuner.py
+++ b/tuner/tuner/libtuner.py
@@ -727,6 +727,35 @@ def collision_handler(index_hash_list: list[tuple[int, str]]) -> tuple[bool, lis
     return collision_detected, unique_indexes
 
 
+def benchmark_candidates(candidate_indices, devices, tuning_client, candidate_trackers):
+    """
+    Runs the benchmarking for a given list of candidate indices.
+    """
+    # Create worker context queue
+    worker_context_queue = create_worker_context_queue(devices)
+
+    # Prepare task list
+    task_list = [
+        BenchmarkPack(
+            iree_benchmark_module_flags=tuning_client.get_iree_benchmark_module_flags(),
+            benchmark_timeout=tuning_client.get_benchmark_timeout_s(),
+            candidate_tracker=candidate_trackers[idx],
+        )
+        for idx in candidate_indices
+    ]
+
+    # Perform benchmarking
+    benchmark_results = multiprocess_progress_wrapper(
+        num_worker=len(devices),
+        task_list=task_list,
+        function=run_iree_benchmark_module_command,
+        initializer=init_worker_context,
+        initializer_inputs=(worker_context_queue,),
+    )
+
+    return benchmark_results
+
+
 def compile(
     args: argparse.Namespace,
     path_config: PathConfig,
@@ -819,41 +848,24 @@ def benchmark(
     logging.debug("benchmark()")
 
     # Benchmarking baselines on each involved device.
-    worker_context_queue = create_worker_context_queue(args.devices)
-    baseline_task_list = [
-        BenchmarkPack(
-            iree_benchmark_module_flags=tuning_client.get_iree_benchmark_module_flags(),
-            benchmark_timeout=tuning_client.get_benchmark_timeout_s(),
-            candidate_tracker=candidate_trackers[0],
-        )
-    ] * len(args.devices)
-    baseline_results = multiprocess_progress_wrapper(
-        num_worker=len(args.devices),
-        task_list=baseline_task_list,
-        function=run_iree_benchmark_module_command,
-        initializer=init_worker_context,
-        initializer_inputs=(worker_context_queue,),
+    baseline_indices = [0] * len(args.devices)
+    baseline_results = benchmark_candidates(
+        candidate_indices=baseline_indices,
+        devices=args.devices,
+        tuning_client=tuning_client,
+        candidate_trackers=candidate_trackers,
     )
+
     baseline_times_by_device = {}
     for r in baseline_results:
         baseline_times_by_device[r.device_id] = r.time
 
-    task_list = [
-        BenchmarkPack(
-            iree_benchmark_module_flags=tuning_client.get_iree_benchmark_module_flags(),
-            benchmark_timeout=tuning_client.get_benchmark_timeout_s(),
-            candidate_tracker=candidate_trackers[i],
-        )
-        for i in compiled_candidates
-        if i != 0
-    ]
-    worker_context_queue = create_worker_context_queue(args.devices)
-    candidate_results = multiprocess_progress_wrapper(
-        num_worker=len(args.devices),
-        task_list=task_list,
-        function=run_iree_benchmark_module_command,
-        initializer=init_worker_context,
-        initializer_inputs=(worker_context_queue,),
+    candidate_indices = [i for i in compiled_candidates if i != 0]
+    candidate_results = benchmark_candidates(
+        candidate_indices=candidate_indices,
+        devices=args.devices,
+        tuning_client=tuning_client,
+        candidate_trackers=candidate_trackers,
     )
 
     # Select top candidates
@@ -875,21 +887,14 @@ def get_speedup(result: BenchmarkResult) -> float:
     top_candidates = [result.candidate_id for result in best_results]
 
     # Benchmarking baselines on each involved device again to check performance regression on devices.
-    worker_context_queue = create_worker_context_queue(args.devices)
-    baseline_task_list = [
-        BenchmarkPack(
-            iree_benchmark_module_flags=tuning_client.get_iree_benchmark_module_flags(),
-            benchmark_timeout=tuning_client.get_benchmark_timeout_s(),
-            candidate_tracker=candidate_trackers[0],
-        )
-    ] * len(args.devices)
-    post_baseline_results = multiprocess_progress_wrapper(
-        num_worker=len(args.devices),
-        task_list=baseline_task_list,
-        function=run_iree_benchmark_module_command,
-        initializer=init_worker_context,
-        initializer_inputs=(worker_context_queue,),
+    post_baseline_indices = [0] * len(args.devices)
+    post_baseline_results = benchmark_candidates(
+        candidate_indices=post_baseline_indices,
+        devices=args.devices,
+        tuning_client=tuning_client,
+        candidate_trackers=candidate_trackers,
    )
+
     post_baseline_times_by_device = {}
     for r in post_baseline_results:
         post_baseline_times_by_device[r.device_id] = r.time
@@ -906,12 +911,12 @@ def get_speedup(result: BenchmarkResult) -> float:
         if post_time > baseline_time * 1.03:
             regression_detected = True
             percentage_slower = ((post_time - baseline_time) / baseline_time) * 100
-            logging.info(
+            logging.warning(
                 f"Performance regression detected on device {device_id}: "
                 f"Baseline time = {baseline_time}, Post-baseline time = {post_time}, "
                 f"Slower by {percentage_slower:.3f}%"
             )
 
     if not regression_detected:
-        logging.info("No performance regressions detected.")
+        logging.debug("No performance regressions detected.")
     return top_candidates
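
Since the refactor extracts a standalone benchmark_candidates helper, it also becomes straightforward to exercise in isolation. Below is a minimal pytest-style sketch of how a reviewer might do that, not part of the patch: FakeTuningClient, the hip:// device strings, and the tuner.libtuner import path are illustrative assumptions, and the test presumes BenchmarkPack exposes candidate_tracker as an attribute (e.g. it is a dataclass).

# test_benchmark_candidates.py -- illustrative sketch only, not part of the patch.
# Assumes the module is importable as `tuner.libtuner` and that BenchmarkPack
# exposes its constructor arguments as attributes (e.g. a dataclass).
from unittest import mock

from tuner import libtuner


class FakeTuningClient:
    """Hypothetical stand-in for the real tuning client; only the two methods
    used by benchmark_candidates are provided."""

    def get_iree_benchmark_module_flags(self):
        return ["--benchmark_repetitions=3"]

    def get_benchmark_timeout_s(self):
        return 60


def test_benchmark_candidates_builds_one_task_per_index(monkeypatch):
    captured = {}

    def fake_wrapper(num_worker, task_list, function, initializer, initializer_inputs):
        # Capture what benchmark_candidates hands to the multiprocessing helper
        # and return dummy results instead of spawning worker processes.
        captured["num_worker"] = num_worker
        captured["task_list"] = task_list
        return ["result"] * len(task_list)

    monkeypatch.setattr(libtuner, "create_worker_context_queue", lambda devices: mock.MagicMock())
    monkeypatch.setattr(libtuner, "multiprocess_progress_wrapper", fake_wrapper)

    trackers = [object() for _ in range(4)]
    results = libtuner.benchmark_candidates(
        candidate_indices=[1, 3],
        devices=["hip://0", "hip://1"],
        tuning_client=FakeTuningClient(),
        candidate_trackers=trackers,
    )

    # One BenchmarkPack per candidate index, each carrying the matching tracker.
    assert captured["num_worker"] == 2
    assert len(captured["task_list"]) == 2
    assert captured["task_list"][0].candidate_tracker is trackers[1]
    assert captured["task_list"][1].candidate_tracker is trackers[3]
    assert results == ["result", "result"]

The same monkeypatching approach covers all three call sites in benchmark() (baseline, candidates, and post-baseline), since they now route through the single helper.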