Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First steps to enable SYCL backend in Python Interface #155

Open
wants to merge 22 commits into
base: sycl-develop
Choose a base branch
from

Conversation

sommerlukas
Copy link
Collaborator

@sommerlukas sommerlukas commented Nov 14, 2024

First implementation steps towards supporting the SYCL backend in the CUTLASS Python Interface.

The main additions from this PR are:

  • Generating a suitable GEMM template and arguments for the CUTLASS 3.x API and Intel PVC as target.
  • Calling DPC++ instead of nvcc to compile device and host code.
  • Using the DPCTL library to transfer data and launch the kernel via SYCL.

The support so far focuses on a simple GEMM, epilogues (e.g, with visitor) are not yet supported.

Compilation is currently only possible with development versions of DPC++, the -fsycl-rtc-mode flag that was added to support CUTLASS nested parameter classes in free-function kernels as part of this work is not yet available in releases.

The activation of the SYCL backend via environment variable is a temporary solution, a follow-up will look into a cleaner solution.

@sommerlukas sommerlukas self-assigned this Nov 27, 2024

math_instructions = [
MathInstruction(
[16, 8, 16],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should it be 8, 16, 16 to match the 8x16x16 (M,N,K) MMA operation for bfloat?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that probably makes sense. This value was probably based on values used for CUDA devices, makes sense to adapt it for PVC.

I changed it in the latest commit.

@@ -7026,6 +7026,47 @@ def GenerateSM90(manifest, cuda_version):

###################################################################################################

def GeneratePVC_TensorOp_16b_gemm(manifest, cuda_version):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is cuda version here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's the CUDA version, e.g., 12.4.0, defined here.

Right now, we don't use that parameter. If we come to a point where we need to make distinctions based on SYCL version or similar, we can change this to reflect a version that we need.

For now, we only have this parameter to be compatible with the expected interface here (via generate_function_name and generate_function).

def GeneratePVC_TensorOp_16b_gemm(manifest, cuda_version):
# TODO: Add remaining supported configurations
layouts = [
[[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8]]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is 8 the alignment?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think so.

Copy link
Collaborator

@aacostadiaz aacostadiaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Look good, thanks!!!

I left some questions but I think they will be more relevant for follow up PRs.

@sommerlukas sommerlukas force-pushed the python-interface-enable-sycl branch from f4b9079 to 19524ee Compare December 13, 2024 13:17
Signed-off-by: Lukas Sommer <[email protected]>
Copy link
Collaborator

@FMarno FMarno left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks fine to me. Hard to see individual issues, but I also don't really have a knowledge of the whole system.

Comment on lines 212 to 224
if self._is_sycl():
q = dpctl.SyclQueue(cutlass.sycl_device())
module = dpctl.program.create_program_from_spirv(q, cubin_image)
else:
err, module = cuda.cuModuleLoadData(cubin_image)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("Cuda Error: {}".format(err))

if self._is_sycl():
kernel = module.get_sycl_kernel(operation_name)
else:
err, kernel = cuda.cuModuleGetFunction(
module, bytes(str.encode(operation_name)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if self._is_sycl():
q = dpctl.SyclQueue(cutlass.sycl_device())
module = dpctl.program.create_program_from_spirv(q, cubin_image)
else:
err, module = cuda.cuModuleLoadData(cubin_image)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("Cuda Error: {}".format(err))
if self._is_sycl():
kernel = module.get_sycl_kernel(operation_name)
else:
err, kernel = cuda.cuModuleGetFunction(
module, bytes(str.encode(operation_name)))
if self._is_sycl():
q = dpctl.SyclQueue(cutlass.sycl_device())
module = dpctl.program.create_program_from_spirv(q, cubin_image)
kernel = module.get_sycl_kernel(operation_name)
else:
err, module = cuda.cuModuleLoadData(cubin_image)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("Cuda Error: {}".format(err))
err, kernel = cuda.cuModuleGetFunction(
module, bytes(str.encode(operation_name)))

if self.backend == "nvrtc":
# 3. compile
# 3. compile
if self.backend == "nvrtc": # with nvrtc backend
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if self.backend == "nvrtc": # with nvrtc backend
if self.backend == "nvrtc":

@@ -303,6 +335,50 @@ def emit_compile_(self, operation_list, compilation_options, host_compilation_op
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError("NVRTC Error: {}".format(err))

elif self.backend == "dpcpp": # with DPC++ backend
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
elif self.backend == "dpcpp": # with DPC++ backend
elif self.backend == "dpcpp":

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants