@property
def pad(self):
    """
    Multiple to which input images are padded before being given to the pipe.

    The value is read from ``self.pipe.image_processor.vae_scale_factor``
    when that attribute chain exists. The setup code passes this value to
    ``util.pad_image_to_multiple`` as the padding size.

    Returns:
        The image processor's ``vae_scale_factor``, or 8 when the pipe has
        no image processor, the processor has no such attribute, or the
        attribute is unset (None or 0).
    """
    default_pad = 8
    # getattr with a default collapses the nested hasattr checks and also
    # tolerates ``image_processor`` being present but set to None.
    processor = getattr(self.pipe, 'image_processor', None)
    factor = getattr(processor, 'vae_scale_factor', None)
    # A missing/None/zero factor cannot be used as a divisor by the padding
    # arithmetic, so fall back to the default instead of propagating it.
    if not factor:
        return default_pad
    return factor
def pad_image_to_multiple(image: Image, padding_size: int = 8) -> Image:
    """
    Pads the input image by repeating the right-most column and the bottom-most
    row of pixels so that its dimensions are divisible by 'padding_size'.

    The original pixels stay at the top-left corner of the result, so the
    padding can be removed again by cropping to the original size.

    Args:
        image (Image): The input PIL Image. It is not modified.
        padding_size (int): The multiple to which dimensions are padded.
            Must be at least 1.

    Returns:
        Image: A new PIL Image of the same mode whose width and height are
        the smallest multiples of 'padding_size' that fit the input.

    Raises:
        ValueError: If 'padding_size' is less than 1.
    """
    # Zero would raise an opaque ZeroDivisionError below and a negative value
    # would produce a canvas smaller than the input, silently cropping it.
    if padding_size < 1:
        raise ValueError(
            f"padding_size must be a positive integer, got {padding_size!r}")

    width, height = image.size

    # Round each dimension up to the next multiple of padding_size
    new_width = ((width + padding_size - 1) // padding_size) * padding_size
    new_height = ((height + padding_size - 1) // padding_size) * padding_size
    pad_right = new_width - width
    pad_bottom = new_height - height

    # Create a new image with the new dimensions and put the original in the
    # top-left corner
    padded_image = Image.new(image.mode, (new_width, new_height))
    padded_image.paste(image, (0, 0))

    if pad_right > 0:
        # Stretch the last column sideways to fill the right padding area
        last_column = image.crop((width - 1, 0, width, height))
        right_padding = last_column.resize((pad_right, height), Image.NEAREST)
        padded_image.paste(right_padding, (width, 0))

    if pad_bottom > 0:
        # Take the last row from the already right-padded canvas: it spans the
        # full new width, so stretching it downwards also fills the
        # bottom-right corner with the original bottom-right pixel.
        last_row = padded_image.crop((0, height - 1, new_width, height))
        bottom_padding = last_row.resize((new_width, pad_bottom), Image.NEAREST)
        padded_image.paste(bottom_padding, (0, height))

    return padded_image