Skip to content

Commit

Permalink
refactor the convert script and add readme
Browse files Browse the repository at this point in the history
  • Loading branch information
unrealMJ committed Oct 31, 2023
1 parent 15ff325 commit 617c565
Show file tree
Hide file tree
Showing 17 changed files with 581 additions and 1,508 deletions.
16 changes: 0 additions & 16 deletions controlnet/.gitignore

This file was deleted.

34 changes: 31 additions & 3 deletions controlnet/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,33 @@
# Mindspore-controlnet
# Mindspore-ControlNet

1. Convert PyTorch checkpoints to MindSpore
The Stable diffusion code is copied from https://github.com/mindspore-lab/minddiffusion

2. Run test_controlnet.py to use controlnet.
## Install
``` shell
pip install mindspore==1.9.0
pip install -r requirements.txt
```


## 1. Inference with pretrained ControlNet
1. Download the PyTorch ControlNet checkpoints from https://huggingface.co/lllyasviel/ControlNet/tree/main/models or https://huggingface.co/lllyasviel/ControlNet-v1-1/tree/main.

2. Convert the downloaded PyTorch checkpoints to MindSpore checkpoints, or download already-converted checkpoints directly from https://huggingface.co/unrealMJ/MindSpore-ControlNet and put them into torch2ms/ms_weight
```shell
python torch2ms/convert.py --input_path xxxx --output_path xxxx # convert full model
python torch2ms/convert.py --input_path xxxx --output_path xxxx --only_controlnet # convert controlnet only
```

3. Run run_controlnet_inference.py to use controlnet.
```shell
python run_controlnet_inference.py --input_path xxxx --output_path xxxx
```

## 2. Train ControlNet from scratch
1. Download the dataset from https://huggingface.co/datasets/fusing/fill50k


2. Run run_controlnet_train.py to train controlnet.
```shell
python run_controlnet_train.py --data_path xxxx --train_config configs/train_controlnet_config.json --model_config configs/cldm_v15.yaml
```
89 changes: 59 additions & 30 deletions controlnet/ldm/cldm/cldm.py → controlnet/cldm/cldm.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,37 @@
import mindspore as ms
import mindspore.nn as nn

