From 2f1efe7eea70d22e8c47e1f5813884fa52221c97 Mon Sep 17 00:00:00 2001 From: Yiwu Chen <210at85@gmail.com> Date: Thu, 20 Jun 2024 00:00:51 +0000 Subject: [PATCH] [stdlib] Simplify `SIMD.reduce_op` functions - Introduced a new `reduce` overload for non-capturing functions - Simplify `SIMD.reduce_op` functions using `SIMD.op` methods directly Signed-off-by: Yiwu Chen <210at85@gmail.com> --- stdlib/src/builtin/simd.mojo | 99 +++++++++++------------------------- 1 file changed, 31 insertions(+), 68 deletions(-) diff --git a/stdlib/src/builtin/simd.mojo b/stdlib/src/builtin/simd.mojo index f0cc3e4dffe..d8dda53076f 100644 --- a/stdlib/src/builtin/simd.mojo +++ b/stdlib/src/builtin/simd.mojo @@ -2151,13 +2151,32 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( # Reduce operations # ===------------------------------------------------------------------=== # + alias _T = SIMD[type, _] + + # TODO: remove when non-capturing can be converted to capturing. + @always_inline + fn reduce[ + func: fn[width: Int] ( + Self._T[width], + Self._T[width], + ) -> Self._T[width], + size_out: Int = 1, + ](self) -> Self._T[size_out]: + @always_inline + @parameter + fn body[w: Int](lhs: Self._T[w], rhs: Self._T[w]) -> Self._T[w]: + return func(lhs, rhs) + + return self.reduce[body, size_out]() + @always_inline fn reduce[ - func: fn[type: DType, width: Int] ( - SIMD[type, width], SIMD[type, width] - ) capturing -> SIMD[type, width], + func: fn[width: Int] ( + Self._T[width], + Self._T[width], + ) capturing -> Self._T[width], size_out: Int = 1, - ](self) -> SIMD[type, size_out]: + ](self) -> Self._T[size_out]: """Reduces the vector using a provided reduce operator. Parameters: @@ -2174,7 +2193,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( @parameter if size == size_out: - return rebind[SIMD[type, size_out]](self) + return rebind[Self._T[size_out]](self) else: var lhs: Self._SIMDHalfType var rhs: Self._SIMDHalfType @@ -2182,7 +2201,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( return func(lhs, rhs).reduce[func, size_out]() @always_inline("nodebug") - fn reduce_max[size_out: Int = 1](self) -> SIMD[type, size_out]: + fn reduce_max[size_out: Int = 1](self) -> Self._T[size_out]: """Reduces the vector using the `max` operator. Parameters: @@ -2202,17 +2221,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( @parameter if is_x86() or size_out > 1: - - @always_inline - @parameter - fn max_reduce_body[ - type: DType, width: Int - ](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[ - type, width - ]: - return v1.max(v2) - - return self.reduce[max_reduce_body, size_out]() + return self.reduce[Self._T.max, size_out]() @parameter if type.is_floating_point(): @@ -2260,17 +2269,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( @parameter if is_x86() or size_out > 1: - - @always_inline - @parameter - fn min_reduce_body[ - type: DType, width: Int - ](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[ - type, width - ]: - return v1.min(v2) - - return self.reduce[min_reduce_body, size_out]() + return self.reduce[Self._T.min, size_out]() @parameter if type.is_floating_point(): @@ -2311,15 +2310,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( The sum of all vector elements. """ - - @always_inline - @parameter - fn add_reduce_body[ - type: DType, width: Int - ](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[type, width]: - return v1 + v2 - - return self.reduce[add_reduce_body, size_out]() + return self.reduce[Self._T.__add__, size_out]() @always_inline fn reduce_mul[size_out: Int = 1](self) -> SIMD[type, size_out]: @@ -2335,15 +2326,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( Returns: The product of all vector elements. """ - - @always_inline - @parameter - fn mul_reduce_body[ - type: DType, width: Int - ](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[type, width]: - return v1 * v2 - - return self.reduce[mul_reduce_body, size_out]() + return self.reduce[Self._T.__mul__, size_out]() @always_inline fn reduce_and[size_out: Int = 1](self) -> SIMD[type, size_out]: @@ -2369,17 +2352,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( @parameter if size_out > 1: - - @always_inline - @parameter - fn and_reduce_body[ - type: DType, width: Int - ](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[ - type, width - ]: - return v1 & v2 - - return self.reduce[and_reduce_body, size_out]() + return self.reduce[Self._T.__and__, size_out]() @parameter if size == 1: @@ -2415,17 +2388,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( @parameter if size_out > 1: - - @always_inline - @parameter - fn or_reduce_body[ - type: DType, width: Int - ](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[ - type, width - ]: - return v1 | v2 - - return self.reduce[or_reduce_body, size_out]() + return self.reduce[Self._T.__or__, size_out]() @parameter if size == 1: