diff --git a/entropix/torch_model.py b/entropix/torch_model.py
index 0ebb3e9..6217dfe 100644
--- a/entropix/torch_model.py
+++ b/entropix/torch_model.py
@@ -42,7 +42,7 @@ def attention(x: torch.Tensor, layer_weights: LayerWeights, model_params, cur_po
   xq = F.linear(x, layer_weights.wq).reshape(bsz, -1, model_params.n_local_heads, model_params.head_dim)
   xk = F.linear(x, layer_weights.wk).reshape(bsz, -1, model_params.n_local_kv_heads, model_params.head_dim)
   xv = F.linear(x, layer_weights.wv).reshape(bsz, -1, model_params.n_local_kv_heads, model_params.head_dim)
-  xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+  xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, dtype=xq.dtype)
   keys, values, kvcache = kvcache.update(xk, xv, layer_idx, cur_pos, n_rep)
   xq = torch.permute(xq, (0, 2, 1, 3))  # (bs, n_heads, seqlen, head_dim)
   keys = torch.permute(keys, (0, 2, 3, 1))  # (bs, n_heads, head_dim, cache_len + seqlen)
@@ -55,7 +55,7 @@ def attention(x: torch.Tensor, layer_weights: LayerWeights, model_params, cur_po
   mask = torch.where(scores != 0.0, scores, DEFAULT_MASK_VALUE)
   padded_logits = torch.where((mask >= DEFAULT_MASK_VALUE * 0.5), scores, DEFAULT_MASK_VALUE)
   scores = F.softmax(padded_logits, dim=-1).to(torch.float32)
-  output = torch.matmul(scores, values)
+  output = torch.matmul(scores.to(values.dtype), values)
   output = output.transpose(1, 2).reshape(xq.shape[0], xq.shape[2], -1)
   out = F.linear(output, layer_weights.wo)
   return out, kvcache, pre_scores
@@ -77,4 +77,4 @@ def xfmr(xfmr_weights: XfmrWeights, model_params: ModelParams, tokens: torch.Ten
     h = h + h_attn
     h = h + feed_forward(rms_norm(h, xfmr_weights.layer_weights[i].ffn_norm), xfmr_weights.layer_weights[i])
   logits = F.linear(rms_norm(h, xfmr_weights.norm), xfmr_weights.output)
-  return logits, kvcache, scores, attn_stats
\ No newline at end of file
+  return logits, kvcache, scores, attn_stats