From ab7165f2c7ea358df969d68a0fb0ce9bb184a083 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sun, 18 Aug 2024 01:15:10 -0700
Subject: [PATCH] [TPU] Optimize RoPE forward_native2 (#7636)

---
 .../model_executor/layers/rotary_embedding.py | 53 ++++++++++---------
 1 file changed, 27 insertions(+), 26 deletions(-)

diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 7b3acd7f3c9ea..fa85f72e3dfaf 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -46,15 +46,23 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
 
 def _apply_rotary_emb(
     x: torch.Tensor,
-    freqs_cis: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
 ) -> torch.Tensor:
-    x_ = torch.view_as_complex(
-        torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
-    x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
-    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
-    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
-                          -1).transpose(1, 2)
-    return x_out
+    """
+    Args:
+        x: [num_tokens, num_heads, head_size]
+        cos: [num_tokens, head_size // 2]
+        sin: [num_tokens, head_size // 2]
+    """
+    orig_dtype = x.dtype
+    x = x.float()
+    x1, x2 = torch.chunk(x, 2, dim=-1)
+    cos = cos.unsqueeze(-2)
+    sin = sin.unsqueeze(-2)
+    o1 = x1 * cos - x2 * sin
+    o2 = x2 * cos + x1 * sin
+    return torch.cat((o1, o2), dim=-1).to(orig_dtype)
 
 
 class RotaryEmbedding(CustomOp):
@@ -78,14 +86,10 @@ def __init__(
         self.dtype = dtype
 
         cache = self._compute_cos_sin_cache()
+        cache = cache.to(dtype)
+        self.register_buffer("cos_sin_cache", cache, persistent=False)
+
         self.use_native2 = current_platform.is_tpu() and is_neox_style
-        if not self.use_native2:
-            cache = cache.to(dtype)
-            self.register_buffer("cos_sin_cache", cache, persistent=False)
-        else:
-            cos, sin = cache.chunk(2, dim=-1)
-            freqs_cis = cos + 1j * sin
-            self.register_buffer("freqs_cis", freqs_cis, persistent=False)
 
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         """Compute the inverse frequency."""
@@ -173,28 +177,25 @@ def forward_native2(
         This method might perform better than `forward_native()` when
         compiled.
         """
-        if positions.dim() == 1:
-            batch_size = 1
-            seq_len = positions.shape[0]
-        else:
-            batch_size, seq_len = positions.shape
         if offsets is not None:
            positions = positions + offsets
-        freqs_cis = self.freqs_cis.index_select(0, positions.flatten())
-        freqs_cis = freqs_cis.view(batch_size, 1, seq_len, -1)
+        positions = positions.flatten()
+        num_tokens = positions.shape[0]
+        cos_sin = self.cos_sin_cache.index_select(0, positions)
+        cos, sin = cos_sin.chunk(2, dim=-1)
 
         query_shape = query.shape
-        query = query.view(batch_size, seq_len, -1, self.head_size)
+        query = query.view(num_tokens, -1, self.head_size)
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
-        query_rot = _apply_rotary_emb(query_rot, freqs_cis)
+        query_rot = _apply_rotary_emb(query_rot, cos, sin)
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
 
         key_shape = key.shape
-        key = key.view(batch_size, seq_len, -1, self.head_size)
+        key = key.view(num_tokens, -1, self.head_size)
         key_rot = key[..., :self.rotary_dim]
         key_pass = key[..., self.rotary_dim:]
-        key_rot = _apply_rotary_emb(key_rot, freqs_cis)
+        key_rot = _apply_rotary_emb(key_rot, cos, sin)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
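
For reference, here is a minimal standalone sketch (outside the vLLM codebase; the toy shapes, positions, and hand-rolled neox-style cos/sin cache are illustrative, not vLLM's actual configuration) of what the rewritten helper computes. It reproduces the new _apply_rotary_emb from the patch and cross-checks it against the complex-number formulation (freqs_cis = cos + 1j * sin) that the patch removes; the two agree because (x1 + i*x2)(cos + i*sin) = (x1*cos - x2*sin) + i*(x2*cos + x1*sin).

# Sketch only: toy sizes and cache below are assumptions for illustration.
import torch


def _apply_rotary_emb(x, cos, sin):
    # New helper from the patch: x is [num_tokens, num_heads, head_size],
    # cos/sin are [num_tokens, head_size // 2].
    orig_dtype = x.dtype
    x = x.float()
    x1, x2 = torch.chunk(x, 2, dim=-1)
    cos = cos.unsqueeze(-2)
    sin = sin.unsqueeze(-2)
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    return torch.cat((o1, o2), dim=-1).to(orig_dtype)


num_tokens, num_heads, head_size, max_pos, base = 5, 4, 8, 16, 10000
x = torch.randn(num_tokens, num_heads, head_size)

# Toy neox-style cache laid out as [max_pos, head_size]: first half cos,
# second half sin (the cos_sin_cache layout that forward_native2 chunks apart).
inv_freq = 1.0 / (base**(torch.arange(0, head_size, 2).float() / head_size))
freqs = torch.einsum("i,j->ij", torch.arange(max_pos).float(), inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)

positions = torch.tensor([0, 3, 7, 2, 9])
cos, sin = cos_sin_cache.index_select(0, positions).chunk(2, dim=-1)
out = _apply_rotary_emb(x, cos, sin)

# Reference: the complex-number form removed by the patch, freqs_cis = cos + i*sin.
x1, x2 = x.float().chunk(2, dim=-1)
x_complex = torch.view_as_complex(torch.stack((x1, x2), dim=-1))
freqs_cis = (cos + 1j * sin).unsqueeze(-2)  # broadcast over the head dimension
ref = x_complex * freqs_cis
ref = torch.cat((ref.real, ref.imag), dim=-1)
torch.testing.assert_close(out, ref)
print("real cos/sin path matches the complex freqs_cis path")

The sketch also mirrors the patch's other change: forward_native2 no longer reshapes to batch_size/seq_len but flattens positions to a single num_tokens axis, so one code path serves both 1-D and 2-D position tensors.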