Skip to content

Commit

Permalink
[stdlib] Simplify SIMD.reduce_op functions
Browse files Browse the repository at this point in the history
- Introduced a new `reduce` overload for non-capturing functions
- Simplify `SIMD.reduce_op` functions using `SIMD.op` methods directly

Signed-off-by: Yiwu Chen <[email protected]>
  • Loading branch information
soraros committed Jun 20, 2024
1 parent d96acc9 commit 2f1efe7
Showing 1 changed file with 31 additions and 68 deletions.
99 changes: 31 additions & 68 deletions stdlib/src/builtin/simd.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -2151,13 +2151,32 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
# Reduce operations
# ===------------------------------------------------------------------=== #

alias _T = SIMD[type, _]

# TODO: remove when non-capturing can be converted to capturing.
@always_inline
fn reduce[
func: fn[width: Int] (
Self._T[width],
Self._T[width],
) -> Self._T[width],
size_out: Int = 1,
](self) -> Self._T[size_out]:
@always_inline
@parameter
fn body[w: Int](lhs: Self._T[w], rhs: Self._T[w]) -> Self._T[w]:
return func(lhs, rhs)

return self.reduce[body, size_out]()

@always_inline
fn reduce[
func: fn[type: DType, width: Int] (
SIMD[type, width], SIMD[type, width]
) capturing -> SIMD[type, width],
func: fn[width: Int] (
Self._T[width],
Self._T[width],
) capturing -> Self._T[width],
size_out: Int = 1,
](self) -> SIMD[type, size_out]:
](self) -> Self._T[size_out]:
"""Reduces the vector using a provided reduce operator.
Parameters:
Expand All @@ -2174,15 +2193,15 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](

@parameter
if size == size_out:
return rebind[SIMD[type, size_out]](self)
return rebind[Self._T[size_out]](self)
else:
var lhs: Self._SIMDHalfType
var rhs: Self._SIMDHalfType
lhs, rhs = self.split()
return func(lhs, rhs).reduce[func, size_out]()

@always_inline("nodebug")
fn reduce_max[size_out: Int = 1](self) -> SIMD[type, size_out]:
fn reduce_max[size_out: Int = 1](self) -> Self._T[size_out]:
"""Reduces the vector using the `max` operator.
Parameters:
Expand All @@ -2202,17 +2221,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](

@parameter
if is_x86() or size_out > 1:

@always_inline
@parameter
fn max_reduce_body[
type: DType, width: Int
](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[
type, width
]:
return v1.max(v2)

return self.reduce[max_reduce_body, size_out]()
return self.reduce[Self._T.max, size_out]()

@parameter
if type.is_floating_point():
Expand Down Expand Up @@ -2260,17 +2269,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](

@parameter
if is_x86() or size_out > 1:

@always_inline
@parameter
fn min_reduce_body[
type: DType, width: Int
](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[
type, width
]:
return v1.min(v2)

return self.reduce[min_reduce_body, size_out]()
return self.reduce[Self._T.min, size_out]()

@parameter
if type.is_floating_point():
Expand Down Expand Up @@ -2311,15 +2310,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
The sum of all vector elements.
"""

@always_inline
@parameter
fn add_reduce_body[
type: DType, width: Int
](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[type, width]:
return v1 + v2

return self.reduce[add_reduce_body, size_out]()
return self.reduce[Self._T.__add__, size_out]()

@always_inline
fn reduce_mul[size_out: Int = 1](self) -> SIMD[type, size_out]:
Expand All @@ -2335,15 +2326,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
Returns:
The product of all vector elements.
"""

@always_inline
@parameter
fn mul_reduce_body[
type: DType, width: Int
](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[type, width]:
return v1 * v2

return self.reduce[mul_reduce_body, size_out]()
return self.reduce[Self._T.__mul__, size_out]()

@always_inline
fn reduce_and[size_out: Int = 1](self) -> SIMD[type, size_out]:
Expand All @@ -2369,17 +2352,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](

@parameter
if size_out > 1:

@always_inline
@parameter
fn and_reduce_body[
type: DType, width: Int
](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[
type, width
]:
return v1 & v2

return self.reduce[and_reduce_body, size_out]()
return self.reduce[Self._T.__and__, size_out]()

@parameter
if size == 1:
Expand Down Expand Up @@ -2415,17 +2388,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](

@parameter
if size_out > 1:

@always_inline
@parameter
fn or_reduce_body[
type: DType, width: Int
](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[
type, width
]:
return v1 | v2

return self.reduce[or_reduce_body, size_out]()
return self.reduce[Self._T.__or__, size_out]()

@parameter
if size == 1:
Expand Down

0 comments on commit 2f1efe7

Please sign in to comment.