Skip to content

Commit

Permalink
Add support to overload multiple ops in same handle_fn
Browse files Browse the repository at this point in the history
Signed-off-by: Stanley Winata <[email protected]>
  • Loading branch information
raikonenfnu committed Feb 4, 2025
1 parent e080622 commit 5531c94
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 20 deletions.
20 changes: 20 additions & 0 deletions iree/turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,22 @@ def shuffle(src: "Register", offset: int, width: int) -> "Register":
...


def gt(lhs: "Register", rhs: "Register") -> "Register":
...


def ge(lhs: "Register", rhs: "Register") -> "Register":
...


def lt(lhs: "Register", rhs: "Register") -> "Register":
...


def le(lhs: "Register", rhs: "Register") -> "Register":
...


def cast(src: "Register", dtype: DataType) -> "Register":
...

Expand Down Expand Up @@ -758,6 +774,10 @@ def infer_type(self):
@define_py_op(operator.ge)
@define_py_op(operator.lt)
@define_py_op(operator.le)
@define_interface_op("gt")
@define_interface_op("ge")
@define_interface_op("lt")
@define_interface_op("le")
@dataclass
class ComparisonPyOp(BinaryOpBase, ABC):
def infer_type(self):
Expand Down
36 changes: 16 additions & 20 deletions iree/turbine/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,13 @@
exp2,
extract,
extract_slice,
ge,
get_custom,
get_result,
gt,
le,
log2,
lt,
maximum,
minimum,
mma,
Expand Down Expand Up @@ -582,11 +586,17 @@ def get_constant_attr(value: Any, element_type: IrType) -> Attribute:
raise CodegenError(f"Cannot create a constant attribute for type `{element_type}`")


def handle_op(op: Callable[..., Any]):
def handle_op(op: Callable[..., Any] | list[Callable[..., Any]]):
def decorator(
f: Callable[[WaveEmitter, fx.Node], None]
) -> Callable[[WaveEmitter, fx.Node], None]:
WaveEmitter.OP_HANDLERS[op.__name__] = f
if isinstance(op, Callable):
WaveEmitter.OP_HANDLERS[op.__name__] = f
elif isinstance(op, list):
for op_iter in op:
WaveEmitter.OP_HANDLERS[op_iter.__name__] = f
else:
raise ValueError("handle_op only handle Callable or list of Callable")
return f

return decorator
Expand Down Expand Up @@ -1247,14 +1257,12 @@ def handle_div(lhs: Value, rhs: Value) -> OpResult:
element_type.is_signed or element_type.is_signless
):
result = arith_d.divsi(lhs, rhs)
elif _is_integer_like_type(element_type) and element_type.is_unsigned:
result = arith_d.divui(lhs, rhs)
else:
raise ValidationError(f"Found unhandled operand type for div: {element_type}")
return result


@handle_binary_op(operator.gt)
@handle_binary_op([operator.gt, gt])
def handle_gt(lhs: Value, rhs: Value) -> OpResult:
element_type = get_type_or_element_type(lhs.type)
if _is_float_type(element_type):
Expand All @@ -1263,14 +1271,12 @@ def handle_gt(lhs: Value, rhs: Value) -> OpResult:
element_type.is_signed or element_type.is_signless
):
result = arith_d.cmpi(arith_d.CmpIPredicate.sgt, lhs, rhs)
elif _is_integer_like_type(element_type) and element_type.is_unsigned:
result = arith_d.cmpi(arith_d.CmpIPredicate.ugt, lhs, rhs)
else:
raise ValidationError(f"Found unhandled operand type for gt: {element_type}")
return result


@handle_binary_op(operator.ge)
@handle_binary_op([ge, operator.ge])
def handle_ge(lhs: Value, rhs: Value) -> OpResult:
element_type = get_type_or_element_type(lhs.type)
if _is_float_type(element_type):
Expand All @@ -1279,14 +1285,12 @@ def handle_ge(lhs: Value, rhs: Value) -> OpResult:
element_type.is_signed or element_type.is_signless
):
result = arith_d.cmpi(arith_d.CmpIPredicate.sge, lhs, rhs)
elif _is_integer_like_type(element_type) and element_type.is_unsigned:
result = arith_d.cmpi(arith_d.CmpIPredicate.uge, lhs, rhs)
else:
raise ValidationError(f"Found unhandled operand type for ge: {element_type}")
return result


@handle_binary_op(operator.lt)
@handle_binary_op([operator.lt, lt])
def handle_lt(lhs: Value, rhs: Value) -> OpResult:
element_type = get_type_or_element_type(lhs.type)
if _is_float_type(element_type):
Expand All @@ -1295,14 +1299,12 @@ def handle_lt(lhs: Value, rhs: Value) -> OpResult:
element_type.is_signed or element_type.is_signless
):
result = arith_d.cmpi(arith_d.CmpIPredicate.slt, lhs, rhs)
elif _is_integer_like_type(element_type) and element_type.is_unsigned:
result = arith_d.cmpi(arith_d.CmpIPredicate.ult, lhs, rhs)
else:
raise ValidationError(f"Found unhandled operand type for lt: {element_type}")
return result


@handle_binary_op(operator.le)
@handle_binary_op([operator.le, le])
def handle_le(lhs: Value, rhs: Value) -> OpResult:
element_type = get_type_or_element_type(lhs.type)
if _is_float_type(element_type):
Expand All @@ -1311,8 +1313,6 @@ def handle_le(lhs: Value, rhs: Value) -> OpResult:
element_type.is_signed or element_type.is_signless
):
result = arith_d.cmpi(arith_d.CmpIPredicate.sle, lhs, rhs)
elif _is_integer_like_type(element_type) and element_type.is_unsigned:
result = arith_d.cmpi(arith_d.CmpIPredicate.ule, lhs, rhs)
else:
raise ValidationError(f"Found unhandled operand type for le: {element_type}")
return result
Expand All @@ -1327,8 +1327,6 @@ def handle_maximum(lhs: Value, rhs: Value) -> OpResult:
element_type.is_signed or element_type.is_signless
):
result = arith_d.maxsi(lhs, rhs)
elif _is_integer_like_type(element_type) and element_type.is_unsigned:
result = arith_d.maxui(lhs, rhs)
else:
raise ValidationError(
f"Found unhandled operand type for maximum: {element_type}"
Expand All @@ -1345,8 +1343,6 @@ def handle_minimum(lhs: Value, rhs: Value) -> OpResult:
element_type.is_signed() or element_type.is_signless()
):
result = arith_d.minsi(lhs, rhs)
elif _is_integer_like_type(element_type) and element_type.is_unsigned:
result = arith_d.minui(lhs, rhs)
else:
raise ValidationError(
f"Found unhandled operand type for minimum: {element_type}"
Expand Down

0 comments on commit 5531c94

Please sign in to comment.