Add device bridge support for HIP and CUDA. #3

Merged 7 commits on Apr 25, 2024
22 changes: 22 additions & 0 deletions shark_turbine/ops/iree.py
@@ -8,6 +8,7 @@
from typing import cast

from ..support.ir_imports import (
Operation,
RankedTensorType,
StringAttr,
Value,
@@ -60,3 +61,24 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
key = cast(AttrArg, ksel.arg_descs[0])
_emit_tensor_trace(kb, cast(str, key.v), [kb.arg_bindings[1]])
kb.yield_results(kb.arg_bindings[1])


@CustomOp.register(library=IREE_LIBRARY)
class _test_add(CustomOp):
signature = "_test_add(Tensor t1, Tensor t2) -> (Tensor)"

def select(self, ksel: KernelSelection):
t1_desc = ksel.arg_tensor(0)
t1_desc.specialize_all_dims()
t2_desc = ksel.arg_tensor(1)
t2_desc.specialize_all_dims()
result_desc = ksel.return_new_tensor(list(t1_desc.t.shape), t1_desc.t.dtype)
result_desc.specialize_all_dims()

def generate(self, ksel: KernelSelection, kb: KernelBuilder):
t1, t2 = kb.arg_bindings
result_type = t1.type # type: ignore
result = Operation.create(
"tosa.add", results=[result_type], operands=[t1, t2]
).result
kb.yield_results(result)
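
As a quick orientation (not part of the diff), the op registered above is callable eagerly through the turbine custom-op machinery; a minimal sketch for a CPU tensor, mirroring the tests added later in this PR:

import torch
from shark_turbine.ops import iree as iree_ops

# Eager dispatch goes through the CustomOp registration above; the kernel is
# compiled for the device backing the input tensors.
t = torch.arange(4, dtype=torch.float32)
print(iree_ops._test_add(t, t))  # tensor([0., 2., 4., 6.])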
164 changes: 158 additions & 6 deletions shark_turbine/runtime/device.py
@@ -7,6 +7,7 @@
from functools import lru_cache
from typing import Callable, Optional, Union
from threading import local, Lock
import warnings

import torch

@@ -73,6 +74,9 @@ class DeviceState:
"device",
"driver",
"instance",
"enumerated_info",
"torch_device",
"dlpack_device_type_code",
]

def __init__(
@@ -81,10 +85,37 @@ def __init__(
driver: Union[str, HalDriver],
device: Optional[HalDevice] = None,
vm_instance: Optional[VmInstance] = None,
enumerated_info: Optional[dict] = None,
torch_device: Optional[torch.device] = None,
dlpack_device_type_code: int = 0,
):
self.instance = vm_instance or get_vm_instance()
self.driver = driver if isinstance(driver, HalDriver) else get_driver(driver)
self.device = device if device else self.driver.create_default_device()
self.enumerated_info = enumerated_info or {}
self.torch_device = torch_device
self.dlpack_device_type_code = dlpack_device_type_code

@property
def enumerated_device_id(self) -> int:
try:
return self.enumerated_info["device_id"]
except KeyError as e:
raise RuntimeError("No enumerated device_id for device") from e

@property
def enumerated_path(self) -> str:
try:
return self.enumerated_info["path"]
except KeyError as e:
raise RuntimeError("No enumerated path for device") from e

@property
def enumerated_name(self) -> str:
try:
return self.enumerated_info["name"]
except KeyError as e:
raise RuntimeError("No enumerated name for device") from e

@staticmethod
@lru_cache(maxsize=None)
@@ -139,7 +170,10 @@ class Device:
compile_target_flags: tuple[str, ...]

def __new__(
- cls, uri: Optional[str] = None, *, device_state: Optional[DeviceState] = None
+ cls,
+ uri: Optional[str] = None,
+ *,
+ device_state: Optional[DeviceState] = None,
):
if uri is not None:
# Construction by URI is cached on the thread.
@@ -243,6 +277,9 @@ def clear(self):
...
raise MismatchedDeviceSetClearError()

def dump_device_info(self) -> str:
return self._s.driver.dump_device_info(self._s.enumerated_device_id)

def __repr__(self):
return f"<Turbine Device: {self._s.device}>"

@@ -256,6 +293,11 @@ def __exit__(self, type, value, traceback):
_CURRENT_THREAD.stack.pop()


################################################################################
# CPU import/export
################################################################################


def _device_import_torch_tensor_cpu(device: Device, t: torch.Tensor) -> HalBufferView:
hal_device = device.hal_device
element_type = dtype_to_element_type(t.dtype)
@@ -284,20 +326,65 @@ def _device_export_torch_tensor_cpu(
return torch.from_numpy(mapped_array)


################################################################################
# CUDA and HIP import/export
################################################################################


def _device_import_torch_tensor_cuda_hip(
device: Device, t: torch.Tensor
) -> HalBufferView:
# We currently only support contiguous tensors, so ensure that.
if not t.is_contiguous():
t = t.contiguous()
# TODO: The 'None' here tells the producer to synchronize on the default
# stream. For async, we should advance our timeline and signal when an
# event is raised on Torch's stream at the current position.
capsule = t.__dlpack__(None)
bv = device.hal_device.from_dlpack_capsule(capsule)
return bv


def _device_export_torch_tensor_cuda_hip(
device: Device, bv: HalBufferView, like: torch.Tensor
) -> torch.Tensor:
state = device._s
device_type_code = state.dlpack_device_type_code
assert device_type_code > 0
torch_device = state.torch_device
assert torch_device is not None
device_index = torch_device.index
t = torch.from_dlpack(
device.hal_device.create_dlpack_capsule(bv, device_type_code, device_index)
)
if t.dtype != like.dtype:
t = t.view(like.dtype)
# TODO: For async, we should enqueue an event on Torch's stream which will
# signal when this tensor is produced (i.e. at the current point in our
# timeline).
return t
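
For reference (not part of the diff), the DLPack device type codes used here come from the DLPack spec: kDLCUDA = 2 and kDLROCM = 10. A small sketch to confirm what a given PyTorch build reports for its "cuda" device type:

import torch

if torch.cuda.is_available():
    t = torch.zeros(1, device="cuda:0")
    # __dlpack_device__() returns (device_type, device_index); expect 2 on a
    # CUDA build and 10 on a ROCm/HIP build exposing the cuda device type.
    print(t.__dlpack_device__())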


# Mapping of torch tensor importers keyed by driver name.
TORCH_TENSOR_IMPORTERS: dict[str, Callable[[Device, torch.Tensor], HalBufferView]] = {
"cuda": _device_import_torch_tensor_cuda_hip,
"hip": _device_import_torch_tensor_cuda_hip,
"local-sync": _device_import_torch_tensor_cpu,
"local-task": _device_import_torch_tensor_cpu,
}

TORCH_TENSOR_EXPORTERS: dict[
str, Callable[[Device, HalBufferView, torch.Tensor], torch.Tensor]
] = {
"cuda": _device_export_torch_tensor_cuda_hip,
"hip": _device_export_torch_tensor_cuda_hip,
"local-sync": _device_export_torch_tensor_cpu,
"local-task": _device_export_torch_tensor_cpu,
}

DEVICE_TARGET_COMPILE_FLAGS: dict[str, tuple[str, ...]] = {
"cuda": ("--iree-hal-target-backends=cuda",),
"hip": ("--iree-hal-target-backends=rocm",),
"local-task": (
"--iree-hal-target-backends=llvm-cpu",
"--iree-llvmcpu-target-cpu-features=host",
@@ -357,14 +444,79 @@ def get_device_from_torch(torch_device: torch.device) -> Device:

def _create_device_from_torch(torch_device: torch.device) -> Optional[Device]:
torch_type = torch_device.type
- uri = None
if torch_type == "cpu":
- uri = "local-task"
cpu_driver = get_driver("local-task")
cpu_enumerated = cpu_driver.query_available_devices()
assert len(cpu_enumerated) >= 1
cpu_default = cpu_enumerated[0]
cpu_device_state = DeviceState(
driver=cpu_driver,
device=cpu_driver.create_default_device(),
enumerated_info=cpu_default,
torch_device=torch_device,
dlpack_device_type_code=1,
)
return Device(device_state=cpu_device_state)
elif torch_type == "cuda":
# Fork based on HIP or real CUDA.
props = torch.cuda.get_device_properties(torch_device)
if not hasattr(props, "gcnArchName"):
# Real CUDA.
return _create_cuda_device(torch_device, props)
else:
# HIP as CUDA.
return _create_hip_device(torch_device, props)

return None


def _create_cuda_device(torch_device: torch.device, props) -> Optional[Device]:
# Note that the dlpack device type code for real CUDA is 2.
device = _create_cuda_like_device(torch_device, props, "cuda", 2)
if device:
device.compile_target_flags = device.compile_target_flags + (
f"--iree-hal-cuda-llvm-target-arch=sm_{props.major}{props.minor}",
)
return device


def _create_hip_device(torch_device: torch.device, props) -> Optional[Device]:
# Note that the dlpack device type code for ROCM is 10.
device = _create_cuda_like_device(torch_device, props, "hip", 10)
if device:
gcn_arch_name = props.gcnArchName
device.compile_target_flags = device.compile_target_flags + (
f"--iree-rocm-target-chip={gcn_arch_name}",
)
return device

- if uri is None:
- return None

- return Device(uri)

def _create_cuda_like_device(
torch_device: torch.device, props, driver_name: str, dlpack_device_type_code: int
) -> Optional[Device]:
if torch.cuda.device_count() > 1:
warnings.warn(
f"Multiple {driver_name} devices detected: Turbine does not yet "
f"guarantee stable device mapping"
)

requested_index = torch_device.index
driver = get_driver(driver_name)
available_infos = driver.query_available_devices()
if requested_index >= len(available_infos):
return None
device_info = available_infos[requested_index]
hal_device = driver.create_device(device_info)
device_state = DeviceState(
driver=driver,
device=hal_device,
vm_instance=get_vm_instance(),
enumerated_info=device_info,
torch_device=torch_device,
dlpack_device_type_code=dlpack_device_type_code,
)
device = Device(device_state=device_state)
return device


###############################################################################
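Taken together, the device.py changes let a torch.device resolve to an enumerated Turbine Device. A minimal usage sketch (assuming get_device_from_torch is imported from shark_turbine.runtime.device, where this diff defines it):

import torch
from shark_turbine.runtime.device import get_device_from_torch

# CPU always resolves; the enumerated local-task device backs torch CPU tensors.
d = get_device_from_torch(torch.device("cpu"))
print(d.compile_target_flags)  # includes --iree-hal-target-backends=llvm-cpu

if torch.cuda.is_available():
    # On CUDA or HIP builds this selects the enumerated device matching cuda:0.
    d_gpu = get_device_from_torch(torch.device("cuda:0"))
    print(d_gpu.dump_device_info())
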
9 changes: 4 additions & 5 deletions shark_turbine/runtime/op_reg/base.py
@@ -74,11 +74,10 @@ def def_library(ns) -> torch.library.Library:


def default_dispatch_keys() -> list[str]:
- # TODO: Dynamically determine what devices to register against.
- # Note that we have to register against specific keys instead of the
- # fallback, as fallback is too broad and breaks certain elements of
- # fx tracing.
- return ["CPU"]
+ keys = ["CPU"]
+ if torch.cuda.is_available():
+ keys.append("CUDA")
+ return keys


# All such custom kernels are registered in the 'turbine' library/namespace.
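For context (not part of the diff), the dispatch keys returned above feed torch.library registrations, which bind an implementation per backend key rather than through a global fallback. A rough sketch with a hypothetical namespace and op, for illustration only:

import torch

lib = torch.library.Library("turbine_demo", "DEF")  # hypothetical namespace
lib.define("echo(Tensor t) -> Tensor")

def _echo_impl(t):
    return t.clone()

# Register the same implementation against each selected dispatch key.
for key in ["CPU"] + (["CUDA"] if torch.cuda.is_available() else []):
    lib.impl("echo", _echo_impl, key)
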
4 changes: 3 additions & 1 deletion shark_turbine/runtime/op_reg/compiler.py
@@ -100,8 +100,10 @@ def compile_standalone_kernel(
with kb.ip, Location.unknown():
ksel.op.generate(ksel, kb)
kb.module_op.verify()
# DO NOT SUBMIT: https://github.com/iree-org/iree/issues/17132
enable_debug_info = False
module_asm = kb.module_op.get_asm(
- binary=True, enable_debug_info=True, print_generic_op_form=False
+ binary=True, enable_debug_info=enable_debug_info, print_generic_op_form=False
)
generation_time = default_timer() - start

47 changes: 47 additions & 0 deletions tests/runtime/device_test.py
@@ -7,6 +7,7 @@
import logging
import unittest
import threading
import warnings

import torch

@@ -122,6 +123,52 @@ def testCompilerFlags(self):
self.assertIn("--iree-hal-target-backends=llvm-cpu", d.compile_target_flags)


# Make CUDA testing conditional.
test_cuda_device = None
if torch.cuda.is_available():
try:
test_cuda_device = torch.device("cuda:0")
except:
...
if test_cuda_device is None:
warnings.warn("Not testing CUDA interop (device not available)")


@unittest.skipUnless(test_cuda_device, "CUDA not available")
class TorchCUDAInterop(unittest.TestCase):
def setUp(self):
self._cuda_props = torch.cuda.get_device_properties(test_cuda_device)
print(self._cuda_props)
print(dir(self._cuda_props))
self._is_hip = False
if hasattr(self._cuda_props, "gcnArchName"):
print("Detected HIP device as CUDA")
self._is_hip = True

def testFromTorchDevice(self):
torch_device = torch.device("cuda:0")
device = get_device_from_torch(torch_device)
print(device.dump_device_info())

def testJit(self):
from shark_turbine.ops import iree as iree_ops

t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cuda:0")
result = iree_ops._test_add(t, t)
expected = torch.tensor([2.0, 4.0, 6.0, 8.0, 10.0], device="cpu")
torch.testing.assert_close(result.cpu(), expected)


class TorchCPUInterop(unittest.TestCase):
def testJit(self):
from shark_turbine.ops import iree as iree_ops

t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cpu")
result = iree_ops._test_add(t, t)
expected = torch.tensor([2.0, 4.0, 6.0, 8.0, 10.0], device="cpu")
torch.testing.assert_close(result, expected)


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()