diff --git a/.gitignore b/.gitignore index 43bd95c2b..b55675b79 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ poetry.lock noxenv.txt noxsettings.toml +hyperparamtuning/ ### Python ### *.pyc @@ -16,6 +17,8 @@ push_to_pypi.sh .nfs* *.log *.json +!kernel_tuner/schema/T1/1.0.0/input-schema.json +!test/test_T1_input.json *.csv .cache *.ipynb_checkpoints diff --git a/kernel_tuner/__init__.py b/kernel_tuner/__init__.py index b64d69813..40b88d463 100644 --- a/kernel_tuner/__init__.py +++ b/kernel_tuner/__init__.py @@ -1,5 +1,5 @@ from kernel_tuner.integration import store_results, create_device_targets -from kernel_tuner.interface import tune_kernel, run_kernel +from kernel_tuner.interface import tune_kernel, tune_kernel_T1, run_kernel from importlib.metadata import version diff --git a/kernel_tuner/backends/backend.py b/kernel_tuner/backends/backend.py index a37c9d6e7..13954be6a 100644 --- a/kernel_tuner/backends/backend.py +++ b/kernel_tuner/backends/backend.py @@ -1,16 +1,16 @@ -"""This module contains the interface of all kernel_tuner backends""" +"""This module contains the interface of all kernel_tuner backends.""" from __future__ import print_function from abc import ABC, abstractmethod class Backend(ABC): - """Base class for kernel_tuner backends""" + """Base class for kernel_tuner backends.""" @abstractmethod def ready_argument_list(self, arguments): """This method must implement the allocation of the arguments on device memory.""" - pass + return arguments @abstractmethod def compile(self, kernel_instance): @@ -59,7 +59,7 @@ def memcpy_htod(self, dest, src): class GPUBackend(Backend): - """Base class for GPU backends""" + """Base class for GPU backends.""" @abstractmethod def __init__(self, device, iterations, compiler_options, observers): @@ -82,7 +82,7 @@ def copy_texture_memory_args(self, texmem_args): class CompilerBackend(Backend): - """Base class for compiler backends""" + """Base class for compiler backends.""" @abstractmethod def __init__(self, iterations, compiler_options, compiler): diff --git a/kernel_tuner/backends/hypertuner.py b/kernel_tuner/backends/hypertuner.py new file mode 100644 index 000000000..65a263ce1 --- /dev/null +++ b/kernel_tuner/backends/hypertuner.py @@ -0,0 +1,131 @@ +"""This module contains a 'device' for hyperparameter tuning using the autotuning methodology.""" + +import platform +from pathlib import Path + +from numpy import mean + +from kernel_tuner.backends.backend import Backend +from kernel_tuner.observers.observer import BenchmarkObserver + +try: + methodology_available = True + from autotuning_methodology.experiments import generate_experiment_file + from autotuning_methodology.report_experiments import get_strategy_scores +except ImportError: + methodology_available = False + + +class ScoreObserver(BenchmarkObserver): + def __init__(self, dev): + self.dev = dev + self.scores = [] + + def after_finish(self): + self.scores.append(self.dev.last_score) + + def get_results(self): + results = {'score': mean(self.scores), 'scores': self.scores.copy()} + self.scores = [] + return results + +class HypertunerFunctions(Backend): + """Class for executing hyperparameter tuning.""" + units = {} + + def __init__(self, iterations): + self.iterations = iterations + self.observers = [ScoreObserver(self)] + self.name = platform.processor() + self.max_threads = 1024 + self.last_score = None + + # set the environment options + env = dict() + env["iterations"] = self.iterations + self.env = env + + # check for the methodology package + if 
methodology_available is not True: + raise ImportError("Unable to import the autotuning methodology, run `pip install autotuning_methodology`.") + + def ready_argument_list(self, arguments): + arglist = super().ready_argument_list(arguments) + if arglist is None: + arglist = [] + return arglist + + def compile(self, kernel_instance): + super().compile(kernel_instance) + path = Path(__file__).parent.parent.parent / "hyperparamtuning" + path.mkdir(exist_ok=True) + + # TODO get applications & GPUs args from benchmark + gpus = ["RTX_3090", "RTX_2080_Ti"] + applications = None + # applications = [ + # { + # "name": "convolution", + # "folder": "./cached_data_used/kernels", + # "input_file": "convolution.json" + # }, + # { + # "name": "pnpoly", + # "folder": "./cached_data_used/kernels", + # "input_file": "pnpoly.json" + # } + # ] + + # strategy settings + strategy: str = kernel_instance.arguments[0] + hyperparams = [{'name': k, 'value': v} for k, v in kernel_instance.params.items()] + hyperparams_string = "_".join(f"{k}={str(v)}" for k, v in kernel_instance.params.items()) + searchspace_strategies = [{ + "autotuner": "KernelTuner", + "name": f"{strategy.lower()}_{hyperparams_string}", + "display_name": strategy.replace('_', ' ').capitalize(), + "search_method": strategy.lower(), + 'search_method_hyperparameters': hyperparams + }] + + # any additional settings + override = { + "experimental_groups_defaults": { + "samples": self.iterations + } + } + + name = kernel_instance.name if len(kernel_instance.name) > 0 else kernel_instance.kernel_source.kernel_name + experiments_filepath = generate_experiment_file(name, path, searchspace_strategies, applications, gpus, + override=override, overwrite_existing_file=True) + return str(experiments_filepath) + + def start_event(self): + return super().start_event() + + def stop_event(self): + return super().stop_event() + + def kernel_finished(self): + super().kernel_finished() + return True + + def synchronize(self): + return super().synchronize() + + def run_kernel(self, func, gpu_args=None, threads=None, grid=None, stream=None): + # generate the experiments file + experiments_filepath = Path(func) + + # run the methodology to get a fitness score for this configuration + scores = get_strategy_scores(str(experiments_filepath)) + self.last_score = scores[list(scores.keys())[0]]['score'] + + def memset(self, allocation, value, size): + return super().memset(allocation, value, size) + + def memcpy_dtoh(self, dest, src): + return super().memcpy_dtoh(dest, src) + + def memcpy_htod(self, dest, src): + return super().memcpy_htod(dest, src) diff --git a/kernel_tuner/core.py b/kernel_tuner/core.py index abd4a017e..f139111e7 100644 --- a/kernel_tuner/core.py +++ b/kernel_tuner/core.py @@ -1,4 +1,4 @@ -""" Module for grouping the core functionality needed by most runners """ +"""Module for grouping the core functionality needed by most runners.""" import logging import re @@ -14,15 +14,16 @@ import kernel_tuner.util as util from kernel_tuner.accuracy import Tunable -from kernel_tuner.backends.pycuda import PyCudaFunctions +from kernel_tuner.backends.compiler import CompilerFunctions from kernel_tuner.backends.cupy import CupyFunctions from kernel_tuner.backends.hip import HipFunctions +from kernel_tuner.backends.hypertuner import HypertunerFunctions from kernel_tuner.backends.nvcuda import CudaFunctions from kernel_tuner.backends.opencl import OpenCLFunctions -from kernel_tuner.backends.compiler import CompilerFunctions +from kernel_tuner.backends.pycuda import 
PyCudaFunctions from kernel_tuner.observers.nvml import NVMLObserver -from kernel_tuner.observers.tegra import TegraObserver from kernel_tuner.observers.observer import ContinuousObserver, OutputObserver, PrologueObserver +from kernel_tuner.observers.tegra import TegraObserver try: import torch @@ -45,15 +46,15 @@ class KernelInstance(_KernelInstance): - """Class that represents the specific parameterized instance of a kernel""" + """Class that represents the specific parameterized instance of a kernel.""" def delete_temp_files(self): - """Delete any generated temp files""" + """Delete any generated temp files.""" for v in self.temp_files.values(): util.delete_temp_file(v) def prepare_temp_files_for_error_msg(self): - """Prepare temp file with source code, and return list of temp file names""" + """Prepare temp file with source code, and return list of temp file names.""" temp_filename = util.get_temp_filename(suffix=self.kernel_source.get_suffix()) util.write_file(temp_filename, self.kernel_string) ret = [temp_filename] @@ -89,7 +90,7 @@ def __init__(self, kernel_name, kernel_sources, lang, defines=None): self.lang = lang.upper() def get_kernel_string(self, index=0, params=None): - """retrieve the kernel source with the given index and return as a string + """Retrieve the kernel source with the given index and return as a string. See util.get_kernel_string() for details. @@ -105,13 +106,16 @@ def get_kernel_string(self, index=0, params=None): """ logging.debug("get_kernel_string called") + if hasattr(self, 'lang') and self.lang.upper() == "HYPERTUNER": + return "" + kernel_source = self.kernel_sources[index] return util.get_kernel_string(kernel_source, params) def prepare_list_of_files( self, kernel_name, params, grid, threads, block_size_names ): - """prepare the kernel string along with any additional files + """Prepare the kernel string along with any additional files. The first file in the list is allowed to include or read in the others The files beyond the first are considered additional files that may also contain tunable parameters @@ -144,6 +148,9 @@ def prepare_list_of_files( """ temp_files = dict() + if self.lang.upper() == "HYPERTUNER": + return tuple(["", "", temp_files]) + for i, f in enumerate(self.kernel_sources): if i > 0 and not util.looks_like_a_filename(f): raise ValueError( @@ -194,7 +201,6 @@ def get_suffix(self, index=0): This uses the user-specified suffix if available, or one based on the lang/backend otherwise. """ - # TODO: Consider delegating this to the backend suffix = self.get_user_suffix(index) if suffix is not None: @@ -207,7 +213,7 @@ def get_suffix(self, index=0): return ".c" def check_argument_lists(self, kernel_name, arguments): - """Check if the kernel arguments have the correct types + """Check if the kernel arguments have the correct types. This is done by calling util.check_argument_list on each kernel string. """ @@ -223,7 +229,7 @@ def check_argument_lists(self, kernel_name, arguments): class DeviceInterface(object): - """Class that offers a High-Level Device Interface to the rest of the Kernel Tuner""" + """Class that offers a High-Level Device Interface to the rest of the Kernel Tuner.""" def __init__( self, @@ -236,7 +242,7 @@ def __init__( iterations=7, observers=None, ): - """Instantiate the DeviceInterface, based on language in kernel source + """Instantiate the DeviceInterface, based on language in kernel source. 
:param kernel_source: The kernel sources :type kernel_source: kernel_tuner.core.KernelSource @@ -266,6 +272,7 @@ def __init__( """ lang = kernel_source.lang + self.requires_warmup = True logging.debug("DeviceInterface instantiated, lang=%s", lang) @@ -311,6 +318,9 @@ def __init__( iterations=iterations, observers=observers, ) + elif lang.upper() == "HYPERTUNER": + dev = HypertunerFunctions(iterations=iterations) + self.requires_warmup = False else: raise ValueError("Sorry, support for languages other than CUDA, OpenCL, HIP, C, and Fortran is not implemented yet") self.dev = dev @@ -351,8 +361,7 @@ def __init__( print("Using: " + self.dev.name) def benchmark_prologue(self, func, gpu_args, threads, grid, result): - """Benchmark prologue one kernel execution per PrologueObserver""" - + """Benchmark prologue one kernel execution per PrologueObserver.""" for obs in self.prologue_observers: self.dev.synchronize() obs.before_start() @@ -362,8 +371,7 @@ def benchmark_prologue(self, func, gpu_args, threads, grid, result): result.update(obs.get_results()) def benchmark_default(self, func, gpu_args, threads, grid, result): - """Benchmark one kernel execution for 'iterations' at a time""" - + """Benchmark one kernel execution for 'iterations' at a time.""" self.dev.synchronize() for _ in range(self.iterations): for obs in self.benchmark_observers: @@ -387,7 +395,7 @@ def benchmark_default(self, func, gpu_args, threads, grid, result): def benchmark_continuous(self, func, gpu_args, threads, grid, result, duration): - """Benchmark continuously for at least 'duration' seconds""" + """Benchmark continuously for at least 'duration' seconds.""" iterations = int(np.ceil(duration / (result["time"] / 1000))) self.dev.synchronize() for obs in self.continuous_observers: @@ -485,7 +493,7 @@ def benchmark(self, func, gpu_args, instance, verbose, objective, skip_nvml_sett def check_kernel_output( self, func, gpu_args, instance, answer, atol, verify, verbose ): - """runs the kernel once and checks the result against answer""" + """Runs the kernel once and checks the result against answer.""" logging.debug("check_kernel_output") #if not using custom verify function, check if the length is the same @@ -633,7 +641,7 @@ def compile_and_benchmark(self, kernel_source, gpu_args, params, kernel_options, return result def compile_kernel(self, instance, verbose): - """compile the kernel for this specific instance""" + """Compile the kernel for this specific instance.""" logging.debug("compile_kernel " + instance.name) # compile kernel_string into device func @@ -664,23 +672,23 @@ def compile_kernel(self, instance, verbose): @staticmethod def preprocess_gpu_arguments(old_arguments, params): - """ Get a flat list of arguments based on the configuration given by `params` """ + """Get a flat list of arguments based on the configuration given by `params`.""" return _preprocess_gpu_arguments(old_arguments, params) def copy_shared_memory_args(self, smem_args): - """adds shared memory arguments to the most recently compiled module""" + """Adds shared memory arguments to the most recently compiled module.""" self.dev.copy_shared_memory_args(smem_args) def copy_constant_memory_args(self, cmem_args): - """adds constant memory arguments to the most recently compiled module""" + """Adds constant memory arguments to the most recently compiled module.""" self.dev.copy_constant_memory_args(cmem_args) def copy_texture_memory_args(self, texmem_args): - """adds texture memory arguments to the most recently compiled module""" + """Adds texture 
memory arguments to the most recently compiled module.""" self.dev.copy_texture_memory_args(texmem_args) def create_kernel_instance(self, kernel_source, kernel_options, params, verbose): - """create kernel instance from kernel source, parameters, problem size, grid divisors, and so on""" + """Create kernel instance from kernel source, parameters, problem size, grid divisors, and so on.""" grid_div = ( kernel_options.grid_div_x, kernel_options.grid_div_y, @@ -725,15 +733,15 @@ def create_kernel_instance(self, kernel_source, kernel_options, params, verbose) return KernelInstance(name, kernel_source, kernel_string, temp_files, threads, grid, params, arguments) def get_environment(self): - """Return dictionary with information about the environment""" + """Return dictionary with information about the environment.""" return self.dev.env def memcpy_dtoh(self, dest, src): - """perform a device to host memory copy""" + """Perform a device to host memory copy.""" self.dev.memcpy_dtoh(dest, src) def ready_argument_list(self, arguments): - """ready argument list to be passed to the kernel, allocates gpu mem if necessary""" + """Ready argument list to be passed to the kernel, allocates gpu mem if necessary.""" flat_args = [] # Flatten all arguments into a single list. Required to deal with `Tunable`s @@ -760,7 +768,7 @@ def ready_argument_list(self, arguments): return gpu_args def run_kernel(self, func, gpu_args, instance): - """Run a compiled kernel instance on a device""" + """Run a compiled kernel instance on a device.""" logging.debug("run_kernel %s", instance.name) logging.debug("thread block dims (%d, %d, %d)", *instance.threads) logging.debug("grid dims (%d, %d, %d)", *instance.grid) @@ -782,7 +790,7 @@ def run_kernel(self, func, gpu_args, instance): def _preprocess_gpu_arguments(old_arguments, params): - """ Get a flat list of arguments based on the configuration given by `params` """ + """Get a flat list of arguments based on the configuration given by `params`.""" new_arguments = [] for argument in old_arguments: @@ -795,8 +803,7 @@ def _preprocess_gpu_arguments(old_arguments, params): def _default_verify_function(instance, answer, result_host, atol, verbose): - """default verify function based on np.allclose""" - + """Default verify function based on np.allclose.""" # first check if the length is the same if len(instance.arguments) != len(answer): raise TypeError( @@ -919,7 +926,7 @@ def _flatten(a): # these functions facilitate compiling templated kernels with PyCuda def split_argument_list(argument_list): - """split all arguments in a list into types and names""" + """Split all arguments in a list into types and names.""" regex = r"(.*[\s*]+)(.+)?" 
type_list = [] name_list = [] @@ -933,10 +940,10 @@ def split_argument_list(argument_list): def apply_template_typenames(type_list, templated_typenames): - """replace the typename tokens in type_list with their templated typenames""" + """Replace the typename tokens in type_list with their templated typenames.""" def replace_typename_token(matchobj): - """function for a whitespace preserving token regex replace""" + """Function for a whitespace preserving token regex replace.""" # replace only the match, leaving the whitespace around it as is return ( matchobj.group(1) @@ -954,7 +961,7 @@ def replace_typename_token(matchobj): def get_templated_typenames(template_parameters, template_arguments): - """based on the template parameters and arguments, create dict with templated typenames""" + """Based on the template parameters and arguments, create dict with templated typenames.""" templated_typenames = {} for i, param in enumerate(template_parameters): if "typename " in param: @@ -964,7 +971,7 @@ def get_templated_typenames(template_parameters, template_arguments): def wrap_templated_kernel(kernel_string, kernel_name): - """rewrite kernel_string to insert wrapper function for templated kernel""" + """Rewrite kernel_string to insert wrapper function for templated kernel.""" # parse kernel_name to find template_arguments and real kernel name name = kernel_name.split("<")[0] template_arguments = re.search(r".*?<(.*)>", kernel_name, re.S).group(1).split(",") diff --git a/kernel_tuner/file_utils.py b/kernel_tuner/file_utils.py index e5d3dcb90..2b75cc023 100644 --- a/kernel_tuner/file_utils.py +++ b/kernel_tuner/file_utils.py @@ -1,19 +1,43 @@ """This module contains utility functions for operations on files, mostly JSON cache files.""" import json -import os import subprocess from importlib.metadata import PackageNotFoundError, requires, version from pathlib import Path from sys import platform +import jsonschema import xmltodict from packaging.requirements import Requirement from kernel_tuner import util -schema_dir = os.path.dirname(os.path.realpath(__file__)) + "/schema" +schema_dir = Path(__file__).parent / "schema" +def input_file_schema(): + """Get the requested JSON input schema and the version number. + + :returns: the current version of the T1 schemas and the JSON string of the schema + :rtype: string, string + """ + current_version = "1.0.0" + input_file = schema_dir.joinpath(f"T1/{current_version}/input-schema.json") + with input_file.open() as fh: + json_string = json.load(fh) + return current_version, json_string + +def get_input_file(filepath: Path, validate=True) -> dict[str, any]: + """Load the T1 input file from the given path, validates it and returns contents if valid. + + :param filepath: Path to the input file to load. + :returns: the contents of the file if valid. + """ + with filepath.open() as fp: + input_file = json.load(fp) + if validate: + _, input_schema = input_file_schema() + jsonschema.validate(input_file, input_schema) + return input_file def output_file_schema(target): """Get the requested JSON schema and the version number. 
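For illustration, a minimal sketch of how the T1 input loader added above might be used; the file path is an assumption (the test input whitelisted in .gitignore), and schema validation is enabled by default:

from pathlib import Path

from kernel_tuner.file_utils import get_input_file, input_file_schema

# retrieve the bundled T1 input schema and its version
version, schema = input_file_schema()
print(version)  # "1.0.0"

# load a T1 input file and validate it against the schema (raises jsonschema.ValidationError if invalid)
inputs = get_input_file(Path("test/test_T1_input.json"))
print(inputs["KernelSpecification"]["KernelName"])
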
@@ -26,8 +50,8 @@ def output_file_schema(target): """ current_version = "1.0.0" - output_file = schema_dir + f"/T4/{current_version}/{target}-schema.json" - with open(output_file, "r") as fh: + output_file = schema_dir.joinpath(f"T4/{current_version}/{target}-schema.json") + with output_file.open() as fh: json_string = json.load(fh) return current_version, json_string @@ -63,13 +87,10 @@ def make_filenamepath(filenamepath: Path): filepath.mkdir() -def store_output_file(output_filename: str, results, tune_params, objective="time"): - """Store the obtained auto-tuning results in a JSON output file. +def get_t4_results(results, tune_params, objective="time"): + """Get the obtained auto-tuning results in a dictionary. - This function produces a JSON file that adheres to the T4 auto-tuning output JSON schema. - - :param output_filename: Name or 'path / name' of the to be created output file - :type output_filename: string + This function produces a dictionary that adheres to the T4 auto-tuning output JSON schema. :param results: Results list as return by tune_kernel :type results: list of dicts @@ -81,9 +102,6 @@ def store_output_file(output_filename: str, results, tune_params, objective="tim :type objective: string """ - output_filenamepath = Path(filename_ensure_json_extension(output_filename)) - make_filenamepath(output_filenamepath) - timing_keys = ["compile_time", "benchmark_time", "framework_time", "strategy_time", "verification_time"] not_measurement_keys = list(tune_params.keys()) + timing_keys + ["timestamp"] + ["times"] @@ -134,7 +152,30 @@ def store_output_file(output_filename: str, results, tune_params, objective="tim # write output_data to a JSON file version, _ = output_file_schema("results") - output_json = dict(results=output_data, schema_version=version) + output_json = dict(results=output_data, schema_version=version, metadata={'timeunit': 'miliseconds'}) + return output_json + +def store_output_file(output_filename: str, results, tune_params, objective="time"): + """Store the obtained auto-tuning results in a JSON output file. + + This function produces a JSON file that adheres to the T4 auto-tuning output JSON schema. + + :param output_filename: Name or 'path / name' of the to be created output file + :type output_filename: string + + :param results: Results list as return by tune_kernel + :type results: list of dicts + + :param tune_params: Tunable parameters as passed to tune_kernel + :type tune_params: dict + + :param objective: The objective used during auto-tuning, default is 'time'. + :type objective: string + + """ + output_filenamepath = Path(filename_ensure_json_extension(output_filename)) + make_filenamepath(output_filenamepath) + output_json = get_t4_results(results, tune_params, objective) with open(output_filenamepath, "w+") as fh: json.dump(output_json, fh, cls=util.NpEncoder) @@ -175,17 +216,11 @@ def get_device_query(target): raise ValueError("get_device_query target not supported") -def store_metadata_file(metadata_filename: str): - """Store the metadata about the current hardware and software environment in a JSON output file. - - This function produces a JSON file that adheres to the T4 auto-tuning metadata JSON schema. - - :param metadata_filename: Name or 'path / name' of the to be created metadata file - :type metadata_filename: string +def get_t4_metadata(): + """Get the metadata about the current hardware and software environment. + This function produces a dictionary that adheres to the T4 auto-tuning metadata JSON schema. 
""" - metadata_filenamepath = Path(filename_ensure_json_extension(metadata_filename)) - make_filenamepath(metadata_filenamepath) metadata = {} supported_operating_systems = ["linux", "win32", "darwin"] @@ -250,5 +285,20 @@ def store_metadata_file(metadata_filename: str): # write metadata to JSON file version, _ = output_file_schema("metadata") metadata_json = dict(metadata=metadata, schema_version=version) + return metadata_json + +def store_metadata_file(metadata_filename: str): + """Store the metadata about the current hardware and software environment in a JSON output file. + + This function produces a JSON file that adheres to the T4 auto-tuning metadata JSON schema. + + :param metadata_filename: Name or 'path / name' of the to be created metadata file + :type metadata_filename: string + + """ + metadata_filenamepath = Path(filename_ensure_json_extension(metadata_filename)) + make_filenamepath(metadata_filenamepath) + metadata_json = get_t4_metadata() with open(metadata_filenamepath, "w+") as fh: json.dump(metadata_json, fh, indent=" ") + diff --git a/kernel_tuner/hyper.py b/kernel_tuner/hyper.py index f002882f3..b94c58986 100644 --- a/kernel_tuner/hyper.py +++ b/kernel_tuner/hyper.py @@ -1,15 +1,23 @@ -""" Module for functions related to hyperparameter optimization """ +"""Module for functions related to hyperparameter optimization.""" -import itertools -import warnings -import numpy as np + +from pathlib import Path +from random import randint import kernel_tuner -from kernel_tuner.util import get_config_string -def tune_hyper_params(target_strategy, hyper_params, *args, **kwargs): - """ Tune hyperparameters for a given strategy and kernel +def get_random_unique_filename(prefix = '', suffix=''): + """Get a random, unique filename that does not yet exist.""" + def randpath(): + return Path(f"{prefix}{randint(1000, 9999)}{suffix}") + path = randpath() + while path.exists(): + path = randpath() + return path + +def tune_hyper_params(target_strategy: str, hyper_params: dict, *args, **kwargs): + """Tune hyperparameters for a given strategy and kernel. This function is to be called just like tune_kernel, except that you specify a strategy and a dictionary with hyperparameters in front of the arguments you pass to tune_kernel. 
@@ -32,58 +40,59 @@ def tune_hyper_params(target_strategy, hyper_params, *args, **kwargs): :type kwargs: dict """ - if "cache" not in kwargs: - raise ValueError("Please specify a cachefile to store benchmarking data when tuning hyperparameters") + # v Have the methodology as a dependency + # - User inputs: + # - a set of bruteforced cachefiles / template experiments file + # - an optimization algorithm + # - the hyperparameter values to try + # - overarching optimization algorithm (meta-strategy) + # - At each round: + # - The meta-strategy selects a hyperparameter configuration to try + # - Kernel Tuner generates an experiments file with the hyperparameter configuration + # - Kernel Tuner executes this experiments file using the methodology + # - The methodology returns the fitness metric + # - The fitness metric is fed back into the meta-strategy + + iterations = 1 + if "iterations" in kwargs: + iterations = kwargs['iterations'] + del kwargs['iterations'] + + # pass a temporary cache file to avoid duplicate execution + cachefile = get_random_unique_filename('temp_', '.json') + kwargs['cache'] = str(cachefile) def put_if_not_present(target_dict, key, value): target_dict[key] = value if key not in target_dict else target_dict[key] - put_if_not_present(kwargs, "verbose", False) - put_if_not_present(kwargs, "quiet", True) - put_if_not_present(kwargs, "simulation_mode", True) - kwargs['strategy'] = 'brute_force' - - #last position argument is tune_params - tune_params = args[-1] - - #find optimum - kwargs["strategy"] = "brute_force" - results, _ = kernel_tuner.tune_kernel(*args, **kwargs) - optimum = min(results, key=lambda p: p["time"])["time"] - - #could throw a warning for the kwargs that will be overwritten, strategy(_options) - kwargs["strategy"] = target_strategy - - parameter_space = itertools.product(*hyper_params.values()) - all_results = [] - - for params in parameter_space: - strategy_options = dict(zip(hyper_params.keys(), params)) - - kwargs["strategy_options"] = strategy_options - - fevals = [] - p_of_opt = [] - for _ in range(100): - #measure - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - results, _ = kernel_tuner.tune_kernel(*args, **kwargs) - - #get unique function evaluations - unique_fevals = {",".join([str(v) for k, v in record.items() if k in tune_params]) - for record in results} - - fevals.append(len(unique_fevals)) - p_of_opt.append(min(results, key=lambda p: p["time"])["time"] / optimum * 100) - - strategy_options["fevals"] = np.average(fevals) - strategy_options["fevals_std"] = np.std(fevals) - - strategy_options["p_of_opt"] = np.average(p_of_opt) - strategy_options["p_of_opt_std"] = np.std(p_of_opt) - - print(get_config_string(strategy_options)) - all_results.append(strategy_options) - - return all_results + put_if_not_present(kwargs, "verbose", True) + put_if_not_present(kwargs, "quiet", False) + kwargs['simulation_mode'] = False + kwargs['strategy'] = 'dual_annealing' + kwargs['verify'] = None + arguments = [target_strategy] + + # execute the hyperparameter tuning + result, env = kernel_tuner.tune_kernel('hyperparamtuning', None, [], arguments, hyper_params, *args, lang='Hypertuner', + objective='score', objective_higher_is_better=True, iterations=iterations, **kwargs) + + # remove the temporary cachefile and return only unique results in order + cachefile.unlink() + result_unique = dict() + for r in result: + config_id = ",".join(str(r[k]) for k in hyper_params.keys()) + if config_id not in result_unique: + result_unique[config_id] = r + 
return list(result_unique.values()), env + +if __name__ == "__main__": # TODO remove in production + hyperparams = { + 'popsize': [10, 20, 30], + 'maxiter': [50, 100, 150], + 'w': [0.25, 0.5, 0.75], + 'c1': [1.0, 2.0, 3.0], + 'c2': [0.5, 1.0, 1.5] + } + result, env = tune_hyper_params('pso', hyperparams) + print(result) + print(env['best_config']) diff --git a/kernel_tuner/interface.py b/kernel_tuner/interface.py index 97ae22848..7570c321e 100644 --- a/kernel_tuner/interface.py +++ b/kernel_tuner/interface.py @@ -23,14 +23,19 @@ See the License for the specific language governing permissions and limitations under the License. """ + import logging +from argparse import ArgumentParser +from ast import literal_eval from datetime import datetime +from pathlib import Path from time import perf_counter import numpy import kernel_tuner.core as core import kernel_tuner.util as util +from kernel_tuner.file_utils import get_input_file, get_t4_metadata, get_t4_results from kernel_tuner.integration import get_objective_defaults from kernel_tuner.runners.sequential import SequentialRunner from kernel_tuner.runners.simulation import SimulationRunner @@ -399,7 +404,7 @@ def __deepcopy__(self, _): All strategies support the following two options: 1. "max_fevals": the maximum number of unique valid function evaluations (i.e. compiling and - benchmarking a kernel configuration the strategy is allowed to perform as part of the optimization. + benchmarking a kernel configuration) the strategy is allowed to perform as part of the optimization. Note that some strategies implement a default max_fevals of 100. 2. "time_limit": the maximum amount of time in seconds the strategy is allowed to spent on trying to @@ -660,6 +665,8 @@ def tune_kernel( # process cache if cache: + if isinstance(cache, Path): + cache = str(cache.resolve()) if cache[-5:] != ".json": cache += ".json" @@ -684,7 +691,7 @@ def tune_kernel( if results: # checks if results is not empty best_config = util.get_best_config(results, objective, objective_higher_is_better) # add the best configuration to env - env['best_config'] = best_config + env["best_config"] = best_config if not device_options.quiet: units = getattr(runner, "units", None) print("best performing configuration:") @@ -835,3 +842,126 @@ def _check_user_input(kernel_name, kernelsource, arguments, block_size_names): # check for types and length of block_size_names util.check_block_size_names(block_size_names) + + +def tune_kernel_T1(input_filepath: Path, cache_filepath: Path = None, simulation_mode = False, output_T4 = True, iterations = 7, strategy_options = None): + """Call the tune function with a T1 input file.""" + inputs = get_input_file(input_filepath) + kernelspec: dict = inputs["KernelSpecification"] + kernel_name: str = kernelspec["KernelName"] + kernel_filepath = Path(kernelspec["KernelFile"]) + kernel_source = ( + kernel_filepath if kernel_filepath.exists() else Path(input_filepath).parent.parent / kernel_filepath + ) + assert kernel_source.exists(), f"KernelFile '{kernel_source}' does not exist at {kernel_source.resolve()}" + language: str = kernelspec["Language"] + problem_size = kernelspec["ProblemSize"] + device = kernelspec["Device"]["Name"] + strategy = inputs["Search"]["Name"] + + if cache_filepath is None and "SimulationInput" in kernelspec: + cache_filepath = Path(kernelspec["SimulationInput"]) + + # get the grid divisions + grid_divs = {} + for grid_div in ["GridDivX", "GridDivY", "GridDivZ"]: + grid_divs[grid_div] = None + if grid_div in kernelspec and 
len(kernelspec[grid_div]) > 0: + grid_divs[grid_div] = kernelspec[grid_div] + + # convert tuneable parameters + tune_params = dict() + for param in inputs["ConfigurationSpace"]["TuningParameters"]: + tune_param = None + if param["Type"] in ["int", "float"]: + vals = param["Values"] + if vals[:5] == "list(" or (vals[0] == "[" and vals[-1] == "]"): + tune_param = eval(vals) + else: + tune_param = literal_eval(vals) + if tune_param is not None: + tune_params[param["Name"]] = tune_param + else: + raise NotImplementedError(f"Conversion for this type of parameter has not yet been implemented: {param}") + + # convert restrictions + restrictions = list() + for res in inputs["ConfigurationSpace"]["Conditions"]: + restriction = None + if isinstance(res["Expression"], str): + restriction = res["Expression"] + if restriction is not None: + restrictions.append(restriction) + else: + raise NotImplementedError(f"Conversion for this type of restriction has not yet been implemented: {res}") + + # convert arguments (must be after resolving tune_params) + arguments = list() + cmem_arguments = {} + for arg in kernelspec["Arguments"]: + argument = None + if arg["Type"] == "float" and arg["MemoryType"] == "Vector": + size = arg["Size"] + if isinstance(size, str): + args = tune_params.copy() + args["ProblemSize"] = problem_size + size = int(eval(size, args)) + if not isinstance(size, int): + raise TypeError(f"Size should be an integer, but is {size} (type ({type(size)}, from {arg['Size']}))") + if arg["FillType"] == "Constant": + argument = numpy.full(size, arg["FillValue"]).astype(numpy.float32) + elif arg["FillType"] == "Random": + argument = numpy.random.randn(size).astype(numpy.float32) + else: + raise NotImplementedError(f"Conversion for fill type '{arg['FillType']}' has not yet been implemented") + if argument is not None: + arguments.append(argument) + if "MemType" in arg and arg["MemType"] == "Constant": + cmem_arguments[arg["Name"]] = argument + else: + raise NotImplementedError(f"Conversion for this type of argument has not yet been implemented: {arg}") + + # tune with the converted inputs + # TODO add objective to tune_kernel and get_t4_results calls once available in T1 + results, env = tune_kernel( + kernel_name, + kernel_source, + problem_size, + arguments, + tune_params, + device=device, + grid_div_x=grid_divs["GridDivX"], + grid_div_y=grid_divs["GridDivY"], + grid_div_z=grid_divs["GridDivZ"], + cmem_args=cmem_arguments, + restrictions=restrictions, + lang=language, + cache=cache_filepath, + simulation_mode=simulation_mode, + quiet=True, + verbose=False, + iterations=iterations, + strategy=strategy, + strategy_options=strategy_options + ) + if output_T4: + return get_t4_metadata(), get_t4_results(results, tune_params) + return results, env + + +def entry_point(args=None): # pragma: no cover + """Command-line interface entry point.""" + cli = ArgumentParser() + cli.add_argument("input_file", type=str, help="The path to the input json file to execute (T1 standard)") + cli.add_argument( + "cache_file", type=str, help="The path to the cachefile to use (optional)", required=False, default=None + ) + args = cli.parse_args(args) + input_filepath_arg: str = args.input_file + if input_filepath_arg is None or input_filepath_arg == "": + raise ValueError("Invalid '--input_file' option. 
Run 'kernel_tuner -h' to read more.") + input_filepath = Path(input_filepath_arg) + cachefile_filepath_arg = args.cache_file + if cachefile_filepath_arg is not None: + cachefile_filepath_arg = Path(cachefile_filepath_arg) + tune_kernel_T1(input_filepath, cache_filepath=cachefile_filepath_arg) diff --git a/kernel_tuner/runners/sequential.py b/kernel_tuner/runners/sequential.py index aeebd5116..5e53093be 100644 --- a/kernel_tuner/runners/sequential.py +++ b/kernel_tuner/runners/sequential.py @@ -34,7 +34,7 @@ def __init__(self, kernel_source, kernel_options, device_options, iterations, ob self.units = self.dev.units self.quiet = device_options.quiet self.kernel_source = kernel_source - self.warmed_up = False + self.warmed_up = False if self.dev.requires_warmup else True self.simulation_mode = False self.start_time = perf_counter() self.last_strategy_start_time = self.start_time diff --git a/kernel_tuner/schema/T1/1.0.0/input-schema.json b/kernel_tuner/schema/T1/1.0.0/input-schema.json new file mode 100644 index 000000000..bb53ee594 --- /dev/null +++ b/kernel_tuner/schema/T1/1.0.0/input-schema.json @@ -0,0 +1,412 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "$id": "https://github.com/odgaard/TuningSchema/blob/main/TuningSchema.json", + "title": "Tuning format", + "description": "A description of a tuning problem which can be loaded by an autotuning framework", + "type": "object", + "required": [ + "ConfigurationSpace", + "KernelSpecification" + ], + "properties": { + "ConfigurationSpace": { + "type": "object", + "required": [ + "TuningParameters" + ], + "properties": { + "TuningParameters": { + "type": "array", + "items": { + "type": "object", + "required": [ + "Name", + "Type", + "Values" + ], + "properties": { + "Name": { + "type": "string" + }, + "Type": { + "enum": [ + "int", + "uint", + "float", + "bool", + "string" + ] + }, + "Values": { + "type": "string" + } + } + } + }, + "Conditions": { + "type": "array", + "items": { + "type": "object", + "required": [ + "Parameters", + "Expression" + ], + "properties": { + "Parameters": { + "type": "array", + "items": { + "type": "string" + } + }, + "Expression": { + "type": "string" + } + } + } + } + } + }, + "Search": { + "type": "object", + "required": [ + "Name" + ], + "properties": { + "Name": { + "type": "string" + }, + "Attributes": { + "type": "array", + "items": { + "type": "object", + "required": [ + "Name", + "Value" + ], + "properties": { + "Name": { + "type": "string" + }, + "Value": { + "type": [ + "number", + "string", + "boolean", + "object", + "array" + ] + } + } + } + } + } + }, + "Budget": { + "type": "array", + "items": { + "type": "object", + "required": [ + "Type", + "BudgetValue" + ], + "properties": { + "Type": { + "enum": [ + "TuningDuration", + "ConfigurationCount", + "ConfigurationFraction" + ] + }, + "BudgetValue": { + "type": "number" + } + } + } + }, + "General": { + "type": "object", + "properties": { + "FormatVersion": { + "type": "integer" + }, + "LoggingLevel": { + "enum": [ + "Off", + "Error", + "Warning", + "Info", + "Debug" + ] + }, + "TimeUnit": { + "enum": [ + "Nanoseconds", + "Microseconds", + "Milliseconds", + "Seconds" + ] + }, + "OutputFile": { + "type": "string", + "examples": [ + "ReductionOutput", + "Results" + ] + }, + "OutputFormat": { + "enum": [ + "JSON", + "XML" + ] + } + } + }, + "KernelSpecification": { + "type": "object", + "required": [ + "Language", + "KernelName", + "KernelFile", + "GlobalSize", + "LocalSize" + ], + "properties": { + "Device": { + "type": "object", + 
"properties": { + "PlatformId": { + "type": "integer" + }, + "DeviceId": { + "type": "integer" + }, + "Name": { + "type": "string" + } + } + }, + "Language": { + "enum": [ + "OpenCL", + "CUDA", + "Vulkan" + ] + }, + "CompilerOptions": { + "type": "array", + "items": { + "type": "string" + } + }, + "Profiling": { + "type": "boolean" + }, + "KernelName": { + "type": "string" + }, + "KernelFile": { + "type": "string" + }, + "GlobalSizeType": { + "enum": [ + "OpenCL", + "CUDA", + "Vulkan" + ] + }, + "SharedMemory": { + "type": "integer" + }, + "SimulationInput": { + "type": "string" + }, + "GlobalSize": { + "type": "object", + "required": [ + "X" + ], + "properties": { + "X": { + "type": "string" + }, + "Y": { + "type": "string" + }, + "Z": { + "type": "string" + } + } + }, + "LocalSize": { + "type": "object", + "required": [ + "X" + ], + "properties": { + "X": { + "type": "string" + }, + "Y": { + "type": "string" + }, + "Z": { + "type": "string" + } + } + }, + "Arguments": { + "type": "array", + "items": { + "type": "object", + "required": [ + "Type", + "MemoryType" + ], + "properties": { + "Name": { + "type": "string" + }, + "Type": { + "enum": [ + "bool", + "int8", + "uint8", + "int16", + "uint16", + "int32", + "uint32", + "int64", + "uint64", + "half", + "half2", + "half4", + "half8", + "half16", + "float", + "float2", + "float4", + "float8", + "float16", + "double", + "double2", + "double4", + "double8", + "double16", + "custom" + ] + }, + "Size": { + "type": [ + "integer", + "string" + ], + "examples": [ + 720, + 26000, + "ProblemSize[0]+max(filter_width)-1" + ] + }, + "TypeSize": { + "type": "integer", + "examples": [ + 4, + 16 + ] + }, + "FillType": { + "enum": [ + "Constant", + "Random", + "Generator", + "Script", + "BinaryRaw", + "BinaryHDF" + ] + }, + "FillValue": { + "type": "number", + "examples": [ + 40, + 1.0 + ] + }, + "DataSource": { + "type": "string" + }, + "RandomSeed": { + "type": "integer" + }, + "AccessType": { + "enum": [ + "ReadOnly", + "WriteOnly", + "ReadWrite" + ] + }, + "MemoryType": { + "enum": [ + "Scalar", + "Vector", + "Local", + "Symbol" + ] + } + } + } + }, + "ReferenceArguments": { + "type": "array", + "items": { + "type": "object", + "required": [ + "Name", + "TargetName", + "FillType" + ], + "properties": { + "Name": { + "type": "string" + }, + "TargetName": { + "type": "string" + }, + "FillType": { + "enum": [ + "Constant", + "Random", + "Generator", + "Script", + "BinaryRaw", + "BinaryHDF" + ] + }, + "FillValue": { + "type": "number", + "examples": [ + 40, + 1.0 + ] + }, + "DataSource": { + "type": "string" + }, + "RandomSeed": { + "type": "integer" + }, + "ValidationMethod": { + "enum": [ + "AbsoluteDifference", + "SideBySideComparison", + "SideBySideRelativeComparison" + ] + }, + "ValidationThreshold": { + "type": "number" + } + } + } + } + } + } + } +} \ No newline at end of file diff --git a/kernel_tuner/searchspace.py b/kernel_tuner/searchspace.py index 5ee7f7ce2..3024bbf90 100644 --- a/kernel_tuner/searchspace.py +++ b/kernel_tuner/searchspace.py @@ -18,7 +18,7 @@ ) from kernel_tuner.util import check_restrictions as check_instance_restrictions -from kernel_tuner.util import compile_restrictions, default_block_size_names +from kernel_tuner.util import compile_restrictions, default_block_size_names, get_interval supported_neighbor_methods = ["strictly-adjacent", "adjacent", "Hamming"] @@ -47,6 +47,7 @@ def __init__( Optionally sort the searchspace by the order in which the parameter values were specified. 
By default, sort goes from first to last parameter, to reverse this use sort_last_param_first. """ # set the object attributes using the arguments + framework_l = framework.lower() restrictions = restrictions if restrictions is not None else [] self.tune_params = tune_params self.restrictions = restrictions @@ -66,21 +67,27 @@ def __init__( if ( len(restrictions) > 0 and any(isinstance(restriction, str) for restriction in restrictions) - and not (framework.lower() == "pysmt" or framework.lower() == "bruteforce") + and not (framework_l == "pysmt" or framework_l == "bruteforce") ): self.restrictions = compile_restrictions( - restrictions, tune_params, monolithic=False, try_to_constraint=framework.lower() == "pythonconstraint" + restrictions, + tune_params, + monolithic=False, + format=framework_l if framework_l == "pyatf" else None, + try_to_constraint=framework_l == "pythonconstraint", ) # get the framework given the framework argument - if framework.lower() == "pythonconstraint": + if framework_l == "pythonconstraint": searchspace_builder = self.__build_searchspace - elif framework.lower() == "pysmt": + elif framework_l == "pysmt": searchspace_builder = self.__build_searchspace_pysmt - elif framework.lower() == "atf_cache": + elif framework_l == "pyatf": + searchspace_builder = self.__build_searchspace_pyATF + elif framework_l == "atf_cache": searchspace_builder = self.__build_searchspace_ATF_cache self.path_to_ATF_cache = path_to_ATF_cache - elif framework.lower() == "bruteforce": + elif framework_l == "bruteforce": searchspace_builder = self.__build_searchspace_bruteforce else: raise ValueError(f"Invalid framework parameter {framework}") @@ -142,7 +149,7 @@ def __init__( # num_solutions: int = csp.n_solutions() # number of solutions # solutions = [csp.values(sol=i) for i in range(num_solutions)] # list of solutions - def __build_searchspace_bruteforce(self, block_size_names: list, max_threads: int, solver = None): + def __build_searchspace_bruteforce(self, block_size_names: list, max_threads: int, solver=None): # bruteforce solving of the searchspace from itertools import product @@ -164,9 +171,15 @@ def __build_searchspace_bruteforce(self, block_size_names: list, max_threads: in restrictions = [restrictions] block_size_restriction_spaced = f"{' * '.join(used_block_size_names)} <= {max_threads}" block_size_restriction_unspaced = f"{'*'.join(used_block_size_names)} <= {max_threads}" - if block_size_restriction_spaced not in restrictions and block_size_restriction_unspaced not in restrictions: + if ( + block_size_restriction_spaced not in restrictions + and block_size_restriction_unspaced not in restrictions + ): restrictions.append(block_size_restriction_spaced) - if isinstance(self._modified_restrictions, list) and block_size_restriction_spaced not in self._modified_restrictions: + if ( + isinstance(self._modified_restrictions, list) + and block_size_restriction_spaced not in self._modified_restrictions + ): self._modified_restrictions.append(block_size_restriction_spaced) if isinstance(self.restrictions, list): self.restrictions.append(block_size_restriction_spaced) @@ -247,6 +260,73 @@ def all_smt(formula, keys) -> list: return self.__parameter_space_list_to_lookup_and_return_type(parameter_space_list) + def __build_searchspace_pyATF(self, block_size_names: list, max_threads: int, solver: Solver): + """Builds the searchspace using pyATF.""" + from pyatf import TP, Set, Interval, Tuner + from pyatf.cost_functions.generic import CostFunction + from pyatf.search_techniques import 
Exhaustive + + # Define a bogus cost function + costfunc = CostFunction(":") # bash no-op + + # add the Kernel Tuner default blocksize threads restrictions + assert isinstance(self.restrictions, list) + valid_block_size_names = list( + block_size_name for block_size_name in block_size_names if block_size_name in self.param_names + ) + if len(valid_block_size_names) > 0: + # adding the default blocksize restriction requires recompilation because pyATF requires combined restrictions for the same parameter + max_block_size_product = f"{' * '.join(valid_block_size_names)} <= {max_threads}" + restrictions = self._modified_restrictions.copy() + [max_block_size_product] + self.restrictions = compile_restrictions(restrictions, self.tune_params, format="pyatf", try_to_constraint=False) + + # build a dictionary of the restrictions, combined based on last parameter + res_dict = dict() + registered_params = list() + registered_restrictions = list() + for param in self.tune_params.keys(): + registered_params.append(param) + for index, (res, params, source) in enumerate(self.restrictions): + if index in registered_restrictions: + continue + if all(p in registered_params for p in params): + if param in res_dict: + raise KeyError(f"`{param}` is already in res_dict with `{res_dict[param][1]}`, can't add `{source}`") + res_dict[param] = (res, source) + print(source, res, param, params) + registered_restrictions.append(index) + + # define the Tunable Parameters + def get_params(): + params = list() + for index, (key, values) in enumerate(self.tune_params.items()): + vi = get_interval(values) + vals = Interval(vi[0], vi[1], vi[2]) if vi is not None and vi[2] != 0 else Set(*np.array(values).flatten()) + constraint = res_dict.get(key, None) + constraint_source = None + if constraint is not None: + constraint, constraint_source = constraint + # in case of a leftover monolithic restriction, append at the last parameter + if index == len(self.tune_params) - 1 and len(res_dict) == 0 and len(self.restrictions) == 1: + res, params, source = self.restrictions[0] + assert callable(res) + constraint = res + params.append(TP(key, vals, constraint, constraint_source)) + return params + + # tune + _, _, tuning_data = ( + Tuner().verbosity(0).tuning_parameters(*get_params()).search_technique(Exhaustive()).tune(costfunc) + ) + + # transform the result into a list of parameter configurations for validation + tune_params = self.tune_params + parameter_tuple_list = list() + for entry in tuning_data.history._entries: + parameter_tuple_list.append(tuple(entry.configuration[p] for p in tune_params.keys())) + pl = self.__parameter_space_list_to_lookup_and_return_type(parameter_tuple_list) + return pl + def __build_searchspace_ATF_cache(self, block_size_names: list, max_threads: int, solver: Solver): """Imports the valid configurations from an ATF CSV file, returns the searchspace, a dict of the searchspace for fast lookups and the size.""" if block_size_names != default_block_size_names or max_threads != 1024: @@ -298,10 +378,13 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver: if len(valid_block_size_names) > 0: parameter_space.addConstraint(MaxProdConstraint(max_threads), valid_block_size_names) max_block_size_product = f"{' * '.join(valid_block_size_names)} <= {max_threads}" - if isinstance(self._modified_restrictions, list) and max_block_size_product not in self._modified_restrictions: + if ( + isinstance(self._modified_restrictions, list) + and max_block_size_product not in 
self._modified_restrictions + ): self._modified_restrictions.append(max_block_size_product) if isinstance(self.restrictions, list): - self.restrictions.append((MaxProdConstraint(max_threads), valid_block_size_names)) + self.restrictions.append((MaxProdConstraint(max_threads), valid_block_size_names, None)) # construct the parameter space with the constraints applied return parameter_space.getSolutionsAsListDict(order=self.param_names) @@ -314,7 +397,7 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem: # convert to a Constraint type if necessary if isinstance(restriction, tuple): - restriction, required_params = restriction + restriction, required_params, _ = restriction if callable(restriction) and not isinstance(restriction, Constraint): restriction = FunctionConstraint(restriction) @@ -323,10 +406,7 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem: parameter_space.addConstraint(restriction, required_params) elif isinstance(restriction, Constraint): all_params_required = all(param_name in required_params for param_name in self.param_names) - parameter_space.addConstraint( - restriction, - None if all_params_required else required_params - ) + parameter_space.addConstraint(restriction, None if all_params_required else required_params) else: raise ValueError(f"Unrecognized restriction {restriction}") diff --git a/kernel_tuner/strategies/bayes_opt.py b/kernel_tuner/strategies/bayes_opt.py index bd20e29a9..47d3733f1 100644 --- a/kernel_tuner/strategies/bayes_opt.py +++ b/kernel_tuner/strategies/bayes_opt.py @@ -93,6 +93,9 @@ def tune(searchspace: Searchspace, runner, tuning_options): """ max_fevals = tuning_options.strategy_options.get("max_fevals", 100) + # limit max_fevals to max size of the parameter space + max_fevals = min(searchspace.size, max_fevals) + prune_parameterspace = tuning_options.strategy_options.get("pruneparameterspace", True) if not bayes_opt_present: raise ImportError( diff --git a/kernel_tuner/strategies/common.py b/kernel_tuner/strategies/common.py index d01eae937..3420c86ea 100644 --- a/kernel_tuner/strategies/common.py +++ b/kernel_tuner/strategies/common.py @@ -55,10 +55,12 @@ def get_options(strategy_options, options): class CostFunc: def __init__(self, searchspace: Searchspace, tuning_options, runner, *, scaling=False, snap=True): self.runner = runner - self.tuning_options = tuning_options self.snap = snap self.scaling = scaling self.searchspace = searchspace + self.tuning_options = tuning_options + if isinstance(self.tuning_options, dict): + self.tuning_options['max_fevals'] = min(tuning_options['max_fevals'] if 'max_fevals' in tuning_options else np.inf, searchspace.size) self.results = [] def __call__(self, x, check_restrictions=True): diff --git a/kernel_tuner/strategies/hillclimbers.py b/kernel_tuner/strategies/hillclimbers.py index b64e7d733..ccd4eebf0 100644 --- a/kernel_tuner/strategies/hillclimbers.py +++ b/kernel_tuner/strategies/hillclimbers.py @@ -1,13 +1,12 @@ import random -from kernel_tuner import util from kernel_tuner.searchspace import Searchspace from kernel_tuner.strategies.common import CostFunc def base_hillclimb(base_sol: tuple, neighbor_method: str, max_fevals: int, searchspace: Searchspace, tuning_options, cost_func: CostFunc, restart=True, randomize=True, order=None): - """ Hillclimbing search until max_fevals is reached or no improvement is found + """Hillclimbing search until max_fevals is reached or no improvement is found. 
Base hillclimber that evaluates neighbouring solutions in a random or fixed order and possibly immediately moves to the neighbour if it is an improvement. @@ -51,6 +50,9 @@ def base_hillclimb(base_sol: tuple, neighbor_method: str, max_fevals: int, searc """ if randomize and order: raise ValueError("Using a preset order and randomize at the same time is not supported.") + + # limit max_fevals to max size of the parameter space + max_fevals = min(searchspace.size, max_fevals) tune_params = searchspace.tune_params diff --git a/kernel_tuner/strategies/random_sample.py b/kernel_tuner/strategies/random_sample.py index 022eda534..06ab4b9f6 100644 --- a/kernel_tuner/strategies/random_sample.py +++ b/kernel_tuner/strategies/random_sample.py @@ -17,7 +17,7 @@ def tune(searchspace: Searchspace, runner, tuning_options): # override if max_fevals is specified if "max_fevals" in tuning_options: - num_samples = tuning_options.max_fevals + num_samples = min(tuning_options.max_fevals, searchspace.size) samples = searchspace.get_random_sample(num_samples) diff --git a/kernel_tuner/strategies/simulated_annealing.py b/kernel_tuner/strategies/simulated_annealing.py index dce929b7b..dcb9e3f26 100644 --- a/kernel_tuner/strategies/simulated_annealing.py +++ b/kernel_tuner/strategies/simulated_annealing.py @@ -27,7 +27,10 @@ def tune(searchspace: Searchspace, runner, tuning_options): # if user supplied max_fevals that is lower then max_iter we will # scale the annealing schedule to fit max_fevals - max_feval = tuning_options.strategy_options.get("max_fevals", max_iter) + max_fevals = tuning_options.strategy_options.get("max_fevals", max_iter) + + # limit max_fevals to max size of the parameter space + max_fevals = min(searchspace.size, max_fevals) # get random starting point and evaluate cost pos = list(searchspace.get_random_sample(1)[0]) @@ -64,7 +67,7 @@ def tune(searchspace: Searchspace, runner, tuning_options): old_cost = new_cost c = len(tuning_options.unique_results) - T = T_start * alpha**(max_iter/max_feval*c) + T = T_start * alpha**(max_iter/max_fevals*c) # check if solver gets stuck and if so restart from random position if c == c_old: diff --git a/kernel_tuner/util.py b/kernel_tuner/util.py index 0d2cef696..860e139ae 100644 --- a/kernel_tuner/util.py +++ b/kernel_tuner/util.py @@ -1,4 +1,5 @@ """Module for kernel tuner utility functions.""" + import errno import json import logging @@ -9,6 +10,7 @@ import time import warnings from inspect import signature +from pathlib import Path from types import FunctionType from typing import Optional, Union @@ -138,9 +140,7 @@ def check_argument_list(kernel_name, kernel_string, args): for arguments_set, arguments in enumerate(kernel_arguments): collected_errors.append(list()) if len(arguments) != len(args): - collected_errors[arguments_set].append( - "Kernel and host argument lists do not match in size." 
- ) + collected_errors[arguments_set].append("Kernel and host argument lists do not match in size.") continue for i, arg in enumerate(args): kernel_argument = arguments[i] @@ -186,10 +186,7 @@ def check_stop_criterion(to): """Checks if max_fevals is reached or time limit is exceeded.""" if "max_fevals" in to and len(to.unique_results) >= to.max_fevals: raise StopCriterionReached("max_fevals reached") - if "time_limit" in to and ( - ((time.perf_counter() - to.start_time) + (to.simulated_time * 1e-3)) - > to.time_limit - ): + if "time_limit" in to and (((time.perf_counter() - to.start_time) + (to.simulated_time * 1e-3)) > to.time_limit): raise StopCriterionReached("time limit exceeded") @@ -198,13 +195,7 @@ def check_tune_params_list(tune_params, observers, simulation_mode=False): forbidden_names = ("grid_size_x", "grid_size_y", "grid_size_z", "time") for name, param in tune_params.items(): if name in forbidden_names: - raise ValueError( - "Tune parameter " - + name - + " with value " - + str(param) - + " has a forbidden name!" - ) + raise ValueError("Tune parameter " + name + " with value " + str(param) + " has a forbidden name!") if any("nvml_" in param for param in tune_params): if not simulation_mode and (not observers or not any(isinstance(obs, NVMLObserver) for obs in observers)): raise ValueError("Tune parameters starting with nvml_ require an NVMLObserver!") @@ -243,6 +234,7 @@ def check_block_size_params_names_list(block_size_names, tune_params): UserWarning, ) + def check_restriction(restrict, params: dict) -> bool: """Check whether a configuration meets a search space restriction.""" # if it's a python-constraint, convert to function and execute @@ -256,10 +248,17 @@ def check_restriction(restrict, params: dict) -> bool: elif callable(restrict): return restrict(**params) # if it's a tuple, use only the parameters in the second argument to call the restriction - elif (isinstance(restrict, tuple) and len(restrict) == 2 - and callable(restrict[0]) and isinstance(restrict[1], (list, tuple))): + elif ( + isinstance(restrict, tuple) + and (len(restrict) == 2 or len(restrict) == 3) + and callable(restrict[0]) + and isinstance(restrict[1], (list, tuple)) + ): # unpack the tuple - restrict, selected_params = restrict + if len(restrict) == 2: + restrict, selected_params = restrict + else: + restrict, selected_params, source = restrict # look up the selected parameters and their value selected_params = dict((key, params[key]) for key in selected_params) # call the restriction @@ -272,6 +271,7 @@ def check_restriction(restrict, params: dict) -> bool: else: raise ValueError(f"Unkown restriction type {type(restrict)} ({restrict})") + def check_restrictions(restrictions, params: dict, verbose: bool) -> bool: """Check whether a configuration meets the search space restrictions.""" if callable(restrictions): @@ -296,29 +296,45 @@ def check_restrictions(restrictions, params: dict, verbose: bool) -> bool: def convert_constraint_restriction(restrict: Constraint): """Convert the python-constraint to a function for backwards compatibility.""" if isinstance(restrict, FunctionConstraint): + def f_restrict(p): return restrict._func(*p) + elif isinstance(restrict, AllDifferentConstraint): + def f_restrict(p): return len(set(p)) == len(p) + elif isinstance(restrict, AllEqualConstraint): + def f_restrict(p): return all(x == p[0] for x in p) + elif isinstance(restrict, MaxProdConstraint): + def f_restrict(p): return np.prod(p) <= restrict._maxprod + elif isinstance(restrict, MinProdConstraint): + def 
f_restrict(p): return np.prod(p) >= restrict._minprod + elif isinstance(restrict, MaxSumConstraint): + def f_restrict(p): return sum(p) <= restrict._maxsum + elif isinstance(restrict, ExactSumConstraint): + def f_restrict(p): return sum(p) == restrict._exactsum + elif isinstance(restrict, MinSumConstraint): + def f_restrict(p): return sum(p) >= restrict._minsum + elif isinstance(restrict, (InSetConstraint, NotInSetConstraint, SomeInSetConstraint, SomeNotInSetConstraint)): raise NotImplementedError( f"Restriction of the type {type(restrict)} is explicitely not supported in backwards compatibility mode, because the behaviour is too complex. Please rewrite this constraint to a function to use it with this algorithm." @@ -343,9 +359,7 @@ def config_valid(config, tuning_options, max_threads): if not legal: return False block_size_names = tuning_options.get("block_size_names", None) - valid_thread_block_dimensions = check_thread_block_dimensions( - params, max_threads, block_size_names - ) + valid_thread_block_dimensions = check_thread_block_dimensions(params, max_threads, block_size_names) return valid_thread_block_dimensions @@ -372,9 +386,7 @@ def detect_language(kernel_string): def get_best_config(results, objective, objective_higher_is_better=False): """Returns the best configuration from a list of results according to some objective.""" func = max if objective_higher_is_better else min - ignore_val = ( - sys.float_info.max if not objective_higher_is_better else -sys.float_info.max - ) + ignore_val = sys.float_info.max if not objective_higher_is_better else -sys.float_info.max best_config = func( results, key=lambda x: x[objective] if isinstance(x[objective], float) else ignore_val, @@ -419,18 +431,10 @@ def get_dimension_divisor(divisor_list, default, params): if callable(divisor_list): return divisor_list(params) else: - return np.prod( - [int(eval(replace_param_occurrences(s, params))) for s in divisor_list] - ) + return np.prod([int(eval(replace_param_occurrences(s, params))) for s in divisor_list]) - divisors = [ - get_dimension_divisor(d, block_size_names[i], params) - for i, d in enumerate(grid_div) - ] - return tuple( - int(np.ceil(float(current_problem_size[i]) / float(d))) - for i, d in enumerate(divisors) - ) + divisors = [get_dimension_divisor(d, block_size_names[i], params) for i, d in enumerate(grid_div)] + return tuple(int(np.ceil(float(current_problem_size[i]) / float(d))) for i, d in enumerate(divisors)) def get_instance_string(params): @@ -438,6 +442,28 @@ def get_instance_string(params): return "_".join([str(i) for i in params.values()]) +def get_interval(a: list): + """Checks if an array can be an interval. Returns (start, end, step) if interval, otherwise None.""" + if len(a) < 3: + return None + if not all(isinstance(e, (int, float)) for e in a): + return None + a_min = min(a) + a_max = max(a) + if len(a) <= 2: + return (a_min, a_max, a_max-a_min) + # determine the first step size + step = a[1]-a_min + # for each element, the step size should be equal to the first step + for i, e in enumerate(a): + if e-a[i-1] != step: + return None + result = (a_min, a_max, step) + if not all(isinstance(e, (int, float)) for e in result): + return None + return result + + def get_kernel_string(kernel_source, params=None): """Retrieve the kernel source and return as a string. @@ -452,8 +478,9 @@ def get_kernel_string(kernel_source, params=None): after all. 
:param kernel_source: One of the sources for the kernel, could be a - function that generates the kernel code, a string containing a filename - that points to the kernel source, or just a string that contains the code. + function that generates the kernel code, a string or Path containing a + filename that points to the kernel source, or just a string that + contains the code. :type kernel_source: string or callable :param params: Dictionary containing the tunable parameters for this specific @@ -469,6 +496,8 @@ def get_kernel_string(kernel_source, params=None): kernel_string = None if callable(kernel_source): kernel_string = kernel_source(params) + elif isinstance(kernel_source, Path): + kernel_string = read_file(kernel_source) elif isinstance(kernel_source, str): if looks_like_a_filename(kernel_source): kernel_string = read_file(kernel_source) or kernel_source @@ -492,9 +521,7 @@ def get_problem_size(problem_size, params): elif isinstance(s, (int, np.integer)): current_problem_size[i] = s else: - raise TypeError( - "Error: problem_size should only contain strings or integers" - ) + raise TypeError("Error: problem_size should only contain strings or integers") return current_problem_size @@ -569,11 +596,12 @@ def get_total_timings(results, env, overhead_time): return env -NVRTC_VALID_CC = np.array(['50', '52', '53', '60', '61', '62', '70', '72', '75', '80', '87', '89', '90', '90a']) +NVRTC_VALID_CC = np.array(["50", "52", "53", "60", "61", "62", "70", "72", "75", "80", "87", "89", "90", "90a"]) + def to_valid_nvrtc_gpu_arch_cc(compute_capability: str) -> str: """Returns a valid Compute Capability for NVRTC `--gpu-architecture=`, as per https://docs.nvidia.com/cuda/nvrtc/index.html#group__options.""" - return max(NVRTC_VALID_CC[NVRTC_VALID_CC<=compute_capability], default='52') + return max(NVRTC_VALID_CC[NVRTC_VALID_CC <= compute_capability], default="52") def print_config(config, tuning_options, runner): @@ -761,7 +789,10 @@ def prepare_kernel_string(kernel_name, kernel_string, params, grid, threads, blo def read_file(filename): """Return the contents of the file named filename or None if file not found.""" - if os.path.isfile(filename): + if isinstance(filename, Path): + with filename.open() as f: + return f.read() + elif os.path.isfile(filename): with open(filename, "r") as f: return f.read() @@ -822,14 +853,16 @@ def has_kw_argument(func, name): return lambda answer, result_host, atol: v(answer, result_host) -def parse_restrictions(restrictions: list[str], tune_params: dict, monolithic = False, try_to_constraint = True) -> list[tuple[Union[Constraint, str], list[str]]]: +def parse_restrictions( + restrictions: list[str], tune_params: dict, monolithic=False, format=None, try_to_constraint=True +) -> list[tuple[Union[Constraint, str], list[str]]]: """Parses restrictions from a list of strings into compilable functions and constraints, or a single compilable function (if monolithic is True). 
Returns a list of tuples of (strings or constraints) and parameters.""" # rewrite the restrictions so variables are singled out regex_match_variable = r"([a-zA-Z_$][a-zA-Z_$0-9]*)" def replace_params(match_object): key = match_object.group(1) - if key in tune_params: + if key in tune_params and format != "pyatf": param = str(key) return "params[params_index['" + param + "']]" else: @@ -854,8 +887,8 @@ def to_multiple_restrictions(restrictions: list[str]) -> list[str]: split_restrictions.append(res) continue # find the indices of splittable comparators - comparators = ['<=', '>=', '>', '<'] - comparators_indices = [(m.start(0), m.end(0)) for m in re.finditer('|'.join(comparators), res)] + comparators = ["<=", ">=", ">", "<"] + comparators_indices = [(m.start(0), m.end(0)) for m in re.finditer("|".join(comparators), res)] if len(comparators_indices) <= 1: # this can't be split further split_restrictions.append(res) @@ -863,15 +896,19 @@ def to_multiple_restrictions(restrictions: list[str]) -> list[str]: # split the restrictions from the previous to the next comparator for index in range(len(comparators_indices)): temp_copy = res - prev_stop = comparators_indices[index-1][1] + 1 if index > 0 else 0 - next_stop = comparators_indices[index+1][0] if index < len(comparators_indices) - 1 else len(temp_copy) + prev_stop = comparators_indices[index - 1][1] + 1 if index > 0 else 0 + next_stop = ( + comparators_indices[index + 1][0] if index < len(comparators_indices) - 1 else len(temp_copy) + ) split_restrictions.append(temp_copy[prev_stop:next_stop].strip()) return split_restrictions - def to_numeric_constraint(restriction: str, params: list[str]) -> Optional[Union[MinSumConstraint, ExactSumConstraint, MaxSumConstraint, MaxProdConstraint]]: + def to_numeric_constraint( + restriction: str, params: list[str] + ) -> Optional[Union[MinSumConstraint, ExactSumConstraint, MaxSumConstraint, MaxProdConstraint]]: """Converts a restriction to a built-in numeric constraint if possible.""" - comparators = ['<=', '==', '>=', '>', '<'] - comparators_found = re.findall('|'.join(comparators), restriction) + comparators = ["<=", "==", ">=", ">", "<"] + comparators_found = re.findall("|".join(comparators), restriction) # check if there is exactly one comparator, if not, return None if len(comparators_found) != 1: return None @@ -897,19 +934,21 @@ def is_or_evals_to_number(s: str) -> Optional[Union[int, float]]: if (left_num is None and right_num is None) or (left_num is not None and right_num is not None): # left_num and right_num can't be both None or both a constant return None - number, variables, variables_on_left = (left_num, right.strip(), False) if left_num is not None else (right_num, left.strip(), True) + number, variables, variables_on_left = ( + (left_num, right.strip(), False) if left_num is not None else (right_num, left.strip(), True) + ) # if the number is an integer, we can map '>' to '>=' and '<' to '<=' by changing the number (does not work with floating points!) 
number_is_int = isinstance(number, int) if number_is_int: - if comparator == '<': + if comparator == "<": if variables_on_left: # (x < 2) == (x <= 2-1) number -= 1 else: # (2 < x) == (2+1 <= x) number += 1 - elif comparator == '>': + elif comparator == ">": if variables_on_left: # (x > 2) == (x >= 2+1) number += 1 @@ -918,8 +957,8 @@ def is_or_evals_to_number(s: str) -> Optional[Union[int, float]]: number -= 1 # check if an operator is applied on the variables, if not return - operators = [r'\*\*', r'\*', r'\+'] - operators_found = re.findall(str('|'.join(operators)), variables) + operators = [r"\*\*", r"\*", r"\+"] + operators_found = re.findall(str("|".join(operators)), variables) if len(operators_found) == 0: # no operators found, return only based on comparator if len(params) != 1 or variables not in params: @@ -927,12 +966,12 @@ def is_or_evals_to_number(s: str) -> Optional[Union[int, float]]: return None # map to a Constraint # if there are restrictions with a single variable, it will be used to prune the domain at the start - elif comparator == '==': + elif comparator == "==": return ExactSumConstraint(number) - elif comparator == '<=' or (comparator == '<' and number_is_int): - return MaxSumConstraint(number) if variables_on_left else MinSumConstraint(number) - elif comparator == '>=' or (comparator == '>' and number_is_int): - return MinSumConstraint(number) if variables_on_left else MaxSumConstraint(number) + elif comparator == "<=" or (comparator == "<" and number_is_int): + return MaxSumConstraint(number) if variables_on_left else MinSumConstraint(number) + elif comparator == ">=" or (comparator == ">" and number_is_int): + return MinSumConstraint(number) if variables_on_left else MaxSumConstraint(number) raise ValueError(f"Invalid comparator {comparator}") # check which operator is applied on the variables @@ -946,34 +985,36 @@ def is_or_evals_to_number(s: str) -> Optional[Union[int, float]]: # check if there are only pure, non-recurring variables (no operations or constants) in the restriction if len(splitted) == len(params) and all(s.strip() in params for s in splitted): # map to a Constraint - if operator == '**': + if operator == "**": # power operations are not (yet) supported, added to avoid matching the double asterisk return None - elif operator == '*': - if comparator == '<=' or (comparator == '<' and number_is_int): + elif operator == "*": + if comparator == "<=" or (comparator == "<" and number_is_int): return MaxProdConstraint(number) if variables_on_left else MinProdConstraint(number) - elif comparator == '>=' or (comparator == '>' and number_is_int): + elif comparator == ">=" or (comparator == ">" and number_is_int): return MinProdConstraint(number) if variables_on_left else MaxProdConstraint(number) - elif operator == '+': - if comparator == '==': + elif operator == "+": + if comparator == "==": return ExactSumConstraint(number) - elif comparator == '<=' or (comparator == '<' and number_is_int): + elif comparator == "<=" or (comparator == "<" and number_is_int): return MaxSumConstraint(number) if variables_on_left else MinSumConstraint(number) - elif comparator == '>=' or (comparator == '>' and number_is_int): + elif comparator == ">=" or (comparator == ">" and number_is_int): return MinSumConstraint(number) if variables_on_left else MaxSumConstraint(number) else: raise ValueError(f"Invalid operator {operator}") return None - def to_equality_constraint(restriction: str, params: list[str]) -> Optional[Union[AllEqualConstraint, AllDifferentConstraint]]: + def 
to_equality_constraint( + restriction: str, params: list[str] + ) -> Optional[Union[AllEqualConstraint, AllDifferentConstraint]]: """Converts a restriction to either an equality or inequality constraint on all the parameters if possible.""" # check if all parameters are involved if len(params) != len(tune_params): return None # find whether (in)equalities appear in this restriction - equalities_found = re.findall('==', restriction) - inequalities_found = re.findall('!=', restriction) + equalities_found = re.findall("==", restriction) + inequalities_found = re.findall("!=", restriction) # check if one of the two have been found, if none or both have been found, return None if not (len(equalities_found) > 0 ^ len(inequalities_found) > 0): return None @@ -984,12 +1025,21 @@ def to_equality_constraint(restriction: str, params: list[str]) -> Optional[Unio # check if there are only pure, non-recurring variables (no operations or constants) in the restriction if len(splitted) == len(params) and all(s.strip() in params for s in splitted): # map to a Constraint - if comparator == '==': + if comparator == "==": return AllEqualConstraint() - elif comparator == '!=': + elif comparator == "!=": return AllDifferentConstraint() return ValueError(f"Not possible: comparator should be '==' or '!=', is {comparator}") return None + + # remove functionally duplicate restrictions (preserves order and whitespace) + if all(isinstance(r, str) for r in restrictions): + # clean the restriction strings to functional equivalence + restrictions_cleaned = [r.replace(' ', '') for r in restrictions] + restrictions_cleaned_unique = list(dict.fromkeys(restrictions_cleaned)) # dict preserves order + # get the indices of the unique restrictions, use these to build a new list of restrictions + restrictions_unique_indices = [restrictions_cleaned.index(r) for r in restrictions_cleaned_unique] + restrictions = [restrictions[i] for i in restrictions_unique_indices] # create the parsed restrictions if monolithic is False: @@ -1005,7 +1055,12 @@ def to_equality_constraint(restriction: str, params: list[str]) -> Optional[Unio finalized_constraint = None if try_to_constraint and " or " not in res and " and " not in res: # if applicable, strip the outermost round brackets - while parsed_restriction[0] == '(' and parsed_restriction[-1] == ')' and '(' not in parsed_restriction[1:] and ')' not in parsed_restriction[:1]: + while ( + parsed_restriction[0] == "(" + and parsed_restriction[-1] == ")" + and "(" not in parsed_restriction[1:] + and ")" not in parsed_restriction[:1] + ): parsed_restriction = parsed_restriction[1:-1] # check if we can turn this into the built-in numeric comparison constraint finalized_constraint = to_numeric_constraint(parsed_restriction, params_used) @@ -1014,11 +1069,40 @@ def to_equality_constraint(restriction: str, params: list[str]) -> Optional[Unio finalized_constraint = to_equality_constraint(parsed_restriction, params_used) if finalized_constraint is None: # we must turn it into a general function - finalized_constraint = f"def r({', '.join(params_used)}): return {parsed_restriction} \n" + if format is not None and format.lower() == "pyatf": + finalized_constraint = parsed_restriction + else: + finalized_constraint = f"def r({', '.join(params_used)}): return {parsed_restriction} \n" parsed_restrictions.append((finalized_constraint, params_used)) + + # if pyATF, restrictions that are set on the same parameter must be combined into one + if format is not None and format.lower() == "pyatf": + res_dict = dict() 
+ registered_params = list() + registered_restrictions = list() + parsed_restrictions_pyatf = list() + for param in tune_params.keys(): + registered_params.append(param) + for index, (res, params) in enumerate(parsed_restrictions): + if index in registered_restrictions: + continue + if all(p in registered_params for p in params): + if param not in res_dict: + res_dict[param] = (list(), list()) + res_dict[param][0].append(res) + res_dict[param][1].extend(params) + registered_restrictions.append(index) + # combine multiple restrictions into one + for res_tuple in res_dict.values(): + res, params_used = res_tuple + params_used = list(dict.fromkeys(params_used)) # param_used should only contain unique, dict preserves order + parsed_restrictions_pyatf.append((f"def r({', '.join(params_used)}): return ({') and ('.join(res)}) \n", params_used)) + parsed_restrictions = parsed_restrictions_pyatf else: # create one monolithic function - parsed_restrictions = ") and (".join([re.sub(regex_match_variable, replace_params, res) for res in restrictions]) + parsed_restrictions = ") and (".join( + [re.sub(regex_match_variable, replace_params, res) for res in restrictions] + ) # tidy up the code by removing the last suffix and unnecessary spaces parsed_restrictions = "(" + parsed_restrictions.strip() + ")" @@ -1027,13 +1111,28 @@ def to_equality_constraint(restriction: str, params: list[str]) -> Optional[Unio # provide a mapping of the parameter names to the index in the tuple received params_index = dict(zip(tune_params.keys(), range(len(tune_params.keys())))) - parsed_restrictions = [(f"def restrictions(*params): params_index = {params_index}; return {parsed_restrictions} \n", list(tune_params.keys()))] + if format == "pyatf": + parsed_restrictions = [ + ( + f"def restrictions({', '.join(params_index.keys())}): return {parsed_restrictions} \n", + list(tune_params.keys()), + ) + ] + else: + parsed_restrictions = [ + ( + f"def restrictions(*params): params_index = {params_index}; return {parsed_restrictions} \n", + list(tune_params.keys()), + ) + ] return parsed_restrictions -def compile_restrictions(restrictions: list, tune_params: dict, monolithic = False, try_to_constraint = True) -> list[tuple[Union[str, Constraint, FunctionType], list[str]]]: - """Parses restrictions from a list of strings into a list of strings, Functions, or Constraints (if `try_to_constraint`) and parameters used, or a single Function if monolithic is true.""" +def compile_restrictions( + restrictions: list, tune_params: dict, monolithic=False, format=None, try_to_constraint=True +) -> list[tuple[Union[str, Constraint, FunctionType], list[str], Union[str, None]]]: + """Parses restrictions from a list of strings into a list of strings, Functions, or Constraints (if `try_to_constraint`) and parameters used and source, or a single Function if monolithic is true.""" # filter the restrictions to get only the strings restrictions_str, restrictions_ignore = [], [] for r in restrictions: @@ -1042,7 +1141,9 @@ def compile_restrictions(restrictions: list, tune_params: dict, monolithic = Fal return restrictions_ignore # parse the strings - parsed_restrictions = parse_restrictions(restrictions_str, tune_params, monolithic=monolithic, try_to_constraint=try_to_constraint) + parsed_restrictions = parse_restrictions( + restrictions_str, tune_params, monolithic=monolithic, format=format, try_to_constraint=try_to_constraint + ) # compile the parsed restrictions into a function compiled_restrictions: list[tuple] = list() @@ -1051,10 +1152,10 @@ def 
compile_restrictions(restrictions: list, tune_params: dict, monolithic = Fal # if it's a string, parse it to a function code_object = compile(restriction, "", "exec") func = FunctionType(code_object.co_consts[0], globals()) - compiled_restrictions.append((func, params_used)) + compiled_restrictions.append((func, params_used, restriction)) elif isinstance(restriction, Constraint): # otherwise it already is a Constraint, pass it directly - compiled_restrictions.append((restriction, params_used)) + compiled_restrictions.append((restriction, params_used, None)) else: raise ValueError(f"Restriction {restriction} is neither a string or Constraint {type(restriction)}") @@ -1066,9 +1167,10 @@ def compile_restrictions(restrictions: list, tune_params: dict, monolithic = Fal noncompiled_restrictions = [] for r in restrictions_ignore: if isinstance(r, tuple) and len(r) == 2 and isinstance(r[1], (list, tuple)): - noncompiled_restrictions.append(r) + restriction, params_used = r + noncompiled_restrictions.append((restriction, params_used, restriction)) else: - noncompiled_restrictions.append((r, ())) + noncompiled_restrictions.append((r, [], r)) return noncompiled_restrictions + compiled_restrictions @@ -1103,18 +1205,12 @@ def process_cache(cache, kernel_options, tuning_options, runner): # if file does not exist, create new cache if not os.path.isfile(cache): if tuning_options.simulation_mode: - raise ValueError( - f"Simulation mode requires an existing cachefile: file {cache} does not exist" - ) + raise ValueError(f"Simulation mode requires an existing cachefile: file {cache} does not exist") c = dict() c["device_name"] = runner.dev.name c["kernel_name"] = kernel_options.kernel_name - c["problem_size"] = ( - kernel_options.problem_size - if not callable(kernel_options.problem_size) - else "callable" - ) + c["problem_size"] = kernel_options.problem_size if not callable(kernel_options.problem_size) else "callable" c["tune_params_keys"] = list(tuning_options.tune_params.keys()) c["tune_params"] = tuning_options.tune_params c["objective"] = tuning_options.objective @@ -1140,33 +1236,25 @@ def process_cache(cache, kernel_options, tuning_options, runner): # check if it is safe to continue tuning from this cache if cached_data["device_name"] != runner.dev.name: raise ValueError( - "Cannot load cache which contains results for different device" + f"Cannot load cache which contains results for different device (cache: {cached_data['device_name']}, actual: {runner.dev.name})" ) if cached_data["kernel_name"] != kernel_options.kernel_name: raise ValueError( - "Cannot load cache which contains results for different kernel" + f"Cannot load cache which contains results for different kernel (cache: {cached_data['kernel_name']}, actual: {kernel_options.kernel_name})" ) if "problem_size" in cached_data and not callable(kernel_options.problem_size): + # if it's a single value, convert to an array + if isinstance(cached_data["problem_size"], int): + cached_data["problem_size"] = [cached_data["problem_size"]] # if problem_size is not iterable, compare directly if not hasattr(kernel_options.problem_size, "__iter__"): if cached_data["problem_size"] != kernel_options.problem_size: - raise ValueError( - "Cannot load cache which contains results for different problem_size" - ) + raise ValueError("Cannot load cache which contains results for different problem_size") # else (problem_size is iterable) # cache returns list, problem_size is likely a tuple. 
Therefore, the next check # checks the equality of all items in the list/tuples individually - elif not all( - [ - i == j - for i, j in zip( - cached_data["problem_size"], kernel_options.problem_size - ) - ] - ): - raise ValueError( - "Cannot load cache which contains results for different problem_size" - ) + elif not all([i == j for i, j in zip(cached_data["problem_size"], kernel_options.problem_size)]): + raise ValueError("Cannot load cache which contains results for different problem_size") if cached_data["tune_params_keys"] != list(tuning_options.tune_params.keys()): if all(key in tuning_options.tune_params for key in cached_data["tune_params_keys"]): raise ValueError( @@ -1189,7 +1277,7 @@ def correct_open_cache(cache, open_cache=True): filestr = cachefile.read().strip() # if file was not properly closed, pretend it was properly closed - if len(filestr) > 0 and not filestr[-3:] in ["}\n}", "}}}"]: + if len(filestr) > 0 and filestr[-3:] not in ["}\n}", "}}}"]: # remove the trailing comma if any, and append closing brackets if filestr[-1] == ",": filestr = filestr[:-1] @@ -1203,6 +1291,7 @@ def correct_open_cache(cache, open_cache=True): return filestr + def read_cache(cache, open_cache=True): """Read the cachefile into a dictionary, if open_cache=True prepare the cachefile for appending.""" filestr = correct_open_cache(cache, open_cache) diff --git a/noxfile.py b/noxfile.py index e32bbb588..75c9ea902 100644 --- a/noxfile.py +++ b/noxfile.py @@ -84,7 +84,7 @@ def check_development_environment(session: Session) -> None: return None output: str = session.run("poetry", "install", "--sync", "--dry-run", "--with", "test", silent=True, external=True) match = re.search(r"Package operations: (\d+) installs, (\d+) updates, (\d+) removals, \d+ skipped", output) - assert match is not None + assert match is not None, f"Could not check development environment, reason: {output}" groups = match.groups() installs, updates, removals = int(groups[0]), int(groups[1]), int(groups[2]) if installs > 0 or updates > 0: diff --git a/pyproject.toml b/pyproject.toml index 13d1cb647..323978437 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,8 @@ repository = "https://github.com/KernelTuner/kernel_tuner" "Tracker" = "https://github.com/KernelTuner/kernel_tuner/issues" [tool.poetry.build] generate-setup-file = false +[tool.poetry.scripts] +kernel_tuner = "kernel_tuner.interface:entry_point" # ATTENTION: if anything is changed here, run `poetry update` [tool.poetry.dependencies] diff --git a/test/context.py b/test/context.py index cccc3332a..ba5030430 100644 --- a/test/context.py +++ b/test/context.py @@ -63,6 +63,12 @@ except ImportError: pyhip_present = False +try: + from autotuning_methodology.report_experiments import get_strategy_scores + methodology_present = True +except ImportError: + methodology_present = False + skip_if_no_pycuda = pytest.mark.skipif( not pycuda_present, reason="PyCuda not installed or no CUDA device detected" ) @@ -83,6 +89,7 @@ skip_if_no_openmp = pytest.mark.skipif(not openmp_present, reason="No OpenMP found") skip_if_no_openacc = pytest.mark.skipif(not openacc_present, reason="No nvc++ on PATH") skip_if_no_pyhip = pytest.mark.skipif(not pyhip_present, reason="No PyHIP found") +skip_if_no_methodology = pytest.mark.skipif(not methodology_present, reason="Autotuning Methodology not found") def skip_backend(backend: str): diff --git a/test/convolution.cu b/test/convolution.cu new file mode 100644 index 000000000..ecafcf4b8 --- /dev/null +++ b/test/convolution.cu @@ -0,0 +1,166 
@@ +#define image_height 4096 +#define image_width 4096 + +#ifndef filter_height + #define filter_height 17 +#endif +#ifndef filter_width + #define filter_width 17 +#endif + +#define border_height ((filter_height/2)*2) +#define border_width ((filter_width/2)*2) +#define input_height (image_height + border_height) +#define input_width (image_width + border_width) + +#ifndef block_size_x + #define block_size_x 16 +#endif +#ifndef block_size_y + #define block_size_y 16 +#endif +#ifndef block_size_z + #define block_size_z 1 +#endif +#ifndef tile_size_x + #define tile_size_x 1 +#endif +#ifndef tile_size_y + #define tile_size_y 1 +#endif + +#define i_end min(block_size_y*tile_size_y+border_height, input_height) +#define j_end min(block_size_x*tile_size_x+border_width, input_width) + +/* + * If requested, we can use the __ldg directive to load data through the + * read-only cache. + */ +#define USE_READ_ONLY_CACHE read_only +#if USE_READ_ONLY_CACHE == 1 +#define LDG(x, y) __ldg(x+y) +#elif USE_READ_ONLY_CACHE == 0 +#define LDG(x, y) x[y] +#endif + +__constant__ float d_filter[33*33]; //large enough for the largest filter + +/* + * If use_padding == 1, we introduce (only when necessary) a number of padding + * columns in shared memory to avoid shared memory bank conflicts + * + * padding columns are only inserted when block_size_x is not a multiple of 32 (the assumed number of memory banks) + * and when the width of the data needed is not a multiple of 32. The latter is because some filter_widths never + * cause bank conflicts. + * + * If not passed as a tunable parameter, padding is on by default + */ +#define shared_mem_width (block_size_x*tile_size_x+border_width) +#ifndef use_padding + #define use_padding 1 +#endif +#if use_padding == 1 + #if (((block_size_x % 32)!=0) && (((shared_mem_width-block_size_x)%32) != 0)) + // next line uses &31 instead of %32, because % in C is remainder not modulo + #define padding_columns ((32 - (border_width + block_size_x*tile_size_x - block_size_x)) & 31) + #undef shared_mem_width + #define shared_mem_width (block_size_x*tile_size_x+border_width+padding_columns) + #endif +#endif + + +__global__ void convolution_kernel(float *output, float *input, float *filter) { + int ty = threadIdx.y; + int tx = threadIdx.x; + int by = blockIdx.y * block_size_y * tile_size_y; + int bx = blockIdx.x * block_size_x * tile_size_x; + + //shared memory to hold all input data need by this thread block + __shared__ float sh_input[block_size_y*tile_size_y+border_height][shared_mem_width]; + + //load all input data needed by this thread block into shared memory + #pragma unroll + for (int i=ty; i 0 - + result, env = tune_hyper_params(target_strategy, hyper_params, iterations=1, verbose=True) + assert len(result) == 2 + assert 'best_config' in env
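The util changes in this diff extend get_kernel_string and read_file to accept pathlib.Path objects in addition to plain strings and callables. A minimal sketch of that new code path; the kernel file name is purely illustrative and not part of the patch:

    from pathlib import Path

    from kernel_tuner.util import get_kernel_string

    # a Path is now read directly via read_file, just like a filename string
    kernel_source = Path("convolution.cu")  # hypothetical file containing the kernel code
    kernel_string = get_kernel_string(kernel_source, params={"block_size_x": 16})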
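The reformatted to_valid_nvrtc_gpu_arch_cc helper clamps a device's compute capability to the nearest value accepted by NVRTC's --gpu-architecture option. A small sketch of its behaviour, assuming the values listed in NVRTC_VALID_CC above:

    from kernel_tuner.util import to_valid_nvrtc_gpu_arch_cc

    print(to_valid_nvrtc_gpu_arch_cc("86"))  # "80": the highest listed CC that does not exceed 86
    print(to_valid_nvrtc_gpu_arch_cc("40"))  # "52": the default when no listed CC is low enough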
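With this patch, compile_restrictions returns 3-tuples of (constraint or compiled function, parameters used, source string or None) instead of 2-tuples, while simple comparisons are still mapped onto python-constraint classes where possible. A minimal sketch with made-up tunable parameters:

    from kernel_tuner.util import compile_restrictions

    tune_params = {"block_size_x": [16, 32, 64], "block_size_y": [1, 2, 4]}
    restrictions = ["block_size_x*block_size_y<=1024", "block_size_x >= block_size_y"]

    for restriction, params_used, source in compile_restrictions(restrictions, tune_params):
        # expected: a MaxProdConstraint (source None) and a compiled function (source kept)
        print(type(restriction).__name__, params_used, source)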
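The new format="pyatf" mode merges restrictions over the same parameters into a single generated function. The sketch below also passes try_to_constraint=False so that every restriction stays a plain expression before merging; that flag combination is an assumption on how the pyATF path is meant to be driven, not something the patch prescribes:

    from kernel_tuner.util import compile_restrictions

    tune_params = {"block_size_x": [16, 32, 64], "block_size_y": [1, 2, 4]}
    restrictions = ["block_size_x*block_size_y<=1024", "block_size_x >= block_size_y"]

    pyatf = compile_restrictions(restrictions, tune_params, format="pyatf", try_to_constraint=False)
    func, params_used, source = pyatf[0]
    # source is roughly: "def r(block_size_x, block_size_y): return (block_size_x*block_size_y<=1024) and (block_size_x >= block_size_y)"
    print(source)
    print(func(block_size_x=32, block_size_y=2))  # True: 64 <= 1024 and 32 >= 2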
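The test fragment at the end of this diff exercises the new hypertuner backend through tune_hyper_params. A sketch of that call in isolation; the import location, strategy name and hyperparameter values are assumptions, and actually running it requires the autotuning_methodology package plus benchmark data:

    from kernel_tuner.hyper import tune_hyper_params  # assumed import location

    target_strategy = "genetic_algorithm"              # illustrative strategy to hyper-tune
    hyper_params = {"popsize": [10, 20], "maxiter": [50, 100]}

    # scores each hyperparameter configuration with the autotuning methodology,
    # returning the tuning results and an environment dict with the best configuration
    result, env = tune_hyper_params(target_strategy, hyper_params, iterations=1, verbose=True)
    print(env["best_config"])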