Skip to content

Commit

Permalink
[stdlib] Simplify SIMD.reduce_op functions
Browse files Browse the repository at this point in the history
- Introduced a new `reduce` overload for non-capturing functions
- Simplify `SIMD.reduce_op` functions using `SIMD.op` methods directly

Signed-off-by: Yiwu Chen <[email protected]>
  • Loading branch information
soraros committed Jul 30, 2024
1 parent d232e9f commit 963d261
Showing 1 changed file with 50 additions and 58 deletions.
108 changes: 50 additions & 58 deletions stdlib/src/builtin/simd.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -2009,13 +2009,45 @@ struct SIMD[type: DType, size: Int](
# Reduce operations
# ===------------------------------------------------------------------=== #

alias _T = SIMD[type, _]

# TODO: remove when non-capturing can be converted to capturing.
@always_inline
fn reduce[
func: fn[type: DType, width: Int] (
SIMD[type, width], SIMD[type, width]
) capturing -> SIMD[type, width],
func: fn[width: Int] (
Self._T[width],
Self._T[width],
) -> Self._T[width],
size_out: Int = 1,
](self) -> SIMD[type, size_out]:
](self) -> Self._T[size_out]:
"""Reduces the vector using a provided reduce operator.
Parameters:
func: The reduce function to apply to elements in this SIMD.
size_out: The width of the reduction.
Constraints:
`size_out` must not exceed width of the vector.
Returns:
A new scalar which is the reduction of all vector elements.
"""

@always_inline
@parameter
fn body[w: Int](lhs: Self._T[w], rhs: Self._T[w]) -> Self._T[w]:
return func(lhs, rhs)

return self.reduce[body, size_out]()

@always_inline
fn reduce[
func: fn[width: Int] (
Self._T[width],
Self._T[width],
) capturing -> Self._T[width],
size_out: Int = 1,
](self) -> Self._T[size_out]:
"""Reduces the vector using a provided reduce operator.
Parameters:
Expand All @@ -2032,15 +2064,15 @@ struct SIMD[type: DType, size: Int](

@parameter
if size == size_out:
return rebind[SIMD[type, size_out]](self)
return rebind[Self._T[size_out]](self)
else:
var lhs: Self._SIMDHalfType
var rhs: Self._SIMDHalfType
lhs, rhs = self.split()
return func(lhs, rhs).reduce[func, size_out]()

@always_inline("nodebug")
fn reduce_max[size_out: Int = 1](self) -> SIMD[type, size_out]:
fn reduce_max[size_out: Int = 1](self) -> Self._T[size_out]:
"""Reduces the vector using the `max` operator.
Parameters:
Expand All @@ -2063,14 +2095,12 @@ struct SIMD[type: DType, size: Int](

@always_inline
@parameter
fn max_reduce_body[
type: DType, width: Int
](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[
type, width
]:
fn body[
width: Int
](v1: Self._T[width], v2: Self._T[width]) -> Self._T[width]:
return max(v1, v2)

return self.reduce[max_reduce_body, size_out]()
return self.reduce[body, size_out]()

@parameter
if type.is_floating_point():
Expand Down Expand Up @@ -2121,14 +2151,12 @@ struct SIMD[type: DType, size: Int](

@always_inline
@parameter
fn min_reduce_body[
type: DType, width: Int
](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[
type, width
]:
fn body[
width: Int
](v1: Self._T[width], v2: Self._T[width]) -> Self._T[width]:
return min(v1, v2)

return self.reduce[min_reduce_body, size_out]()
return self.reduce[body, size_out]()

@parameter
if type.is_floating_point():
Expand Down Expand Up @@ -2169,15 +2197,7 @@ struct SIMD[type: DType, size: Int](
The sum of all vector elements.
"""

@always_inline
@parameter
fn add_reduce_body[
type: DType, width: Int
](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[type, width]:
return v1 + v2

return self.reduce[add_reduce_body, size_out]()
return self.reduce[Self._T.__add__, size_out]()

@always_inline
fn reduce_mul[size_out: Int = 1](self) -> SIMD[type, size_out]:
Expand All @@ -2193,15 +2213,7 @@ struct SIMD[type: DType, size: Int](
Returns:
The product of all vector elements.
"""

@always_inline
@parameter
fn mul_reduce_body[
type: DType, width: Int
](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[type, width]:
return v1 * v2

return self.reduce[mul_reduce_body, size_out]()
return self.reduce[Self._T.__mul__, size_out]()

@always_inline
fn reduce_and[size_out: Int = 1](self) -> SIMD[type, size_out]:
Expand All @@ -2227,17 +2239,7 @@ struct SIMD[type: DType, size: Int](

@parameter
if size_out > 1:

@always_inline
@parameter
fn and_reduce_body[
type: DType, width: Int
](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[
type, width
]:
return v1 & v2

return self.reduce[and_reduce_body, size_out]()
return self.reduce[Self._T.__and__, size_out]()

@parameter
if size == 1:
Expand Down Expand Up @@ -2273,17 +2275,7 @@ struct SIMD[type: DType, size: Int](

@parameter
if size_out > 1:

@always_inline
@parameter
fn or_reduce_body[
type: DType, width: Int
](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[
type, width
]:
return v1 | v2

return self.reduce[or_reduce_body, size_out]()
return self.reduce[Self._T.__or__, size_out]()

@parameter
if size == 1:
Expand Down

0 comments on commit 963d261

Please sign in to comment.