From 90918fcff55f68793b32c6221b3b174092e2977c Mon Sep 17 00:00:00 2001 From: Ben van Werkhoven Date: Fri, 11 Oct 2024 11:21:37 +0200 Subject: [PATCH] add support for any case spelling of block size name defaults --- kernel_tuner/interface.py | 4 ++-- kernel_tuner/util.py | 13 ++++++++++++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/kernel_tuner/interface.py b/kernel_tuner/interface.py index 97ae22848..bd421aeab 100644 --- a/kernel_tuner/interface.py +++ b/kernel_tuner/interface.py @@ -592,8 +592,8 @@ def tune_kernel( # check for forbidden names in tune parameters util.check_tune_params_list(tune_params, observers, simulation_mode=simulation_mode) - # check whether block_size_names are used as expected - util.check_block_size_params_names_list(block_size_names, tune_params) + # check whether block_size_names are used + block_size_names = util.check_block_size_params_names_list(block_size_names, tune_params) # ensure there is always at least three names util.append_default_block_size_names(block_size_names) diff --git a/kernel_tuner/util.py b/kernel_tuner/util.py index 0d2cef696..52279dcb7 100644 --- a/kernel_tuner/util.py +++ b/kernel_tuner/util.py @@ -237,11 +237,22 @@ def check_block_size_params_names_list(block_size_names, tune_params): "Block size name " + name + " is not specified in the tunable parameters list!", UserWarning ) else: # if default block size names are used - if not any([k in default_block_size_names for k in tune_params.keys()]): + if not any([k.lower() in default_block_size_names for k in tune_params.keys()]): warnings.warn( "None of the tunable parameters specify thread block dimensions!", UserWarning, ) + else: + # check for alternative case spelling of defaults such as BLOCK_SIZE_X or block_Size_X etc + result = [] + for k in tune_params.keys(): + if k.lower() in default_block_size_names and k not in default_block_size_names: + result.append(k) + # ensure order of block_size_names is correct regardless of case used + block_size_names = sorted(result, key=str.casefold) + + return block_size_names + def check_restriction(restrict, params: dict) -> bool: """Check whether a configuration meets a search space restriction."""