diff --git a/tuner/examples/dispatch/dispatch_tuner.py b/tuner/examples/dispatch/dispatch_tuner.py
index 3eb92f3a1..1d87f6c52 100644
--- a/tuner/examples/dispatch/dispatch_tuner.py
+++ b/tuner/examples/dispatch/dispatch_tuner.py
@@ -109,48 +109,49 @@ def main():
     path_config.base_dir.mkdir(parents=True, exist_ok=True)
     path_config.output_unilog.touch()
     candidate_trackers: list[libtuner.CandidateTracker] = []
-    mlir_ctx = ir.Context()
-    logger = logging.getLogger("tune")
-    tuner_context = TunerContext(mlir_ctx, logger)
-    dispatch_tuner = DispatchTuner(tuner_context)
-    stop_after_phase: str = args.stop_after
-
-    print("Setup logging")
-    libtuner.setup_logging(args, path_config)
-    print(path_config.run_log, end="\n\n")
-
-    if not args.dry_run:
-        print("Validating devices")
-        libtuner.validate_devices(args.devices)
-        print("Validation successful!\n")
-
-    print("Generating candidates...")
-    candidates = libtuner.generate_candidates(args, path_config, candidate_trackers)
-    print(f"Stored candidates in {path_config.candidates_dir}\n")
-    if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
-        return
-
-    print("Compiling candidates...")
-    compiled_candidates = libtuner.compile_dispatches(
-        args, path_config, candidates, candidate_trackers, dispatch_tuner
-    )
-    print(f"Compiled files are stored in {path_config.compiled_dir}\n")
-    if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches:
-        return
-
-    print("Benchmarking compiled candidates...")
-    top_candidates = libtuner.benchmark_dispatches(
-        args, path_config, compiled_candidates, candidate_trackers, dispatch_tuner
-    )
-    print(f"\nStored results in {path_config.output_unilog.resolve()}\n")
-    if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches:
-        return
-
-    libtuner.save_pickle(path_config.candidate_trackers_pkl, candidate_trackers)
-    print(f"Candidate trackers are saved in {path_config.candidate_trackers_pkl}\n")
-    print("Check the detailed execution logs in:")
-    print(path_config.run_log.resolve())
+    with TunerContext() as tuner_context:
+        dispatch_tuner = DispatchTuner(tuner_context)
+        stop_after_phase: str = args.stop_after
+
+        print("Setup logging")
+        libtuner.setup_logging(args, path_config)
+        print(path_config.run_log, end="\n\n")
+
+        if not args.dry_run:
+            print("Validating devices")
+            libtuner.validate_devices(args.devices)
+            print("Validation successful!\n")
+
+        print("Generating candidates...")
+        candidates = libtuner.generate_candidates(
+            args, path_config, candidate_trackers, tuner_context
+        )
+        print(f"Stored candidates in {path_config.candidates_dir}\n")
+        if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
+            return
+
+        print("Compiling candidates...")
+        compiled_candidates = libtuner.compile_dispatches(
+            args, path_config, candidates, candidate_trackers, dispatch_tuner
+        )
+        print(f"Compiled files are stored in {path_config.compiled_dir}\n")
+        if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches:
+            return
+
+        print("Benchmarking compiled candidates...")
+        top_candidates = libtuner.benchmark_dispatches(
+            args, path_config, compiled_candidates, candidate_trackers, dispatch_tuner
+        )
+        print(f"\nStored results in {path_config.output_unilog.resolve()}\n")
+        if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches:
+            return
+
+        libtuner.save_pickle(path_config.candidate_trackers_pkl, candidate_trackers)
+        print(f"Candidate trackers are saved in {path_config.candidate_trackers_pkl}\n")
+
+        print("Check the detailed execution logs in:")
+        print(path_config.run_log.resolve())
 
     for candidate in candidate_trackers:
         libtuner.logging.debug(candidate)
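The hunk above replaces hand-wired `ir.Context`/logger setup with a single context manager. A minimal sketch of the new calling convention, assuming the refactored `TunerContext` from the `tuner/tuner/common.py` hunk later in this diff is importable as `tuner.common.TunerContext`:

```python
from tuner.common import TunerContext

# TunerContext now constructs its own ir.Context and falls back to the
# "tune" logger, so one with-statement replaces three lines of setup.
with TunerContext() as tuner_context:
    # tuner_context.mlir_ctx and tuner_context.logger are ready here;
    # all tuning phases run inside this block so the MLIR context is
    # entered and exited exactly once.
    tuner_context.logger.info("tuning phases run here")
```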
diff --git a/tuner/examples/punet/punet_autotune.py b/tuner/examples/punet/punet_autotune.py
index 2bfdb4d24..189dc63d3 100644
--- a/tuner/examples/punet/punet_autotune.py
+++ b/tuner/examples/punet/punet_autotune.py
@@ -23,6 +23,7 @@
 
 from tuner import libtuner
 from pathlib import Path
+from tuner.common import *
 
 
 class PunetClient(libtuner.TuningClient):
@@ -142,51 +143,54 @@ def main():
         print("Validation successful!\n")
 
     print("Generating candidates...")
-    candidates = libtuner.generate_candidates(args, path_config, candidate_trackers)
-    print(f"Stored candidates in {path_config.candidates_dir}\n")
-    if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
-        return
-
-    print("Compiling candidates...")
-    compiled_candidates = libtuner.compile_dispatches(
-        args, path_config, candidates, candidate_trackers, punet_client
-    )
-    print(f"Compiled files are stored in {path_config.compiled_dir}\n")
-    if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches:
-        return
-
-    print("Benchmarking compiled candidates...")
-    top_candidates = libtuner.benchmark_dispatches(
-        args, path_config, compiled_candidates, candidate_trackers, punet_client
-    )
-    print(f"Stored results in {path_config.output_unilog}\n")
-    if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches:
-        return
-
-    print(f"Compiling top model candidates...")
-    punet_candidates = libtuner.compile_models(
-        args, path_config, top_candidates, candidate_trackers, punet_client
-    )
-    print(f"Model candidates compiled in {path_config.base_dir}\n")
-    if stop_after_phase == libtuner.ExecutionPhases.compile_models:
-        return
-
-    print("Benchmarking model candidates...")
-    libtuner.benchmark_models(
-        args, path_config, punet_candidates, candidate_trackers, punet_client
-    )
-    print(f"Stored results in {path_config.output_unilog}")
-    if stop_after_phase == libtuner.ExecutionPhases.benchmark_models:
-        return
-
-    libtuner.summerize_top_candidates(path_config, candidate_trackers)
-    print(f"Stored top candidates info in {path_config.result_summary_log}\n")
-
-    libtuner.save_pickle(path_config.candidate_trackers_pkl, candidate_trackers)
-    print(f"Candidate trackers are saved in {path_config.candidate_trackers_pkl}\n")
-
-    print("Check the detailed execution logs in:")
-    print(path_config.run_log)
+    with TunerContext() as tuner_context:
+        candidates = libtuner.generate_candidates(
+            args, path_config, candidate_trackers, tuner_context
+        )
+        print(f"Stored candidates in {path_config.candidates_dir}\n")
+        if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
+            return
+
+        print("Compiling candidates...")
+        compiled_candidates = libtuner.compile_dispatches(
+            args, path_config, candidates, candidate_trackers, punet_client
+        )
+        print(f"Compiled files are stored in {path_config.compiled_dir}\n")
+        if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches:
+            return
+
+        print("Benchmarking compiled candidates...")
+        top_candidates = libtuner.benchmark_dispatches(
+            args, path_config, compiled_candidates, candidate_trackers, punet_client
+        )
+        print(f"Stored results in {path_config.output_unilog}\n")
+        if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches:
+            return
+
+        print(f"Compiling top model candidates...")
+        punet_candidates = libtuner.compile_models(
+            args, path_config, top_candidates, candidate_trackers, punet_client
+        )
+        print(f"Model candidates compiled in {path_config.base_dir}\n")
+        if stop_after_phase == libtuner.ExecutionPhases.compile_models:
+            return
+
+        print("Benchmarking model candidates...")
+        libtuner.benchmark_models(
+            args, path_config, punet_candidates, candidate_trackers, punet_client
+        )
+        print(f"Stored results in {path_config.output_unilog}")
+        if stop_after_phase == libtuner.ExecutionPhases.benchmark_models:
+            return
+
+        libtuner.summerize_top_candidates(path_config, candidate_trackers)
+        print(f"Stored top candidates info in {path_config.result_summary_log}\n")
+
+        libtuner.save_pickle(path_config.candidate_trackers_pkl, candidate_trackers)
+        print(f"Candidate trackers are saved in {path_config.candidate_trackers_pkl}\n")
+
+        print("Check the detailed execution logs in:")
+        print(path_config.run_log)
 
     for candidate in candidate_trackers:
         libtuner.logging.debug(candidate)
diff --git a/tuner/examples/test/tuner_test.py b/tuner/examples/test/tuner_test.py
index ec963ed9e..d6c0d131f 100644
--- a/tuner/examples/test/tuner_test.py
+++ b/tuner/examples/test/tuner_test.py
@@ -111,65 +111,63 @@ def main():
         print("Validation successful!\n")
 
     print("Generating candidates...")
-    mlir_ctx = ir.Context()
-    logger = logging.getLogger("tune")
-    tuner_context = TunerContext(mlir_ctx, logger)
-    test_tuner = TestTuner(tuner_context)
-    candidates = libtuner.generate_candidate_specs(
-        args, path_config, candidate_trackers, test_tuner
-    )
-    print(f"Stored candidate specs in {path_config.specs_dir}\n")
-    if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
-        return
-
-    print("Compiling candidates...")
-    compiled_candidates = libtuner.compile(
-        args, path_config, candidates, candidate_trackers, test_tuner
-    )
-
-    print("Benchmarking compiled candidates...")
-    top_candidates = libtuner.benchmark(
-        args,
-        path_config,
-        compiled_candidates,
-        candidate_trackers,
-        test_tuner,
-        args.test_num_dispatch_candidates,
-    )
-
-    print("Compiling models with top candidates...")
-    test_tuner.compile_flags = [
-        "--iree-hal-target-backends=rocm",
-        f"--iree-hip-target={args.test_hip_target}",
-    ]
-    compiled_model_candidates = libtuner.compile(
-        args,
-        path_config,
-        top_candidates,
-        candidate_trackers,
-        test_tuner,
-        args.test_model_file,
-    )
-
-    print("Benchmarking compiled model candidates...")
-    test_tuner.benchmark_flags = [
-        "--benchmark_repetitions=3",
-        "--input=2048x2048xf16",
-        "--input=2048x2048xf16",
-    ]
-    top_model_candidates = libtuner.benchmark(
-        args,
-        path_config,
-        compiled_model_candidates,
-        candidate_trackers,
-        test_tuner,
-        args.test_num_model_candidates,
-    )
-
-    print(f"Top model candidates: {top_model_candidates}")
-
-    print("Check the detailed execution logs in:")
-    print(path_config.run_log.resolve())
+    with TunerContext() as tuner_context:
+        test_tuner = TestTuner(tuner_context)
+        candidates = libtuner.generate_candidate_specs(
+            args, path_config, candidate_trackers, test_tuner
+        )
+        print(f"Stored candidate specs in {path_config.specs_dir}\n")
+        if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
+            return
+
+        print("Compiling candidates...")
+        compiled_candidates = libtuner.compile(
+            args, path_config, candidates, candidate_trackers, test_tuner
+        )
+
+        print("Benchmarking compiled candidates...")
+        top_candidates = libtuner.benchmark(
+            args,
+            path_config,
+            compiled_candidates,
+            candidate_trackers,
+            test_tuner,
+            args.test_num_dispatch_candidates,
+        )
+
+        print("Compiling models with top candidates...")
+        test_tuner.compile_flags = [
+            "--iree-hal-target-backends=rocm",
+            f"--iree-hip-target={args.test_hip_target}",
+        ]
+        compiled_model_candidates = libtuner.compile(
+            args,
+            path_config,
+            top_candidates,
+            candidate_trackers,
+            test_tuner,
+            args.test_model_file,
+        )
+
+        print("Benchmarking compiled model candidates...")
+        test_tuner.benchmark_flags = [
+            "--benchmark_repetitions=3",
+            "--input=2048x2048xf16",
+            "--input=2048x2048xf16",
+        ]
+        top_model_candidates = libtuner.benchmark(
+            args,
+            path_config,
+            compiled_model_candidates,
+            candidate_trackers,
+            test_tuner,
+            args.test_num_model_candidates,
+        )
+
+        print(f"Top model candidates: {top_model_candidates}")
+
+        print("Check the detailed execution logs in:")
+        print(path_config.run_log.resolve())
 
     for candidate in candidate_trackers:
         libtuner.logging.debug(candidate)
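The `tuner_test.py` hunk also shows the flag-swapping pattern: one client object is reused across compile and benchmark phases, with its flag lists reassigned between calls (each phase presumably reads whatever flags are set at call time). A self-contained sketch of that mechanism using a stand-in class; the real `TestTuner` and `libtuner` phase functions are not reproduced here:

```python
# Stand-in for a TuningClient subclass; only the flag attributes matter.
class _StubClient:
    def __init__(self) -> None:
        self.compile_flags: list[str] = []
        self.benchmark_flags: list[str] = []


def _compile(client: _StubClient) -> None:
    # Stand-in for a libtuner compile phase reading the current flags.
    print("compiling with:", client.compile_flags)


client = _StubClient()
_compile(client)  # dispatch phase: default (empty) flags

# Retarget the same client for the model phase; values copied from the hunk.
client.compile_flags = [
    "--iree-hal-target-backends=rocm",
]
_compile(client)  # model phase: retargeted flags
```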
f"--iree-hip-target={args.test_hip_target}", + ] + compiled_model_candidates = libtuner.compile( + args, + path_config, + top_candidates, + candidate_trackers, + test_tuner, + args.test_model_file, + ) + + print("Benchmarking compiled model candidates...") + test_tuner.benchmark_flags = [ + "--benchmark_repetitions=3", + "--input=2048x2048xf16", + "--input=2048x2048xf16", + ] + top_model_candidates = libtuner.benchmark( + args, + path_config, + compiled_model_candidates, + candidate_trackers, + test_tuner, + args.test_num_model_candidates, + ) + + print(f"Top model candidates: {top_model_candidates}") + + print("Check the detailed execution logs in:") + print(path_config.run_log.resolve()) for candidate in candidate_trackers: libtuner.logging.debug(candidate) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 45cb3512a..d67b4cec3 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -572,6 +572,7 @@ def tune( lhs_dims: str = "mk", # Dimensions for the left-hand side operand in matrix operations rhs_dims: str = "nk", # Dimensions for the right-hand side operand in matrix operations tile_dims: str = "mnk", # Dimensions for the tile size + tuner_context: TunerContext = TunerContext(), ): input_file = str(input) @@ -586,57 +587,55 @@ def tune( mlir_template = read_input_mlir(input_file) mlir_text = "".join(mlir_template) - with ir.Context() as ctx: - tuner_context = TunerContext(ctx, tune_logger) - mlir_module = parse_mlir(mlir_text, tuner_context) - # Save the input file as the first candidate. - with open(path.join(output, f"0.mlir"), "w") as f: - f.write(mlir_text) - - dispatch_tuner_registry = DispatchTunerRegistry() - dispatch_tuner_registry.register( - [ - MmtTuner(), - ConvTuner(), - ContractionTuner(lhs_dims, rhs_dims, tile_dims), - BatchMmtTuner(), - BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims), - ] - ) + mlir_module = parse_mlir(mlir_text, tuner_context) + # Save the input file as the first candidate. + with open(path.join(output, f"0.mlir"), "w") as f: + f.write(mlir_text) + + dispatch_tuner_registry = DispatchTunerRegistry() + dispatch_tuner_registry.register( + [ + MmtTuner(), + ConvTuner(), + ContractionTuner(lhs_dims, rhs_dims, tile_dims), + BatchMmtTuner(), + BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims), + ] + ) + + walk_result: OpWalkResult = walk_mlir_op(mlir_module, dispatch_tuner_registry) + + variant_op_list = iree_codegen.get_executable_variant_ops(mlir_module) + assert len(variant_op_list) == 1, "Expect one executable variant op" + variant_op = variant_op_list[0] + # Get the MMA intrinisic intructions supported by the target. + mma_list = iree_codegen.query_mma_intrinsics(variant_op) + + dispatch_tuner = walk_result.dispatch_tuner + assert dispatch_tuner, "No suitable dispatch tuner found" + problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template) + tune_logger.debug(str(problem_size)) + configs = [] + for i, config in enumerate( + generate_solutions(tuner_context, problem_size, num_subgroups, mma_list) + ): + if i >= limit: + break + tune_logger.info(f"Solution #{i+1}: {config}") + configs.append(config) + tf_mlir = dispatch_tuner.apply_params(problem_size, mlir_template, config) + + with open(path.join(output, f"{i+1}.mlir"), "w") as f: + f.write(tf_mlir.modified) + with open(path.join(output, f"{i+1}_config.mlir"), "w") as f: + f.write(tf_mlir.embeddable) + + # TODO: Fix pickling for ir types. 
diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py
index d135a8502..44f6b28b9 100644
--- a/tuner/tuner/candidate_gen_test.py
+++ b/tuner/tuner/candidate_gen_test.py
@@ -27,9 +27,12 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
     from logging import Logger
     from unittest.mock import MagicMock
 
-    with ir.Context() as ctx:
-        logger: Logger = MagicMock(spec=Logger)
-        yield common.TunerContext(ctx, logger)
+    # Mock the logger
+    mock_logger = MagicMock(spec=Logger)
+
+    # Use TunerContext with the mocked logger
+    with common.TunerContext(logger=mock_logger) as ctx:
+        yield ctx
 
 
 def remove_comments(mlir: str) -> str:
diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py
index 9d0ce7777..edc5c3f09 100644
--- a/tuner/tuner/common.py
+++ b/tuner/tuner/common.py
@@ -4,12 +4,13 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-import re
+from __future__ import annotations
 import logging
 from dataclasses import astuple, dataclass
 from enum import Enum
 from typing import Optional
 from typing import Any
+from typing_extensions import Literal
 
 from iree.compiler import ir  # type: ignore
 
@@ -38,16 +39,18 @@ def getI64(self, value: int) -> ir.IntegerAttr:
 
 
 class TunerContext:
-    def __init__(self, mlir_ctx: ir.Context, logger: logging.Logger):
-        self.mlir_ctx: ir.Context = mlir_ctx
-        self.logger: logging.Logger = logger
-        self.type: CommonTypes = CommonTypes(mlir_ctx)
-
-    def __enter__(self):
+    def __init__(self, logger: Optional[logging.Logger] = None):
+        self.mlir_ctx: ir.Context = ir.Context()
+        self.logger: logging.Logger = logger or logging.getLogger(
+            "tune"
+        )  # Default to "tune" logger
+        self.type: CommonTypes = CommonTypes(self.mlir_ctx)
+
+    def __enter__(self) -> TunerContext:
         self.mlir_ctx.__enter__()
         return self
 
-    def __exit__(self, exc_type, exc_value, traceback):
+    def __exit__(self, exc_type, exc_value, traceback) -> Literal[False]:
         self.mlir_ctx.__exit__(exc_type, exc_value, traceback)
         return False
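A short sketch of the two construction modes the refactored `__init__` above allows, assuming the module is importable as `tuner.common`:

```python
import logging

from tuner.common import TunerContext

# 1. Zero-argument form: builds its own ir.Context and defaults to the
#    "tune" logger.
with TunerContext() as ctx:
    ctx.logger.info("using the default 'tune' logger")

# 2. Injected logger, e.g. for tests or custom handlers; the logger
#    name here is hypothetical.
custom_logger = logging.getLogger("my_tuner")
with TunerContext(logger=custom_logger) as ctx:
    ctx.logger.info("using an injected logger")
```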
diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py
index 6157bb355..393e5da94 100644
--- a/tuner/tuner/common_test.py
+++ b/tuner/tuner/common_test.py
@@ -23,9 +23,12 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
     from logging import Logger
     from unittest.mock import MagicMock
 
-    with ir.Context() as ctx:
-        logger: Logger = MagicMock(spec=Logger)
-        yield common.TunerContext(ctx, logger)
+    # Mock the logger
+    mock_logger = MagicMock(spec=Logger)
+
+    # Use TunerContext with the mocked logger
+    with common.TunerContext(logger=mock_logger) as ctx:
+        yield ctx
 
 
 @pytest.fixture
diff --git a/tuner/tuner/dispatch_constraints_test.py b/tuner/tuner/dispatch_constraints_test.py
index 842ea8509..048f43339 100644
--- a/tuner/tuner/dispatch_constraints_test.py
+++ b/tuner/tuner/dispatch_constraints_test.py
@@ -25,9 +25,12 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
     from logging import Logger
     from unittest.mock import MagicMock
 
-    with ir.Context() as ctx:
-        logger: Logger = MagicMock(spec=Logger)
-        yield common.TunerContext(ctx, logger)
+    # Mock the logger
+    mock_logger = MagicMock(spec=Logger)
+
+    # Use TunerContext with the mocked logger
+    with common.TunerContext(logger=mock_logger) as ctx:
+        yield ctx
 
 
 def test_generate_solutions(tuner_ctx: common.TunerContext) -> None:
diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py
index 0b87be659..0230a8f37 100644
--- a/tuner/tuner/dispatch_parser_test.py
+++ b/tuner/tuner/dispatch_parser_test.py
@@ -27,9 +27,12 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
     from logging import Logger
     from unittest.mock import MagicMock
 
-    with ir.Context() as ctx:
-        logger: Logger = MagicMock(spec=Logger)
-        yield common.TunerContext(ctx, logger)
+    # Mock the logger
+    mock_logger = MagicMock(spec=Logger)
+
+    # Use TunerContext with the mocked logger
+    with common.TunerContext(logger=mock_logger) as ctx:
+        yield ctx
 
 
 def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None:
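All four test files now share this fixture shape. A hypothetical test using it, with assertions against the attributes the new `__init__` in `common.py` defines:

```python
import pytest
from typing import Generator
from logging import Logger
from unittest.mock import MagicMock

from tuner import common


@pytest.fixture
def tuner_ctx() -> Generator[common.TunerContext, None, None]:
    # Same fixture body as the hunks above: a mocked logger, and
    # TunerContext entered directly with no separate ir.Context.
    mock_logger = MagicMock(spec=Logger)
    with common.TunerContext(logger=mock_logger) as ctx:
        yield ctx


def test_context_carries_mlir_state(tuner_ctx: common.TunerContext) -> None:
    # The yielded context owns the MLIR context, the (mocked) logger,
    # and the CommonTypes helper built from it.
    assert tuner_ctx.mlir_ctx is not None
    assert tuner_ctx.type is not None
    tuner_ctx.logger.debug("calls are captured by the MagicMock")
```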
diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py
index 4a4e724da..41300801e 100644
--- a/tuner/tuner/libtuner.py
+++ b/tuner/tuner/libtuner.py
@@ -173,16 +173,6 @@ class TuningClient(ABC):
     def __init__(self, tuner_context: TunerContext):
         self.tuner_context = tuner_context
 
-    def __enter__(self):
-        # Enter the context of TunerContext
-        self.tuner_context.__enter__()
-        return self
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        # Exit the context of TunerContext
-        self.tuner_context.__exit__(exc_type, exc_value, traceback)
-        return False
-
     @abstractmethod
     def get_iree_compile_flags(self) -> list[str]:
         pass
@@ -971,6 +961,7 @@ def generate_candidates(
     args: argparse.Namespace,
     path_config: PathConfig,
     candidate_trackers: list[CandidateTracker],
+    tuner_context: TunerContext,
 ) -> list[int]:
     """Generate candidate files for tuning. Returns the list of candidate indexes"""
     logging.debug("generate_candidates()")
@@ -1002,6 +993,7 @@ def generate_candidates(
         lhs_dims=args.lhs_dims,
         rhs_dims=args.rhs_dims,
         tile_dims=args.tile_dims,
+        tuner_context=tuner_context,
     )
     mlirs = sorted(
         path_config.candidates_dir.glob("*.mlir"), key=numerical_sort_key