From 617c5651735a3caa4ee9d994077dca798a6e86ca Mon Sep 17 00:00:00 2001
From: majie <1264592305@qq.com>
Date: Sun, 8 Oct 2023 15:17:26 +0800
Subject: [PATCH] refactor the convert script and add readme

---
 controlnet/.gitignore                         |  16 -
 controlnet/README.md                          |  34 +-
 controlnet/{ldm => }/cldm/cldm.py             |  89 +++--
 controlnet/cldm/dataset.py                    | 312 +++++++++++++++++
 controlnet/{ldm => }/cldm/ddim_hacked.py      |  61 ----
 .../configs/train_controlnet_config.json      |  26 ++
 controlnet/configs/v1-inference-chinese.yaml  |  64 ----
 controlnet/demo.py                            |  58 ---
 controlnet/inpaint.py                         | 331 ------------------
 ...trolnet.py => run_controlnet_inference.py} | 124 ++++---
 .../{run_train.py => run_controlnet_train.py} | 129 +++----
 controlnet/run_db_train.py                    | 257 --------------
 controlnet/scripts/run_train.sh               |  38 --
 controlnet/scripts/run_txt2img.sh             |  33 --
 controlnet/test.py                            | 148 --------
 controlnet/torch2ms/convert.py                |  54 +--
 controlnet/txt2img.py                         | 315 -----------------
 17 files changed, 581 insertions(+), 1508 deletions(-)
 delete mode 100644 controlnet/.gitignore
 rename controlnet/{ldm => }/cldm/cldm.py (83%)
 create mode 100644 controlnet/cldm/dataset.py
 rename controlnet/{ldm => }/cldm/ddim_hacked.py (77%)
 create mode 100644 controlnet/configs/train_controlnet_config.json
 delete mode 100644 controlnet/configs/v1-inference-chinese.yaml
 delete mode 100644 controlnet/demo.py
 delete mode 100644 controlnet/inpaint.py
 rename controlnet/{test_controlnet.py => run_controlnet_inference.py} (52%)
 rename controlnet/{run_train.py => run_controlnet_train.py} (64%)
 mode change 100755 => 100644
 delete mode 100644 controlnet/run_db_train.py
 delete mode 100755 controlnet/scripts/run_train.sh
 delete mode 100644 controlnet/scripts/run_txt2img.sh
 delete mode 100644 controlnet/test.py
 delete mode 100644 controlnet/txt2img.py

diff --git a/controlnet/.gitignore b/controlnet/.gitignore
deleted file mode 100644
index 19c4f59..0000000
--- a/controlnet/.gitignore
+++ /dev/null
@@ -1,16 +0,0 @@
-rank_0/*
-
-models/*
-
-torch2ms/ms_weight/
-torch2ms/numpy_weight/
-
-model_struct/
-
-demo/
-
-output/
-
-slurm.sh
-
-*.pyc
\ No newline at end of file
diff --git a/controlnet/README.md b/controlnet/README.md
index 9f8d4aa..b1e578f 100644
--- a/controlnet/README.md
+++ b/controlnet/README.md
@@ -1,5 +1,33 @@
-# Mindspore-controlnet
+# MindSpore-ControlNet
 
-1. Convert pytorch ckeckpoints to mindspore
+The Stable Diffusion code is copied from https://github.com/mindspore-lab/minddiffusion
 
-2. Run test_controlnet.py to use controlnet.
\ No newline at end of file
+## Install
+```shell
+    pip install mindspore==1.9.0
+    pip install -r requirements.txt
+```
+
+
+## 1. Inference with a pretrained ControlNet
+1. Download the PyTorch ControlNet checkpoints from https://huggingface.co/lllyasviel/ControlNet/tree/main/models or https://huggingface.co/lllyasviel/ControlNet-v1-1/tree/main.
+
+2. Convert the downloaded PyTorch checkpoints to MindSpore checkpoints, or download already converted checkpoints from https://huggingface.co/unrealMJ/MindSpore-ControlNet, and put them into torch2ms/ms_weight.
+    ```shell
+    python torch2ms/convert.py --input_path xxxx --output_path xxxx  # convert the full model
+    python torch2ms/convert.py --input_path xxxx --output_path xxxx --only_controlnet  # convert the ControlNet weights only
+    ```
+
+3. Run run_controlnet_inference.py to run ControlNet inference.
+    ```shell
+    python run_controlnet_inference.py --input_path xxxx --output_path xxxx
+    ```
+
+## 2. Train ControlNet from scratch
+1. Download the dataset from https://huggingface.co/datasets/fusing/fill50k

+
+
+2. 
Run run_controlnet_train.py to train controlnet. + ```shell + python run_controlnet_train.py --data_path xxxx --train_config configs/train_controlnet_config.json --model_config configs/cldm_v15.yaml + ``` diff --git a/controlnet/ldm/cldm/cldm.py b/controlnet/cldm/cldm.py similarity index 83% rename from controlnet/ldm/cldm/cldm.py rename to controlnet/cldm/cldm.py index 0b14179..25e84b9 100644 --- a/controlnet/ldm/cldm/cldm.py +++ b/controlnet/cldm/cldm.py @@ -1,15 +1,15 @@ import mindspore as ms import mindspore.nn as nn -from modules.diffusionmodules.openaimodel import UNetModel, ResBlock, Downsample, AttentionBlock -from modules.diffusionmodules.util import ( +from ldm.modules.diffusionmodules.openaimodel import UNetModel, ResBlock, Downsample, AttentionBlock +from ldm.modules.diffusionmodules.util import ( conv_nd, linear, zero_module, timestep_embedding, ) -from modules.attention import SpatialTransformer -from models.diffusion.ddpm import LatentDiffusion +from ldm.modules.attention import SpatialTransformer +from ldm.models.diffusion.ddpm import LatentDiffusion from ldm.util import exists, instantiate_from_config @@ -17,23 +17,21 @@ class ControlledUnetModel(UNetModel): def construct(self, x, timesteps, context, control_1, control_2, control_3, control_4, control_5, control_6): - # control = [] - # split api 与2.0不同 - control = ms.ops.split(control_1, axis=0, output_num=control_1.shape[0]) \ - + ms.ops.split(control_2, axis=0, output_num=control_2.shape[0]) \ - + ms.ops.split(control_3, axis=0, output_num=control_3.shape[0]) \ - + ms.ops.split(control_4, axis=0, output_num=control_4.shape[0]) \ - + ms.ops.split(control_5, axis=0, output_num=control_5.shape[0]) \ - + ms.ops.split(control_6, axis=0, output_num=control_6.shape[0]) - control = list(control) - + control = [] + if control_1 is not None: + control = ms.ops.split(control_1, axis=0, output_num=control_1.shape[0]) \ + + ms.ops.split(control_2, axis=0, output_num=control_2.shape[0]) \ + + ms.ops.split(control_3, axis=0, output_num=control_3.shape[0]) \ + + ms.ops.split(control_4, axis=0, output_num=control_4.shape[0]) \ + + ms.ops.split(control_5, axis=0, output_num=control_5.shape[0]) \ + + ms.ops.split(control_6, axis=0, output_num=control_6.shape[0]) + control = list(control) hs = [] - # mindspore不需要包装torch.no_grad() t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) emb = self.time_embed(t_emb) h = x.astype(self.dtype) - # h = x + for module in self.input_blocks: for cell in module: h = cell(h, emb, context) @@ -43,26 +41,22 @@ def construct(self, x, timesteps, context, h = module(h, emb, context) control_idx = -1 - if control is not None: + if len(control) != 0: h += control[control_idx] control_idx -= 1 - only_mid_control = False hs_idx = -1 - # TODO: check all tensor dtype for i, module in enumerate(self.output_blocks): - if only_mid_control or control is None: + if only_mid_control or len(control) == 0: h = ms.ops.concat([h, hs[hs_idx].astype(h.dtype)], axis=1) else: h = ms.ops.concat([h, hs[hs_idx].astype(h.dtype) + control[control_idx].astype(h.dtype)], axis=1) control_idx -= 1 hs_idx -= 1 - # h = module(h, emb, context) for cell in module: h = cell(h, emb, context) - # # h = h.astype(x.dtype) return self.out(h) @@ -300,7 +294,7 @@ def __init__( def make_zero_conv(self, channels): return nn.SequentialCell([ zero_module( - conv_nd(self.dims, channels, channels, 1, padding=0, has_bias=True, pad_mode='pad') + conv_nd(self.dims, channels, channels, 1, padding=0, has_bias=True, 
pad_mode='pad').to_float(self.dtype) ) ]) @@ -342,10 +336,6 @@ def __init__(self, control_stage_config, control_key, only_mid_control, *args, * self.control_key = control_key self.only_mid_control = only_mid_control self.control_scales = [1.0] * 13 - - def get_input(self, x, c): - # TODO: support train - pass def apply_model(self, x_noisy, t, cond, *args, **kwargs): diffusion_model = self.model.diffusion_model @@ -353,7 +343,11 @@ def apply_model(self, x_noisy, t, cond, *args, **kwargs): cond_txt = ms.ops.concat(cond['c_crossattn'], 1) if cond['c_concat'] is None: - pass + eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, + control_1=None, control_2=None, + control_3=None, control_4=None, + control_5=None, control_6=None, + ) else: control = self.control_model(x=x_noisy, hint=ms.ops.concat(cond['c_concat'], 1), timesteps=t, context=cond_txt) @@ -366,8 +360,6 @@ def apply_model(self, x_noisy, t, cond, *args, **kwargs): control_5 = ms.ops.concat(control[7:9], 0) control_6 = ms.ops.concat(control[9: ]) - # from mindspore.common import mutable - # control = mutable(control) eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control_1=control_1, control_2=control_2, @@ -376,3 +368,40 @@ def apply_model(self, x_noisy, t, cond, *args, **kwargs): ) return eps + def get_input(self, x, c, control): + x, c = super().get_input(x, c) + + control = ms.numpy.transpose(control, (0, 3, 1, 2)) + + return x, c, control + + def construct(self, x, c, control): + t = ms.ops.UniformInt()((x.shape[0],), ms.Tensor(0, dtype=ms.dtype.int32), ms.Tensor(self.num_timesteps, dtype=ms.dtype.int32)) + x, c, control = self.get_input(x, c, control) + c = self.get_learned_conditioning_fortrain(c) + return self.p_losses(x, c, t, control) + + def p_losses(self, x_start, cond, t, control, noise=None): + noise = ms.numpy.randn(x_start.shape) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + tmp = {'c_concat': [control], 'c_crossattn': [cond]} + model_output = self.apply_model(x_noisy, t, tmp) + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + + logvar_t = self.logvar[t] + loss = loss_simple / ms.ops.exp(logvar_t) + logvar_t + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean((1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss += (self.original_elbo_weight * loss_vlb) + + return loss \ No newline at end of file diff --git a/controlnet/cldm/dataset.py b/controlnet/cldm/dataset.py new file mode 100644 index 0000000..4a4eff4 --- /dev/null +++ b/controlnet/cldm/dataset.py @@ -0,0 +1,312 @@ +import os +import gc +from random import randint +from collections import defaultdict + +import pandas as pd +import albumentations +import numpy as np +from PIL import Image +import imagesize +import mindspore as ms +from mindspore.dataset import GeneratorDataset + +from toolz.sandbox import unzip + +def control_collate(inputs): + """ + Return: + :img_feat (batch_size, height, weight, 3) + :txt_tokens (n, max_txt_len) + """ + img_feat, txt_tokens, control_feat = map(list, unzip(inputs)) + batch = { + 'img_feat': img_feat, + 'txt_tokens': txt_tokens, + 'control_feat': control_feat + } + return batch + +data_column = [ + 'img_feat', + 'txt_tokens', + 'control_feat' +] + + +def load_data( + data_path, + batch_size, + tokenizer, + image_size=512, + 
image_filter_size=256, + device_num=1, + random_crop=False, + filter_small_size=True, + rank_id=0, + sample_num=-1 + ): + + + if not os.path.exists(data_path): + raise ValueError("Data directory does not exist!") + all_images, all_captions, all_conds = list_image_files_captions_recursively(data_path) + print(f"The first image path is {all_images[0]}, and the caption is {all_captions[0]}") + print(f"total data num: {len(all_images)}") + dataloaders = {} + dataset = ImageDataset( + batch_size, + all_images, + all_captions, + all_conds, + tokenizer, + image_size, + image_filter_size, + random_crop=random_crop, + filter_small_size=filter_small_size + ) + datalen = dataset.__len__ + loader = build_dataloader_ft(dataset, datalen, control_collate, batch_size, device_num, rank_id=rank_id) + dataloaders["ftT2I"] = loader + if sample_num==-1: + batchlen = datalen//(batch_size * device_num) + else: + batchlen = sample_num + metaloader = MetaLoader(dataloaders, datalen=batchlen, task_num=len(dataloaders.keys())) + dataset = GeneratorDataset(metaloader, column_names=data_column, shuffle=True) + + return dataset + + +def build_dataloader_ft(dataset, datalens, collate_fn, batch_size, device_num, rank_id=0): + sampler = BatchSampler(datalens, batch_size=batch_size, device_num=device_num) + loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=collate_fn, device_num=device_num, drop_last=True, rank_id=rank_id) + return loader + + +def list_image_files_captions_recursively(data_path): + import json + all_images = [] + all_conds = [] + all_captions = [] + with open(f'{data_path}/train.jsonl', 'r') as f: + for line in f: + data = json.loads(line) + all_images.append(f'{data_path}/{data["image"]}') + all_conds.append(f'{data_path}/{data["conditioning_image"]}') + all_captions.append(data["text"]) + + assert len(all_images) == len(all_captions) + return all_images, all_captions, all_conds + + +class ImageDataset(): + def __init__( + self, + batch_size, + image_paths, + captions, + conds, + tokenizer, + image_size, + image_filter_size, + shuffle=True, + random_crop=False, + filter_small_size=False + ): + super().__init__() + self.batch_size = batch_size + self.tokenizer = tokenizer + self.image_size = image_size + self.image_filter_size = image_filter_size + self.local_images = image_paths + self.local_captions = captions + self.local_control = conds + self.shuffle = shuffle + self.random_crop = random_crop + self.filter_small_size = filter_small_size + + @property + def __len__(self): + return len(self.local_images) + + def random_sample(self): + return self.__getitem__(randint(0, self.__len__() - 1)) + + def sequential_sample(self, ind): + if ind >= self.__len__() - 1: + return self.__getitem__(0) + return self.__getitem__(ind + 1) + + def skip_sample(self, ind): + if self.shuffle: + return self.random_sample() + return self.sequential_sample(ind=ind) + + def __getitem__(self, idx): + # images preprocess + img_path = self.local_images[idx] + img = Image.open(img_path).convert('RGB') + img = np.asarray(img).astype(np.float32) + img = (img / 127.5 - 1.0) + + # control + control_path = self.local_control[idx] + control = Image.open(control_path).convert('RGB') + control = np.asarray(control).astype(np.float32) + control = control / 255.0 + + # caption preprocess + caption = self.local_captions[idx] + caption_input = self.tokenize(caption) + return np.array(img, dtype=np.float32), np.array(caption_input, dtype=np.int32), np.array(control, dtype=np.float32) + + def tokenize(self, text): + SOT_TEXT = 
"<|startoftext|>" + EOT_TEXT = "<|endoftext|>" + CONTEXT_LEN = 77 + + sot_token = self.tokenizer.encoder[SOT_TEXT] + eot_token = self.tokenizer.encoder[EOT_TEXT] + tokens = [sot_token] + self.tokenizer.encode(text) + [eot_token] + result = np.zeros([CONTEXT_LEN]) + if len(tokens) > CONTEXT_LEN: + tokens = tokens[:CONTEXT_LEN - 1] + [eot_token] + result[:len(tokens)] = tokens + + return result + + +class BatchSampler: + """ + Batch Sampler + """ + + def __init__(self, lens, batch_size, device_num): + self._lens = lens + self._batch_size = batch_size * device_num + + def _create_ids(self): + return list(range(self._lens)) + + def __iter__(self): + ids = self._create_ids() + batches = [ids[i:i + self._batch_size] for i in range(0, len(ids), self._batch_size)] + gc.collect() + return iter(batches) + + def __len__(self): + raise ValueError("NOT supported. " + "This has some randomness across epochs") + + +class DataLoader: + """ DataLoader """ + + def __init__(self, dataset, batch_sampler, collate_fn, device_num=1, drop_last=True, rank_id=0): + self.dataset = dataset + self.batch_sampler = batch_sampler + self.collat_fn = collate_fn + self.device_num = device_num + self.rank_id = rank_id + self.drop_last = drop_last + self.batch_size = len(next(iter(self.batch_sampler))) + + def __iter__(self): + self.step_index = 0 + self.batch_indices = iter(self.batch_sampler) + + return self + + def __next__(self): + try: + indices = next(self.batch_indices) + if len(indices) != self.batch_size and self.drop_last: + return self.__next__() + except StopIteration: + self.batch_indices = iter(self.batch_sampler) + indices = next(self.batch_indices) + data = [] + per_batch = len(indices) // self.device_num + index = indices[self.rank_id * per_batch:(self.rank_id + 1) * per_batch] + for idx in index: + data.append(self.dataset[idx]) + + data = self.collat_fn(data) + return data + + +class MetaLoader(): + """ wraps multiple data loaders """ + + def __init__(self, loaders, datalen, task_num=1): + assert isinstance(loaders, dict) + self.task_num = task_num + self.name2loader = {} + self.name2iter = {} + self.sampling_pools = [] + self.loaders = loaders + self.datalen = datalen + for n, l in loaders.items(): + if isinstance(l, tuple): + l, r = l + elif isinstance(l, DataLoader): + r = 1 + else: + raise ValueError() + self.name2loader[n] = l + self.name2iter[n] = iter(l) + self.sampling_pools.extend([n] * r) + + self.task = self.sampling_pools[0] + self.task_label = [0] * self.task_num + self.step = 0 + self.step_cnt = 0 + self.task_index_list = np.random.permutation(self.task_num) + self.all_ids = [] + + def init_iter(self, task_name): + self.name2iter[task_name] = iter(self.name2loader[task_name]) + + def return_ids(self): + return self.all_ids + + def get_batch(self, batch, task): + """ get_batch """ + batch = defaultdict(lambda: None, batch) + img_feat = batch.get('img_feat', None) + txt_tokens = batch.get('txt_tokens', None) + control_feat = batch.get('control_feat', None) + output = (img_feat, txt_tokens, control_feat) + + return output + + def __getitem__(self, index): + if self.step_cnt == self.task_num: + self.task_index_list = np.random.permutation(self.task_num) + self.step_cnt = 0 + task_index = self.task_index_list[self.step_cnt] + local_task = self.sampling_pools[task_index] + + iter_ = self.name2iter[local_task] + + name = local_task + try: + batch = next(iter_) + except StopIteration: + self.init_iter(local_task) + iter_ = self.name2iter[local_task] + batch = next(iter_) + + task = name.split('_')[0] 
+ for key, val in batch.items(): + if isinstance(val, np.ndarray): + if val.dtype == np.int64: + batch[key] = val.astype(np.int32) + + output = self.get_batch(batch, task) + self.step_cnt += 1 + return output + + def __len__(self): + return self.datalen \ No newline at end of file diff --git a/controlnet/ldm/cldm/ddim_hacked.py b/controlnet/cldm/ddim_hacked.py similarity index 77% rename from controlnet/ldm/cldm/ddim_hacked.py rename to controlnet/cldm/ddim_hacked.py index 30d440c..6a6e319 100644 --- a/controlnet/ldm/cldm/ddim_hacked.py +++ b/controlnet/cldm/ddim_hacked.py @@ -1,17 +1,3 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ import mindspore as ms from mindspore import ops @@ -141,7 +127,6 @@ def plms_sampling(self, cond, shape, total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] print(f"Running PLMS Sampling with {total_steps} timesteps") - # iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) iterator = time_range old_eps = [] @@ -154,7 +139,6 @@ def plms_sampling(self, cond, shape, assert x0 is not None img_orig = self.model.q_sample(x0, ts, ms.numpy.randn(x0.shape)) img = img_orig * mask + (1. - mask) * img - # print(f'i={i}, {img.shape}') outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, quantize_denoised=quantize_denoised, temperature=temperature, noise_dropout=noise_dropout, score_corrector=score_corrector, @@ -162,7 +146,6 @@ def plms_sampling(self, cond, shape, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, old_eps=old_eps, t_next=ts_next) - # print('end') img, pred_x0, e_t = outs old_eps.append(e_t) if len(old_eps) >= 4: @@ -203,53 +186,13 @@ def get_model_output(x, t): sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas - # def get_x_prev_and_pred_x0(e_t, index): - # # select parameters corresponding to the currently considered timestep - # a_t = ms.numpy.full((b, 1, 1, 1), alphas[index]) - # a_prev = ms.numpy.full((b, 1, 1, 1), alphas_prev[index]) - # sigma_t = ms.numpy.full((b, 1, 1, 1), sigmas[index]) - # sqrt_one_minus_at = ms.numpy.full((b, 1, 1, 1), sqrt_one_minus_alphas[index]) - - # # current prediction for x_0 - # pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - # if quantize_denoised: - # pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - # # direction pointing to x_t - # dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t - # noise = sigma_t * noise_like(x.shape, repeat_noise) * temperature - # if noise_dropout > 0.: - # noise, _ = ops.dropout(noise, p=noise_dropout) - # x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - - # return x_prev, pred_x0 - e_t = get_model_output(x, t) - # if len(old_eps) == 0: - # # Pseudo Improved Euler (2nd order) - # x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) - # e_t_next = get_model_output(x_prev, t_next) - # e_t_prime = (e_t + e_t_next) / 2 - # elif len(old_eps) == 1: - # # 2nd order Pseudo Linear Multistep (Adams-Bashforth) - # e_t_prime = (3 * e_t - old_eps[-1]) / 2 - # elif len(old_eps) == 2: - # # 3nd order Pseudo Linear Multistep (Adams-Bashforth) - # e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 - # elif len(old_eps) >= 3: - # # 4nd order Pseudo Linear Multistep (Adams-Bashforth) - # e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 - - a_t = ms.numpy.full((b, 1, 1, 1), alphas[index]) a_prev = ms.numpy.full((b, 1, 1, 1), alphas_prev[index]) sigma_t = ms.numpy.full((b, 1, 1, 1), sigmas[index]) sqrt_one_minus_at = ms.numpy.full((b, 1, 1, 1), sqrt_one_minus_alphas[index]) - # print(f'x shape is {x.shape}') - # print(f'e_t shape is {e_t.shape}') - # print(f'sqrt_one_minus_at shape is {sqrt_one_minus_at.shape}') - # print(f'a_t shape is {a_t.shape}') pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() if quantize_denoised: pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) @@ -260,8 +203,4 @@ def get_model_output(x, t): noise, _ = ops.dropout(noise, p=noise_dropout) x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - - - # x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) - return x_prev, pred_x0, e_t \ No newline at end of file diff --git a/controlnet/configs/train_controlnet_config.json b/controlnet/configs/train_controlnet_config.json new file mode 100644 index 0000000..228f99d --- /dev/null +++ b/controlnet/configs/train_controlnet_config.json @@ -0,0 +1,26 @@ +{ + "model_config": "controlnet/configs/cldm_v15.yaml", + "pretrained_model_path": "torch2ms/weight", + "data_path": "dataset/fill50k", + "train_batch_size": 1, + "gradient_accumulation_steps": 1, + "optim": "adamw", + "patch_size":32, + "epochs": 20, + "betas": [ + 0.9, + 0.98 + ], + "dropout": 0.1, + "weight_decay": 0.01, + "warmup_steps": 1000, + "seed": 3407, + "image_size": 512, + "image_filter_size": 256, + "random_crop": false, + "filter_small_size": true, + "start_learning_rate": 1e-5, + "end_learning_rate": 1e-7, + "decay_steps": 0, + "save_checkpoint_steps": 10000 +} diff --git a/controlnet/configs/v1-inference-chinese.yaml b/controlnet/configs/v1-inference-chinese.yaml deleted file mode 100644 index 0c0bcc2..0000000 --- a/controlnet/configs/v1-inference-chinese.yaml +++ /dev/null @@ -1,64 +0,0 @@ -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - use_fp16: True - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - 
num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - use_fp16: True - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - use_fp16: True - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder_ZH - params: - use_fp16: True diff --git a/controlnet/demo.py b/controlnet/demo.py deleted file mode 100644 index cd5bedc..0000000 --- a/controlnet/demo.py +++ /dev/null @@ -1,58 +0,0 @@ -import mindspore as ms -import os - -from ldm.util import instantiate_from_config -from omegaconf import OmegaConf - - -def load_model_from_config(config, ckpt, verbose=False): - print(f"Loading model from {ckpt}") - model = instantiate_from_config(config.model) - if os.path.exists(ckpt): - param_dict = ms.load_checkpoint(ckpt) - if param_dict: - param_not_load = ms.load_param_into_net(model, param_dict) - print("param not load:", param_not_load) - else: - print(f"!!!Warning!!!: {ckpt} doesn't exist") - - return model - - -def load_full_model(config, path='./models/wukong/', verbose=False): - model = instantiate_from_config(config.model) - param_not_load = [] - if os.path.isdir(path): - unet = ms.load_checkpoint(os.path.join(path, 'unet.ckpt')) - param_not_load.extend(ms.load_param_into_net(model.model, unet)) - vae = ms.load_checkpoint(os.path.join(path, 'vae.ckpt')) - param_not_load.extend(ms.load_param_into_net(model.first_stage_model, vae)) - text_encoder = ms.load_checkpoint(os.path.join(path, 'text_encoder.ckpt')) - param_not_load.extend(ms.load_param_into_net(model.cond_stage_model, text_encoder)) - else: - param_dict = ms.load_checkpoint(path) - param_not_load.extend(ms.load_param_into_net(model, param_dict)) - print("param not load:", param_not_load) - print("load model from", path) - return model - - - - -if __name__ == '__main__': - config = 'configs/v1-inference-chinese.yaml' - config = OmegaConf.load(config) - model1 = load_full_model(config) - print('---------------------------------------------') - model2 = load_full_model(config, path='./models/wukong-huahua-ms.ckpt') - - for (k1, v1), (k2, v2) in zip(model1.parameters_and_names(), model2.parameters_and_names()): - if k1.startswith('first_stage_model.encoder.down.3.downsample'): - continue - if k1.startswith('first_stage_model.decoder.up.0.upsample'): - continue - if not (v1 == v2).all(): - print(k1, k2, v1.sum(), v2.sum()) - print('error') - exit(0) - print('ok') \ No newline at end of file diff --git a/controlnet/inpaint.py b/controlnet/inpaint.py deleted file mode 100644 index cee7ddf..0000000 --- a/controlnet/inpaint.py +++ /dev/null @@ -1,331 +0,0 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ - -import argparse -import datetime -import math -import os -import sys -import shutil - -import numpy as np -from omegaconf import OmegaConf -from PIL import Image -import mindspore as ms -import mindspore.dataset.vision as vision -from mindspore import Tensor -from mindspore import dtype as mstype -from mindspore import ops - -workspace = os.path.dirname(os.path.abspath(__file__)) -print("workspace:", workspace, flush=True) -sys.path.append(workspace) -from ldm.models.diffusion.plms import PLMSSampler -from ldm.util import instantiate_from_config - - -def make_batch_sd( - image, - mask, - txt, - num_samples=1): - image = np.array(image.convert("RGB")) - image = image[None].transpose(0, 3, 1, 2) - image = Tensor(image, dtype=mstype.float32) / 127.5 - 1.0 - - mask = np.array(mask.convert("L")) - mask = mask.astype(np.float32) / 255.0 - mask = mask[None, None] - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 - mask = Tensor(mask, dtype=mstype.float32) - - masked_image = image * (mask < 0.5) - - batch = { - "image": image.repeat(num_samples, axis=0), - "txt": num_samples * [txt], - "mask": mask.repeat(num_samples, axis=0), - "masked_image": masked_image.repeat(num_samples, axis=0), - } - return batch - -def inpaint(sampler, image, mask, prompt, seed, scale, sample_steps, num_samples=1, w=512, h=512): - model = sampler.model - - prng = np.random.RandomState(seed) - start_code = prng.randn(num_samples, 4, h // 8, w // 8) - start_code = Tensor(start_code, dtype=mstype.float32) - - batch = make_batch_sd(image, mask, txt=prompt, num_samples=num_samples) - - c = model.get_learned_conditioning(batch["txt"]) - - c_cat = list() - for ck in model.concat_keys: - cc = batch[ck] - if ck != model.masked_image_key: - bchw = [num_samples, 4, h // 8, w // 8] - cc = x = ops.ResizeNearestNeighbor((bchw[-2], bchw[-1]))(cc) - else: - cc = model.get_first_stage_encoding(model.encode_first_stage(cc)) - c_cat.append(cc) - c_cat = ops.concat(c_cat, axis=1) - - # cond - cond = {"c_concat": c_cat, "c_crossattn": c} - - # uncond cond - uc_cross = model.get_learned_conditioning(num_samples * [""]) - uc_full = {"c_concat": c_cat, "c_crossattn": uc_cross} - - shape = [model.channels, h // 8, w // 8] - samples_cfg, intermediates = sampler.sample( - sample_steps, - num_samples, - shape, - cond, - verbose=False, - eta=0.0, - unconditional_guidance_scale=scale, - unconditional_conditioning=uc_full, - x_T=start_code, - x0=c_cat[:, 1:], - ) - - x_samples = model.decode_first_stage(samples_cfg) - - result = ops.clip_by_value((x_samples + 1.0) / 2.0, clip_value_min=0.0, clip_value_max=1.0) - - result = result.asnumpy().transpose(0, 2, 3, 1) - result = result * 255 - - result = [Image.fromarray(img.astype(np.uint8)) for img in result] - - return result - - -def image_grid(imgs, rows, cols): - w, h = imgs[0].size - grid = Image.new('RGB', size=(cols * w, rows * h)) - - for i, img in enumerate(imgs): - grid.paste(img, box=(i % cols * w, i // cols * h)) - return grid - -def main(args): - device_id = int(os.getenv("DEVICE_ID", 0)) - ms.context.set_context( - mode=ms.context.GRAPH_MODE, - device_target="Ascend", - device_id=device_id, - max_device_memory="30GB" - ) - - if args.save_graph: - save_graphs_path = "graph" - shutil.rmtree(save_graphs_path) - ms.context.set_context( - save_graphs=True, - save_graphs_path=save_graphs_path - ) - - seed_everything(args.seed) - - if not os.path.isabs(args.config): - args.config = os.path.join(workspace, args.config) - 
config = OmegaConf.load(f"{args.config}") - model = load_model_from_config(config, f"{os.path.join(args.ckpt_path, args.ckpt_name)}") - if args.sampler.lower() == "plms": - sampler = PLMSSampler(model) - else: - raise TypeError("unsupported sampler type") - - img_size = args.img_size - num_samples = args.num_samples - prompt = args.prompt - image = Image.open(args.img).convert("RGB") - mask_image = Image.open(args.mask).convert("RGB") - if args.aug == "resize": - aug_func = lambda x_: x_.resize((img_size, img_size)) - elif args.aug == 'crop': - assert img_size % 2 == 0 - mask_idx = np.where(np.array(mask_image)[:, :, 0] > 127.5) - mask_center = np.array(list(map(np.mean, mask_idx)))[::-1].astype('int') - mask_center = [x_.clip(img_size // 2, size_ - img_size // 2) for x_, size_ in zip(mask_center, image.size)] - aug_func = lambda x_: x_.crop((mask_center[0] - img_size // 2, mask_center[1] - img_size // 2, - mask_center[0] + img_size // 2, mask_center[1] + img_size // 2)) - elif args.aug == 'resizecrop': - mask_idx = np.where(np.array(mask_image)[:, :, 0] > 127.5) - mask_center = np.array(list(map(np.mean, mask_idx)))[::-1].astype('int') - mask_range = max(*[x_.max() - x_.min() for x_ in mask_idx]) - new_img_size = math.ceil(mask_range / args.mask_ratio) - mask_center = [x_.clip(new_img_size // 2, size_ - new_img_size // 2) for x_, size_ in - zip(mask_center, image.size)] - aug_func = lambda x_: x_.crop((mask_center[0] - new_img_size // 2, mask_center[1] - new_img_size // 2, - mask_center[0] + new_img_size // 2, mask_center[1] + new_img_size // 2)).resize( - (img_size, img_size)) - else: - aug_func = lambda x_: x_ - image = aug_func(image) - mask_image = aug_func(mask_image) - mask_image = Image.fromarray(np.array(mask_image)[:, :, -1] > 127.5) - - images = [image, mask_image] - for _ in range(math.ceil(num_samples / args.batch_size)): - output = inpaint( - sampler=sampler, - image=image, - mask=mask_image, - prompt=prompt, - seed=args.seed, - scale=args.guidance_scale, - sample_steps=args.sample_steps, - num_samples=args.batch_size, - h=img_size, - w=img_size - ) - images.extend(output) - - im_save = image_grid(images, 1, num_samples + 2) - ct = datetime.datetime.now().strftime("%Y_%d_%b_%H_%M_%S_") - img_name = ct + prompt.replace(" ", "_") + ".png" - os.makedirs(args.save_path, exist_ok=True) - im_save.save(os.path.join(args.save_path, img_name)) - print("finish inpaint.") - - -def seed_everything(seed): - if seed: - ms.set_seed(seed) - np.random.seed(seed) - -def load_model_from_config(config, ckpt, verbose=False): - print(f"Loading model from {ckpt}") - model = instantiate_from_config(config.model) - if os.path.exists(ckpt): - param_dict = ms.load_checkpoint(ckpt) - if param_dict: - param_not_load = ms.load_param_into_net(model, param_dict) - print("param not load:", param_not_load) - else: - print(f"!!!Warning!!!: {ckpt} doesn't exist") - - return model - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--img", - type=str, - required=True, - help="path to origin image" - ) - parser.add_argument( - "--mask", - type=str, - required=True, - help="path to mask image" - ) - parser.add_argument( - "--save_path", - type=str, - default="output", - help="path to save image" - ) - parser.add_argument( - "--prompt", - type=str, - required=True, - help="" - ) - parser.add_argument( - "--config", - type=str, - default="configs/wukong-huahua_inpaint_inference.yaml", - help="" - ) - parser.add_argument( - "--ckpt_path", - type=str, - default="models", - 
help="" - ) - parser.add_argument( - "--ckpt_name", - type=str, - default="wukong-huahua-inpaint-ms.ckpt", - help="" - ) - parser.add_argument( - "--aug", - type=str, - default="resize", - help="augment type" - ) - parser.add_argument( - "--mask_ratio", - type=float, - default=.75, - help="" - ) - parser.add_argument( - "--num_samples", - type=int, - default=4, - help="num of total samples" - ) - parser.add_argument( - "--img_size", - type=int, - default=512, - help="" - ) - parser.add_argument( - "--batch_size", - type=int, - default=4, - help="batch size of model" - ) - parser.add_argument( - "--seed", - type=int, - default=42 - ) - parser.add_argument( - "--guidance_scale", - type=float, - default=7.5, - help="" - ) - parser.add_argument( - "--sample_steps", - type=int, - default=30, - help="" - ) - parser.add_argument( - "--sampler", - type=str, - default="plms", - help="support plms only" - ) - parser.add_argument( - "--save_graph", - action='store_true', - help="" - ) - args = parser.parse_args() - main(args) \ No newline at end of file diff --git a/controlnet/test_controlnet.py b/controlnet/run_controlnet_inference.py similarity index 52% rename from controlnet/test_controlnet.py rename to controlnet/run_controlnet_inference.py index a6b18f3..cd40bfe 100644 --- a/controlnet/test_controlnet.py +++ b/controlnet/run_controlnet_inference.py @@ -3,16 +3,19 @@ import cv2 from omegaconf import OmegaConf from ldm.util import instantiate_from_config -from ldm.cldm.ddim_hacked import PLMSSampler as DDIMSampler +from cldm.ddim_hacked import PLMSSampler as DDIMSampler import mindspore as ms from mindspore import ops import os +import argparse +import time class CannyDetector: def __call__(self, img, low_threshold, high_threshold): return cv2.Canny(img, low_threshold, high_threshold) +apply_canny = CannyDetector() def create_model(config_path): config = OmegaConf.load(config_path) @@ -71,25 +74,19 @@ def load_state_dict(model, path='torch2ms/ms_weight'): return model -device_id = int(os.getenv("DEVICE_ID", 0)) -ms.context.set_context( - mode=ms.context.GRAPH_MODE, - device_target="GPU", - device_id=device_id, - max_device_memory="30GB" -) - - -apply_canny = CannyDetector() - -model = create_model('./configs/cldm_v15.yaml') -model = load_state_dict(model) +def load_model(config_path, pretrained_path): + config = OmegaConf.load(config_path) + model = instantiate_from_config(config.model) -ddim_sampler = DDIMSampler(model) + model = load_state_dict(model, path=pretrained_path) + ddim_sampler = DDIMSampler(model) + return model, ddim_sampler -def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold): +def process(input_path, low_threshold=100, high_threshold=200, image_resolution=512): + input_image = Image.open(input_path).convert('RGB') + input_image = np.asarray(input_image) img = resize_image(HWC3(input_image), image_resolution) img = input_image H, W, C = img.shape @@ -97,15 +94,27 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti detected_map = apply_canny(img, low_threshold, high_threshold) detected_map = HWC3(detected_map) - # 这里不使用permute,使用numpy transpose或者mindspore transpose - # control = ms.Tensor(detected_map.copy()).float() / 255.0 - control = np.transpose(detected_map.copy(), (2, 0, 1)) + return img, detected_map + + +def inference(control, config_path, pretrained_path, + prompt, n_prompt, num_samples, + ddim_steps, guess_mode, strength, scale, 
eta, + width=512, height=512): + + control_map = control.copy() + + # load model + model, ddim_sampler = load_model(config_path=config_path, pretrained_path=pretrained_path) + + # process control map + control = np.transpose(control.copy(), (2, 0, 1)) control = ms.Tensor(control.copy()) / 255.0 control = ops.stack([control for _ in range(num_samples)], axis=0) - cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]} + cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt] * num_samples)]} un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} - shape = (4, H // 8, W // 8) + shape = (4, height // 8, width // 8) model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01 samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, @@ -118,36 +127,59 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti x_samples = x_samples.asnumpy().copy().clip(0, 255).astype(np.uint8) results = [x_samples[i] for i in range(num_samples)] - return [255 - detected_map] + results - + return [255 - control_map] + results -image = Image.open('input_image.png').convert('RGB') -image = np.asarray(image) -prompt = 'a girl' -a_prompt = 'best quality, extremely detailed' -n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality' -num_samples = 1 -image_resolution = 512 -ddim_steps = 20 -guess_mode = False -strength = 1.0 -scale = 9.0 -seed = -1 -eta = 0.0 -low_threshold = 100 -high_threshold = 200 -result = process(image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold) - - -def save(results): +def save(results, output_path): control = results[0] samples = results[1:] - os.makedirs('output/controlnet', exist_ok=True) + dt_string = time.strftime("%Y-%m-%d-%H-%M-%S") + os.makedirs(f'{output_path}/{dt_string}', exist_ok=True) - Image.fromarray(control).save('output/controlnet/control.png') + Image.fromarray(control).save(f'{output_path}/{dt_string}/control.png') for i in range(len(samples)): - Image.fromarray(samples[i]).save(f'output/controlnet/sample_{i}.png') + Image.fromarray(samples[i]).save(f'{output_path}/{dt_string}/sample_{i}.png') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--pretrained_path', default='torch2ms/ms_weight', type=str) + parser.add_argument('--config_path', default='configs/cldm_v15.yaml', type=str) + parser.add_argument('--input_path', default=None, type=str) + parser.add_argument( + '--prompt', + default='a girl,best quality,extremely detailed', + type=str + ) + parser.add_argument( + '--negative_prompt', + default='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality', + type=str + ) + parser.add_argument('--output_path', default='output/controlnet', type=str) + parser.add_argument('--num_samples', default=1, type=int) + parser.add_argument('--image_resolution', default=512, type=int) + parser.add_argument('--ddim_steps', default=20, type=int) + parser.add_argument('--guess_mode', default=False, type=bool) + parser.add_argument('--strength', default=1.0, type=float) + 
parser.add_argument('--scale', default=9.0, type=float) + parser.add_argument('--eta', default=0.0, type=float) + + args = parser.parse_args() -save(result) \ No newline at end of file + + device_id = int(os.getenv("DEVICE_ID", 0)) + ms.context.set_context( + mode=ms.context.GRAPH_MODE, + device_target="GPU", + device_id=device_id, + max_device_memory="30GB" + ) + + _, control_map = process(args.input_path) + results = inference(control_map, args.config_path, args.pretrained_path, + args.prompt, args.negative_prompt, args.num_samples, + args.ddim_steps, args.guess_mode, args.strength, args.scale, args.eta) + + save(results, args.output_path) diff --git a/controlnet/run_train.py b/controlnet/run_controlnet_train.py old mode 100755 new mode 100644 similarity index 64% rename from controlnet/run_train.py rename to controlnet/run_controlnet_train.py index 1490242..46c016f --- a/controlnet/run_train.py +++ b/controlnet/run_controlnet_train.py @@ -1,17 +1,3 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ import os @@ -30,12 +16,12 @@ from mindspore.communication.management import init, get_rank, get_group_size from mindspore.train.callback import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint -from ldm.data.dataset import load_data +from cldm.dataset import load_data from ldm.modules.train.optim import build_optimizer from ldm.modules.train.callback import OverflowMonitor from ldm.modules.train.learningrate import LearningRate from ldm.modules.train.parallel_config import ParallelConfig -from ldm.models.clip_zh.simple_tokenizer import WordpieceTokenizer +from ldm.models.clip_zh.simple_tokenizer import WordpieceTokenizer, BpeTokenizer from ldm.modules.train.tools import parse_with_config, set_random_seed from ldm.modules.train.cell_wrapper import ParallelTrainOneStepWithLossScaleCell @@ -67,13 +53,13 @@ def init_env(opts): opts.rank = rank_id context.set_context(mode=context.GRAPH_MODE, - device_target="Ascend", + device_target="GPU", device_id=device_id, max_device_memory="30GB", ) """ create dataset""" - tokenizer = WordpieceTokenizer() + tokenizer = BpeTokenizer() dataset = load_data( data_path=opts.data_path, batch_size=opts.train_batch_size, @@ -118,82 +104,65 @@ def get_obj_from_str(string, reload=False): return getattr(importlib.import_module(module, package=None), cls) -def load_model_from_config(config, ckpt, verbose=False): - print(f"Loading model from {ckpt}") - model = instantiate_from_config(config.model) - if os.path.exists(ckpt): - param_dict = ms.load_checkpoint(ckpt) - if param_dict: - param_not_load = ms.load_param_into_net(model, param_dict) - print("param not load:", param_not_load) - else: - print(f"{ckpt} not exist:") +def load_pretrained_model(path='torch2ms/ms_weight', model=None): + print(f"Loading model from {path}") + + param_not_load = [] + + unet_weight = ms.load_checkpoint(os.path.join(path, 'unet.ckpt')) + 
param_not_load.extend(ms.load_param_into_net(model.model, unet_weight)) + + vae_weight = ms.load_checkpoint(os.path.join(path, 'vae.ckpt')) + param_not_load.extend(ms.load_param_into_net(model.first_stage_model, vae_weight)) + + text_encoder_weight = ms.load_checkpoint(os.path.join(path, 'text_encoder.ckpt')) + param_not_load.extend(ms.load_param_into_net(model.cond_stage_model, text_encoder_weight)) + + print("param not load:", param_not_load) + return model -def load_pretrained_model(pretrained_ckpt, net): - print(f"start loading pretrained_ckpt {pretrained_ckpt}") - if os.path.exists(pretrained_ckpt): - param_dict = load_checkpoint(pretrained_ckpt) - param_not_load = load_param_into_net(net, param_dict) - print("param not load:", param_not_load) - else: - print("ckpt file not exist!") - - print("end loading ckpt") - - -def load_pretrained_model_clip_and_vae(pretrained_ckpt, net): - new_param_dict = {} - print(f"start loading pretrained_ckpt {pretrained_ckpt}") - if os.path.exists(pretrained_ckpt): - param_dict = load_checkpoint(pretrained_ckpt) - for key in param_dict: - if key.startswith("first") or key.startswith("cond"): - new_param_dict[key] = param_dict[key] - param_not_load = load_param_into_net(net, new_param_dict) - print("param not load:") - for param in param_not_load: - print(param) - else: - print("ckpt file not exist!") - print("end loading ckpt") def main(opts): dataset, rank_id, device_id, device_num = init_env(opts) - LatentDiffusionWithLoss = instantiate_from_config(opts.model_config) - pretrained_ckpt = os.path.join(opts.pretrained_model_path, opts.pretrained_model_file) - load_pretrained_model(pretrained_ckpt, LatentDiffusionWithLoss) + CLDMWithLoss = instantiate_from_config(opts.model_config) + CLDMWithLoss = load_pretrained_model(opts.pretrained_model_path, CLDMWithLoss) - if opts.enable_lora: - from tk.graph import freeze_delta - # 适配lora算法后,冻结lora模块之外的参数 - freeze_delta(LatentDiffusionWithLoss, 'lora') + # fix SD + for k, v in CLDMWithLoss.parameters_and_names(): + if k.startswith("control_model"): + v.requires_grad = True + else: + v.requires_grad = False if not opts.decay_steps: dataset_size = dataset.get_dataset_size() opts.decay_steps = opts.epochs * dataset_size + lr = LearningRate(opts.start_learning_rate, opts.end_learning_rate, opts.warmup_steps, opts.decay_steps) - optimizer = build_optimizer(LatentDiffusionWithLoss, opts, lr, enable_lora=opts.enable_lora) + optimizer = build_optimizer(CLDMWithLoss, opts, lr) update_cell = DynamicLossScaleUpdateCell(loss_scale_value=opts.init_loss_scale, scale_factor=opts.loss_scale_factor, scale_window=opts.scale_window) - + if opts.use_parallel: - net_with_grads = ParallelTrainOneStepWithLossScaleCell(LatentDiffusionWithLoss, optimizer=optimizer, + net_with_grads = ParallelTrainOneStepWithLossScaleCell(CLDMWithLoss, optimizer=optimizer, scale_sense=update_cell, parallel_config=ParallelConfig) else: - net_with_grads = TrainOneStepWithLossScaleCell(LatentDiffusionWithLoss, optimizer=optimizer, + net_with_grads = TrainOneStepWithLossScaleCell(CLDMWithLoss, optimizer=optimizer, scale_sense=update_cell) + model = Model(net_with_grads) callback = [TimeMonitor(opts.callback_size), LossMonitor(opts.callback_size)] ofm_cb = OverflowMonitor() callback.append(ofm_cb) + if rank_id == 0: dataset_size = dataset.get_dataset_size() if not opts.save_checkpoint_steps: @@ -201,24 +170,13 @@ def main(opts): ckpt_dir = os.path.join(opts.output_path, "ckpt", f"rank_{str(rank_id)}") if not os.path.exists(ckpt_dir): os.makedirs(ckpt_dir) - - if 
not opts.enable_lora: - config_ck = CheckpointConfig(save_checkpoint_steps=opts.save_checkpoint_steps, + + config_ck = CheckpointConfig(save_checkpoint_steps=opts.save_checkpoint_steps, keep_checkpoint_max=10, integrated_save=False) - ckpoint_cb = ModelCheckpoint(prefix="wkhh_txt2img", - directory=ckpt_dir, - config=config_ck) - else: - from tk.graph.ckpt_util import TrainableParamsCheckPoint - - config_ck = CheckpointConfig(save_checkpoint_steps=opts.save_checkpoint_steps, - keep_checkpoint_max=10, - integrated_save=False, - saved_network=LatentDiffusionWithLoss) - ckpoint_cb = TrainableParamsCheckPoint(prefix="wkhh_txt2img_lora", - directory=ckpt_dir, - config=config_ck) + ckpoint_cb = ModelCheckpoint(prefix="wkhh_txt2img", + directory=ckpt_dir, + config=config_ck) callback.append(ckpoint_cb) @@ -232,15 +190,14 @@ def main(opts): parser.add_argument('--use_parallel', default=False, type=str2bool, help='use parallel') parser.add_argument('--data_path', default="dataset", type=str, help='data path') parser.add_argument('--output_path', default="output/", type=str, help='use audio out') - parser.add_argument('--train_config', default="configs/train_config.json", type=str, help='train config path') + parser.add_argument('--train_config', default="configs/train_controlnet_config.json", type=str, help='train config path') parser.add_argument('--model_config', default="configs/v1-train-chinese.yaml", type=str, help='model config path') parser.add_argument('--pretrained_model_path', default="", type=str, help='pretrained model directory') - parser.add_argument('--pretrained_model_file', default="", type=str, help='pretrained model file name') parser.add_argument('--optim', default="adamw", type=str, help='optimizer') parser.add_argument('--seed', default=3407, type=int, help='data path') parser.add_argument('--warmup_steps', default=1000, type=int, help='warmup steps') - parser.add_argument('--train_batch_size', default=10, type=int, help='batch size') + parser.add_argument('--train_batch_size', default=1, type=int, help='batch size') parser.add_argument('--callback_size', default=1, type=int, help='callback size.') parser.add_argument("--start_learning_rate", default=1e-5, type=float,help="The initial learning rate for Adam.") parser.add_argument("--end_learning_rate", default=1e-7, type=float, help="The end learning rate for Adam.") @@ -255,8 +212,6 @@ def main(opts): parser.add_argument('--image_size', default=512, type=int, help='images size') parser.add_argument('--image_filter_size', default=256, type=int, help='image filter size') - parser.add_argument('--enable_lora', default=False, type=str2bool, help='enable lora') - args = parser.parse_args() args = parse_with_config(args) abs_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "")) diff --git a/controlnet/run_db_train.py b/controlnet/run_db_train.py deleted file mode 100644 index 3a1140a..0000000 --- a/controlnet/run_db_train.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import os -import sys -import time -import argparse -import importlib - -import albumentations -import mindspore as ms -from omegaconf import OmegaConf -from mindspore import Model, context -from mindspore import load_checkpoint, load_param_into_net -from mindspore.communication.management import init, get_rank, get_group_size -from mindspore.train.callback import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint -from mindspore.nn import DynamicLossScaleUpdateCell -from mindspore.nn import TrainOneStepWithLossScaleCell -from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell - -from ldm.data.dataset_db import load_data -from ldm.models.clip_zh.simple_tokenizer import WordpieceTokenizer -from ldm.modules.train.optim import build_optimizer -from ldm.modules.train.callback import OverflowMonitor -from ldm.modules.train.learningrate import LearningRate -from ldm.modules.train.parallel_config import ParallelConfig -from ldm.modules.train.tools import parse_with_config, set_random_seed -from ldm.modules.train.cell_wrapper import ParallelTrainOneStepWithLossScaleCell - - -os.environ['HCCL_CONNECT_TIMEOUT'] = '6000' - - -def init_env(opts): - """ init_env """ - set_random_seed(opts.seed) - if opts.use_parallel: - init() - device_id = int(os.getenv('DEVICE_ID')) - device_num = get_group_size() - ParallelConfig.dp = device_num - rank_id = get_rank() - opts.rank = rank_id - print("device_id is {}, rank_id is {}, device_num is {}".format( - device_id, rank_id, device_num)) - context.reset_auto_parallel_context() - context.set_auto_parallel_context( - parallel_mode=context.ParallelMode.DATA_PARALLEL, - gradients_mean=True, - device_num=device_num) - else: - device_num = 1 - device_id = int(os.getenv('DEVICE_ID', 0)) - rank_id = 0 - opts.rank = rank_id - - context.set_context(mode=context.GRAPH_MODE, - device_target="Ascend", - device_id=device_id, - max_device_memory="30GB", - ) - - """ create dataset""" - tokenizer = WordpieceTokenizer() - dataset = load_data( - train_data_path = opts.train_data_path, - reg_data_path = opts.reg_data_path, - train_data_repeats = opts.train_data_repeats, - class_word = opts.class_word, - token = opts.token, - batch_size = opts.train_batch_size, - tokenizer = tokenizer, - image_size=opts.image_size, - image_filter_size=opts.image_filter_size, - device_num=device_num, - random_crop=opts.random_crop, - rank_id=rank_id, - sample_num=-1 - ) - print(f"rank id {rank_id}, sample num is {dataset.get_dataset_size()}") - - return dataset, rank_id, device_id, device_num - - -def instantiate_from_config(config): - config = OmegaConf.load(config).model - if not "target" in config: - if config == '__is_first_stage__': - return None - elif config == "__is_unconditional__": - return None - raise KeyError("Expected key `target` to instantiate.") - return get_obj_from_str(config["target"])(**config.get("params", dict())) - - -def str2bool(b): - if b.lower() not in ["false", "true"]: - raise Exception("Invalid Bool Value") - if b.lower() in ["false"]: - return False - return True - - -def get_obj_from_str(string, reload=False): - module, cls = string.rsplit(".", 1) - if reload: - module_imp = importlib.import_module(module) - importlib.reload(module_imp) - return getattr(importlib.import_module(module, package=None), cls) - - -def load_model_from_config(config, ckpt, verbose=False): 
- print(f"Loading model from {ckpt}") - model = instantiate_from_config(config.model) - if os.path.exists(ckpt): - param_dict = ms.load_checkpoint(ckpt) - if param_dict: - param_not_load = ms.load_param_into_net(model, param_dict) - print("param not load:", param_not_load) - else: - print(f"{ckpt} not exist:") - - return model - - -def load_pretrained_model(pretrained_ckpt, net): - print(f"start loading pretrained_ckpt {pretrained_ckpt}") - if os.path.exists(pretrained_ckpt): - param_dict = load_checkpoint(pretrained_ckpt) - param_not_load = load_param_into_net(net, param_dict) - print("param not load:", param_not_load) - else: - print("ckpt file not exist!") - - print("end loading ckpt") - - -def load_pretrained_model_clip_and_vae(pretrained_ckpt, net): - new_param_dict = {} - print(f"start loading pretrained_ckpt {pretrained_ckpt}") - if os.path.exists(pretrained_ckpt): - param_dict = load_checkpoint(pretrained_ckpt) - for key in param_dict: - if key.startswith("first") or key.startswith("cond"): - new_param_dict[key] = param_dict[key] - param_not_load = load_param_into_net(net, new_param_dict) - print("param not load:") - for param in param_not_load: - print(param) - else: - print("ckpt file not exist!") - - print("end loading ckpt") - - -def main(opts): - dataset, rank_id, device_id, device_num = init_env(opts) - LatentDiffusionWithLoss = instantiate_from_config(opts.model_config) - pretrained_ckpt = os.path.join(opts.pretrained_model_path, opts.pretrained_model_file) - load_pretrained_model(pretrained_ckpt, LatentDiffusionWithLoss) - - if not opts.decay_steps: - dataset_size = dataset.get_dataset_size() - opts.decay_steps = opts.epochs * dataset_size - lr = LearningRate(opts.start_learning_rate, opts.end_learning_rate, opts.warmup_steps, opts.decay_steps) - optimizer = build_optimizer(LatentDiffusionWithLoss, opts, lr) - update_cell = DynamicLossScaleUpdateCell(loss_scale_value=opts.init_loss_scale, - scale_factor=opts.loss_scale_factor, - scale_window=opts.scale_window) - - if opts.use_parallel: - net_with_grads = ParallelTrainOneStepWithLossScaleCell(LatentDiffusionWithLoss, optimizer=optimizer, - scale_sense=update_cell, parallel_config=ParallelConfig) - else: - net_with_grads = TrainOneStepWithLossScaleCell(LatentDiffusionWithLoss, optimizer=optimizer, - scale_sense=update_cell) - model = Model(net_with_grads) - callback = [TimeMonitor(opts.callback_size), LossMonitor(opts.callback_size)] - - ofm_cb = OverflowMonitor() - callback.append(ofm_cb) - - if rank_id == 0: - dataset_size = dataset.get_dataset_size() - if not opts.save_checkpoint_steps: - opts.save_checkpoint_steps = dataset_size - ckpt_dir = os.path.join(opts.output_path, "ckpt", f"rank_{str(rank_id)}") - if not os.path.exists(ckpt_dir): - os.makedirs(ckpt_dir) - config_ck = CheckpointConfig(save_checkpoint_steps=opts.save_checkpoint_steps, - keep_checkpoint_max=10, - integrated_save=False) - ckpoint_cb = ModelCheckpoint(prefix="wkhh_txt2img", - directory=ckpt_dir, - config=config_ck) - callback.append(ckpoint_cb) - - print("start_training...") - model.train(opts.epochs, dataset, callbacks=callback, dataset_sink_mode=False) - - -if __name__ == "__main__": - print('process id:', os.getpid()) - parser = argparse.ArgumentParser() - parser.add_argument('--use_parallel', default=False, type=str2bool, help='use parallel') - parser.add_argument('--data_path', default="dataset", type=str, help='data path') - parser.add_argument('--output_path', default="output/", type=str, help='use audio out') - 
parser.add_argument('--train_config', default="configs/train_db_config.json", type=str, help='train config path') - parser.add_argument('--model_config', default="configs/v1-train-db-chinese.yaml", type=str, help='model config path') - parser.add_argument('--pretrained_model_path', default="", type=str, help='pretrained model directory') - parser.add_argument('--pretrained_model_file', default="", type=str, help='pretrained model file name') - parser.add_argument('--train_data_path', default="", type=str, help='train data path') - parser.add_argument('--reg_data_path', default="", type=str, help='regularization data path') - - parser.add_argument('--train_data_repeats', default=100, type=int, help='repetition times of training data') - parser.add_argument('--class_word', default="", type=str, help='Match class_word to the category of images you want to train') - parser.add_argument('--token', default="α", type=str, help='unique token you want to represent your trained model') - parser.add_argument('--optim', default="adamw", type=str, help='optimizer') - parser.add_argument('--seed', default=3407, type=int, help='data path') - parser.add_argument('--warmup_steps', default=1000, type=int, help='warmup steps') - parser.add_argument('--train_batch_size', default=10, type=int, help='batch size') - parser.add_argument('--callback_size', default=1, type=int, help='callback size.') - parser.add_argument("--start_learning_rate", default=1e-5, type=float,help="The initial learning rate for Adam.") - parser.add_argument("--end_learning_rate", default=1e-7, type=float, help="The end learning rate for Adam.") - parser.add_argument("--decay_steps", default=0, type=int,help="lr decay steps.") - parser.add_argument("--epochs", default=10, type=int, help="epochs") - parser.add_argument("--init_loss_scale", default=65536, type=float, help="loss scale") - parser.add_argument("--loss_scale_factor", default=2, type=float, help="loss scale factor") - parser.add_argument("--scale_window", default=1000, type=float, help="scale window") - parser.add_argument("--save_checkpoint_steps", default=0, type=int, help="save checkpoint steps") - parser.add_argument('--random_crop', default=False, type=str2bool, help='random crop') - parser.add_argument('--filter_small_size', default=True, type=str2bool, help='filter small images') - parser.add_argument('--image_size', default=512, type=int, help='images size') - parser.add_argument('--image_filter_size', default=256, type=int, help='image filter size') - - args = parser.parse_args() - args = parse_with_config(args) - abs_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "")) - args.model_config = os.path.join(abs_path, args.model_config) - print(args) - - start = time.time() - main(args) - end = time.time() - print("training time: ", end-start) \ No newline at end of file diff --git a/controlnet/scripts/run_train.sh b/controlnet/scripts/run_train.sh deleted file mode 100755 index 68380e2..0000000 --- a/controlnet/scripts/run_train.sh +++ /dev/null @@ -1,38 +0,0 @@ -#!/bin/bash -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -export GLOG_v=3 -export HCCL_CONNECT_TIMEOUT=600 -export ASCEND_GLOBAL_LOG_LEVEL=3 -export ASCEND_SLOG_PRINT_TO_STDOUT=0 -device_id=2 - -output_path=output/ -task_name=txt2img -data_path=dataset/ -pretrained_model_path=models/ -train_config_file=configs/train_config.json - -rm -rf ${output_path:?}/${task_name:?} -mkdir -p ${output_path:?}/${task_name:?} -export RANK_SIZE=1;export DEVICE_ID=$device_id;export MS_COMPILER_CACHE_PATH=${output_path:?}/${task_name:?}; \ -nohup python -u run_train.py \ - --data_path=$data_path \ - --train_config=$train_config_file \ - --output_path=$output_path/$task_name \ - --use_parallel=False \ - --pretrained_model_path=$pretrained_model_path \ - > $output_path/$task_name/log_train 2>&1 & \ No newline at end of file diff --git a/controlnet/scripts/run_txt2img.sh b/controlnet/scripts/run_txt2img.sh deleted file mode 100644 index 4ece9ca..0000000 --- a/controlnet/scripts/run_txt2img.sh +++ /dev/null @@ -1,33 +0,0 @@ - -#!/bin/bash -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ - -export GLOG_v=3 -export ASCEND_GLOBAL_LOG_LEVEL=3 -export ASCEND_SLOG_PRINT_TO_STDOUT=0 - -export DEVICE_ID=0; \ -python txt2img.py \ - --prompt "a photo of a girl" \ - --config configs/v1-inference-chinese.yaml \ - --output_path ./output/ \ - --seed 42 \ - --dpm_solver \ - --n_iter 4 \ - --n_samples 4 \ - --W 512 \ - --H 512 \ - --ddim_steps 15 diff --git a/controlnet/test.py b/controlnet/test.py deleted file mode 100644 index b870e67..0000000 --- a/controlnet/test.py +++ /dev/null @@ -1,148 +0,0 @@ -import mindspore as ms -import os -import numpy as np - -from ldm.util import instantiate_from_config -from omegaconf import OmegaConf - - -def load_model_from_config(config, ckpt, verbose=False): - print(f"Loading model from {ckpt}") - model = instantiate_from_config(config.model) - if os.path.exists(ckpt): - param_dict = ms.load_checkpoint(ckpt) - if param_dict: - param_not_load = ms.load_param_into_net(model, param_dict) - print("param not load:", param_not_load) - else: - print(f"!!!Warning!!!: {ckpt} doesn't exist") - - return model - - -def load_full_model(config, path='./models/wukong/', verbose=False): - model = instantiate_from_config(config.model) - param_not_load = [] - if os.path.isdir(path): - unet = ms.load_checkpoint(os.path.join(path, 'unet.ckpt')) - param_not_load.extend(ms.load_param_into_net(model.model, unet)) - vae = ms.load_checkpoint(os.path.join(path, 'vae.ckpt')) - param_not_load.extend(ms.load_param_into_net(model.first_stage_model, vae)) - text_encoder = ms.load_checkpoint(os.path.join(path, 'text_encoder.ckpt')) - param_not_load.extend(ms.load_param_into_net(model.cond_stage_model, text_encoder)) - else: - param_dict = ms.load_checkpoint(path) - param_not_load.extend(ms.load_param_into_net(model, param_dict)) - print("param not load:", param_not_load) - print("load model from", path) - return model - - -def test_load_unet(config): - model = instantiate_from_config(config.model) - param_not_load = [] - unet = ms.load_checkpoint('/mnt/petrelfs/majie/project/minddiffusion/vision/wukong-huahua/torch2ms/ms_weight/unet.ckpt') - param_not_load.extend(ms.load_param_into_net(model.model, unet)) - print("param not load:", param_not_load) - print(len(param_not_load)) - print(len(unet)) - - -def test_load_text_encoder(config): - model = instantiate_from_config(config.model) - param_not_load = [] - text_encoder = ms.load_checkpoint('/mnt/petrelfs/majie/project/minddiffusion/vision/wukong-huahua/torch2ms/ms_weight/text_encoder.ckpt') - param_not_load.extend(ms.load_param_into_net(model.cond_stage_model, text_encoder)) - print("param not load:", param_not_load) - print(len(param_not_load)) - print(len(text_encoder)) - - -def test_load_vae(config): - model = instantiate_from_config(config.model) - param_not_load = [] - vae = ms.load_checkpoint('/mnt/petrelfs/majie/project/minddiffusion/vision/wukong-huahua/torch2ms/ms_weight/vae.ckpt') - param_not_load.extend(ms.load_param_into_net(model.first_stage_model, vae)) - print("param not load:", param_not_load) - print(len(param_not_load)) - print(len(vae)) - - -def test_load_controlnet(config): - model = instantiate_from_config(config.model) - param_not_load = [] - controlnet = ms.load_checkpoint('/mnt/petrelfs/majie/project/minddiffusion/vision/wukong-huahua/torch2ms/ms_weight/controlnet.ckpt') - param_not_load.extend(ms.load_param_into_net(model.control_model, controlnet)) - print("param not load:", param_not_load) - print(len(param_not_load)) - 
print(len(controlnet)) - - -def test_text_encoder_output(config): - model = instantiate_from_config(config.model) - text_encoder = ms.load_checkpoint('/mnt/petrelfs/majie/project/minddiffusion/vision/wukong-huahua/torch2ms/ms_weight/text_encoder.ckpt') - ms.load_param_into_net(model.cond_stage_model, text_encoder) - # input_ shape (1, 77) dtype ms.int64 - # 输入要保持一致 - input_ = ms.Tensor(np.arange(77).reshape(1, 77), dtype=ms.int64) - output = model.cond_stage_model.transformer(input_) - print(output.shape) - print(output.sum(), output.min(), output.max()) - - -def test_vae_output(config): - model = instantiate_from_config(config.model) - vae = ms.load_checkpoint('/mnt/petrelfs/majie/project/minddiffusion/vision/wukong-huahua/torch2ms/ms_weight/vae.ckpt') - ms.load_param_into_net(model.first_stage_model, vae) - # input_ shape (1, 4, 32, 32) dtype ms.float32 - # 输入要保持一致 - input_ = np.ones((1, 3, 256, 256), dtype=np.float32) - input_ = ms.Tensor(input_, dtype=ms.float32) - latents = model.first_stage_model.encode(input_) - print(latents.shape) - print(latents.sum(), latents.min(), latents.max()) - - output = model.first_stage_model.decode(latents) - print(output.shape) - print(output.sum(), output.min(), output.max()) - - -def test_tokenizer_output(config): - model = instantiate_from_config(config.model) - text = 'a photo of a girl' - - output = model.cond_stage_model.tokenize(text) - print(output) - - - - - -if __name__ == '__main__': - config = 'configs/cldm_v15.yaml' - config = OmegaConf.load(config) - - ms.context.set_context( - mode=ms.context.GRAPH_MODE, - device_target="GPU", - device_id=0, - max_device_memory="30GB" - ) - - # model1 = load_full_model(config) - # print('---------------------------------------------') - # model2 = load_full_model(config, path='./models/wukong-huahua-ms.ckpt') - - # for (k1, v1), (k2, v2) in zip(model1.parameters_and_names(), model2.parameters_and_names()): - # if k1.startswith('first_stage_model.encoder.down.3.downsample'): - # continue - # if k1.startswith('first_stage_model.decoder.up.0.upsample'): - # continue - # if not (v1 == v2).all(): - # print(k1, k2, v1.sum(), v2.sum()) - # print('error') - # exit(0) - # print('ok') - # test_load_vae(config) - # test_tokenizer_output(config=config) - test_load_controlnet(config=config) \ No newline at end of file diff --git a/controlnet/torch2ms/convert.py b/controlnet/torch2ms/convert.py index f63b195..dfe2b8f 100644 --- a/controlnet/torch2ms/convert.py +++ b/controlnet/torch2ms/convert.py @@ -2,15 +2,17 @@ import mindspore as ms import pickle import torch +import argparse -def convert_torch_to_numpy(name='cond_stage_model', save=False): - path = '/mnt/petrelfs/majie/project/ControlNet/models/control_sd15_canny.pth' + +def convert_torch_to_numpy(path, name='cond_stage_model', save=False): + # name: cond_stage_model, diffusion_model, first_stage_model, control_model + # path = '/mnt/petrelfs/majie/project/ControlNet/models/control_sd15_canny.pth' torch_weight = torch.load(path) numpy_weight = {} for k, v in torch_weight.items(): if name in k: - # print(k, v.shape) numpy_weight[k] = v.numpy() if save: @@ -19,7 +21,7 @@ def convert_torch_to_numpy(name='cond_stage_model', save=False): return numpy_weight -def save_ms_ckpt(ckpt, name): +def save_ms_ckpt(ckpt, output_dir, name): save_data = [] for k, v in ckpt.items(): save_data.append({ @@ -27,7 +29,7 @@ def save_ms_ckpt(ckpt, name): 'data': v }) - ms.save_checkpoint(save_data, f'./ms_weight/{name}.ckpt') + ms.save_checkpoint(save_data, f'{output_dir}/{name}.ckpt') 
def convert_text_encoder(numpy_weight = {}): @@ -149,11 +151,7 @@ def convert_vae(numpy_weight = {}): return final_ckpt -def convert_unet(): - path = '/mnt/petrelfs/majie/project/ControlNet/convert/unet.pkl' - with open(path, 'rb') as f: - numpy_weight = pickle.load(f) - +def convert_unet(numpy_weight={}): ms_weight = {} change = { @@ -211,14 +209,7 @@ def convert_unet(): continue ms_weight[k] = ms.Tensor(v) - save_data = [] - for k, v in ms_weight.items(): - save_data.append({ - 'name': k, - 'data': v - }) - - ms.save_checkpoint(save_data, 'unet.ckpt') + return ms_weight def convert_controlnet(numpy_weight = {}): @@ -267,7 +258,28 @@ def convert_controlnet(numpy_weight = {}): if __name__ == '__main__': - numpy_weight = convert_torch_to_numpy(name='control_model', save=True) - ckpt = convert_controlnet(numpy_weight) - save_ms_ckpt(ckpt, 'controlnet') + parser = argparse.ArgumentParser() + + parser.add_argument('--input_path', type=str, default=None) + parser.add_argument('--only_controlnet', type=bool, default=False) + parser.add_argument('--output_path', type=str, default=None) + + args = parser.parse_args() + + if args.only_controlnet: + controlnet = convert_torch_to_numpy(path=args.input_path, name='control_model') + save_ms_ckpt(convert_controlnet(controlnet), output_dir=args.output_path, name='controlnet') + + else: + vae = convert_torch_to_numpy(path=args.input_path, name='first_stage_model') + text_encoder = convert_torch_to_numpy(path=args.input_path, name='cond_stage_model') + unet = convert_torch_to_numpy(path=args.input_path, name='diffusion_model') + controlnet = convert_torch_to_numpy(path=args.input_path, name='control_model') + + save_ms_ckpt(convert_vae(vae), output_dir=args.output_path, name='vae') + save_ms_ckpt(convert_text_encoder(text_encoder), output_dir=args.output_path, name='text_encoder') + save_ms_ckpt(convert_unet(unet), output_dir=args.output_path, name='unet') + save_ms_ckpt(convert_controlnet(controlnet), output_dir=args.output_path, name='controlnet') + + print('Done!') diff --git a/controlnet/txt2img.py b/controlnet/txt2img.py deleted file mode 100644 index c35c7c7..0000000 --- a/controlnet/txt2img.py +++ /dev/null @@ -1,315 +0,0 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import os -import time -import sys -import argparse -from PIL import Image -from omegaconf import OmegaConf - -import numpy as np -import mindspore as ms - -workspace = os.path.dirname(os.path.abspath(__file__)) -print("workspace", workspace, flush=True) -sys.path.append(workspace) -from ldm.util import instantiate_from_config -from ldm.models.diffusion.plms import PLMSSampler -from ldm.models.diffusion.dpm_solver import DPMSolverSampler - - -def seed_everything(seed): - if seed: - ms.set_seed(seed) - np.random.seed(seed) - - -def numpy_to_pil(images): - """ - Convert a numpy image or a batch of images to a PIL image. 
- """ - if images.ndim == 3: - images = images[None, ...] - images = (images * 255).round().astype("uint8") - pil_images = [Image.fromarray(image) for image in images] - - return pil_images - - -def str2bool(b): - if b.lower() not in ["false", "true"]: - raise Exception("Invalid Bool Value") - if b.lower() in ["false"]: - return False - return True - - -def load_model_from_config(config, ckpt, verbose=False): - print(f"Loading model from {ckpt}") - model = instantiate_from_config(config.model) - if os.path.exists(ckpt): - param_dict = ms.load_checkpoint(ckpt) - if param_dict: - ms.load_param_into_net(model, param_dict) - else: - print(f"!!!Warning!!!: {ckpt} doesn't exist") - - return model - - -def load_model_from_config_convert(config, path='torch2ms/ms_weight', verbose=False): - print(f"Loading model from {path}") - model = instantiate_from_config(config.model) - unet_weight = ms.load_checkpoint(os.path.join(path, 'unet.ckpt')) - ms.load_param_into_net(model.model, unet_weight) - - vae_weight = ms.load_checkpoint(os.path.join(path, 'vae.ckpt')) - ms.load_param_into_net(model.first_stage_model, vae_weight) - - text_encoder_weight = ms.load_checkpoint(os.path.join(path, 'text_encoder.ckpt')) - ms.load_param_into_net(model.cond_stage_model, text_encoder_weight) - - return model - - -def main(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--data_path", - type=str, - nargs="?", - default="", - help="the prompt to render" - ) - parser.add_argument( - "--prompt", - type=str, - nargs="?", - default="狗 绘画 写实风格", - help="the prompt to render" - ) - parser.add_argument( - "--output_path", - type=str, - nargs="?", - default="output", - help="dir to write results to" - ) - parser.add_argument( - "--skip_grid", - action='store_true', - help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", - ) - parser.add_argument( - "--skip_save", - action='store_true', - help="do not save individual samples. For speed measurements.", - ) - parser.add_argument( - "--ddim_steps", - type=int, - default=50, - help="number of ddim sampling steps", - ) - parser.add_argument( - "--fixed_code", - action='store_true', - help="if enabled, uses the same starting code across samples ", - ) - parser.add_argument( - "--ddim_eta", - type=float, - default=0.0, - help="ddim eta (eta=0.0 corresponds to deterministic sampling", - ) - parser.add_argument( - "--n_iter", - type=int, - default=2, - help="sample this often", - ) - parser.add_argument( - "--H", - type=int, - default=512, - help="image height, in pixel space", - ) - parser.add_argument( - "--W", - type=int, - default=512, - help="image width, in pixel space", - ) - parser.add_argument( - "--n_samples", - type=int, - default=8, - help="how many samples to produce for each given prompt. A.k.a. 
batch size", - ) - parser.add_argument( - "--dpm_solver", - action='store_true', - help="use dpm_solver sampling", - ) - parser.add_argument( - "--n_rows", - type=int, - default=0, - help="rows in the grid (default: n_samples)", - ) - parser.add_argument( - "--scale", - type=float, - default=7.5, - help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", - ) - parser.add_argument( - "--from-file", - type=str, - help="if specified, load prompts from this file", - ) - parser.add_argument( - "--config", - type=str, - default="configs/v1-inference-chinese.yaml", - help="path to config which constructs model", - ) - parser.add_argument( - "--ckpt_path", - type=str, - default="models", - help="path to checkpoint of model", - ) - parser.add_argument( - "--ckpt_name", - type=str, - default="wukong-huahua-ms.ckpt", - help="path to checkpoint of model", - ) - parser.add_argument( - "--seed", - type=int, - default=42, - help="the seed (for reproducible sampling)", - ) - parser.add_argument( - "--precision", - type=str, - help="evaluate at this precision", - choices=["full", "autocast"], - default="autocast" - ) - parser.add_argument("--enable_lora", default=False, type=str2bool, help="enable lora") - parser.add_argument("--lora_ckpt_filepath", type=str, default="models", help="path to checkpoint of model with lora") - opt = parser.parse_args() - work_dir = os.path.dirname(os.path.abspath(__file__)) - print(f"WORK DIR:{work_dir}") - - device_id = int(os.getenv("DEVICE_ID", 0)) - ms.context.set_context( - mode=ms.context.GRAPH_MODE, - device_target="GPU", - device_id=device_id, - max_device_memory="30GB" - ) - - seed_everything(opt.seed) - - if not os.path.isabs(opt.config): - opt.config = os.path.join(work_dir, opt.config) - config = OmegaConf.load(f"{opt.config}") - # model = load_model_from_config(config, f"{os.path.join(opt.ckpt_path, opt.ckpt_name)}") - model = load_model_from_config_convert(config) - - if opt.enable_lora: - lora_ckpt_path = opt.lora_ckpt_filepath - lora_param_dict = ms.load_checkpoint(lora_ckpt_path) - ms.load_param_into_net(model, lora_param_dict) - - if opt.dpm_solver: - sampler = DPMSolverSampler(model) - else: - sampler = PLMSSampler(model) - os.makedirs(opt.output_path, exist_ok=True) - outpath = opt.output_path - - batch_size = opt.n_samples - if not opt.data_path: - prompt = opt.prompt - assert prompt is not None - data = [batch_size * [prompt]] - else: - opt.prompt = os.path.join(opt.data_path, opt.prompt) - print(f"reading prompts from {opt.prompt}") - with open(opt.prompt, "r") as f: - data = f.read().splitlines() - data = [batch_size * [prompt for prompt in data]] - - sample_path = os.path.join(outpath, "samples") - os.makedirs(sample_path, exist_ok=True) - base_count = len(os.listdir(sample_path)) - - start_code = None - if opt.fixed_code: - stdnormal = ms.ops.StandardNormal() - start_code = stdnormal((opt.n_samples, 4, opt.H // 8, opt.W // 8)) - - all_samples = list() - for prompts in data: - for n in range(opt.n_iter): - start_time = time.time() - - uc = None - if opt.scale != 1.0: - uc = model.get_learned_conditioning(batch_size * [""]) - if isinstance(prompts, tuple): - prompts = list(prompts) - c = model.get_learned_conditioning(prompts) - shape = [4, opt.H // 8, opt.W // 8] - samples_ddim, _ = sampler.sample(S=opt.ddim_steps, - conditioning=c, - batch_size=opt.n_samples, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, - eta=opt.ddim_eta, - x_T=start_code - ) - 
x_samples_ddim = model.decode_first_stage(samples_ddim) - x_samples_ddim = ms.ops.clip_by_value((x_samples_ddim + 1.0) / 2.0, - clip_value_min=0.0, clip_value_max=1.0) - x_samples_ddim_numpy = x_samples_ddim.asnumpy() - - if not opt.skip_save: - for x_sample in x_samples_ddim_numpy: - x_sample = 255. * x_sample.transpose(1, 2, 0) - img = Image.fromarray(x_sample.astype(np.uint8)) - img.save(os.path.join(sample_path, f"{base_count:05}.png")) - base_count += 1 - - if not opt.skip_grid: - all_samples.append(x_samples_ddim_numpy) - - end_time = time.time() - print(f"the infer time of a batch is {end_time-start_time}") - - print(f"Your samples are ready and waiting for you here: \n{outpath} \n" - f" \nEnjoy.") - -if __name__ == "__main__": - main()
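
One caveat in the refactored torch2ms/convert.py CLI above: `--only_controlnet` is declared with `type=bool`, and argparse passes the raw string through `bool()`, so any non-empty value (including an explicit `--only_controlnet False`) evaluates to True and selects the ControlNet-only branch, while a bare `--only_controlnet` with no value is rejected. Below is a minimal sketch of a stricter flag definition, reusing the str2bool idiom that this patch deletes from run_db_train.py and txt2img.py; the parser is illustrative only and not code from the patch, and `action='store_true'` would work equally well for a bare on/off switch.

```python
import argparse


def str2bool(b):
    # Same semantics as the helper removed from run_db_train.py / txt2img.py
    # in this patch: only "true" / "false" (case-insensitive) are accepted.
    if b.lower() not in ["false", "true"]:
        raise argparse.ArgumentTypeError("Invalid Bool Value")
    return b.lower() == "true"


parser = argparse.ArgumentParser()
parser.add_argument('--input_path', type=str, default=None)
parser.add_argument('--output_path', type=str, default=None)
# With type=bool, argparse would turn the string "False" into True, because
# bool("False") is True; parsing through str2bool avoids that pitfall.
parser.add_argument('--only_controlnet', type=str2bool, default=False)

args = parser.parse_args(['--only_controlnet', 'False'])
print(args.only_controlnet)  # -> False, i.e. convert the full model
```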