Skip to content

Commit

Permalink
Improve f16/bf16 softmax precision by accumulating in f32
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarflakstad committed Jan 20, 2025
1 parent e8499c8 commit 4c94925
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions candle-metal-kernels/src/reduce.metal
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ impl_arg_reduce_strided(OP, NAME, T) \
template <typename T>
struct MD {
T m;
T d;
float d;

constexpr MD<T>() = default;
constexpr MD<T>() threadgroup = default;
Expand Down Expand Up @@ -807,9 +807,9 @@ struct finalize_softmax {
const uint thread_id,
const uint stop_idx
) {
const T d_total_inverse = fast_divide(static_cast<T>(1.0), md_total.d);
const float d_total_inverse = fast_divide(1.0, md_total.d);
for (uint idx = thread_id; idx < stop_idx; idx += BLOCKSIZE) {
dst[idx] = fast_exp(src[idx] - md_total.m) * d_total_inverse;
dst[idx] = static_cast<T>(fast_exp(src[idx] - md_total.m) * d_total_inverse);
}
}
};
Expand Down
4 changes: 2 additions & 2 deletions candle-metal-kernels/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1095,7 +1095,7 @@ fn softmax() {
let results = run_softmax(&v, last_dim, "softmax_f16");
assert_eq!(
approx_f16(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0858, 0.233, 0.6333]
vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2332, 0.6338]
);

let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
Expand All @@ -1106,7 +1106,7 @@ fn softmax() {
let results = run_softmax(&v, last_dim, "softmax_bf16");
assert_eq!(
approx_bf16(results, 4),
vec![0.0043, 0.0117, 0.0317, 0.0864, 0.2334, 0.6367]
vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328]
);
}

Expand Down

0 comments on commit 4c94925

Please sign in to comment.