From 0c2fa50b84dddd4866313ac37074255d29a94055 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sun, 18 Aug 2024 00:18:53 -0700
Subject: [PATCH] [TPU] Use mark_dynamic only for dummy run (#7634)

---
 vllm/worker/tpu_model_runner.py | 76 ++++++++++++---------------------
 1 file changed, 28 insertions(+), 48 deletions(-)

diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index 86fc1d08c0812..14f14e40b4c0b 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -144,7 +144,11 @@ def load_model(self) -> None:
         )
         model = model.eval()
         xm.wait_device_ops()
-        self.model = CompiledModelWrapper(model)
+        model = ModelWrapper(model)
+        self.model = torch.compile(model,
+                                   backend="openxla",
+                                   fullgraph=True,
+                                   dynamic=False)
 
     def _dummy_run(
         self,
@@ -206,9 +210,31 @@ def _dummy_run(
             )
         t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
         p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
+        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
+        # NOTE(woosuk): There are two stages of compilation: torch.compile and
+        # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
+        # overhead by reusing the FX graph for different shapes.
+        # However, the XLA graph still requires static shapes and needs to be
+        # re-compiled for every different shape. This overhead is inevitable
+        # in the first run, but can be skipped afterwards as we cache the XLA
+        # graphs on disk (VLLM_XLA_CACHE_PATH).
+        if is_prompt:
+            # Prefill
+            torch._dynamo.mark_dynamic(token_ids, 1)
+            torch._dynamo.mark_dynamic(position_ids, 1)
+            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
+        else:
+            # Decode
+            torch._dynamo.mark_dynamic(token_ids, 0)
+            torch._dynamo.mark_dynamic(position_ids, 0)
+            torch._dynamo.mark_dynamic(input_lens, 0)
+            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
+            torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
+            torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
+            torch._dynamo.mark_dynamic(t, 0)
+            torch._dynamo.mark_dynamic(p, 0)
 
         # Dummy run.
-        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
         self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
                    num_samples, kv_caches)
 
@@ -682,52 +708,6 @@ def forward(
         return next_token_ids
 
 
-class CompiledModelWrapper:
-
-    def __init__(self, model: nn.Module):
-        model = ModelWrapper(model)
-        self.model = torch.compile(model,
-                                   backend="openxla",
-                                   fullgraph=True,
-                                   dynamic=False)
-
-    def __call__(
-        self,
-        token_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        attn_metadata: AttentionMetadata,
-        input_lens: torch.Tensor,
-        t: torch.Tensor,
-        p: torch.Tensor,
-        num_samples: int,
-        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
-    ) -> torch.Tensor:
-        # NOTE(woosuk): There are two stages of compilation: torch.compile and
-        # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
-        # overhead by reusing the FX graph for different shapes.
-        # However, the XLA graph will still require static shapes and needs to
-        # be re-compiled for every different shapes. This overhead is inevitable
-        # in the first run, but can be skipped afterwards as we cache the XLA
-        # graphs in the disk (VLLM_XLA_CACHE_PATH).
-        if attn_metadata.num_prefills > 0:
-            # Prefll
-            torch._dynamo.mark_dynamic(token_ids, 1)
-            torch._dynamo.mark_dynamic(position_ids, 1)
-            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
-        else:
-            # Decode
-            torch._dynamo.mark_dynamic(token_ids, 0)
-            torch._dynamo.mark_dynamic(position_ids, 0)
-            torch._dynamo.mark_dynamic(input_lens, 0)
-            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
-            torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
-            torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
-            torch._dynamo.mark_dynamic(t, 0)
-            torch._dynamo.mark_dynamic(p, 0)
-        return self.model(token_ids, position_ids, attn_metadata, input_lens,
-                          t, p, num_samples, kv_caches)
-
-
 def _get_padded_prefill_len(x: int) -> int:
     # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
     # length to be a multiple of 16. We pad the prompt length to the nearest
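
As a rough, standalone illustration of the pattern this patch moves into _dummy_run
(hypothetical toy module and shapes, and the default torch.compile backend on CPU
rather than the TPU-only "openxla" backend with dynamic=False used above): calling
torch._dynamo.mark_dynamic on an input tensor before the warm-up call lets
torch.compile reuse the traced FX graph for other sizes along that dimension instead
of re-tracing.

    import torch
    import torch.nn as nn


    class Toy(nn.Module):
        """Hypothetical stand-in for vLLM's ModelWrapper; not part of the patch."""

        def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
            # Reduce over the sequence dimension so the output shape only
            # depends on the batch dimension.
            return token_ids.float().sum(dim=1)


    model = torch.compile(Toy(), fullgraph=True)

    # Warm-up ("dummy run"): mark the sequence dimension (dim 1) as dynamic so
    # the FX graph traced here also covers other sequence lengths.
    warmup = torch.zeros((8, 16), dtype=torch.int64)
    torch._dynamo.mark_dynamic(warmup, 1)
    model(warmup)

    # Later calls with different lengths along dim 1 reuse the traced graph
    # instead of triggering a fresh Dynamo trace for each shape.
    model(torch.zeros((8, 32), dtype=torch.int64))
    model(torch.zeros((8, 64), dtype=torch.int64))

In the patch itself this warm-up happens inside _dummy_run for the prefill and decode
cases, and the XLA executables compiled for each concrete shape are cached on disk via
VLLM_XLA_CACHE_PATH, so only the first run pays the XLA recompilation cost.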