Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
alexm-redhat committed Feb 11, 2025
1 parent 9a0da81 commit 6f5b773
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
4 changes: 2 additions & 2 deletions vllm/v1/worker/tpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class DecodeData:
attn_metadata: Optional[PallasMetadata] = None


-class TPUModelRunner(ModelRunnerBase):
+class TPUModelRunner:

def __init__(
self,
Expand Down Expand Up @@ -692,7 +692,7 @@ def execute_model(
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
# Update cached state
-self.update_states(scheduler_output)
+self._update_states(scheduler_output)

# If necessary, swap decodes/prompts to have all decodes on the start
ensure_decodes_first(self.input_batch)
Expand Down
5 changes: 3 additions & 2 deletions vllm/v1/worker/tpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch
import torch.distributed
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

Expand All @@ -19,14 +20,13 @@
from vllm.v1.outputs import ModelRunnerOutput
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size
from vllm.v1.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
-from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.utils import bind_kv_cache

logger = init_logger(__name__)


-class TPUWorker(WorkerBase):
+class TPUWorker:

def __init__(
self,
Expand Down Expand Up @@ -212,6 +212,7 @@ def check_health(self) -> None:
# worker will always be healthy as long as it's running.
return


def init_tpu_worker_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
Expand Down

0 comments on commit 6f5b773

Please sign in to comment.