support for batch generation
noskill committed Jan 20, 2025
1 parent 4256ddc commit 35e46a1
Showing 3 changed files with 123 additions and 76 deletions.
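
The core change across all three files is that the pipes' gen() methods now return a list of images instead of a single image, and GenSession.gen_sess() feeds batched inputs to the pipe. A minimal usage sketch of the updated interface follows; the helper name, prompt, and seed are illustrative assumptions, and the pipe is assumed to be an already-configured Prompt2ImPipe (its construction is outside this commit):

import torch

def generate_first(pipe, prompt, seed=42):
    # `pipe` is assumed to be a configured multigen Prompt2ImPipe.
    params = dict(prompt=prompt,
                  negative_prompt="spherical",
                  generator=torch.Generator().manual_seed(seed))
    # After this commit, gen() returns a list of PIL images,
    # so callers index into it (as the updated tests below do).
    images = pipe.gen(params)
    return images[0]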
38 changes: 22 additions & 16 deletions multigen/pipes.py
@@ -411,7 +411,7 @@ def gen(self, inputs: dict):
"""
kwargs = self.prepare_inputs(inputs)
logging.debug("Prompt2ImPipe.gen calling pipe")
image = self.pipe(**kwargs).images[0]
image = self.pipe(**kwargs).images
return image


@@ -501,10 +501,12 @@ def gen(self, inputs: dict):
# so we update kwargs with inputs after pipe_params
kwargs.update({"image": self._input_image})
self.try_set_scheduler(kwargs)
image = self.pipe(**kwargs).images[0]
logging.debug(f'generated image {image}')
result = image.crop((0, 0, self._original_size[0], self._original_size[1]))
return result
res = []
for image in self.pipe(**kwargs).images:
logging.debug(f'generated image {image}')
result = image.crop((0, 0, self._original_size[0], self._original_size[1]))
res.append(result)
return res


class MaskedIm2ImPipe(Im2ImPipe):
@@ -644,13 +646,15 @@ def gen(self, inputs):
if 'sample_mode' not in inputs:
inputs['sample_mode'] = self._sample_mode
inputs['original_image'] = normalised
img_gen = super().gen(inputs)

# compose with original using mask
img_compose = self._mask_compose * img_gen + (1 - self._mask_compose) * self._original_image.crop((0, 0, self._original_size[0], self._original_size[1]))
# convert to PIL image
img_compose = Image.fromarray(img_compose.astype(np.uint8))
return img_compose
images = super().gen(inputs)
res = []
for img_gen in images:
# compose with original using mask
img_compose = self._mask_compose * img_gen + (1 - self._mask_compose) * self._original_image.crop((0, 0, self._original_size[0], self._original_size[1]))
# convert to PIL image
img_compose = Image.fromarray(img_compose.astype(np.uint8))
res.append(img_compose)
return res


class Cond2ImPipe(BasePipe):
@@ -914,10 +918,12 @@ def gen(self, inputs):
inputs = self.prepare_inputs(inputs)
inputs.update({"image": self._input_image,
"control_image": self._condition_image})
image = self.pipe(**inputs).images[0]
result = image.crop((0, 0, self._original_size[0] if self._use_input_size else inputs.get('height'),
res = []
for image in self.pipe(**inputs).images:
result = image.crop((0, 0, self._original_size[0] if self._use_input_size else inputs.get('height'),
self._original_size[1] if self._use_input_size else inputs.get('width') ))
return result
res.append(result)
return res


class CIm2ImPipe(Cond2ImPipe):
@@ -1126,7 +1132,7 @@ def gen(self, inputs):
"mask_image": self._mask_image,
"control_image": self._control_image
})
image = self.pipe(**inputs).images[0]
image = self.pipe(**inputs).images
return image

def _make_inpaint_condition(self, image, image_mask):
133 changes: 87 additions & 46 deletions multigen/sessions.py
@@ -7,11 +7,9 @@


class GenSession:

def __init__(self, session_dir, pipe, config: Cfgen, name_prefix=""):
"""
Initialize a GenSession instance.
Args:
session_dir (str):
The directory to store the session files.
@@ -28,6 +26,8 @@ def __init__(self, session_dir, pipe, config: Cfgen, name_prefix=""):
self.confg = config
self.last_conf = None
self.name_prefix = name_prefix
# Check if sequential CPU offloading is enabled
self.offload_gpu_id = getattr(pipe, 'offload_gpu_id', None)

def get_last_conf(self):
conf = {**self.last_conf}
@@ -36,9 +36,9 @@ def get_last_conf(self):
'feedback': '?',
'cversion': "0.0.1"})
return conf

def get_last_file_prefix(self):
idxs = self.name_prefix + str(self.last_index).zfill(5)
def get_file_prefix(self, index):
idxs = self.name_prefix + str(index).zfill(5)
f_prefix = os.path.join(self.session_dir, idxs)
if os.path.isfile(f_prefix + ".txt"):
cnt = 1
@@ -47,17 +47,15 @@ def get_last_file_prefix(self):
f_prefix += "_" + str(cnt)
return f_prefix

def save_last_conf(self):
self.last_cfg_name = self.get_last_file_prefix() + ".txt"
with open(self.last_cfg_name, 'w') as f:
print(json.dumps(self.get_last_conf(), indent=4), file=f)

def gen_sess(self, add_count = 0, save_img=True,
drop_cfg=False, force_collect=False,
callback=None, save_metadata=False):
def save_conf(self, index, conf):
cfg_name = self.get_file_prefix(index) + ".txt"
with open(cfg_name, 'w') as f:
print(json.dumps(conf, indent=4), file=f)

def gen_sess(self, add_count=0, save_img=True, drop_cfg=False,
force_collect=False, callback=None, save_metadata=False):
"""
Run image generation session
Run image generation session.
Args:
add_count (int, *optional*):
The number of additional iterations to add. Defaults to 0.
@@ -71,43 +69,86 @@ def gen_sess(self, add_count = 0, save_img=True,
A callback function to be called after each iteration. Defaults to None.
save_metadata (bool, *optional*):
Whether to save metadata in the image EXIF. Defaults to False.
Returns:
List[Image.Image]: The generated images if `save_img` is False or `force_collect` is True.
"""
self.confg.max_count += add_count
self.confg.start_count = self.confg.count
self.last_img_name = None
self.last_cfg_name = None
images = None
images = []

if save_img:
os.makedirs(self.session_dir, exist_ok=True)
# collecting images to return if requested or images are not saved
if not save_img or force_collect:
images = []
logging.info(f"add count = {add_count}")
jk = 0
for inputs in self.confg:
self.last_index = self.confg.count - 1
self.last_conf = {**inputs}
# TODO: multiple inputs?
inputs['generator'] = torch.Generator().manual_seed(inputs['generator'])
logging.debug("start generation")
image = self.pipe.gen(inputs)
if save_img:
self.last_img_name = self.get_last_file_prefix() + ".png"
exif = None
if save_metadata:
exif = util.create_exif_metadata(image, json.dumps(self.get_last_conf()))
image.save(self.last_img_name, exif=exif)
if not save_img or force_collect:
images += [image]
# saving cfg only if images are saved and dropping is not requested
if save_img and not drop_cfg:
self.save_last_conf()
if callback is not None:
logging.debug("call callback after generation")
callback()
jk += 1
logging.debug(f"done iteration {jk}")
return images

# Determine batch size
if self.offload_gpu_id is not None:
# Sequential CPU offloading is enabled, set batch_size to a reasonable number
batch_size = 8 # You can adjust this value based on your environment
else:
batch_size = 1 # Process one input at a time

logging.info(f"Starting generation with batch_size = {batch_size}")
confg_iter = iter(self.confg)
index = self.confg.start_count

while True:
batch_inputs_list = []
# Collect inputs into batch
for _ in range(batch_size):
try:
inputs = next(confg_iter)
except StopIteration:
break # No more inputs
batch_inputs_list.append(inputs)

if not batch_inputs_list:
break # All inputs have been processed

# Prepare batch inputs
batch_inputs_dict = {}
for key in batch_inputs_list[0]:
batch_inputs_dict[key] = [input[key] for input in batch_inputs_list]

# Adjust 'generator' field with manual seeds
batch_generators = []
for seed in batch_inputs_dict.get('generator', [None] * len(batch_inputs_list)):
if seed is not None:
batch_generators.append(torch.Generator().manual_seed(seed))
else:
batch_generators.append(torch.Generator())
batch_inputs_dict['generator'] = batch_generators

# Generate images
batch_images = self.pipe.gen(batch_inputs_dict)

# Process generated images
for i, image in enumerate(batch_images):
idx = index + i
self.last_index = idx
self.last_conf = {**batch_inputs_list[i % len(batch_inputs_list)]}
self.last_conf.update(self.pipe.get_config())
self.last_conf.update({'feedback': '?', 'cversion': '0.0.1'})

if save_img:
f_prefix = self.get_file_prefix(idx)
img_name = f_prefix + ".png"
exif = None
if save_metadata:
exif = util.create_exif_metadata(image, json.dumps(self.get_last_conf()))
image.save(img_name, exif=exif)
self.last_img_name = img_name
if not drop_cfg:
# Save configuration
self.save_conf(idx, self.get_last_conf())
if not save_img or force_collect:
images.append(image)
if callback is not None:
logging.debug("Call callback after generation")
callback()

index += len(batch_images)
logging.debug(f"Processed batch up to index {index}")

logging.debug(f"Generation session completed.")
return images if images else None
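
For reference, here is a condensed sketch of the batching logic that gen_sess() now applies before calling pipe.gen(). It mirrors the loop in the diff above; the helper name and the default batch size are illustrative assumptions rather than part of the commit:

import torch

def collect_batch(confg_iter, batch_size=8):
    # Gather up to batch_size input dicts from the Cfgen iterator.
    batch = []
    for _ in range(batch_size):
        try:
            batch.append(next(confg_iter))
        except StopIteration:
            break
    if not batch:
        return None, []
    # Merge per-image dicts into one dict of lists, e.g.
    # {'prompt': [...], 'generator': [...]}, the form gen_sess passes to pipe.gen().
    merged = {key: [inp[key] for inp in batch] for key in batch[0]}
    # Turn integer seeds into per-image torch.Generator objects.
    seeds = merged.get('generator', [None] * len(batch))
    merged['generator'] = [torch.Generator().manual_seed(s) if s is not None
                           else torch.Generator() for s in seeds]
    return merged, batch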
28 changes: 14 additions & 14 deletions tests/pipe_test.py
@@ -28,15 +28,15 @@ def test_basic_txt2im(self):
params = dict(prompt="a cube planet, cube-shaped, space photo, masterpiece",
negative_prompt="spherical",
generator=torch.Generator().manual_seed(seed))
image = pipe.gen(params)
image = pipe.gen(params)[0]
image.save("cube_test.png")

# generate with different scheduler
if self.model_type() == ModelType.FLUX:
params.update(generator=torch.Generator().manual_seed(seed + 1))
else:
params.update(scheduler=self.schedulers[1])
image_ddim = pipe.gen(params)
image_ddim = pipe.gen(params)[0]
image_ddim.save("cube_test2_dimm.png")
diff = self.compute_diff(image_ddim, image)
# check that difference is large
@@ -73,15 +73,15 @@ def test_img2img_basic(self):
seed = 49045438434843
pipe.setup(im, strength=0.7, steps=5, guidance_scale=3.3)
self.assertEqual(3.3, pipe.pipe_params['guidance_scale'])
image = pipe.gen(dict(prompt="cube planet cartoon style", generator=torch.Generator().manual_seed(seed)))
image = pipe.gen(dict(prompt="cube planet cartoon style", generator=torch.Generator().manual_seed(seed)))[0]
image.save('test_img2img_basic.png')
pipe.setup(im, strength=0.7, steps=5, guidance_scale=7.6)
image1 = pipe.gen(dict(prompt="cube planet cartoon style", generator=torch.Generator().manual_seed(seed)))
image1 = pipe.gen(dict(prompt="cube planet cartoon style", generator=torch.Generator().manual_seed(seed)))[0]
diff = self.compute_diff(image1, image)
# check that difference is large
self.assertGreater(diff, 1000)
pipe.setup(im, strength=0.7, steps=5, guidance_scale=3.3)
image2 = pipe.gen(dict(prompt="cube planet cartoon style", generator=torch.Generator().manual_seed(seed)))
image2 = pipe.gen(dict(prompt="cube planet cartoon style", generator=torch.Generator().manual_seed(seed)))[0]
diff = self.compute_diff(image2, image)
# check that difference is small
self.assertLess(diff, 1)
@@ -108,18 +108,18 @@ def test_maskedimg2img_basic(self):
pipe.setup(**param_3_3)
self.assertEqual(3.3, pipe.pipe_params['guidance_scale'])
image = pipe.gen(dict(prompt="cube planet cartoon style",
generator=torch.Generator().manual_seed(seed)))
generator=torch.Generator().manual_seed(seed)))[0]
self.assertEqual(image.width, img.width)
self.assertEqual(image.height, img.height)
image.save('test_img2img_basic.png')
pipe.setup(**param_7_6)
image1 = pipe.gen(dict(prompt="cube planet cartoon style",
generator=torch.Generator().manual_seed(seed)))
generator=torch.Generator().manual_seed(seed)))[0]
diff = self.compute_diff(image1, image)
# check that difference is large
self.assertGreater(diff, 1000)
pipe.setup(**param_3_3)
image2 = pipe.gen(dict(prompt="cube planet cartoon style", generator=torch.Generator().manual_seed(seed)))
image2 = pipe.gen(dict(prompt="cube planet cartoon style", generator=torch.Generator().manual_seed(seed)))[0]
diff = self.compute_diff(image2, image)
# check that difference is small
self.assertLess(diff, 1)
@@ -138,12 +138,12 @@ def test_lpw(self):
params = dict(prompt=prompt,
negative_prompt="spherical",
generator=torch.Generator().manual_seed(seed))
image = pipe.gen(params)
image = pipe.gen(params)[0]
image.save("cube_test_lpw.png")
params = dict(prompt=prompt + " , best quality, famous photo",
negative_prompt="spherical",
generator=torch.Generator().manual_seed(seed))
image1 = pipe.gen(params)
image1 = pipe.gen(params)[0]
image.save("cube_test_lpw1.png")
diff = self.compute_diff(image1, image)
# check that difference is large
@@ -161,12 +161,12 @@ def test_lpw_turned_off(self):
params = dict(prompt=prompt,
negative_prompt="spherical",
generator=torch.Generator().manual_seed(seed))
image = pipe.gen(params)
image = pipe.gen(params)[0]
image.save("cube_test_no_lpw.png")
params = dict(prompt=prompt + " , best quality, famous photo",
negative_prompt="spherical",
generator=torch.Generator().manual_seed(seed))
image1 = pipe.gen(params)
image1 = pipe.gen(params)[0]
image.save("cube_test_no_lpw1.png")
diff = self.compute_diff(image1, image)
# check that difference is large
@@ -193,7 +193,7 @@ def test_controlnet(self):
params = dict(prompt="cube planet minecraft style",
negative_prompt="spherical",
generator=torch.Generator().manual_seed(seed))
image = pipe.gen(params)
image = pipe.gen(params)[0]
image.save("mech_test.png")
img_ref = PIL.Image.open(imgpth)
self.assertEqual(image.width, img_ref.width)
@@ -205,7 +205,7 @@ def test_controlnet(self):
else:
# generate with different scheduler
params.update(scheduler=self.schedulers[1])
image_ddim = pipe.gen(params)
image_ddim = pipe.gen(params)[0]
image_ddim.save("cube_test2_dimm.png")
diff = self.compute_diff(image_ddim, image)
# check that difference is large
