diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 14873f891f..cfdcf10835 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -741,7 +741,7 @@ impl_arg_reduce_strided(OP, NAME, T) \ template struct MD { T m; - T d; + float d; constexpr MD() = default; constexpr MD() threadgroup = default; @@ -807,9 +807,9 @@ struct finalize_softmax { const uint thread_id, const uint stop_idx ) { - const T d_total_inverse = fast_divide(static_cast(1.0), md_total.d); + const float d_total_inverse = fast_divide(1.0, md_total.d); for (uint idx = thread_id; idx < stop_idx; idx += BLOCKSIZE) { - dst[idx] = fast_exp(src[idx] - md_total.m) * d_total_inverse; + dst[idx] = static_cast(fast_exp(src[idx] - md_total.m) * d_total_inverse); } } }; diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 9976666415..d12167389f 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1095,7 +1095,7 @@ fn softmax() { let results = run_softmax(&v, last_dim, "softmax_f16"); assert_eq!( approx_f16(results, 4), - vec![0.0043, 0.0116, 0.0315, 0.0858, 0.233, 0.6333] + vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2332, 0.6338] ); let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] @@ -1106,7 +1106,7 @@ fn softmax() { let results = run_softmax(&v, last_dim, "softmax_bf16"); assert_eq!( approx_bf16(results, 4), - vec![0.0043, 0.0117, 0.0317, 0.0864, 0.2334, 0.6367] + vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328] ); }