Skip to content

Commit

Permalink
fix how sampling params are passing in through the nnsight api to the…
Browse files Browse the repository at this point in the history
… vllm executor
  • Loading branch information
AdamBelfki3 committed Nov 19, 2024
1 parent 7257be9 commit f9a5620
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/nnsight/modeling/vllm/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class NNsightSamplingParams(SamplingParams):

intervention_graph: Optional[InterventionGraph] = None
invoker_group: Optional[int] = None
is_default_param: bool = True

def clone(self) -> "SamplingParams":
"""Deep copy excluding LogitsProcessor objects.
Expand Down
11 changes: 11 additions & 0 deletions src/nnsight/modeling/vllm/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ def _prepare_input(
**kwargs,
)

if kwargs != {}:
param.is_default_param = False

prompts.append(prompt)
params.append(param)

Expand Down Expand Up @@ -265,6 +268,14 @@ def _execute(
**kwargs,
) -> Any:

kwargs.pop('invoker_group')

for param in params:
if param.is_default_param:
for attr, value in kwargs.items():
if hasattr(NNsightSamplingParams, attr):
setattr(param, attr, value)

self.vllm_entrypoint.generate(prompts, sampling_params=params)

if TYPE_CHECKING:
Expand Down
4 changes: 3 additions & 1 deletion src/nnsight/tracing/hacks/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
def handle_conditional(frame: FrameType, condition: "Proxy"):

line_no = frame.f_lineno
source_lines, _ = inspect.getsourcelines(frame)
source_file = inspect.getsourcefile(frame)
with open(source_file, "r") as file:
source_lines = file.readlines()
source = "".join(source_lines)
tree = ast.parse(source)

Expand Down
4 changes: 3 additions & 1 deletion src/nnsight/tracing/hacks/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
def handle_iterator(frame: FrameType, collection: "Proxy"):

line_no = frame.f_lineno
source_lines, _ = inspect.getsourcelines(frame)
source_file = inspect.getsourcefile(frame)
with open(source_file, "r") as file:
source_lines = file.readlines()
source = "".join(source_lines)
tree = ast.parse(source)

Expand Down

0 comments on commit f9a5620

Please sign in to comment.