From 09eec2f99894a8de3f0e551b4b2d138367aeb88b Mon Sep 17 00:00:00 2001 From: Anatoly Belikov Date: Mon, 6 Jan 2025 17:48:30 +0300 Subject: [PATCH] check for allowed components in from_pipe method --- multigen/pipes.py | 21 +++++++++++++++------ multigen/util.py | 7 +++++++ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/multigen/pipes.py b/multigen/pipes.py index 91e5bba..396e561 100755 --- a/multigen/pipes.py +++ b/multigen/pipes.py @@ -539,13 +539,22 @@ def __init__(self, *args, pipe: Optional[StableDiffusionImg2ImgPipeline] = None, def _from_pipe(self, pipe, **args): cls = pipe.__class__ if 'StableDiffusionXLPipeline' in str(cls) : - return self._classxl(**pipe.components, **args) + return self.__verify_from_pipe(self._classxl, pipe, **args) elif 'StableDiffusionPipeline' in str(cls): - return self._class(**pipe.components, **args) + return self.__verify_from_pipe(self._class, pipe, **args) elif 'Flux' in str(cls): - return self._classflux(**pipe.components, **args) + return self.__verify_from_pipe(self._classflux, pipe, **args) raise RuntimeError(f"can't load pipeline from type {cls}") + def __verify_from_pipe(self, cls, pipe, **args): + allowed = util.get_allowed_components(cls) + source_components = set(pipe.components.keys()) + target_components = set(allowed) + + logging.debug("Missing components: ", target_components - source_components) + logging.debug("Extra components: ", source_components - target_components) + return cls(**{k: v for (k, v) in pipe.components.items() if k in allowed}, **args) + def setup(self, image=None, image_painted=None, mask=None, blur=4, blur_compose=4, sample_mode='sample', scale=None, **kwargs): """ @@ -697,12 +706,12 @@ class Cond2ImPipe(BasePipe): "inpaint": 1.0, "qr": 1.5 }) - - cond_scales_defaults_flux = defaultdict(lambda: 0.8, + + cond_scales_defaults_flux = defaultdict(lambda: 0.8, {"canny-dev": 0.6}) def __init__(self, model_id, pipe: Optional[StableDiffusionControlNetPipeline] = None, - ctypes=["soft"], cnets: Optional[List[ControlNetModel]]=None, + ctypes=["soft"], cnets: Optional[List[ControlNetModel]]=None, cnet_ids: Optional[List[str]]=None, model_type=None, **args): """ Constructor diff --git a/multigen/util.py b/multigen/util.py index efbbfc9..de6953e 100644 --- a/multigen/util.py +++ b/multigen/util.py @@ -3,6 +3,7 @@ import psutil from PIL import Image import copy as cp +from inspect import signature import optimum.quanto from optimum.quanto import freeze, qfloat8, quantize as _quantize from diffusers.utils import is_accelerate_available @@ -146,3 +147,9 @@ def weightshare_copy(pipe): # some buffers might not be transfered from pipe to copy copy.to(pipe.device) return copy + + +def get_allowed_components(cls: type) -> dict: + params = signature(cls.__init__).parameters + components = [name for name in params.keys() if name != 'self'] + return components