Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tuner] Fix context management #770

Merged
merged 6 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 66 additions & 64 deletions tuner/examples/simple/simple_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
import argparse
from pathlib import Path
from tuner import libtuner
from tuner.common import *


class TestTuner(libtuner.TuningClient):
def __init__(self, tuner_context: libtuner.TunerContext):
    """Create a simple tuning client bound to a caller-owned TunerContext.

    Args:
        tuner_context: Shared context (MLIR context + logger) whose lifetime
            is managed by the caller; stored by the TuningClient base class.
    """
    super().__init__(tuner_context)
    # Initial dispatch-phase flags; main() replaces both lists before the
    # model compile/benchmark phases.
    self.compile_flags = ["--compile-from=executable-sources"]
    self.benchmark_flags = ["--benchmark_repetitions=3", "--input=1"]

Expand Down Expand Up @@ -104,68 +105,69 @@ def main():
print("Validation successful!\n")

print("Generating candidate tuning specs...")
test_tuner = TestTuner()
candidates = libtuner.generate_candidate_specs(
args, path_config, candidate_trackers, test_tuner
)
print(f"Stored candidate tuning specs in {path_config.specs_dir}\n")
if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
return

print("Compiling dispatch candidates...")
compiled_candidates = libtuner.compile(
args, path_config, candidates, candidate_trackers, test_tuner
)
if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches:
return

print("Benchmarking compiled dispatch candidates...")
top_candidates = libtuner.benchmark(
args,
path_config,
compiled_candidates,
candidate_trackers,
test_tuner,
args.simple_num_dispatch_candidates,
)
if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches:
return

print("Compiling models with top candidates...")
test_tuner.compile_flags = [
"--iree-hal-target-backends=rocm",
f"--iree-hip-target={args.simple_hip_target}",
]
compiled_model_candidates = libtuner.compile(
args,
path_config,
top_candidates,
candidate_trackers,
test_tuner,
args.simple_model_file,
)
if stop_after_phase == libtuner.ExecutionPhases.compile_models:
return

print("Benchmarking compiled model candidates...")
test_tuner.benchmark_flags = [
"--benchmark_repetitions=3",
"--input=2048x2048xf16",
"--input=2048x2048xf16",
]
top_model_candidates = libtuner.benchmark(
args,
path_config,
compiled_model_candidates,
candidate_trackers,
test_tuner,
args.simple_num_model_candidates,
)

print(f"Top model candidates: {top_model_candidates}")

print("Check the detailed execution logs in:")
print(path_config.run_log.resolve())
with TunerContext() as tuner_context:
test_tuner = TestTuner(tuner_context)
candidates = libtuner.generate_candidate_specs(
args, path_config, candidate_trackers, test_tuner
)
print(f"Stored candidate tuning specs in {path_config.specs_dir}\n")
if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
return

print("Compiling dispatch candidates...")
compiled_candidates = libtuner.compile(
args, path_config, candidates, candidate_trackers, test_tuner
)
if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches:
return

print("Benchmarking compiled dispatch candidates...")
top_candidates = libtuner.benchmark(
args,
path_config,
compiled_candidates,
candidate_trackers,
test_tuner,
args.simple_num_dispatch_candidates,
)
if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches:
return

print("Compiling models with top candidates...")
test_tuner.compile_flags = [
"--iree-hal-target-backends=rocm",
f"--iree-hip-target={args.simple_hip_target}",
]
compiled_model_candidates = libtuner.compile(
args,
path_config,
top_candidates,
candidate_trackers,
test_tuner,
args.simple_model_file,
)
if stop_after_phase == libtuner.ExecutionPhases.compile_models:
return

print("Benchmarking compiled model candidates...")
test_tuner.benchmark_flags = [
"--benchmark_repetitions=3",
"--input=2048x2048xf16",
"--input=2048x2048xf16",
]
top_model_candidates = libtuner.benchmark(
args,
path_config,
compiled_model_candidates,
candidate_trackers,
test_tuner,
args.simple_num_model_candidates,
)

print(f"Top model candidates: {top_model_candidates}")

print("Check the detailed execution logs in:")
print(path_config.run_log.resolve())

for candidate in candidate_trackers:
libtuner.logging.debug(candidate)
6 changes: 3 additions & 3 deletions tuner/tuner/candidate_gen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
from logging import Logger
from unittest.mock import MagicMock

with ir.Context() as ctx:
logger: Logger = MagicMock(spec=Logger)
yield common.TunerContext(ctx, logger)
mock_logger = MagicMock(spec=Logger)
with common.TunerContext(logger=mock_logger) as ctx:
yield ctx


def test_get_td_spec_contraction(tuner_ctx: common.TunerContext) -> None:
Expand Down
23 changes: 17 additions & 6 deletions tuner/tuner/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,16 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import re
import logging
from dataclasses import astuple, dataclass, field
from enum import Enum
from types import TracebackType
from typing import Optional
from typing import Any

from iree.compiler import ir # type: ignore

from iree.compiler.dialects import iree_gpu # type: ignore
from iree.compiler.dialects import iree_codegen # type: ignore


class CommonTypes:
Expand All @@ -38,10 +37,22 @@ def getI64(self, value: int) -> ir.IntegerAttr:


