Skip to content

Commit

Permalink
Mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
stellaraccident committed Apr 25, 2024
1 parent 2b8e301 commit 1d4afcc
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
5 changes: 3 additions & 2 deletions shark_turbine/ops/iree.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,13 @@ def select(self, ksel: KernelSelection):
t1_desc.specialize_all_dims()
t2_desc = ksel.arg_tensor(1)
t2_desc.specialize_all_dims()
result_desc = ksel.return_new_tensor(t1_desc.t.shape, t1_desc.t.dtype)
result_desc = ksel.return_new_tensor(list(t1_desc.t.shape), t1_desc.t.dtype)
result_desc.specialize_all_dims()

def generate(self, ksel: KernelSelection, kb: KernelBuilder):
t1, t2 = kb.arg_bindings
result_type = t1.type # type: ignore
result = Operation.create(
"tosa.add", results=[t1.type], operands=[t1, t2]
"tosa.add", results=[result_type], operands=[t1, t2]
).result
kb.yield_results(result)
8 changes: 5 additions & 3 deletions shark_turbine/runtime/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,14 @@ def enumerated_device_id(self) -> int:
@property
def enumerated_path(self) -> str:
try:
return self.enumerated_device_id["path"]
return self.enumerated_info["path"]
except KeyError as e:
raise RuntimeError("No enumerated path for device") from e

@property
def enumerated_name(self) -> str:
try:
return self.enumerated_device_id["name"]
return self.enumerated_info["name"]
except KeyError as e:
raise RuntimeError("No enumerated name for device") from e

Expand Down Expand Up @@ -351,7 +351,9 @@ def _device_export_torch_tensor_cuda_hip(
state = device._s
device_type_code = state.dlpack_device_type_code
assert device_type_code > 0
device_index = state.torch_device.index
torch_device = state.torch_device
assert torch_device is not None
device_index = torch_device.index
t = torch.from_dlpack(
device.hal_device.create_dlpack_capsule(bv, device_type_code, device_index)
)
Expand Down

0 comments on commit 1d4afcc

Please sign in to comment.