Skip to content

Commit

Permalink
Specify subgroup size
Browse files Browse the repository at this point in the history
Signed-off-by: Lukas Sommer <[email protected]>
  • Loading branch information
sommerlukas committed Jan 17, 2025
1 parent 1e63399 commit 58d6cb3
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
30 changes: 25 additions & 5 deletions python/cutlass/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,12 +360,32 @@ def emit_compile_(self, operation_list, compilation_options, host_compilation_op

# Find SPIR-V device code in temporary directory
spv_files = list(pathlib.Path(temp_dump_dir.name).glob("*.spv"))
if len(spv_files) != 1:
raise RuntimeError("More than one SPIR-V files generated")


# TODO(Lukas): Bring this back when we delete the loop to find the right kernel.
# if len(spv_files) != 1:
# raise RuntimeError("More than one SPIR-V files generated")

# TODO(Lukas): This is a temporary solution to be removed.
# When specifying a specific subgroup size, DPC++ currently
# generates multiple SPIR-V files. We create a kernel from each of
# them to find the one with the correct subgroup size. This is
# rather inefficient, as we need to create all these programs first.
q = dpctl.SyclQueue(cutlass.sycl_device())
op_name = f"__sycl_kernel_{operation_list[0].name()}"
for f in spv_files:
with open(f, "rb") as spirv_file:
spirv_image = spirv_file.read()
program = dpctl.program.create_program_from_spirv(q, spirv_image)
spirv_kernel = program.get_sycl_kernel(op_name)
print(spirv_kernel.max_sub_group_size)
if spirv_kernel.max_sub_group_size == 16:
cubin_image = spirv_image
break

# TODO(Lukas): Bring this back when we delete the loop to find the right kernel.
# Load the SPIR-V image
with open(spv_files[0], "rb") as file:
cubin_image = file.read()
# with open(spv_files[0], "rb") as file:
# cubin_image = file.read()

else: # with nvcc backend
# emit code
Expand Down
2 changes: 1 addition & 1 deletion python/cutlass/backend/gemm_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,7 +1118,7 @@ class GemmRTUniversal3x(GemmRTUniversal):
#if defined(CUTLASS_ENABLE_SYCL)
SYCL_EXTERNAL SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
(sycl::ext::oneapi::experimental::nd_range_kernel<
3>))
3>)) [[sycl::reqd_sub_group_size(16)]]
void ${operation_name}(typename Operator::Params const params,
sycl::ext::oneapi::experimental::work_group_memory<char[]> mem) {
auto* smem = &mem[0];
Expand Down

0 comments on commit 58d6cb3

Please sign in to comment.