From 7c01aa61059eca35362fb16a34ca9266c4b259a3 Mon Sep 17 00:00:00 2001 From: Hailey Collet <55606304+HaileyStorm@users.noreply.github.com> Date: Mon, 7 Oct 2024 16:32:53 -0600 Subject: [PATCH] Fix type issue introduced by #28 Commit #28 changed `apply_rotary_embed` to have dtype parameter with default float32, and forces attention softmax to be done float32. Since `attention` doesn't specify the dtype parameter when calling `apply_rotary_embed`, and output matmul doesn't convert back from float32 to match the values type, this is an issue if you're running BF16. This specifies the existing xq.dtype for the dtype parameter when calling `apply_rotary_embed` (alternatively, we could cast keys to float32 in `scores = torch.matmul(xq, keys)`), and converts the scores to match values at the output matmul. --- entropix/torch_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/entropix/torch_model.py b/entropix/torch_model.py index 0ebb3e9..6217dfe 100644 --- a/entropix/torch_model.py +++ b/entropix/torch_model.py @@ -42,7 +42,7 @@ def attention(x: torch.Tensor, layer_weights: LayerWeights, model_params, cur_po xq = F.linear(x, layer_weights.wq).reshape(bsz, -1, model_params.n_local_heads, model_params.head_dim) xk = F.linear(x, layer_weights.wk).reshape(bsz, -1, model_params.n_local_kv_heads, model_params.head_dim) xv = F.linear(x, layer_weights.wv).reshape(bsz, -1, model_params.n_local_kv_heads, model_params.head_dim) - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, dtype=xq.dtype) keys, values, kvcache = kvcache.update(xk, xv, layer_idx, cur_pos, n_rep) xq = torch.permute(xq, (0, 2, 1, 3)) # (bs, n_heads, seqlen, head_dim) keys = torch.permute(keys, (0, 2, 3, 1)) # (bs, n_heads, head_dim, cache_len + seqlen) @@ -55,7 +55,7 @@ def attention(x: torch.Tensor, layer_weights: LayerWeights, model_params, cur_po mask = torch.where(scores != 0.0, scores, DEFAULT_MASK_VALUE) padded_logits = torch.where((mask >= DEFAULT_MASK_VALUE * 0.5), scores, DEFAULT_MASK_VALUE) scores = F.softmax(padded_logits, dim=-1).to(torch.float32) - output = torch.matmul(scores, values) + output = torch.matmul(scores.to(values.dtype), values) output = output.transpose(1, 2).reshape(xq.shape[0], xq.shape[2], -1) out = F.linear(output, layer_weights.wo) return out, kvcache, pre_scores @@ -77,4 +77,4 @@ def xfmr(xfmr_weights: XfmrWeights, model_params: ModelParams, tokens: torch.Ten h = h + h_attn h = h + feed_forward(rms_norm(h, xfmr_weights.layer_weights[i].ffn_norm), xfmr_weights.layer_weights[i]) logits = F.linear(rms_norm(h, xfmr_weights.norm), xfmr_weights.output) - return logits, kvcache, scores, attn_stats \ No newline at end of file + return logits, kvcache, scores, attn_stats