Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for any case spelling of block size name defaults #277

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions kernel_tuner/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,8 +592,8 @@ def tune_kernel(
# check for forbidden names in tune parameters
util.check_tune_params_list(tune_params, observers, simulation_mode=simulation_mode)

# check whether block_size_names are used as expected
util.check_block_size_params_names_list(block_size_names, tune_params)
# check whether block_size_names are used
block_size_names = util.check_block_size_params_names_list(block_size_names, tune_params)

# ensure there is always at least three names
util.append_default_block_size_names(block_size_names)
Expand Down
13 changes: 12 additions & 1 deletion kernel_tuner/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,22 @@ def check_block_size_params_names_list(block_size_names, tune_params):
"Block size name " + name + " is not specified in the tunable parameters list!", UserWarning
)
else: # if default block size names are used
if not any([k in default_block_size_names for k in tune_params.keys()]):
if not any([k.lower() in default_block_size_names for k in tune_params.keys()]):
warnings.warn(
"None of the tunable parameters specify thread block dimensions!",
UserWarning,
)
else:
# check for alternative case spelling of defaults such as BLOCK_SIZE_X or block_Size_X etc
result = []
for k in tune_params.keys():
if k.lower() in default_block_size_names and k not in default_block_size_names:
result.append(k)
# ensure order of block_size_names is correct regardless of case used
block_size_names = sorted(result, key=str.casefold)

return block_size_names


def check_restriction(restrict, params: dict) -> bool:
"""Check whether a configuration meets a search space restriction."""
Expand Down