diff --git a/tests/math_builtin_api/modules/test_generator.py b/tests/math_builtin_api/modules/test_generator.py index b5cc86d7b..fd3852bab 100644 --- a/tests/math_builtin_api/modules/test_generator.py +++ b/tests/math_builtin_api/modules/test_generator.py @@ -77,57 +77,62 @@ """) } -def generate_value(base_type, dim): - val = "" - for i in range(dim): - if base_type == "bool": - val += "true," - if base_type == "float" or base_type == "double" or base_type == "sycl::half": - # 10 digits of precision for floats, doubles and half. - val += str(round(random.uniform(0.1, 0.9), 10)) - if base_type == "double": - val += "," - else: - val += "f," - # random 8 bit integer - if base_type == "char": - val += str(random.randint(0, 127)) + "," - if base_type == "signed char" or base_type == "int8_t": - val += str(random.randint(-128, 127)) + "," - if base_type == "unsigned char" or base_type == "uint8_t": - val += str(random.randint(0, 255)) + "," +def get_literal_suffix(base_type): + mapping = { + "float": "f", "unsigned long": "U", "uint32_t": "U", + "sycl::half": "f", "long long": "LL", "int64_t": "LL", + "unsigned long long": "LLU", "uint64_t": "LLU" } + return mapping[base_type] if base_type in mapping else "" + +def generate_literal_value(base_type): + if base_type == "bool": + return "true" + if base_type == "float" or base_type == "double" or base_type == "sycl::half": + # 10 digits of precision for floats, doubles and half. + return round(random.uniform(0.1, 0.9), 10) + # random 8 bit integer + if base_type == "char": + return random.randint(0, 127) + if base_type == "signed char" or base_type == "int8_t": + return random.randint(-128, 127) + if base_type == "unsigned char" or base_type == "uint8_t": + return random.randint(0, 255) # random 16 bit integer - if base_type == "int" or base_type == "short" or base_type == "int16_t": - val += str(random.randint(-32768, 32767)) + "," - if base_type == "unsigned" or base_type == "unsigned short" or base_type == "uint16_t": - val += str(random.randint(0, 65535)) + "," + if base_type == "int" or base_type == "short" or base_type == "int16_t": + return random.randint(-32768, 32767) + if base_type == "unsigned" or base_type == "unsigned short" or base_type == "uint16_t": + return random.randint(0, 65535) # random 32 bit integer - if base_type == "long" or base_type == "int32_t": - val += str(random.randint(-2147483648, 2147483647)) + "," - if base_type == "unsigned long" or base_type == "uint32_t": - val += str(random.randint(0, 4294967295)) + "U" + "," + if base_type == "long" or base_type == "int32_t": + return random.randint(-2147483648, 2147483647) + if base_type == "unsigned long" or base_type == "uint32_t": + return random.randint(0, 4294967295) # random 64 bit integer - if base_type == "long long" or base_type == "int64_t": - val += str(random.randint(-9223372036854775808, 9223372036854775807)) + "LL" + "," - if base_type == "unsigned long long" or base_type == "uint64_t": - val += str(random.randint(0, 18446744073709551615)) + "LLU" + "," - return val[:-1] + if base_type == "long long" or base_type == "int64_t": + return random.randint(-9223372036854775808, 9223372036854775807) + if base_type == "unsigned long long" or base_type == "uint64_t": + return random.randint(0, 18446744073709551615) + +def generate_value(base_type, dim): + values = [str(generate_literal_value(base_type)) + get_literal_suffix(base_type) for _ in range(dim)] + return ','.join(values) def generate_multi_ptr(var_name, var_type, memory, decorated): decl = "" + value = generate_value(var_type.base_type, var_type.dim) if memory == "global": source_name = "multiPtrSourceData" - decl = var_type.name + " " + source_name + "(" + generate_value(var_type.base_type, var_type.dim) + ");\n" + decl = var_type.name + " " + source_name + "(" + value + ");\n" decl += "sycl::multi_ptr<" + var_type.name + ", sycl::access::address_space::global_space," + decorated + "> " decl += var_name + "(acc);\n" if memory == "local": source_name = "multiPtrSourceData" - decl = var_type.name + " " + source_name + "(" + generate_value(var_type.base_type, var_type.dim) + ");\n" + decl = var_type.name + " " + source_name + "(" + value + ");\n" decl += "sycl::multi_ptr<" + var_type.name + ", sycl::access::address_space::local_space," + decorated + "> " decl += var_name + "(acc);\n" if memory == "private": source_name = "multiPtrSourceData" - decl = var_type.name + " " + source_name + "(" + generate_value(var_type.base_type, var_type.dim) + ");\n" + decl = var_type.name + " " + source_name + "(" + value + ");\n" decl += "sycl::multi_ptr<" + var_type.name + ", sycl::access::address_space::private_space," + decorated + "> " decl += var_name + " = sycl::address_space_cast(&" decl += source_name + ");\n" @@ -136,6 +141,26 @@ def generate_multi_ptr(var_name, var_type, memory, decorated): def generate_variable(var_name, var_type, var_index): return var_type.name + " " + var_name + "(" + generate_value(var_type.base_type, var_type.dim) + ");\n" +# argument generator for clamp which makes sure that its third argument is at least equal to its second argument in every dimension. +def generate_arguments_clamp(sig): + arg_types = sig.arg_types + arg_names = ["inputData_" + str(i) for i in range(3)] + arg0 = [str(generate_literal_value(arg_types[0].base_type)) + get_literal_suffix(arg_types[0].base_type) for _ in range(arg_types[0].dim)] + arg1 = [generate_literal_value(arg_types[1].base_type) for _ in range(arg_types[1].dim)] + arg2 = [generate_literal_value(arg_types[2].base_type) for _ in range(arg_types[2].dim)] + + # clamp requires that minval (arg1) <= maxval (arg2) + for i in range(arg_types[1].dim): + if arg1[i] > arg2[i]: + arg1[i], arg2[i] = arg2[i], arg1[i] # swap + + arg1 = [str(x) + get_literal_suffix(arg_types[1].base_type) for x in arg1] + arg2 = [str(x) + get_literal_suffix(arg_types[2].base_type) for x in arg2] + arg_vals = [arg0, arg1, arg2] + args = [arg_types[i].name + " " + arg_names[i] + "(" + ",".join(arg_vals[i]) + ");\n" for i in range(3)] + return (arg_names, " ".join(args)) + + def generate_arguments(sig, memory, decorated): arg_src = "" arg_names = [] @@ -156,7 +181,6 @@ def generate_arguments(sig, memory, decorated): current_arg = generate_multi_ptr(arg_name, arg, memory, decorated ) else: current_arg = generate_variable(arg_name, arg, arg_index) - arg_src += current_arg + " " arg_index += 1 return (arg_names, arg_src) @@ -283,7 +307,9 @@ def generate_reference_ptr(types, sig, arg_names, arg_src): def generate_test_case(test_id, types, sig, memory, check, decorated = ""): testCaseSource = test_case_templates_check[memory] if check else test_case_templates[memory] testCaseId = str(test_id) - (arg_names, arg_src) = generate_arguments(sig, memory, decorated) + # for the clamp function we use a separate argument generator to make sure that its preconditions are met, + # otherwise argument generation for clamp would be completely random. + (arg_names, arg_src) = generate_arguments(sig, memory, decorated) if sig.name != "clamp" else generate_arguments_clamp(sig) testCaseSource = testCaseSource.replace("$REFERENCE", generate_reference(sig, arg_names, arg_src)) testCaseSource = testCaseSource.replace("$PTR_REF", generate_reference_ptr(types, sig, arg_names, arg_src)) testCaseSource = testCaseSource.replace("$TEST_ID", testCaseId) diff --git a/util/math_helper.h b/util/math_helper.h index 3953050a7..9c4dbbeb4 100644 --- a/util/math_helper.h +++ b/util/math_helper.h @@ -162,6 +162,7 @@ run_func_on_vector_result_ref(funT fun, Args... args) { sycl::vec res; std::map undefined; for (int i = 0; i < N; i++) { + undefined[i] = false; resultRef element = fun(getElement(args, i)...); if (element.undefined.empty()) setElement(res, i, element.res);