from modules.diffusionmodules.openaimodel import UNetModel, ResBlock, Downsample, AttentionBlock
from modules.diffusionmodules.util import (
from ldm.modules.diffusionmodules.openaimodel import UNetModel, ResBlock, Downsample, AttentionBlock
from ldm.modules.diffusionmodules.util import (
conv_nd,
linear,
zero_module,
timestep_embedding,
)
from modules.attention import SpatialTransformer
from models.diffusion.ddpm import LatentDiffusion
from ldm.modules.attention import SpatialTransformer
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.util import exists, instantiate_from_config


class ControlledUnetModel(UNetModel):
def construct(self, x, timesteps, context,
control_1, control_2, control_3, control_4, control_5, control_6):

# control = []
# the split API here differs from the one in MindSpore 2.0
control = ms.ops.split(control_1, axis=0, output_num=control_1.shape[0]) \
+ ms.ops.split(control_2, axis=0, output_num=control_2.shape[0]) \
+ ms.ops.split(control_3, axis=0, output_num=control_3.shape[0]) \
+ ms.ops.split(control_4, axis=0, output_num=control_4.shape[0]) \
+ ms.ops.split(control_5, axis=0, output_num=control_5.shape[0]) \
+ ms.ops.split(control_6, axis=0, output_num=control_6.shape[0])
control = list(control)

control = []
if control_1 is not None:
control = ms.ops.split(control_1, axis=0, output_num=control_1.shape[0]) \
+ ms.ops.split(control_2, axis=0, output_num=control_2.shape[0]) \
+ ms.ops.split(control_3, axis=0, output_num=control_3.shape[0]) \
+ ms.ops.split(control_4, axis=0, output_num=control_4.shape[0]) \
+ ms.ops.split(control_5, axis=0, output_num=control_5.shape[0]) \
+ ms.ops.split(control_6, axis=0, output_num=control_6.shape[0])
control = list(control)
hs = []

# MindSpore does not need a torch.no_grad() wrapper here
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
h = x.astype(self.dtype)
# h = x

for module in self.input_blocks:
for cell in module:
h = cell(h, emb, context)
Expand All @@ -43,26 +41,22 @@ def construct(self, x, timesteps, context,
h = module(h, emb, context)

control_idx = -1
if control is not None:
if len(control) != 0:
h += control[control_idx]
control_idx -= 1


only_mid_control = False
hs_idx = -1
# TODO: check all tensor dtype
for i, module in enumerate(self.output_blocks):
if only_mid_control or control is None:
if only_mid_control or len(control) == 0:
h = ms.ops.concat([h, hs[hs_idx].astype(h.dtype)], axis=1)
else:
h = ms.ops.concat([h, hs[hs_idx].astype(h.dtype) + control[control_idx].astype(h.dtype)], axis=1)
control_idx -= 1
hs_idx -= 1
# h = module(h, emb, context)
for cell in module:
h = cell(h, emb, context)

# # h = h.astype(x.dtype)
return self.out(h)


Expand Down Expand Up @@ -300,7 +294,7 @@ def __init__(
def make_zero_conv(self, channels):
return nn.SequentialCell([
zero_module(
conv_nd(self.dims, channels, channels, 1, padding=0, has_bias=True, pad_mode='pad')
conv_nd(self.dims, channels, channels, 1, padding=0, has_bias=True, pad_mode='pad').to_float(self.dtype)
)
])

Expand Down Expand Up @@ -342,18 +336,18 @@ def __init__(self, control_stage_config, control_key, only_mid_control, *args, *
self.control_key = control_key
self.only_mid_control = only_mid_control
self.control_scales = [1.0] * 13

def get_input(self, x, c):
# TODO: support train
pass

def apply_model(self, x_noisy, t, cond, *args, **kwargs):
diffusion_model = self.model.diffusion_model

cond_txt = ms.ops.concat(cond['c_crossattn'], 1)

if cond['c_concat'] is None:
pass
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt,
control_1=None, control_2=None,
control_3=None, control_4=None,
control_5=None, control_6=None,
)
else:
control = self.control_model(x=x_noisy, hint=ms.ops.concat(cond['c_concat'], 1), timesteps=t,
context=cond_txt)
Expand All @@ -366,8 +360,6 @@ def apply_model(self, x_noisy, t, cond, *args, **kwargs):
control_5 = ms.ops.concat(control[7:9], 0)
control_6 = ms.ops.concat(control[9: ])

# from mindspore.common import mutable
# control = mutable(control)

eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt,
control_1=control_1, control_2=control_2,
Expand All @@ -376,3 +368,40 @@ def apply_model(self, x_noisy, t, cond, *args, **kwargs):
)
return eps

def get_input(self, x, c, control):
    """Prepare the training inputs and reorder the control tensor's axes.

    ``x`` and ``c`` are passed through the parent class's ``get_input``.
    ``control`` is permuted with axis order ``(0, 3, 1, 2)``, i.e. its last
    axis is moved to position 1.
    NOTE(review): presumably this turns an NHWC hint image into NCHW;
    confirm against the dataset pipeline.

    Returns:
        A tuple ``(x, c, control)`` of the processed inputs.
    """
    latent, cond = super().get_input(x, c)
    hint = ms.numpy.transpose(control, (0, 3, 1, 2))
    return latent, cond, hint

def construct(self, x, c, control):
    """Run one training forward pass and return the loss.

    Steps:
      1. Draw one random int32 timestep per batch element (batch size is
         ``x.shape[0]``) between 0 and ``self.num_timesteps``; the upper
         bound is exclusive according to the ``UniformInt`` operator docs.
      2. Pre-process ``x``, ``c`` and ``control`` with ``self.get_input``.
      3. Encode ``c`` with ``self.get_learned_conditioning_fortrain``.
      4. Delegate to ``self.p_losses``.
    """
    # The batch size is read from the raw input, before get_input touches it.
    batch_size = x.shape[0]
    lower = ms.Tensor(0, dtype=ms.dtype.int32)
    upper = ms.Tensor(self.num_timesteps, dtype=ms.dtype.int32)
    draw_timesteps = ms.ops.UniformInt()
    t = draw_timesteps((batch_size,), lower, upper)

    latents, prompt, hint = self.get_input(x, c, control)
    prompt_cond = self.get_learned_conditioning_fortrain(prompt)
    return self.p_losses(latents, prompt_cond, t, hint)

def p_losses(self, x_start, cond, t, control, noise=None):
    """Compute the diffusion training loss for one batch.

    Args:
        x_start: Clean input that is noised via ``self.q_sample``.
        cond: Conditioning passed to the model under ``'c_crossattn'``.
        t: Per-sample timesteps; also used to index ``self.logvar`` and
            ``self.lvlb_weights``.
        control: Control tensor passed to the model under ``'c_concat'``.
        noise: Optional noise with the same shape as ``x_start``. When
            ``None`` (the default) fresh noise is drawn with
            ``ms.numpy.randn``. Previously this argument was silently
            overwritten; it is now honoured when supplied.

    Returns:
        ``l_simple_weight * mean(loss_simple / exp(logvar_t) + logvar_t)
        + original_elbo_weight * mean(lvlb_weights[t] * loss_vlb)``.

    Raises:
        NotImplementedError: If ``self.parameterization`` is neither
            ``"x0"`` nor ``"eps"``.
    """
    # Only sample noise when the caller did not provide any.
    if noise is None:
        noise = ms.numpy.randn(x_start.shape)

    x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
    # apply_model expects each conditioning entry as a list of tensors.
    model_cond = {'c_concat': [control], 'c_crossattn': [cond]}
    model_output = self.apply_model(x_noisy, t, model_cond)

    if self.parameterization == "x0":
        target = x_start
    elif self.parameterization == "eps":
        target = noise
    else:
        raise NotImplementedError()

    # Per-sample loss: reduce over every axis except the batch axis.
    loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])

    # Weight the simple loss by the per-timestep log-variance.
    logvar_t = self.logvar[t]
    loss = loss_simple / ms.ops.exp(logvar_t) + logvar_t
    loss = self.l_simple_weight * loss.mean()

    # Variational-bound term, weighted per timestep by lvlb_weights.
    loss_vlb = self.get_loss(model_output, target, mean=False).mean((1, 2, 3))
    loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
    loss += (self.original_elbo_weight * loss_vlb)

    return loss
Loading

0 comments on commit 617c565

Please sign in to comment.