From dbfe3d91bfc3568534f0d55caee3792b0f1202fa Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Wed, 17 Jul 2024 17:59:54 +0800
Subject: [PATCH] rope positions need higher precision

Signed-off-by: Yu Chin Fabian Lim
---
 .../hf_models/modeling_utils/position_embedding/rope.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/rope.py b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/rope.py
index 7dd9232..1dd5bd6 100644
--- a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/rope.py
+++ b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/rope.py
@@ -55,9 +55,7 @@ def _set_cos_sin_cache(
         self, seq_len: int, device: torch.device, dtype: torch.dtype
     ) -> None:
         self.max_seq_len_cached = seq_len
-        t = torch.arange(
-            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
-        )
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
         freqs = torch.outer(t, self.inv_freq)
 
         # Different from paper, but it uses a different permutation in order to obtain the same calculation