feat (vllm): Implement vLLM wrapper
AdamBelfki3 committed Sep 23, 2024
1 parent 760d88b commit d63652a
Showing 2 changed files with 91 additions and 3 deletions.
4 changes: 1 addition & 3 deletions src/nnsight/models/NNsightModel.py
@@ -105,9 +105,7 @@ def __init__(
            self._custom_model = True
            self._dispatched = True
            self._model = model_key
-
-        # Otherwise load from _load(...).
-        if not self._custom_model:
+        else: # Otherwise load from _load(...).
            # Load skeleton of model by putting all tensors on meta.
            with init_empty_weights(include_buffers=meta_buffers):
                self._model = self._load(self._model_key, *args, **kwargs)
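The hunk above folds the two code paths in NNsightModel.__init__ into a single if/else: a pre-instantiated torch.nn.Module is wrapped directly, while any other model key falls through to _load(...). A minimal sketch of that control flow (illustrative only; _SketchModel and its _load stub are not part of NNsight):

import torch

class _SketchModel:
    def __init__(self, model_key, *args, **kwargs):
        # A ready-made torch module is used as-is and counted as dispatched;
        # anything else (e.g. a Hugging Face repo id) goes through _load(...).
        self._custom_model = isinstance(model_key, torch.nn.Module)

        if self._custom_model:
            self._dispatched = True
            self._model = model_key
        else:  # Otherwise load from _load(...).
            # NNsight loads a meta-device skeleton here; this stub only marks the path.
            self._model = self._load(model_key, *args, **kwargs)

    def _load(self, model_key, *args, **kwargs):
        raise NotImplementedError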
90 changes: 90 additions & 0 deletions src/nnsight/models/VLLM.py
@@ -0,0 +1,90 @@
from typing import Any, List, Tuple, Union, Optional

from ..util import WrapperModule
from .mixins import GenerationMixin

try:
    from vllm import LLM, RequestOutput
except Exception as e:

    raise type(e)(
        "Install vllm in your environment to use it with NNsight. " + \
        "https://docs.vllm.ai/en/latest/getting_started/installation.html"
    ) from e

class VLLM(GenerationMixin):
    ''' NNsight wrapper to conduct interventions on a vLLM inference engine.

    .. code-block:: python

        from nnsight.models.VLLM import VLLM
        from vllm import SamplingParams

        model = VLLM("gpt2")

        prompt = ["The Eiffel Tower is in the city of"]
        sampling_params = SamplingParams(temperature=0.0, top_p=0.95, stop=["."])

        with model.trace(prompt, sampling_params=sampling_params) as tracer:
            model.model.transformer.h[8].output[-1][:] = 0

            outputs = model.output.save()

        for output in outputs.value:
            prompt = output.prompt
            generated_text = output.outputs[0].text
            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    '''

    class VLLModel(WrapperModule):
        ''' Pytorch Wrapper for the vLLM engine to work seamlessly with NNsight.

        Attributes:
            llm (vllm.LLM): vLLM inference engine instance.
            model (torch.nn.Module): Underlying model of the vLLM instance.
        '''

        def __init__(self, *args, **kwargs) -> None:

            super().__init__()

            self.llm = LLM(*args, dtype="half", **kwargs)

            self.model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model

    def __init__(self, model_key: str, *args, **kwargs) -> None:

        model_key = self._load(model_key, **kwargs)

        super().__init__(model_key, *args, **kwargs)

    def _load(self, repo_id: str, **kwargs) -> VLLModel:

        model = VLLM.VLLModel(model=repo_id, **kwargs)

        return model

    def _execute(self, prepared_inputs: Union[List[str], str], *args, generate=True, **kwargs) -> List[RequestOutput]:

        output = self._model.llm.generate(prepared_inputs, *args, use_tqdm=False, **kwargs)

        output = self._model(output)

        return output

    def _prepare_inputs(self, *inputs: Union[List[str], str]) -> Tuple[Tuple[List[str]], int]:

        if isinstance(inputs[0], list):
            return inputs, len(inputs[0])
        else:
            return ([inputs[0]],), 1

    def _batch_inputs(
        self,
        batched_inputs: Optional[Tuple[List[str]]],
        prepared_inputs: List[str],
    ) -> Tuple[List[str]]:

        if batched_inputs is None:

            return (prepared_inputs, )

        return (batched_inputs[0] + prepared_inputs, )
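
A quick sketch of how the input helpers above behave, assuming vllm is installed, a constructed VLLM instance named model, and placeholder prompts: _prepare_inputs wraps a single string into a one-element prompt list, and _batch_inputs concatenates successive prompt lists into one batch.

model = VLLM("gpt2")

# A bare string becomes a one-element prompt list with a total count of 1.
prepared, total = model._prepare_inputs("The Eiffel Tower is in the city of")
assert prepared == (["The Eiffel Tower is in the city of"],) and total == 1

# Batching starts from None, then each call extends the accumulated prompt list.
batch = model._batch_inputs(None, prepared[0])
batch = model._batch_inputs(batch, ["Rome is the capital of"])
assert batch == (["The Eiffel Tower is in the city of", "Rome is the capital of"],)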
