Commit

nnsight.VLLM is loaded first on the 'meta' device

AdamBelfki3 committed Oct 4, 2024
1 parent d63652a · commit 066281e

Showing 1 changed file with 67 additions and 11 deletions.

src/nnsight/models/VLLM.py (78 changes: 67 additions & 11 deletions)
@@ -1,10 +1,16 @@
-from typing import Any, List, Tuple, Union, Optional
+from typing import List, Optional, Tuple, Union
 
 from ..util import WrapperModule
 from .mixins import GenerationMixin
 
 try:
     from vllm import LLM, RequestOutput
+    from vllm.distributed import (destroy_distributed_environment,
+                                  destroy_model_parallel,
+                                  init_distributed_environment,
+                                  initialize_model_parallel)
+    from vllm.engine.arg_utils import EngineArgs
+    from vllm.model_executor.model_loader.loader import _initialize_model
 except Exception as e:
 
     raise type(e)(
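Note: the final context line above, `raise type(e)(`, is truncated by the collapsed diff view, and the repository's actual error message is not shown. The guarded-import pattern it belongs to, sketched here with an assumed message, re-raises the original exception type with an install hint:

try:
    from vllm import LLM
except Exception as e:
    # assumed wording; the commit's real message is cut off above
    raise type(e)(
        "Failed to import vllm. Install it (e.g. `pip install vllm`) "
        "to use nnsight.VLLM."
    ) from e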
@@ -39,33 +45,84 @@ class VLLModel(WrapperModule):
         ''' Pytorch Wrapper for the vLLM engine to work seamlessly with NNsight.
 
         Attributes:
-            llm (vllm.LLM): vLLM inference engine instance.
             model (torch.nn.Module): Underlying model of the vLLM instance.
+            llm_engine (vllm.LLM): vLLM inference engine instance.
         '''
 
-        def __init__(self, *args, **kwargs) -> None:
+        def __init__(self, model, llm_engine = None) -> None:
 
             super().__init__()
 
-            self.llm = LLM(*args, dtype="half", **kwargs)
+            self.model = model
 
-            self.model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
+            self.llm_engine = llm_engine
 
     def __init__(self, model_key: str, *args, **kwargs) -> None:
 
         model_key = self._load(model_key, **kwargs)
 
         super().__init__(model_key, *args, **kwargs)
 
     def _load(self, repo_id: str, **kwargs) -> VLLModel:
 
-        model = VLLM.VLLModel(model=repo_id, **kwargs)
+        if self._model is None:
+
+            # no parallelism during initialization
+            kwargs["tensor_parallel_size"] = 1
+            kwargs["pipeline_parallel_size"] = 1
+
+            # creating the vLLM engine args
+            engine_args = EngineArgs(
+                model=repo_id,
+                **kwargs,
+            )
+
+            # creating the vLLM engine configuration
+            engine_config_dict = engine_args.create_engine_config().to_dict()
+
+            # starting the distributed environment
+            init_distributed_environment(
+                engine_config_dict["parallel_config"].world_size,
+                0,
+                'tcp://127.0.0.1:47303',
+                0,
+                backend="nccl",
+            )
+
+            # start the tensor parallel group
+            initialize_model_parallel(
+                engine_config_dict["parallel_config"].tensor_parallel_size,
+                engine_config_dict["parallel_config"].pipeline_parallel_size,
+                'nccl',
+            )
+
+            # initialize the model
+            model = _initialize_model(
+                model_config=engine_config_dict["model_config"],
+                load_config=engine_config_dict["load_config"],
+                lora_config=None,
+                cache_config=engine_config_dict["cache_config"],
+                scheduler_config=engine_config_dict["scheduler_config"],
+            )
+
+            return VLLM.VLLModel(model)
+        else:
+
+            # destroy the distributed environment created during the initial model initialization
+            destroy_model_parallel()
+            destroy_distributed_environment()
+
+            if kwargs.get("tensor_parallel_size", 1) > 1:
+                raise Exception("Tensor Parallelism is currently not supported with nnsight.VLLM")
+
+            llm = LLM(repo_id, **kwargs)
+
+            model = llm.llm_engine.model_executor.driver_worker.model_runner.model
+
-        return model
+            return VLLM.VLLModel(model, llm)
 
     def _execute(self, prepared_inputs: Union[List[str], str], *args, generate=True, **kwargs) -> List[RequestOutput]:
 
-        output = self._model.llm.generate(prepared_inputs, *args, use_tqdm=False, **kwargs)
+        output = self._model.llm_engine.generate(prepared_inputs, *args, use_tqdm=False, **kwargs)
 
         output = self._model(output)
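The `_load` rewrite above is the core of this commit: on the first call (`self._model is None`) it builds only the bare torch module, per the commit message effectively a 'meta'-device skeleton that nnsight can scan without loading weights, and on the second call it tears down the bootstrap distributed state and constructs the real vllm.LLM engine. A standalone sketch of that first phase, reusing the exact calls from the diff (the "gpt2" repo id is illustrative, and _initialize_model is a vLLM internal whose signature is taken from the diff rather than a documented API):

from vllm.distributed import (init_distributed_environment,
                              initialize_model_parallel)
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.model_loader.loader import _initialize_model

# build the engine configuration without starting an inference engine
config = EngineArgs(model="gpt2").create_engine_config().to_dict()

# single-process "distributed" environment: world size from the config,
# rank 0, a local TCP rendezvous address, local rank 0
init_distributed_environment(
    config["parallel_config"].world_size, 0, "tcp://127.0.0.1:47303", 0,
    backend="nccl",
)
initialize_model_parallel(1, 1, "nccl")  # no tensor/pipeline parallelism yet

# construct the module tree that nnsight will wrap; weights are not loaded here
skeleton = _initialize_model(
    model_config=config["model_config"],
    load_config=config["load_config"],
    lora_config=None,
    cache_config=config["cache_config"],
    scheduler_config=config["scheduler_config"],
)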

@@ -82,7 +139,6 @@ def _batch_inputs(
         batched_inputs: Optional[Tuple[List[str]]],
         prepared_inputs: List[str],
     ) -> Tuple[List[str]]:
-        breakpoint()
         if batched_inputs is None:
 
             return (prepared_inputs, )
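With this change, the wrapper is built in two passes: the first `_load` call returns the cheap skeleton, and nnsight is expected to call `_load` again (its usual meta-first dispatch, assumed here) to stand up the real engine before execution. A hedged usage sketch, where "gpt2" and the traced module path are illustrative and the trace call assumes nnsight's standard interception API applies to this wrapper:

from nnsight.models.VLLM import VLLM

model = VLLM("gpt2")

with model.trace("The Eiffel Tower is in the city of"):
    # module path follows vLLM's GPT-2 implementation; adjust per model
    hidden = model.model.transformer.h[5].output.save()

print(hidden.value)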
