From 35e46a189610afd12565fcfb4d2838465cd6736d Mon Sep 17 00:00:00 2001 From: Anatoly Belikov Date: Mon, 20 Jan 2025 20:24:30 +0300 Subject: [PATCH 1/3] support for batch generation --- multigen/pipes.py | 38 +++++++------ multigen/sessions.py | 133 ++++++++++++++++++++++++++++--------------- tests/pipe_test.py | 28 ++++----- 3 files changed, 123 insertions(+), 76 deletions(-) diff --git a/multigen/pipes.py b/multigen/pipes.py index a8a0145..609978a 100755 --- a/multigen/pipes.py +++ b/multigen/pipes.py @@ -411,7 +411,7 @@ def gen(self, inputs: dict): """ kwargs = self.prepare_inputs(inputs) logging.debug("Prompt2ImPipe.gen calling pipe") - image = self.pipe(**kwargs).images[0] + image = self.pipe(**kwargs).images return image @@ -501,10 +501,12 @@ def gen(self, inputs: dict): # so we update kwargs with inputs after pipe_params kwargs.update({"image": self._input_image}) self.try_set_scheduler(kwargs) - image = self.pipe(**kwargs).images[0] - logging.debug(f'generated image {image}') - result = image.crop((0, 0, self._original_size[0], self._original_size[1])) - return result + res = [] + for image in self.pipe(**kwargs).images: + logging.debug(f'generated image {image}') + result = image.crop((0, 0, self._original_size[0], self._original_size[1])) + res.append(result) + return res class MaskedIm2ImPipe(Im2ImPipe): @@ -644,13 +646,15 @@ def gen(self, inputs): if 'sample_mode' not in inputs: inputs['sample_mode'] = self._sample_mode inputs['original_image'] = normalised - img_gen = super().gen(inputs) - - # compose with original using mask - img_compose = self._mask_compose * img_gen + (1 - self._mask_compose) * self._original_image.crop((0, 0, self._original_size[0], self._original_size[1])) - # convert to PIL image - img_compose = Image.fromarray(img_compose.astype(np.uint8)) - return img_compose + images = super().gen(inputs) + res = [] + for img_gen in images: + # compose with original using mask + img_compose = self._mask_compose * img_gen + (1 - self._mask_compose) * self._original_image.crop((0, 0, self._original_size[0], self._original_size[1])) + # convert to PIL image + img_compose = Image.fromarray(img_compose.astype(np.uint8)) + res.append(img_compose) + return res class Cond2ImPipe(BasePipe): @@ -914,10 +918,12 @@ def gen(self, inputs): inputs = self.prepare_inputs(inputs) inputs.update({"image": self._input_image, "control_image": self._condition_image}) - image = self.pipe(**inputs).images[0] - result = image.crop((0, 0, self._original_size[0] if self._use_input_size else inputs.get('height'), + res = [] + for image in self.pipe(**inputs).images: + result = image.crop((0, 0, self._original_size[0] if self._use_input_size else inputs.get('height'), self._original_size[1] if self._use_input_size else inputs.get('width') )) - return result + res.append(image) + return res class CIm2ImPipe(Cond2ImPipe): @@ -1126,7 +1132,7 @@ def gen(self, inputs): "mask_image": self._mask_image, "control_image": self._control_image }) - image = self.pipe(**inputs).images[0] + image = self.pipe(**inputs).images return image def _make_inpaint_condition(self, image, image_mask): diff --git a/multigen/sessions.py b/multigen/sessions.py index 99696e5..d7b88ee 100755 --- a/multigen/sessions.py +++ b/multigen/sessions.py @@ -7,11 +7,9 @@ class GenSession: - def __init__(self, session_dir, pipe, config: Cfgen, name_prefix=""): """ Initialize a GenSession instance. - Args: session_dir (str): The directory to store the session files. @@ -28,6 +26,8 @@ def __init__(self, session_dir, pipe, config: Cfgen, name_prefix=""): self.confg = config self.last_conf = None self.name_prefix = name_prefix + # Check if sequential CPU offloading is enabled + self.offload_gpu_id = getattr(pipe, 'offload_gpu_id', None) def get_last_conf(self): conf = {**self.last_conf} @@ -36,9 +36,9 @@ def get_last_conf(self): 'feedback': '?', 'cversion': "0.0.1"}) return conf - - def get_last_file_prefix(self): - idxs = self.name_prefix + str(self.last_index).zfill(5) + + def get_file_prefix(self, index): + idxs = self.name_prefix + str(index).zfill(5) f_prefix = os.path.join(self.session_dir, idxs) if os.path.isfile(f_prefix + ".txt"): cnt = 1 @@ -47,17 +47,15 @@ def get_last_file_prefix(self): f_prefix += "_" + str(cnt) return f_prefix - def save_last_conf(self): - self.last_cfg_name = self.get_last_file_prefix() + ".txt" - with open(self.last_cfg_name, 'w') as f: - print(json.dumps(self.get_last_conf(), indent=4), file=f) - - def gen_sess(self, add_count = 0, save_img=True, - drop_cfg=False, force_collect=False, - callback=None, save_metadata=False): + def save_conf(self, index, conf): + cfg_name = self.get_file_prefix(index) + ".txt" + with open(cfg_name, 'w') as f: + print(json.dumps(conf, indent=4), file=f) + + def gen_sess(self, add_count=0, save_img=True, drop_cfg=False, + force_collect=False, callback=None, save_metadata=False): """ - Run image generation session - + Run image generation session. Args: add_count (int, *optional*): The number of additional iterations to add. Defaults to 0. @@ -71,7 +69,6 @@ def gen_sess(self, add_count = 0, save_img=True, A callback function to be called after each iteration. Defaults to None. save_metadata (bool, *optional*): Whether to save metadata in the image EXIF. Defaults to False. - Returns: List[Image.Image]: The generated images if `save_img` is False or `force_collect` is True. """ @@ -79,35 +76,79 @@ def gen_sess(self, add_count = 0, save_img=True, self.confg.start_count = self.confg.count self.last_img_name = None self.last_cfg_name = None - images = None + images = [] + if save_img: os.makedirs(self.session_dir, exist_ok=True) - # collecting images to return if requested or images are not saved - if not save_img or force_collect: - images = [] - logging.info(f"add count = {add_count}") - jk = 0 - for inputs in self.confg: - self.last_index = self.confg.count - 1 - self.last_conf = {**inputs} - # TODO: multiple inputs? - inputs['generator'] = torch.Generator().manual_seed(inputs['generator']) - logging.debug("start generation") - image = self.pipe.gen(inputs) - if save_img: - self.last_img_name = self.get_last_file_prefix() + ".png" - exif = None - if save_metadata: - exif = util.create_exif_metadata(image, json.dumps(self.get_last_conf())) - image.save(self.last_img_name, exif=exif) - if not save_img or force_collect: - images += [image] - # saving cfg only if images are saved and dropping is not requested - if save_img and not drop_cfg: - self.save_last_conf() - if callback is not None: - logging.debug("call callback after generation") - callback() - jk += 1 - logging.debug(f"done iteration {jk}") - return images + + # Determine batch size + if self.offload_gpu_id is not None: + # Sequential CPU offloading is enabled, set batch_size to a reasonable number + batch_size = 8 # You can adjust this value based on your environment + else: + batch_size = 1 # Process one input at a time + + logging.info(f"Starting generation with batch_size = {batch_size}") + confg_iter = iter(self.confg) + index = self.confg.start_count + + while True: + batch_inputs_list = [] + # Collect inputs into batch + for _ in range(batch_size): + try: + inputs = next(confg_iter) + except StopIteration: + break # No more inputs + batch_inputs_list.append(inputs) + + if not batch_inputs_list: + break # All inputs have been processed + + # Prepare batch inputs + batch_inputs_dict = {} + for key in batch_inputs_list[0]: + batch_inputs_dict[key] = [input[key] for input in batch_inputs_list] + + # Adjust 'generator' field with manual seeds + batch_generators = [] + for seed in batch_inputs_dict.get('generator', [None] * len(batch_inputs_list)): + if seed is not None: + batch_generators.append(torch.Generator().manual_seed(seed)) + else: + batch_generators.append(torch.Generator()) + batch_inputs_dict['generator'] = batch_generators + + # Generate images + batch_images = self.pipe.gen(batch_inputs_dict) + + # Process generated images + for i, image in enumerate(batch_images): + idx = index + i + self.last_index = idx + self.last_conf = {**batch_inputs_list[i % len(batch_inputs_list)]} + self.last_conf.update(self.pipe.get_config()) + self.last_conf.update({'feedback': '?', 'cversion': '0.0.1'}) + + if save_img: + f_prefix = self.get_file_prefix(idx) + img_name = f_prefix + ".png" + exif = None + if save_metadata: + exif = util.create_exif_metadata(image, json.dumps(self.get_last_conf())) + image.save(img_name, exif=exif) + self.last_img_name = img_name + if not drop_cfg: + # Save configuration + self.save_conf(idx, self.get_last_conf()) + if not save_img or force_collect: + images.append(image) + if callback is not None: + logging.debug("Call callback after generation") + callback() + + index += len(batch_images) + logging.debug(f"Processed batch up to index {index}") + + logging.debug(f"Generation session completed.") + return images if images else None diff --git a/tests/pipe_test.py b/tests/pipe_test.py index 71af63d..23a1fc1 100644 --- a/tests/pipe_test.py +++ b/tests/pipe_test.py @@ -28,7 +28,7 @@ def test_basic_txt2im(self): params = dict(prompt="a cube planet, cube-shaped, space photo, masterpiece", negative_prompt="spherical", generator=torch.Generator().manual_seed(seed)) - image = pipe.gen(params) + image = pipe.gen(params)[0] image.save("cube_test.png") # generate with different scheduler @@ -36,7 +36,7 @@ def test_basic_txt2im(self): params.update(generator=torch.Generator().manual_seed(seed + 1)) else: params.update(scheduler=self.schedulers[1]) - image_ddim = pipe.gen(params) + image_ddim = pipe.gen(params)[0] image_ddim.save("cube_test2_dimm.png") diff = self.compute_diff(image_ddim, image) # check that difference is large @@ -73,15 +73,15 @@ def test_img2img_basic(self): seed = 49045438434843 pipe.setup(im, strength=0.7, steps=5, guidance_scale=3.3) self.assertEqual(3.3, pipe.pipe_params['guidance_scale']) - image = pipe.gen(dict(prompt="cube planet cartoon style", generator=torch.Generator().manual_seed(seed))) + image = pipe.gen(dict(prompt="cube planet cartoon style", generator=torch.Generator().manual_seed(seed)))[0] image.save('test_img2img_basic.png') pipe.setup(im, strength=0.7, steps=5, guidance_scale=7.6) - image1 = pipe.gen(dict(prompt="cube planet cartoon style", generator=torch.Generator().manual_seed(seed))) + image1 = pipe.gen(dict(prompt="cube planet cartoon style", generator=torch.Generator().manual_seed(seed)))[0] diff = self.compute_diff(image1, image) # check that difference is large self.assertGreater(diff, 1000) pipe.setup(im, strength=0.7, steps=5, guidance_scale=3.3) - image2 = pipe.gen(dict(prompt="cube planet cartoon style", generator=torch.Generator().manual_seed(seed))) + image2 = pipe.gen(dict(prompt="cube planet cartoon style", generator=torch.Generator().manual_seed(seed)))[0] diff = self.compute_diff(image2, image) # check that difference is small self.assertLess(diff, 1) @@ -108,18 +108,18 @@ def test_maskedimg2img_basic(self): pipe.setup(**param_3_3) self.assertEqual(3.3, pipe.pipe_params['guidance_scale']) image = pipe.gen(dict(prompt="cube planet cartoon style", - generator=torch.Generator().manual_seed(seed))) + generator=torch.Generator().manual_seed(seed)))[0] self.assertEqual(image.width, img.width) self.assertEqual(image.height, img.height) image.save('test_img2img_basic.png') pipe.setup(**param_7_6) image1 = pipe.gen(dict(prompt="cube planet cartoon style", - generator=torch.Generator().manual_seed(seed))) + generator=torch.Generator().manual_seed(seed)))[0] diff = self.compute_diff(image1, image) # check that difference is large self.assertGreater(diff, 1000) pipe.setup(**param_3_3) - image2 = pipe.gen(dict(prompt="cube planet cartoon style", generator=torch.Generator().manual_seed(seed))) + image2 = pipe.gen(dict(prompt="cube planet cartoon style", generator=torch.Generator().manual_seed(seed)))[0] diff = self.compute_diff(image2, image) # check that difference is small self.assertLess(diff, 1) @@ -138,12 +138,12 @@ def test_lpw(self): params = dict(prompt=prompt, negative_prompt="spherical", generator=torch.Generator().manual_seed(seed)) - image = pipe.gen(params) + image = pipe.gen(params)[0] image.save("cube_test_lpw.png") params = dict(prompt=prompt + " , best quality, famous photo", negative_prompt="spherical", generator=torch.Generator().manual_seed(seed)) - image1 = pipe.gen(params) + image1 = pipe.gen(params)[0] image.save("cube_test_lpw1.png") diff = self.compute_diff(image1, image) # check that difference is large @@ -161,12 +161,12 @@ def test_lpw_turned_off(self): params = dict(prompt=prompt, negative_prompt="spherical", generator=torch.Generator().manual_seed(seed)) - image = pipe.gen(params) + image = pipe.gen(params)[0] image.save("cube_test_no_lpw.png") params = dict(prompt=prompt + " , best quality, famous photo", negative_prompt="spherical", generator=torch.Generator().manual_seed(seed)) - image1 = pipe.gen(params) + image1 = pipe.gen(params)[0] image.save("cube_test_no_lpw1.png") diff = self.compute_diff(image1, image) # check that difference is large @@ -193,7 +193,7 @@ def test_controlnet(self): params = dict(prompt="cube planet minecraft style", negative_prompt="spherical", generator=torch.Generator().manual_seed(seed)) - image = pipe.gen(params) + image = pipe.gen(params)[0] image.save("mech_test.png") img_ref = PIL.Image.open(imgpth) self.assertEqual(image.width, img_ref.width) @@ -205,7 +205,7 @@ def test_controlnet(self): else: # generate with different scheduler params.update(scheduler=self.schedulers[1]) - image_ddim = pipe.gen(params) + image_ddim = pipe.gen(params)[0] image_ddim.save("cube_test2_dimm.png") diff = self.compute_diff(image_ddim, image) # check that difference is large From c324921a42bda8e018ef08a8079dc30337ae0bbe Mon Sep 17 00:00:00 2001 From: Anatoly Belikov Date: Tue, 21 Jan 2025 16:13:52 +0300 Subject: [PATCH 2/3] fix padding issue --- multigen/pipes.py | 23 +++++++++++++------ multigen/util.py | 55 ++++++++++++++++++++++++++++------------------ tests/pipe_test.py | 4 ++-- 3 files changed, 52 insertions(+), 30 deletions(-) diff --git a/multigen/pipes.py b/multigen/pipes.py index 609978a..dbad91c 100755 --- a/multigen/pipes.py +++ b/multigen/pipes.py @@ -362,6 +362,14 @@ def prepare_inputs(self, inputs): self.try_set_scheduler(inputs) return kwargs + @property + def pad(self): + pad = 8 + if hasattr(self.pipe, 'image_processor'): + if hasattr(self.pipe.image_processor, 'vae_scale_factor'): + pad = self.pipe.image_processor.vae_scale_factor + return pad + class Prompt2ImPipe(BasePipe): """ @@ -446,7 +454,7 @@ def setup(self, fimage, image=None, strength=0.75, self._input_image = self.scale_image(self._input_image, scale) self._original_size = self._input_image.size logging.debug("origin image size {self._original_size}") - self._input_image = util.pad_image_to_multiple_of_8(self._input_image) + self._input_image = util.pad_image_to_multiple(self._input_image, self.pad) self.pipe_params.update({ "width": self._input_image.width if width is None else width, "height": self._input_image.height if height is None else height, @@ -599,12 +607,13 @@ def setup(self, image=None, image_painted=None, mask=None, blur=4, input_image = self._image_painted if self._image_painted is not None else self._original_image super().setup(fimage=None, image=input_image, scale=scale, **kwargs) + if self._original_image is not None: self._original_image = self.scale_image(self._original_image, scale) - self._original_image = util.pad_image_to_multiple_of_8(self._original_image) + self._original_image = util.pad_image_to_multiple(self._original_image, self.pad) if self._image_painted is not None: self._image_painted = self.scale_image(self._image_painted, scale) - self._image_painted = util.pad_image_to_multiple_of_8(self._image_painted) + self._image_painted = util.pad_image_to_multiple(self._image_painted, self.pad) # there are two options: # 1. mask is provided @@ -621,7 +630,7 @@ def setup(self, image=None, image_painted=None, mask=None, blur=4, pil_mask = Image.fromarray(mask) if pil_mask.mode != "L": pil_mask = pil_mask.convert("L") - pil_mask = util.pad_image_to_multiple_of_8(pil_mask) + pil_mask = util.pad_image_to_multiple(pil_mask, self.pad) self._mask = pil_mask self._mask_blur = self.blur_mask(pil_mask, blur) self._mask_compose = self.blur_mask(pil_mask.crop((0, 0, self._original_size[0], self._original_size[1])) @@ -867,7 +876,7 @@ def setup(self, fimage, width=None, height=None, image = Image.open(fimage).convert("RGB") if image is None else image self._original_size = image.size self._use_input_size = width is None or height is None - image = util.pad_image_to_multiple_of_8(image) + image = util.pad_image_to_multiple(image, self.pad) self._condition_image = [image] self._input_image = [image] if cscales is None: @@ -922,7 +931,7 @@ def gen(self, inputs): for image in self.pipe(**inputs).images: result = image.crop((0, 0, self._original_size[0] if self._use_input_size else inputs.get('height'), self._original_size[1] if self._use_input_size else inputs.get('width') )) - res.append(image) + res.append(result) return res @@ -1044,7 +1053,7 @@ def _proc_cimg(self, oriImg): condition_image += [Image.fromarray(formatted)] else: condition_image += [Image.fromarray(oriImg)] - return condition_image + return [c.resize((oriImg.shape[1], oriImg.shape[0])) for c in condition_image] class InpaintingPipe(MaskedIm2ImPipe): diff --git a/multigen/util.py b/multigen/util.py index de6953e..5aa69e7 100644 --- a/multigen/util.py +++ b/multigen/util.py @@ -28,37 +28,50 @@ def create_exif_metadata(im: Image, custom_metadata): return exif -def pad_image_to_multiple_of_8(image: Image) -> Image: +def pad_image_to_multiple(image: Image, padding_size: int = 8) -> Image: """ - Pads the input image by repeating the bottom or right-most column of pixels - so that the height and width of the image is divisible by 8. + Pads the input image by repeating the bottom and right-most rows and columns of pixels + so that its dimensions are divisible by 'padding_size'. Args: - image (Image): The input PIL image. + image (Image): The input PIL Image. + padding_size (int): The multiple to which dimensions are padded. Returns: - Image: The padded PIL image. + Image: The padded PIL Image. """ - # Calculate the new dimensions - new_width = (image.width + 7) // 8 * 8 - new_height = (image.height + 7) // 8 * 8 + new_width = ((image.width + padding_size - 1) // padding_size) * padding_size + new_height = ((image.height + padding_size - 1) // padding_size) * padding_size + + # Calculate padding amounts + pad_right = new_width - image.width + pad_bottom = new_height - image.height - # Create a new image with the new dimensions and paste the original image onto it + # Create a new image with the new dimensions padded_image = Image.new(image.mode, (new_width, new_height)) padded_image.paste(image, (0, 0)) - # Repeat the right-most column of pixels to fill the horizontal padding - for x in range(new_width - image.width): - box = (image.width + x, 0, image.width + x + 1, image.height) - region = image.crop((image.width - 1, 0, image.width, image.height)) - padded_image.paste(region, box) - - # Repeat the bottom-most row of pixels to fill the vertical padding - for y in range(new_height - image.height): - box = (0, image.height + y, image.width, image.height + y + 1) - region = image.crop((0, image.height - 1, image.width, image.height)) - padded_image.paste(region, box) + # Check if padding is needed + if pad_right > 0 or pad_bottom > 0: + # Get the last column and row + if pad_right > 0: + last_column = image.crop((image.width - 1, 0, image.width, image.height)) + # Resize the last column to fill the right padding area + right_padding = last_column.resize((pad_right, image.height), Image.NEAREST) + padded_image.paste(right_padding, (image.width, 0)) + + if pad_bottom > 0: + last_row = image.crop((0, image.height - 1, image.width, image.height)) + # Resize the last row to fill the bottom padding area + bottom_padding = last_row.resize((image.width, pad_bottom), Image.NEAREST) + padded_image.paste(bottom_padding, (0, image.height)) + + if pad_right > 0 and pad_bottom > 0: + # Fill the bottom-right corner + last_pixel = image.getpixel((image.width - 1, image.height - 1)) + corner = Image.new(image.mode, (pad_right, pad_bottom), last_pixel) + padded_image.paste(corner, (image.width, image.height)) return padded_image @@ -97,7 +110,7 @@ def awailable_ram(): def quantize(pipe, dtype=qfloat8): components = ['unet', 'transformer', 'text_encoder', 'text_encoder_2', 'vae'] - + for component in components: if hasattr(pipe, component): component_obj = getattr(pipe, component) diff --git a/tests/pipe_test.py b/tests/pipe_test.py index 23a1fc1..8690107 100644 --- a/tests/pipe_test.py +++ b/tests/pipe_test.py @@ -220,10 +220,10 @@ def test_cond2im(self): params = dict(prompt="child in the coat playing in sandbox", negative_prompt="spherical", generator=torch.Generator().manual_seed(seed)) - img = pipe.gen(params) + img = pipe.gen(params)[0] self.assertEqual(img.size, (768, 768)) pipe.setup("./pose6.jpeg") - img1 = pipe.gen(params) + img1 = pipe.gen(params)[0] self.assertEqual(img1.size, (450, 450)) From c784cdd88e28324edd0c73dc4387b763cf50c9fe Mon Sep 17 00:00:00 2001 From: Anatoly Belikov Date: Tue, 21 Jan 2025 17:46:04 +0300 Subject: [PATCH 3/3] fix logging --- multigen/pipes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/multigen/pipes.py b/multigen/pipes.py index dbad91c..b17c999 100755 --- a/multigen/pipes.py +++ b/multigen/pipes.py @@ -565,8 +565,8 @@ def __verify_from_pipe(self, cls, pipe, **args): source_components = set(pipe.components.keys()) target_components = set(allowed) - logging.debug("Missing components: ", target_components - source_components) - logging.debug("Extra components: ", source_components - target_components) + logging.debug("Missing components: " + str(target_components - source_components)) + logging.debug("Extra components: " + str(source_components - target_components)) return cls(**{k: v for (k, v) in pipe.components.items() if k in allowed}, **args) def setup(self, image=None, image_painted=None, mask=None, blur=4,