Dev #222

Merged
merged 10 commits on Sep 1, 2024
2 changes: 1 addition & 1 deletion docs/source/notebooks/tutorials/walkthrough.ipynb
@@ -14,7 +14,7 @@
},
"source": [
"An interactive version of this walkthrough can be found\n",
"[here](https://colab.research.google.com/github/ndif-team/nnsight/blob/main/NNsight_Walkthough.ipynb)\n",
"[here](https://colab.research.google.com/github/ndif-team/nnsight/blob/main/NNsight_Walkthrough.ipynb)\n",
"\n",
"In this era of large-scale deep learning, the most interesting AI models are\n",
"massive black boxes that are hard to run. Ordinary commercial inference service\n",
51 changes: 28 additions & 23 deletions src/nnsight/contexts/GraphBasedContext.py
@@ -53,14 +53,13 @@ def apply(
Returns:
InterventionProxy: Proxy of applying that function.
"""



proxy_value = inspect._empty

if validate is False:

proxy_value = None

return self.graph.create(
target=target,
proxy_value=proxy_value,
@@ -182,12 +181,12 @@ def list(self, *args, **kwargs) -> InterventionProxy:
"""NNsight helper method to create a traceable list."""

return self.apply(list, *args, **kwargs)

def set(self, *args, **kwargs) -> InterventionProxy:
"""NNsight helper method to create a traceable set."""

return self.apply(set, *args, **kwargs)

def dict(self, *args, **kwargs) -> InterventionProxy:
"""NNsight helper method to create a traceable dictionary."""

@@ -247,15 +246,17 @@ def bridge_backend_handle(self, bridge: Bridge) -> None:
from torch.utils import data


def global_patch(fn):
def global_patch(fn, applied_fn=None):

if applied_fn is None:

applied_fn = fn

@wraps(fn)
def inner(*args, **kwargs):

return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply(
fn,
*args,
**kwargs
applied_fn, *args, **kwargs
)

return inner
@@ -272,8 +273,18 @@ class GlobalTracingContext(GraphBasedContext):
TORCH_HANDLER: GlobalTracingContext.GlobalTracingTorchHandler
PATCHER: Patcher = Patcher(
[
Patch(torch.nn, global_patch(torch.nn.Parameter), "Parameter"),
Patch(data, global_patch(data.DataLoader), "DataLoader"),
Patch(
torch.nn.Parameter,
global_patch(
torch.nn.Parameter.__init__, applied_fn=torch.nn.Parameter
),
"__init__",
),
Patch(
data.DataLoader,
global_patch(data.DataLoader.__init__, applied_fn=data.DataLoader),
"__init__",
),
Patch(torch, global_patch(torch.arange), "arange"),
Patch(torch, global_patch(torch.empty), "empty"),
Patch(torch, global_patch(torch.eye), "eye"),
@@ -288,7 +299,7 @@ class GlobalTracingContext(GraphBasedContext):
Patch(torch, global_patch(torch.zeros), "zeros"),
]
+ [
Patch(torch.optim, global_patch(value), key)
Patch(value, global_patch(value.__init__, applied_fn=value), "__init__")
for key, value in getmembers(torch.optim, isclass)
if issubclass(value, torch.optim.Optimizer)
]
@@ -304,9 +315,7 @@ def __torch_function__(self, func, types, args, kwargs=None):

if "_VariableFunctionsClass" in func.__qualname__:
return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply(
func,
*args,
**kwargs
func, *args, **kwargs
)

return func(*args, **kwargs)
@@ -391,9 +400,7 @@ def register(graph_based_context: GraphBasedContext) -> None:

assert GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph is None

GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph = (
graph_based_context.graph
)
GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph = graph_based_context.graph

GlobalTracingContext.TORCH_HANDLER.__enter__()
GlobalTracingContext.PATCHER.__enter__()
@@ -440,6 +447,4 @@ def __getattribute__(self, name: str) -> Any:


GlobalTracingContext.GLOBAL_TRACING_CONTEXT = GlobalTracingContext()
GlobalTracingContext.TORCH_HANDLER = (
GlobalTracingContext.GlobalTracingTorchHandler()
)
GlobalTracingContext.TORCH_HANDLER = GlobalTracingContext.GlobalTracingTorchHandler()
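
Side note on the `global_patch` change above: the new `applied_fn` parameter lets a class be patched at its `__init__` while the graph node targets the class itself, so constructing `torch.nn.Parameter`, `DataLoader`, or any `torch.optim` optimizer inside the global tracing context yields a traceable proxy. Below is a minimal sketch of the wrapper's behavior using an illustrative `Widget` class and a stand-in tracer (not the library's real `GLOBAL_TRACING_CONTEXT`):

```python
from functools import wraps


class StubTracer:
    """Stand-in for GlobalTracingContext.GLOBAL_TRACING_CONTEXT (illustrative only)."""

    def apply(self, target, *args, **kwargs):
        # The real implementation creates an intervention graph node; here we
        # simply record which callable would be applied and with what arguments.
        return ("proxy", target.__name__, args, kwargs)


TRACER = StubTracer()


def global_patch(fn, applied_fn=None):
    # When fn is SomeClass.__init__, passing applied_fn=SomeClass makes the
    # traced node apply the class constructor rather than the unbound __init__.
    if applied_fn is None:
        applied_fn = fn

    @wraps(fn)
    def inner(*args, **kwargs):
        return TRACER.apply(applied_fn, *args, **kwargs)

    return inner


class Widget:
    def __init__(self, size):
        self.size = size


patched_init = global_patch(Widget.__init__, applied_fn=Widget)
print(patched_init(3))  # ('proxy', 'Widget', (3,), {})
```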
201 changes: 37 additions & 164 deletions src/nnsight/models/DiffusionModel.py
@@ -3,14 +3,15 @@
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from diffusers import DiffusionPipeline, SchedulerMixin
from PIL import Image
from transformers import BatchEncoding, CLIPTextModel, CLIPTokenizer
from diffusers import DiffusionPipeline
from transformers import BatchEncoding
from typing_extensions import Self

from .. import util
from ..envoy import Envoy
from .mixins import GenerationMixin
from .NNsightModel import NNsight
from torch._guards import detect_fake_mode
from .. import util


class Diffuser(util.WrapperModule):
def __init__(self, *args, **kwargs) -> None:
@@ -23,135 +24,15 @@ def __init__(self, *args, **kwargs) -> None:
setattr(self, key, value)

self.tokenizer = self.pipeline.tokenizer

@torch.no_grad()
def scan(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
):

# 0. Default height and width to unet
height = (
height
or self.pipeline.unet.config.sample_size * self.pipeline.vae_scale_factor
)
width = (
width
or self.pipeline.unet.config.sample_size * self.pipeline.vae_scale_factor
)

# 1. Check inputs. Raise error if not correct
self.pipeline.check_inputs(
prompt,
height,
width,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
)

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]

do_classifier_free_guidance = guidance_scale > 1.0

# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None)
if cross_attention_kwargs is not None
else None
)
prompt_embeds, negative_prompt_embeds = self.pipeline.encode_prompt(
prompt,
"meta",
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

# 4. Prepare timesteps
timesteps = self.pipeline.scheduler.timesteps

# 5. Prepare latent variables
num_channels_latents = self.pipeline.unet.config.in_channels
latents = self.pipeline.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
"meta",
generator,
latents,
)

# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.pipeline.prepare_extra_step_kwargs(generator, eta)

# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
)

# predict the noise residual
noise_pred = self.pipeline.unet(
latent_model_input,
0,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]

# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)

if not output_type == "latent":
image = self.pipeline.vae.decode(
latents / self.pipeline.vae.config.scaling_factor, return_dict=False
)[0]
else:
image = latents
has_nsfw_concept = None
class DiffusionModel(GenerationMixin, NNsight):

def __new__(cls, *args, **kwargs) -> Self | Envoy:
return object.__new__(cls)

class DiffusionModel(GenerationMixin, NNsight):
def __init__(self, *args, **kwargs) -> None:

self._model: Diffuser = None

super().__init__(*args, **kwargs)
@@ -162,7 +43,6 @@ def _load(self, repo_id: str, device_map=None, **kwargs) -> Diffuser:

model = Diffuser(
repo_id,
trust_remote_code=True,
device_map=None,
low_cpu_mem_usage=False,
**kwargs,
@@ -178,57 +58,50 @@ def _prepare_inputs(
self,
inputs: Union[str, List[str]],
) -> Any:

if isinstance(inputs, str):
inputs = [inputs]

return (inputs,), len(inputs)

# def _forward(self, inputs, *args, n_imgs=1, img_size=512, **kwargs) -> None:
# text_tokens, latents = inputs

# text_embeddings = self.meta_model.get_text_embeddings(text_tokens, n_imgs)

# latents = torch.cat([latents] * 2).to("meta")
return (inputs,), len(inputs)

# return self.meta_model.unet(
# latents,
# torch.zeros((1,), device="meta"),
# encoder_hidden_states=text_embeddings,
# ).sample
def _batch_inputs(
self,
batched_inputs: Optional[Dict[str, Any]],
prepared_inputs: BatchEncoding,
) -> torch.Tensor:

def _batch_inputs(self, batched_inputs: Optional[Dict[str, Any]],
prepared_inputs: BatchEncoding,) -> torch.Tensor:

if batched_inputs is None:

return prepared_inputs

return batched_inputs + prepared_inputs

def _execute_forward(self, prepared_inputs: Any, *args, **kwargs):

device = next(self._model.parameters()).device
def _execute_forward(self, prepared_inputs: Any, *args, **kwargs):

return self._model.unet(
prepared_inputs,
*args
*args,
**kwargs,
)

def _execute_generate(
self, prepared_inputs: Any, *args, **kwargs
self, prepared_inputs: Any, *args, seed: int = None, **kwargs
):
device = next(self._model.parameters()).device

if detect_fake_mode(prepared_inputs):

output = self._model.scan(*prepared_inputs)

else:

output = self._model.pipeline(prepared_inputs, *args, **kwargs)


if self._scanning():

kwargs["num_inference_steps"] = 1

generator = torch.Generator()

if seed is not None:

generator = generator.manual_seed(seed)

output = self._model.pipeline(
prepared_inputs, *args, generator=generator, **kwargs
)

output = self._model(output)

return output
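
On the reproducibility piece of `_execute_generate`: the added `seed` keyword builds a `torch.Generator`, seeds it only when a value is given, and hands it to the diffusers pipeline. A small sketch of that logic in isolation (the helper name `make_generator` is ours, not part of the diff):

```python
import torch


def make_generator(seed=None):
    # Mirrors the added logic: a fresh generator, seeded only when a seed is given.
    generator = torch.Generator()
    if seed is not None:
        generator = generator.manual_seed(seed)
    return generator


# With the same seed the latents drawn by the pipeline are identical, so the
# generated images are repeatable; passing generator=make_generator(423) to a
# diffusers pipeline call has the same effect as seed=423 in this diff.
a = torch.randn(4, generator=make_generator(423))
b = torch.randn(4, generator=make_generator(423))
assert torch.equal(a, b)
```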