Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support for batch generation #93

Merged
merged 3 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 39 additions & 24 deletions multigen/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,14 @@ def prepare_inputs(self, inputs):
self.try_set_scheduler(inputs)
return kwargs

@property
def pad(self):
pad = 8
if hasattr(self.pipe, 'image_processor'):
if hasattr(self.pipe.image_processor, 'vae_scale_factor'):
pad = self.pipe.image_processor.vae_scale_factor
return pad


class Prompt2ImPipe(BasePipe):
"""
Expand Down Expand Up @@ -411,7 +419,7 @@ def gen(self, inputs: dict):
"""
kwargs = self.prepare_inputs(inputs)
logging.debug("Prompt2ImPipe.gen calling pipe")
image = self.pipe(**kwargs).images[0]
image = self.pipe(**kwargs).images
return image


Expand Down Expand Up @@ -446,7 +454,7 @@ def setup(self, fimage, image=None, strength=0.75,
self._input_image = self.scale_image(self._input_image, scale)
self._original_size = self._input_image.size
logging.debug("origin image size {self._original_size}")
self._input_image = util.pad_image_to_multiple_of_8(self._input_image)
self._input_image = util.pad_image_to_multiple(self._input_image, self.pad)
self.pipe_params.update({
"width": self._input_image.width if width is None else width,
"height": self._input_image.height if height is None else height,
Expand Down Expand Up @@ -501,10 +509,12 @@ def gen(self, inputs: dict):
# so we update kwargs with inputs after pipe_params
kwargs.update({"image": self._input_image})
self.try_set_scheduler(kwargs)
image = self.pipe(**kwargs).images[0]
logging.debug(f'generated image {image}')
result = image.crop((0, 0, self._original_size[0], self._original_size[1]))
return result
res = []
for image in self.pipe(**kwargs).images:
logging.debug(f'generated image {image}')
result = image.crop((0, 0, self._original_size[0], self._original_size[1]))
res.append(result)
return res


class MaskedIm2ImPipe(Im2ImPipe):
Expand Down Expand Up @@ -555,8 +565,8 @@ def __verify_from_pipe(self, cls, pipe, **args):
source_components = set(pipe.components.keys())
target_components = set(allowed)

logging.debug("Missing components: ", target_components - source_components)
logging.debug("Extra components: ", source_components - target_components)
logging.debug("Missing components: " + str(target_components - source_components))
logging.debug("Extra components: " + str(source_components - target_components))
return cls(**{k: v for (k, v) in pipe.components.items() if k in allowed}, **args)

def setup(self, image=None, image_painted=None, mask=None, blur=4,
Expand Down Expand Up @@ -597,12 +607,13 @@ def setup(self, image=None, image_painted=None, mask=None, blur=4,
input_image = self._image_painted if self._image_painted is not None else self._original_image

super().setup(fimage=None, image=input_image, scale=scale, **kwargs)

if self._original_image is not None:
self._original_image = self.scale_image(self._original_image, scale)
self._original_image = util.pad_image_to_multiple_of_8(self._original_image)
self._original_image = util.pad_image_to_multiple(self._original_image, self.pad)
if self._image_painted is not None:
self._image_painted = self.scale_image(self._image_painted, scale)
self._image_painted = util.pad_image_to_multiple_of_8(self._image_painted)
self._image_painted = util.pad_image_to_multiple(self._image_painted, self.pad)

# there are two options:
# 1. mask is provided
Expand All @@ -619,7 +630,7 @@ def setup(self, image=None, image_painted=None, mask=None, blur=4,
pil_mask = Image.fromarray(mask)
if pil_mask.mode != "L":
pil_mask = pil_mask.convert("L")
pil_mask = util.pad_image_to_multiple_of_8(pil_mask)
pil_mask = util.pad_image_to_multiple(pil_mask, self.pad)
self._mask = pil_mask
self._mask_blur = self.blur_mask(pil_mask, blur)
self._mask_compose = self.blur_mask(pil_mask.crop((0, 0, self._original_size[0], self._original_size[1]))
Expand All @@ -644,13 +655,15 @@ def gen(self, inputs):
if 'sample_mode' not in inputs:
inputs['sample_mode'] = self._sample_mode
inputs['original_image'] = normalised
img_gen = super().gen(inputs)

# compose with original using mask
img_compose = self._mask_compose * img_gen + (1 - self._mask_compose) * self._original_image.crop((0, 0, self._original_size[0], self._original_size[1]))
# convert to PIL image
img_compose = Image.fromarray(img_compose.astype(np.uint8))
return img_compose
images = super().gen(inputs)
res = []
for img_gen in images:
# compose with original using mask
img_compose = self._mask_compose * img_gen + (1 - self._mask_compose) * self._original_image.crop((0, 0, self._original_size[0], self._original_size[1]))
# convert to PIL image
img_compose = Image.fromarray(img_compose.astype(np.uint8))
res.append(img_compose)
return res


class Cond2ImPipe(BasePipe):
Expand Down Expand Up @@ -863,7 +876,7 @@ def setup(self, fimage, width=None, height=None,
image = Image.open(fimage).convert("RGB") if image is None else image
self._original_size = image.size
self._use_input_size = width is None or height is None
image = util.pad_image_to_multiple_of_8(image)
image = util.pad_image_to_multiple(image, self.pad)
self._condition_image = [image]
self._input_image = [image]
if cscales is None:
Expand Down Expand Up @@ -914,10 +927,12 @@ def gen(self, inputs):
inputs = self.prepare_inputs(inputs)
inputs.update({"image": self._input_image,
"control_image": self._condition_image})
image = self.pipe(**inputs).images[0]
result = image.crop((0, 0, self._original_size[0] if self._use_input_size else inputs.get('height'),
res = []
for image in self.pipe(**inputs).images:
result = image.crop((0, 0, self._original_size[0] if self._use_input_size else inputs.get('height'),
self._original_size[1] if self._use_input_size else inputs.get('width') ))
return result
res.append(result)
return res


class CIm2ImPipe(Cond2ImPipe):
Expand Down Expand Up @@ -1038,7 +1053,7 @@ def _proc_cimg(self, oriImg):
condition_image += [Image.fromarray(formatted)]
else:
condition_image += [Image.fromarray(oriImg)]
return condition_image
return [c.resize((oriImg.shape[1], oriImg.shape[0])) for c in condition_image]


class InpaintingPipe(MaskedIm2ImPipe):
Expand Down Expand Up @@ -1126,7 +1141,7 @@ def gen(self, inputs):
"mask_image": self._mask_image,
"control_image": self._control_image
})
image = self.pipe(**inputs).images[0]
image = self.pipe(**inputs).images
return image

def _make_inpaint_condition(self, image, image_mask):
Expand Down
133 changes: 87 additions & 46 deletions multigen/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@


class GenSession:

def __init__(self, session_dir, pipe, config: Cfgen, name_prefix=""):
"""
Initialize a GenSession instance.

Args:
session_dir (str):
The directory to store the session files.
Expand All @@ -28,6 +26,8 @@ def __init__(self, session_dir, pipe, config: Cfgen, name_prefix=""):
self.confg = config
self.last_conf = None
self.name_prefix = name_prefix
# Check if sequential CPU offloading is enabled
self.offload_gpu_id = getattr(pipe, 'offload_gpu_id', None)

def get_last_conf(self):
conf = {**self.last_conf}
Expand All @@ -36,9 +36,9 @@ def get_last_conf(self):
'feedback': '?',
'cversion': "0.0.1"})
return conf

def get_last_file_prefix(self):
idxs = self.name_prefix + str(self.last_index).zfill(5)
def get_file_prefix(self, index):
idxs = self.name_prefix + str(index).zfill(5)
f_prefix = os.path.join(self.session_dir, idxs)
if os.path.isfile(f_prefix + ".txt"):
cnt = 1
Expand All @@ -47,17 +47,15 @@ def get_last_file_prefix(self):
f_prefix += "_" + str(cnt)
return f_prefix

def save_last_conf(self):
self.last_cfg_name = self.get_last_file_prefix() + ".txt"
with open(self.last_cfg_name, 'w') as f:
print(json.dumps(self.get_last_conf(), indent=4), file=f)

def gen_sess(self, add_count = 0, save_img=True,
drop_cfg=False, force_collect=False,
callback=None, save_metadata=False):
def save_conf(self, index, conf):
cfg_name = self.get_file_prefix(index) + ".txt"
with open(cfg_name, 'w') as f:
print(json.dumps(conf, indent=4), file=f)

def gen_sess(self, add_count=0, save_img=True, drop_cfg=False,
force_collect=False, callback=None, save_metadata=False):
"""
Run image generation session

Run image generation session.
Args:
add_count (int, *optional*):
The number of additional iterations to add. Defaults to 0.
Expand All @@ -71,43 +69,86 @@ def gen_sess(self, add_count = 0, save_img=True,
A callback function to be called after each iteration. Defaults to None.
save_metadata (bool, *optional*):
Whether to save metadata in the image EXIF. Defaults to False.

Returns:
List[Image.Image]: The generated images if `save_img` is False or `force_collect` is True.
"""
self.confg.max_count += add_count
self.confg.start_count = self.confg.count
self.last_img_name = None
self.last_cfg_name = None
images = None
images = []

if save_img:
os.makedirs(self.session_dir, exist_ok=True)
# collecting images to return if requested or images are not saved
if not save_img or force_collect:
images = []
logging.info(f"add count = {add_count}")
jk = 0
for inputs in self.confg:
self.last_index = self.confg.count - 1
self.last_conf = {**inputs}
# TODO: multiple inputs?
inputs['generator'] = torch.Generator().manual_seed(inputs['generator'])
logging.debug("start generation")
image = self.pipe.gen(inputs)
if save_img:
self.last_img_name = self.get_last_file_prefix() + ".png"
exif = None
if save_metadata:
exif = util.create_exif_metadata(image, json.dumps(self.get_last_conf()))
image.save(self.last_img_name, exif=exif)
if not save_img or force_collect:
images += [image]
# saving cfg only if images are saved and dropping is not requested
if save_img and not drop_cfg:
self.save_last_conf()
if callback is not None:
logging.debug("call callback after generation")
callback()
jk += 1
logging.debug(f"done iteration {jk}")
return images

# Determine batch size
if self.offload_gpu_id is not None:
# Sequential CPU offloading is enabled, set batch_size to a reasonable number
batch_size = 8 # You can adjust this value based on your environment
else:
batch_size = 1 # Process one input at a time

logging.info(f"Starting generation with batch_size = {batch_size}")
confg_iter = iter(self.confg)
index = self.confg.start_count

while True:
batch_inputs_list = []
# Collect inputs into batch
for _ in range(batch_size):
try:
inputs = next(confg_iter)
except StopIteration:
break # No more inputs
batch_inputs_list.append(inputs)

if not batch_inputs_list:
break # All inputs have been processed

# Prepare batch inputs
batch_inputs_dict = {}
for key in batch_inputs_list[0]:
batch_inputs_dict[key] = [input[key] for input in batch_inputs_list]

# Adjust 'generator' field with manual seeds
batch_generators = []
for seed in batch_inputs_dict.get('generator', [None] * len(batch_inputs_list)):
if seed is not None:
batch_generators.append(torch.Generator().manual_seed(seed))
else:
batch_generators.append(torch.Generator())
batch_inputs_dict['generator'] = batch_generators

# Generate images
batch_images = self.pipe.gen(batch_inputs_dict)

# Process generated images
for i, image in enumerate(batch_images):
idx = index + i
self.last_index = idx
self.last_conf = {**batch_inputs_list[i % len(batch_inputs_list)]}
self.last_conf.update(self.pipe.get_config())
self.last_conf.update({'feedback': '?', 'cversion': '0.0.1'})

if save_img:
f_prefix = self.get_file_prefix(idx)
img_name = f_prefix + ".png"
exif = None
if save_metadata:
exif = util.create_exif_metadata(image, json.dumps(self.get_last_conf()))
image.save(img_name, exif=exif)
self.last_img_name = img_name
if not drop_cfg:
# Save configuration
self.save_conf(idx, self.get_last_conf())
if not save_img or force_collect:
images.append(image)
if callback is not None:
logging.debug("Call callback after generation")
callback()

index += len(batch_images)
logging.debug(f"Processed batch up to index {index}")

logging.debug(f"Generation session completed.")
return images if images else None
Loading