diff --git a/stdlib/src/builtin/simd.mojo b/stdlib/src/builtin/simd.mojo index 24483f939e3..ca2cad4f355 100644 --- a/stdlib/src/builtin/simd.mojo +++ b/stdlib/src/builtin/simd.mojo @@ -2131,13 +2131,45 @@ struct SIMD[type: DType, size: Int]( # Reduce operations # ===------------------------------------------------------------------=== # + alias _T = SIMD[type, _] + + # TODO: remove when non-capturing can be converted to capturing. @always_inline fn reduce[ - func: fn[type: DType, width: Int] ( - SIMD[type, width], SIMD[type, width] - ) capturing -> SIMD[type, width], + func: fn[width: Int] ( + Self._T[width], + Self._T[width], + ) -> Self._T[width], size_out: Int = 1, - ](self) -> SIMD[type, size_out]: + ](self) -> Self._T[size_out]: + """Reduces the vector using a provided reduce operator. + + Parameters: + func: The reduce function to apply to elements in this SIMD. + size_out: The width of the reduction. + + Constraints: + `size_out` must not exceed width of the vector. + + Returns: + A new scalar which is the reduction of all vector elements. + """ + + @always_inline + @parameter + fn body[w: Int](lhs: Self._T[w], rhs: Self._T[w]) -> Self._T[w]: + return func(lhs, rhs) + + return self.reduce[body, size_out]() + + @always_inline + fn reduce[ + func: fn[width: Int] ( + Self._T[width], + Self._T[width], + ) capturing -> Self._T[width], + size_out: Int = 1, + ](self) -> Self._T[size_out]: """Reduces the vector using a provided reduce operator. Parameters: @@ -2154,7 +2186,7 @@ struct SIMD[type: DType, size: Int]( @parameter if size == size_out: - return rebind[SIMD[type, size_out]](self) + return rebind[Self._T[size_out]](self) else: var lhs: Self._SIMDHalfType var rhs: Self._SIMDHalfType @@ -2162,7 +2194,7 @@ struct SIMD[type: DType, size: Int]( return func(lhs, rhs).reduce[func, size_out]() @always_inline("nodebug") - fn reduce_max[size_out: Int = 1](self) -> SIMD[type, size_out]: + fn reduce_max[size_out: Int = 1](self) -> Self._T[size_out]: """Reduces the vector using the `max` operator. Parameters: @@ -2185,14 +2217,12 @@ struct SIMD[type: DType, size: Int]( @always_inline @parameter - fn max_reduce_body[ - type: DType, width: Int - ](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[ - type, width - ]: + fn body[ + width: Int + ](v1: Self._T[width], v2: Self._T[width]) -> Self._T[width]: return max(v1, v2) - return self.reduce[max_reduce_body, size_out]() + return self.reduce[body, size_out]() @parameter if type.is_floating_point(): @@ -2243,14 +2273,12 @@ struct SIMD[type: DType, size: Int]( @always_inline @parameter - fn min_reduce_body[ - type: DType, width: Int - ](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[ - type, width - ]: + fn body[ + width: Int + ](v1: Self._T[width], v2: Self._T[width]) -> Self._T[width]: return min(v1, v2) - return self.reduce[min_reduce_body, size_out]() + return self.reduce[body, size_out]() @parameter if type.is_floating_point(): @@ -2291,15 +2319,7 @@ struct SIMD[type: DType, size: Int]( The sum of all vector elements. """ - - @always_inline - @parameter - fn add_reduce_body[ - type: DType, width: Int - ](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[type, width]: - return v1 + v2 - - return self.reduce[add_reduce_body, size_out]() + return self.reduce[Self._T.__add__, size_out]() @always_inline fn reduce_mul[size_out: Int = 1](self) -> SIMD[type, size_out]: @@ -2315,15 +2335,7 @@ struct SIMD[type: DType, size: Int]( Returns: The product of all vector elements. """ - - @always_inline - @parameter - fn mul_reduce_body[ - type: DType, width: Int - ](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[type, width]: - return v1 * v2 - - return self.reduce[mul_reduce_body, size_out]() + return self.reduce[Self._T.__mul__, size_out]() @always_inline fn reduce_and[size_out: Int = 1](self) -> SIMD[type, size_out]: @@ -2349,17 +2361,7 @@ struct SIMD[type: DType, size: Int]( @parameter if size_out > 1: - - @always_inline - @parameter - fn and_reduce_body[ - type: DType, width: Int - ](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[ - type, width - ]: - return v1 & v2 - - return self.reduce[and_reduce_body, size_out]() + return self.reduce[Self._T.__and__, size_out]() @parameter if size == 1: @@ -2395,17 +2397,7 @@ struct SIMD[type: DType, size: Int]( @parameter if size_out > 1: - - @always_inline - @parameter - fn or_reduce_body[ - type: DType, width: Int - ](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[ - type, width - ]: - return v1 | v2 - - return self.reduce[or_reduce_body, size_out]() + return self.reduce[Self._T.__or__, size_out]() @parameter if size == 1: