[TPU] Use mark_dynamic only for dummy run (vllm-project#7634)
WoosukKwon authored Aug 18, 2024
1 parent ce14335 commit 0c2fa50
Showing 1 changed file, vllm/worker/tpu_model_runner.py, with 28 additions and 48 deletions.
@@ -144,7 +144,11 @@ def load_model(self) -> None:
         )
         model = model.eval()
         xm.wait_device_ops()
-        self.model = CompiledModelWrapper(model)
+        model = ModelWrapper(model)
+        self.model = torch.compile(model,
+                                   backend="openxla",
+                                   fullgraph=True,
+                                   dynamic=False)
 
     def _dummy_run(
         self,
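For readers unfamiliar with the pattern above, here is a minimal, self-contained sketch of compiling a module with the openxla backend. TinyModel is an illustrative stand-in, not vLLM's ModelWrapper. With dynamic=False, torch.compile specializes on the exact input shapes it sees, which is why the dummy run below must mark the variable dimensions as dynamic.

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    # Illustrative stand-in for vLLM's ModelWrapper; any nn.Module works.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) * 2.0


model = TinyModel().eval()
# backend="openxla" requires torch_xla and an XLA device (e.g. a TPU VM).
compiled_model = torch.compile(model,
                               backend="openxla",
                               fullgraph=True,  # fail loudly instead of graph-breaking
                               dynamic=False)   # specialize on concrete shapes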
@@ -206,9 +210,31 @@ def _dummy_run(
         )
         t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
         p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
-        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
 
+        # NOTE(woosuk): There are two stages of compilation: torch.compile and
+        # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
+        # overhead by reusing the FX graph for different shapes.
+        # However, the XLA graph still requires static shapes and needs to be
+        # re-compiled for every different shape. This overhead is inevitable
+        # in the first run, but can be skipped afterwards as we cache the XLA
+        # graphs on disk (VLLM_XLA_CACHE_PATH).
+        if is_prompt:
+            # Prefill
+            torch._dynamo.mark_dynamic(token_ids, 1)
+            torch._dynamo.mark_dynamic(position_ids, 1)
+            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
+        else:
+            # Decode
+            torch._dynamo.mark_dynamic(token_ids, 0)
+            torch._dynamo.mark_dynamic(position_ids, 0)
+            torch._dynamo.mark_dynamic(input_lens, 0)
+            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
+            torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
+            torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
+            torch._dynamo.mark_dynamic(t, 0)
+            torch._dynamo.mark_dynamic(p, 0)
         # Dummy run.
+        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
         self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
                    num_samples, kv_caches)
 
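Why marking only the dummy inputs can be enough: Dynamo traces the warm-up call with the marked dimensions symbolic, so the resulting FX graph is reused for other batch sizes instead of being re-traced. A self-contained sketch of that behavior (counting_backend and f are illustrative, not vLLM code):

import torch

compile_count = 0


def counting_backend(gm, example_inputs):
    # Minimal custom Dynamo backend: count compilations, then run the
    # captured FX graph eagerly.
    global compile_count
    compile_count += 1
    return gm.forward


@torch.compile(backend=counting_backend, fullgraph=True)
def f(x: torch.Tensor) -> torch.Tensor:
    return x.sum(dim=-1)


for batch_size in (8, 16, 24):
    x = torch.zeros(batch_size, 128)
    torch._dynamo.mark_dynamic(x, 0)  # the batch dimension may vary
    f(x)

print(compile_count)  # 1: a single FX graph serves all three batch sizes

As the NOTE above says, this only removes the torch.compile re-trace; XLA still compiles one graph per concrete shape, which is why those graphs are cached via VLLM_XLA_CACHE_PATH.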
@@ -682,52 +708,6 @@ def forward(
         return next_token_ids
 
-
-class CompiledModelWrapper:
-
-    def __init__(self, model: nn.Module):
-        model = ModelWrapper(model)
-        self.model = torch.compile(model,
-                                   backend="openxla",
-                                   fullgraph=True,
-                                   dynamic=False)
-
-    def __call__(
-        self,
-        token_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        attn_metadata: AttentionMetadata,
-        input_lens: torch.Tensor,
-        t: torch.Tensor,
-        p: torch.Tensor,
-        num_samples: int,
-        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
-    ) -> torch.Tensor:
-        # NOTE(woosuk): There are two stages of compilation: torch.compile and
-        # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
-        # overhead by reusing the FX graph for different shapes.
-        # However, the XLA graph still requires static shapes and needs to be
-        # re-compiled for every different shape. This overhead is inevitable
-        # in the first run, but can be skipped afterwards as we cache the XLA
-        # graphs on disk (VLLM_XLA_CACHE_PATH).
-        if attn_metadata.num_prefills > 0:
-            # Prefill
-            torch._dynamo.mark_dynamic(token_ids, 1)
-            torch._dynamo.mark_dynamic(position_ids, 1)
-            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
-        else:
-            # Decode
-            torch._dynamo.mark_dynamic(token_ids, 0)
-            torch._dynamo.mark_dynamic(position_ids, 0)
-            torch._dynamo.mark_dynamic(input_lens, 0)
-            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
-            torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
-            torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
-            torch._dynamo.mark_dynamic(t, 0)
-            torch._dynamo.mark_dynamic(p, 0)
-        return self.model(token_ids, position_ids, attn_metadata, input_lens,
-                          t, p, num_samples, kv_caches)
-
 
 def _get_padded_prefill_len(x: int) -> int:
     # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
     # length to be a multiple of 16. We pad the prompt length to the nearest
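The helper's body is cut off by the diff view at this point. As a hedged reconstruction (an assumption for illustration, not necessarily the committed code), such a helper typically rounds up to the next power of two with a floor of 16: every result is then a multiple of 16 as the kernel requires, and the small set of distinct padded lengths bounds how many shapes XLA must compile.

def padded_prefill_len_sketch(x: int) -> int:
    # Hypothetical sketch, not vLLM's implementation.
    if x <= 16:
        # Floor of 16 satisfies the Pallas kernel's multiple-of-16 constraint.
        return 16
    # Next power of two >= x (e.g. 100 -> 128); for x > 16 this is always
    # a multiple of 16, and it caps the number of shapes XLA compiles.
    return 1 << (x - 1).bit_length()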
