refactor pipe weightsharing for quantised models #87

Merged
merged 6 commits on Dec 4, 2024
9 changes: 9 additions & 0 deletions .github/workflows/python-app.yml
@@ -53,9 +53,18 @@ jobs:
METAFUSION_MODELS_DIR: models-full
run: |
cd tests && python pipe_test.py
- name: Test loader
env:
METAFUSION_MODELS_DIR: models-full
run: |
cd tests && python test_loader.py
- name: Test worker
env:
METAFUSION_MODELS_DIR: models-full
run: |
cd tests && python test_worker.py
- name: Test worker flux
env:
METAFUSION_MODELS_DIR: models-full
run: |
cd tests && python test_worker_flux.py
143 changes: 59 additions & 84 deletions multigen/loader.py
@@ -1,4 +1,5 @@
from typing import Type, List
from typing import Type, List, Union, Optional, Any
from dataclasses import dataclass
import random
import copy as cp
from contextlib import nullcontext
@@ -10,44 +11,32 @@
import diffusers

from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline, StableDiffusionXLControlNetPipeline
from diffusers.utils import is_accelerate_available
if is_accelerate_available():
from accelerate import init_empty_weights
else:
init_empty_weights = nullcontext

from .util import get_model_size, awailable_ram, quantize, weightshare_copy


logger = logging.getLogger(__file__)


def weightshare_copy(pipe):
@dataclass(frozen=True)
class ModelDescriptor:
"""
Create a new pipe object then assign weights using load_state_dict from passed 'pipe'
Descriptor class for model identification that includes quantization information
"""
copy = pipe.__class__(**pipe.components)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
for key, component in copy.components.items():
if getattr(copy, key) is None:
continue
if key in ('tokenizer', 'tokenizer_2', 'feature_extractor'):
setattr(copy, key, cp.deepcopy(getattr(copy, key)))
continue
cls = getattr(copy, key).__class__
if hasattr(cls, 'from_config'):
setattr(copy, key, cls.from_config(getattr(copy, key).config))
else:
setattr(copy, key, cls(getattr(copy, key).config))
# assign=True is needed since our copy is on "meta" device, i.e. weights are empty
for key, component in copy.components.items():
if key == 'tokenizer' or key == 'tokenizer_2':
continue
obj = getattr(copy, key)
if hasattr(obj, 'load_state_dict'):
obj.load_state_dict(getattr(pipe, key).state_dict(), assign=True)
# some buffers might not be transferred from pipe to copy
copy.to(pipe.device)
return copy
model_id: str
quantize_dtype: Optional[Any] = None

def __hash__(self):
return hash((self.model_id, str(self.quantize_dtype)))

def __eq__(self, other):
if isinstance(other, str):
return self.model_id == other

if not isinstance(other, ModelDescriptor):
return False
return (self.model_id == other.model_id and
self.quantize_dtype == other.quantize_dtype)
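
For illustration, a minimal sketch of how the new descriptor behaves as a cache key (the model path and dtype are hypothetical, and torch is assumed to be imported as at the top of loader.py):

    # Illustrative only: hypothetical model path and quantization dtype.
    desc_plain = ModelDescriptor('models/sdxl-base')
    desc_quant = ModelDescriptor('models/sdxl-base', torch.int8)

    cache = {desc_plain: 'pipe_a', desc_quant: 'pipe_b'}  # distinct keys: the dtype is part of the hash
    assert desc_plain != desc_quant
    # __eq__ also accepts a bare model_id string, so string lookups match either entry:
    assert desc_plain == 'models/sdxl-base' and desc_quant == 'models/sdxl-base'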


class Loader:
@@ -56,9 +45,8 @@ class for loading diffusion pipelines from files.
"""
def __init__(self):
self._lock = threading.RLock()
self._cpu_pipes = dict()
# idx -> list of (model_id, pipe) pairs
self._gpu_pipes = dict()
self._cpu_pipes = dict() # ModelDescriptor -> pipe
self._gpu_pipes = dict() # gpu idx -> list of (ModelDescriptor, pipe) pairs

def get_gpu(self, model_id) -> List[int]:
"""
@@ -73,24 +61,29 @@ def get_gpu(self, model_id) -> List[int]:
return result

def load_pipeline(self, cls: Type[DiffusionPipeline], path, torch_dtype=torch.bfloat16,
device=None, offload_device=None, **additional_args):
device=None, offload_device=None, quantize_dtype=None, **additional_args):
with self._lock:
logger.debug(f'looking for pipeline {cls} from {path} on {device}')
result = None
descriptor = ModelDescriptor(path, quantize_dtype)
found_quantized = False
if device is None:
device = torch.device('cpu', 0)
if device.type == 'cuda':
idx = device.index
gpu_pipes = self._gpu_pipes.get(idx, [])
for (key, value) in gpu_pipes:
if key == path:
if key == descriptor:
logger.debug(f'found pipe in gpu cache {key}')
result = self.from_pipe(cls, value, additional_args)
logger.debug(f'created pipe from gpu cache {key} on {device}')
return result
for (key, pipe) in self._cpu_pipes.items():
if key == path:
if key == descriptor:
found_quantized = True
logger.debug(f'found pipe in cpu cache {key} {pipe.device}')
if device.type == 'cuda':
pipe = cp.deepcopy(pipe)
result = self.from_pipe(cls, pipe, additional_args)
break
if result is None:
@@ -106,16 +99,26 @@ def load_pipeline(self, cls: Type[DiffusionPipeline], path, torch_dtype=torch.bf
logger.debug("prepare pipe before returning from loader")
logger.debug(f"{path} on {result.device} {result.dtype}")

# Add quantization if specified
if (not found_quantized) and quantize_dtype is not None:
logger.debug(f'Quantizing pipeline to {quantize_dtype}')
quantize(result, dtype=quantize_dtype)

if result.device != device:
logger.debug(f"move pipe to {device}")
result = result.to(dtype=torch_dtype, device=device)
if result.dtype != torch_dtype:
result = result.to(dtype=torch_dtype)

self.cache_pipeline(result, path)
logger.debug(f'result device before weightshare_copy {result.device}')
result = weightshare_copy(result)
logger.debug(f'result device after weightshare_copy {result.device}')
assert result.device.type == device.type
if device.type == 'cuda':
assert result.device.index == device.index
logger.debug(f'returning {type(result)} from {path} on {result.device}')
logger.debug(f'returning {type(result)} from {path} \
on {result.device} scheduler {id(result.scheduler)}')
return result

def from_pipe(self, cls, pipe, additional_args):
@@ -131,86 +134,58 @@ def from_pipe(self, cls, pipe, additional_args):
components.pop('controlnet')
return cls(**components, **additional_args)

def cache_pipeline(self, pipe: DiffusionPipeline, model_id):
def cache_pipeline(self, pipe: DiffusionPipeline, descriptor: ModelDescriptor):
logger.debug(f'caching pipeline {descriptor} {pipe}')
with self._lock:
device = pipe.device
if model_id not in self._cpu_pipes:
if descriptor not in self._cpu_pipes:
# deepcopy is needed since Module.to is an inplace operation
size = get_model_size(pipe)
ram = awailable_ram()
logger.debug(f'{model_id} has size {size}, ram {ram}')
logger.debug(f'{descriptor} has size {size}, ram {ram}')
if ram < size * 2.5 and self._cpu_pipes:
key_to_delete = random.choice(list(self._cpu_pipes.keys()))
self._cpu_pipes.pop(key_to_delete)
item = pipe
if pipe.device.type == 'cuda':
item = cp.deepcopy(pipe).to('cpu')
self._cpu_pipes[model_id] = item
logger.debug(f'storing {model_id} on cpu')
device = pipe.device
logger.debug("deepcopy pipe from gpu to save it in cpu cache")
item = cp.deepcopy(pipe.to('cpu'))
pipe.to(device)
self._cpu_pipes[descriptor] = item
logger.debug(f'storing {descriptor} on cpu')
assert pipe.device == device
if pipe.device.type == 'cuda':
self._store_gpu_pipe(pipe, model_id)
logger.debug(f'storing {model_id} on {pipe.device}')
self._store_gpu_pipe(pipe, descriptor)
logger.debug(f'storing {descriptor} on {pipe.device}')
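
A worked example of the eviction check above (numbers hypothetical): get_model_size and awailable_ram both report MiB, so a 6500 MiB pipe needs roughly 16 GiB of free RAM before a CPU copy is kept.

    size = 6500            # MiB reported by get_model_size for the pipe being cached
    ram = 12000            # MiB of available system RAM from awailable_ram()
    if ram < size * 2.5:   # 12000 < 16250: not enough headroom for the deepcopy,
        pass               # so a random entry is evicted from the CPU cache first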

def clear_cache(self, device):
logger.debug(f'clear_cache pipelines from {device}')
with self._lock:
if device.type == 'cuda':
self._gpu_pipes[device.index] = []

def _store_gpu_pipe(self, pipe, model_id):
def _store_gpu_pipe(self, pipe, descriptor: ModelDescriptor):
idx = pipe.device.index
assert idx is not None
# for now just clear all other pipelines
self._gpu_pipes[idx] = [(model_id, pipe)]
self._gpu_pipes[idx] = [(descriptor, pipe)]

def remove_pipeline(self, model_id):
self._cpu_pipes.pop(model_id)

def get_pipeline(self, model_id, device=None):
def get_pipeline(self, descriptor: Union[ModelDescriptor, str], device=None):
with self._lock:
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu', 0)
if device.type == 'cuda':
idx = device.index
gpu_pipes = self._gpu_pipes.get(idx, ())
for (key, value) in gpu_pipes:
if key == model_id:
if key == descriptor:
return value
for (key, pipe) in self._cpu_pipes.items():
if key == model_id:
if key == descriptor:
return pipe

return None


def count_params(model):
total_size = sum(param.numel() for param in model.parameters())
mul = 2
if model.dtype in (torch.float16, torch.bfloat16):
mul = 2
elif model.dtype == torch.float32:
mul = 4
return total_size * mul


def get_size(obj):
return sys.getsizeof(obj)


def get_model_size(pipeline):
total_size = 0
for name, component in pipeline.components.items():
if isinstance(component, torch.nn.Module):
total_size += count_params(component)
elif hasattr(component, 'tokenizer'):
total_size += count_params(component.tokenizer)
else:
total_size += get_size(component)
return total_size / (1024 * 1024)


def awailable_ram():
mem = psutil.virtual_memory()
available_ram = mem.available
return available_ram / (1024 * 1024)
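
A minimal usage sketch of the refactored loader (the model path and quantization dtype are illustrative; which dtypes are accepted depends on util.quantize, which is not part of this diff). Repeated calls with the same path and quantize_dtype should be served from the cache and returned through weightshare_copy, so each caller gets its own pipe object over shared weights:

    import torch
    from diffusers import StableDiffusionXLPipeline
    from multigen.loader import Loader

    loader = Loader()
    device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu', 0)

    # First call: loads from disk, quantizes (quantize_dtype given, nothing cached yet), caches the pipe.
    pipe_a = loader.load_pipeline(StableDiffusionXLPipeline, 'models/sdxl-base',
                                  device=device, quantize_dtype=torch.int8)
    # Second call: hits the cache entry keyed by ModelDescriptor('models/sdxl-base', torch.int8)
    # and returns a weight-sharing copy instead of reloading or re-quantizing.
    pipe_b = loader.load_pipeline(StableDiffusionXLPipeline, 'models/sdxl-base',
                                  device=device, quantize_dtype=torch.int8)
    assert pipe_a is not pipe_b  # independent pipe (and scheduler) objects over shared component weights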
19 changes: 11 additions & 8 deletions multigen/pipes.py
@@ -149,8 +149,10 @@ def _get_model_type(self):
return ModelType.SD
elif module.startswith('diffusers.pipelines.flux.pipeline_flux'):
return ModelType.FLUX
elif 'masked_stable_diffusion_xl_img2img' in module:
return ModelType.SDXL
else:
raise RuntimeError("unsuported model type {self.pipe.__class__}")
raise RuntimeError(f"unsuported model type {self.pipe.__class__}")

def _initialize_pipe(self, device, offload_device):
# sometimes text encoder is on a different device
@@ -744,7 +746,8 @@ def __init__(self, model_id, pipe: Optional[StableDiffusionControlNetPipeline] =
else:
raise RuntimeError(f"Unexpected model type {type(self.pipe)}")
self.model_type = t_model_type
logging.debug(f"from_pipe source dtype {self.pipe.dtype}")
device = self.pipe.device
logging.debug(f"from_pipe source dtype {self.pipe.dtype} {device}")
cnets = self._load_cnets(cnets, cnet_ids, args.get('offload_device', None), self.pipe.dtype)
prev_dtype = self.pipe.dtype
if self.model_type == ModelType.SDXL:
@@ -754,11 +757,11 @@ def __init__(self, model_id, pipe: Optional[StableDiffusionControlNetPipeline] =
else:
self.pipe = self._class.from_pipe(self.pipe, controlnet=cnets)
logging.debug(f"after from_pipe result dtype {self.pipe.dtype}")
for cnet in cnets:
cnet.to(prev_dtype)
logging.debug(f'moving cnet {id(cnet)} to self.pipe.dtype {prev_dtype}')
if 'offload_device' not in args:
cnet.to(self.pipe.device)
for cnet in cnets:
cnet.to(prev_dtype)
logging.debug(f'moving cnet {id(cnet)} to self.pipe.dtype {prev_dtype}')
if 'offload_device' not in args:
cnet.to(device)
else:
# don't load anything, just reuse pipe
super().__init__(model_id=model_id, pipe=pipe, **args)
@@ -1052,7 +1055,7 @@ def __init__(self, model_id, pipe: Optional[StableDiffusionControlNetPipeline] =
"""
dtype = torch.float32
if torch.cuda.is_available():
dtype = torch.float16
dtype = torch.bfloat16
dtype = args.get('torch_type', dtype)
cnet = ControlNetModel.from_pretrained(
Cond2ImPipe.cpath+Cond2ImPipe.cmodels["inpaint"], torch_dtype=dtype)