class TunerContext:
    """Owns the MLIR context, logger, and common types for one tuning run.

    Acts as a context manager so the embedded ``ir.Context`` is entered and
    exited together with this object:

        with TunerContext() as ctx:
            ...
    """

    def __init__(self, logger: Optional[logging.Logger] = None):
        # Create and own the MLIR context instead of receiving one, so a
        # single `with TunerContext()` manages both lifetimes.
        self.mlir_ctx: ir.Context = ir.Context()
        # Fall back to the shared "tune" logger when none is supplied.
        self.logger: logging.Logger = logger or logging.getLogger("tune")
        self.type: CommonTypes = CommonTypes(self.mlir_ctx)

    def __enter__(self) -> "TunerContext":
        # Enter the owned MLIR context so IR objects can be constructed in
        # the `with` body without an explicit `ir.Context` block.
        self.mlir_ctx.__enter__()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool:
        # Delegate exception handling/suppression to the MLIR context's
        # own __exit__.
        return self.mlir_ctx.__exit__(exc_type, exc_value, traceback)


class DispatchKind(Enum):
Expand Down
6 changes: 3 additions & 3 deletions tuner/tuner/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
from logging import Logger
from unittest.mock import MagicMock

with ir.Context() as ctx:
logger: Logger = MagicMock(spec=Logger)
yield common.TunerContext(ctx, logger)
mock_logger = MagicMock(spec=Logger)
with common.TunerContext(logger=mock_logger) as ctx:
yield ctx


@pytest.fixture
Expand Down
6 changes: 3 additions & 3 deletions tuner/tuner/dispatch_constraints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
from logging import Logger
from unittest.mock import MagicMock

with ir.Context() as ctx:
logger: Logger = MagicMock(spec=Logger)
yield common.TunerContext(ctx, logger)
mock_logger = MagicMock(spec=Logger)
with common.TunerContext(logger=mock_logger) as ctx:
yield ctx


def test_generate_solutions(tuner_ctx: common.TunerContext) -> None:
Expand Down
6 changes: 3 additions & 3 deletions tuner/tuner/dispatch_parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
from logging import Logger
from unittest.mock import MagicMock

with ir.Context() as ctx:
logger: Logger = MagicMock(spec=Logger)
yield common.TunerContext(ctx, logger)
mock_logger = MagicMock(spec=Logger)
with common.TunerContext(logger=mock_logger) as ctx:
yield ctx


CONTRACTION_TEMPLATE = r"""
Expand Down
26 changes: 11 additions & 15 deletions tuner/tuner/libtuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@

import math
import signal
import subprocess
import sys
import shutil
import logging
import argparse
from datetime import datetime
from enum import Enum
from pathlib import Path
import time
import multiprocessing
import queue
from tqdm import tqdm
Expand All @@ -37,6 +35,7 @@
import iree.runtime as ireert # type: ignore
import iree.compiler as ireec # type: ignore
from iree.compiler import ir # type: ignore
from iree.compiler.dialects import iree_codegen # type: ignore
from . import candidate_gen
from . import dispatch_parser
from .op_matchers import *
Expand Down Expand Up @@ -103,10 +102,8 @@ def get_candidate_vmfb_filename(self, candidate_id: int) -> str:


class TuningClient(ABC):
def __init__(self, tuner_context: TunerContext):
    """Bind this client to an externally managed TunerContext.

    The caller owns the context's lifetime (typically via
    ``with TunerContext() as ctx:``), so no MLIR context or logger is
    created here.

    Args:
        tuner_context: Shared context used by all tuning phases.
    """
    self.tuner_context = tuner_context

@abstractmethod
def get_iree_compile_flags(self) -> list[str]:
Expand Down Expand Up @@ -644,15 +641,14 @@ def generate_candidate_specs(
# source mlir.
mlir_text = candidate_gen.strip_compilation_info(path_config.template_mlir)
mlir_module = dispatch_parser.parse_mlir(mlir_text, tuning_client.tuner_context)
with tuning_client.tuner_context.mlir_ctx:
logging.debug("Captured messages from candidate_gen.py:")
config_specs: list[ir.Module] = candidate_gen.generate_configs_and_td_specs(
input_module=mlir_module,
tuner_context=tuning_client.tuner_context,
limit=args.num_candidates,
num_subgroups=args.num_subgroups,
codegen_pipeline=get_iree_codegen_pipeline(args.codegen_pipeline),
)
logging.debug("Captured messages from candidate_gen.py:")
config_specs: list[ir.Module] = candidate_gen.generate_configs_and_td_specs(
input_module=mlir_module,
tuner_context=tuning_client.tuner_context,
limit=args.num_candidates,
num_subgroups=args.num_subgroups,
codegen_pipeline=get_iree_codegen_pipeline(args.codegen_pipeline),
)
logging.debug("candidate_gen.py ends")
handle_error(
condition=(len(config_specs) <= 1), msg="Failed to generate any candidates"
Expand Down
2 changes: 1 addition & 1 deletion tuner/tuner/op_matchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def get_map_result_dim_positions(map: ir.AffineMap):


class ContractionOpInterfaceMatcher(GenericOpMatcher):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.contraction_dimensions: Optional[ContractionDimensions] = None
self.lhs_dims: Optional[list[int]] = None
Expand Down
Loading