Open-source model definitions #1

Open
patrickvonplaten opened this issue Nov 6, 2023 · 7 comments

@patrickvonplaten

Hey @gabgoh,

Super cool that you're open-sourcing the consistency decoder of DALL·E 3 with an MIT license ❤️

Any chance you can also add the model definitions of the torch.jit binary? Otherwise it'll be quite difficult to port the model to other libraries.

@gabgoh
Collaborator

gabgoh commented Nov 6, 2023

I can try to get a more human-readable version of this pushed, but does decoder_consistency.ckpt.code work for the time being? The complete model definition is in there.

@liuliu

liuliu commented Nov 6, 2023

It is not very readable, because you have to plumb through various wrappers. For example, at the top level you get:

class ConvUNetVAE(Module):
  __parameters__ = []
  __buffers__ = []
  training : bool
  _is_full_backward_hook : Optional[bool]
  blocks : __torch__.torch.nn.modules.container.ModuleDict
  def forward(self: __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.ConvUNetVAE,
    x: Tensor,
    t: Tensor,
    features: Tensor) -> Tensor:
    blocks = self.blocks
    output = blocks.output
    blocks0 = self.blocks
    up_0_conv_3 = blocks0.up_0_conv_3
    blocks1 = self.blocks
    up_0_conv_2 = blocks1.up_0_conv_2
    blocks2 = self.blocks
    up_0_conv_1 = blocks2.up_0_conv_1
    blocks3 = self.blocks
    up_0_conv_0 = blocks3.up_0_conv_0
    blocks4 = self.blocks
    up_1_upsamp = blocks4.up_1_upsamp
    blocks5 = self.blocks
    up_1_conv_3 = blocks5.up_1_conv_3
    blocks6 = self.blocks
    up_1_conv_2 = blocks6.up_1_conv_2
    blocks7 = self.blocks
    up_1_conv_1 = blocks7.up_1_conv_1
    blocks8 = self.blocks
    up_1_conv_0 = blocks8.up_1_conv_0
    blocks9 = self.blocks
    up_2_upsamp = blocks9.up_2_upsamp
    blocks10 = self.blocks
    up_2_conv_3 = blocks10.up_2_conv_3
    blocks11 = self.blocks
    up_2_conv_2 = blocks11.up_2_conv_2
    blocks12 = self.blocks
    up_2_conv_1 = blocks12.up_2_conv_1
    blocks13 = self.blocks
    up_2_conv_0 = blocks13.up_2_conv_0
    blocks14 = self.blocks
    up_3_upsamp = blocks14.up_3_upsamp
    blocks15 = self.blocks
    up_3_conv_3 = blocks15.up_3_conv_3
    blocks16 = self.blocks
    up_3_conv_2 = blocks16.up_3_conv_2
    blocks17 = self.blocks
    up_3_conv_1 = blocks17.up_3_conv_1
    blocks18 = self.blocks
    up_3_conv_0 = blocks18.up_3_conv_0
    blocks19 = self.blocks
    mid_conv_1 = blocks19.mid_conv_1
    blocks20 = self.blocks
    mid_conv_0 = blocks20.mid_conv_0
    blocks21 = self.blocks
    down_3_conv_2 = blocks21.down_3_conv_2
    blocks22 = self.blocks
    down_3_conv_1 = blocks22.down_3_conv_1
    blocks23 = self.blocks
    down_3_conv_0 = blocks23.down_3_conv_0
    blocks24 = self.blocks
    down_2_downsamp = blocks24.down_2_downsamp
    blocks25 = self.blocks
    down_2_conv_2 = blocks25.down_2_conv_2
    blocks26 = self.blocks
    down_2_conv_1 = blocks26.down_2_conv_1
    blocks27 = self.blocks
    down_2_conv_0 = blocks27.down_2_conv_0
    blocks28 = self.blocks
    down_1_downsamp = blocks28.down_1_downsamp
    blocks29 = self.blocks
    down_1_conv_2 = blocks29.down_1_conv_2
    blocks30 = self.blocks
    down_1_conv_1 = blocks30.down_1_conv_1
    blocks31 = self.blocks
    down_1_conv_0 = blocks31.down_1_conv_0
    blocks32 = self.blocks
    down_0_downsamp = blocks32.down_0_downsamp
    blocks33 = self.blocks
    down_0_conv_2 = blocks33.down_0_conv_2
    blocks34 = self.blocks
    down_0_conv_1 = blocks34.down_0_conv_1
    blocks35 = self.blocks
    down_0_conv_0 = blocks35.down_0_conv_0
    blocks36 = self.blocks
    embed_image = blocks36.embed_image
    blocks37 = self.blocks
    embed_time = blocks37.embed_time
    input = torch.to(features, torch.device("cuda:0"), 6)
    features0 = torch.upsample_nearest2d(input, None, [8., 8.])
    x0 = torch.cat([x, features0], 1)
    _0 = (embed_time).forward(t, )
    _1 = (embed_image).forward(x0, )
    _2 = (down_0_conv_0).forward(_1, _0, )
    _3 = (down_0_conv_1).forward(_2, _0, )
    _4 = (down_0_conv_2).forward(_3, _0, )
    _5 = (down_0_downsamp).forward(_4, _0, )
    _6 = (down_1_conv_0).forward(_5, _0, )
    _7 = (down_1_conv_1).forward(_6, _0, )
    _8 = (down_1_conv_2).forward(_7, _0, )
    _9 = (down_1_downsamp).forward(_8, _0, )
    _10 = (down_2_conv_0).forward(_9, _0, )
    _11 = (down_2_conv_1).forward(_10, _0, )
    _12 = (down_2_conv_2).forward(_11, _0, )
    _13 = (down_2_downsamp).forward(_12, _0, )
    _14 = (down_3_conv_0).forward(_13, _0, )
    _15 = (down_3_conv_1).forward(_14, _0, )
    _16 = (down_3_conv_2).forward(_15, _0, )
    _17 = (mid_conv_1).forward((mid_conv_0).forward(_16, _0, ), _0, )
    _18 = (up_3_conv_0).forward(_17, _16, _0, )
    _19 = (up_3_conv_1).forward(_18, _15, _0, )
    _20 = (up_3_conv_2).forward(_19, _14, _0, )
    _21 = (up_3_conv_3).forward(_20, _13, _0, )
    _22 = (up_2_conv_0).forward((up_3_upsamp).forward(_21, _0, ), _12, _0, )
    _23 = (up_2_conv_1).forward(_22, _11, _0, )
    _24 = (up_2_conv_2).forward(_23, _10, _0, )
    _25 = (up_2_conv_3).forward(_24, _9, _0, )
    _26 = (up_1_conv_0).forward((up_2_upsamp).forward(_25, _0, ), _8, _0, )
    _27 = (up_1_conv_1).forward(_26, _7, _0, )
    _28 = (up_1_conv_2).forward(_27, _6, _0, )
    _29 = (up_1_conv_3).forward(_28, _5, _0, )
    _30 = (up_0_conv_0).forward((up_1_upsamp).forward(_29, _0, ), _4, _0, )
    _31 = (up_0_conv_1).forward(_30, _3, _0, )
    _32 = (up_0_conv_2).forward(_31, _2, _0, )
    _33 = (up_0_conv_3).forward(_32, _1, _0, )
    return (output).forward(_33, )

How the blocks are defined then has to be looked up here (in decoder/code/__torch__/torch/nn/modules/container.py):

class ModuleDict(Module):
  __parameters__ = []
  __buffers__ = []
  training : bool
  _is_full_backward_hook : Optional[bool]
  embed_image : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.ImageEmbedding
  embed_time : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.TimestepEmbedding
  down_0_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.ConvResblock
  down_0_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_9.ConvResblock
  down_0_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_15.ConvResblock
  down_0_downsamp : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_21.ConvResblock
  down_1_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_28.ConvResblock
  down_1_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_34.ConvResblock
  down_1_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_40.ConvResblock
  down_1_downsamp : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_46.ConvResblock
  down_2_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_53.ConvResblock
  down_2_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_59.ConvResblock
  down_2_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_65.ConvResblock
  down_2_downsamp : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_71.ConvResblock
  down_3_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_77.ConvResblock
  down_3_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_83.ConvResblock
  down_3_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_89.ConvResblock
  mid_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_95.ConvResblock
  mid_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_101.ConvResblock
  up_3_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_108.ConvResblock
  up_3_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_115.ConvResblock
  up_3_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_122.ConvResblock
  up_3_conv_3 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_129.ConvResblock
  up_3_upsamp : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_135.ConvResblock
  up_2_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_142.ConvResblock
  up_2_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_149.ConvResblock
  up_2_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_156.ConvResblock
  up_2_conv_3 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_163.ConvResblock
  up_2_upsamp : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_169.ConvResblock
  up_1_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_176.ConvResblock
  up_1_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_183.ConvResblock
  up_1_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_190.ConvResblock
  up_1_conv_3 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_197.ConvResblock
  up_1_upsamp : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_203.ConvResblock
  up_0_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_210.ConvResblock
  up_0_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_217.ConvResblock
  up_0_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_224.ConvResblock
  up_0_conv_3 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_231.ConvResblock
  output : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.ImageUnembedding

And from there, you need to dig into various Python files to find the Conv2d configuration, etc.
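
The `___torch_mangle_N` segments in the listing above appear to be TorchScript's way of keeping apart classes that share a qualified name. Stripping them makes it easier to see that every block between `embed_image` and `output` is the same `ConvResblock` class with a different configuration. A small sketch, with hypothetical helper names (`demangle`, `block_classes`) that are not part of the released code:

```python
import re

# Matches one mangling segment such as ".___torch_mangle_9".
_MANGLE = re.compile(r"\.___torch_mangle_\d+")


def demangle(qualified_name):
    """Drop '___torch_mangle_N' segments from a qualified class name."""
    return _MANGLE.sub("", qualified_name)


def block_classes(listing):
    """Map attribute name -> short class name for 'name : qualified.Type' lines."""
    classes = {}
    for line in listing.splitlines():
        name, sep, qualified = line.partition(" : ")
        if sep:  # skip lines that are not attribute annotations
            short = demangle(qualified.strip()).rsplit(".", 1)[-1]
            classes[name.strip()] = short
    return classes
```

Pasting the `ModuleDict` attribute lines into `block_classes` gives a compact name-to-class table that can be compared against a hand-written port.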

BTW, this is just from unzipping decoder.pt to inspect the underlying Python code.
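
A minimal sketch of that inspection using only the standard library, assuming the checkpoint is a regular TorchScript archive (a zip file whose `code/` directory holds the generated `.py` sources). The function names and the `decoder.pt` path in the usage note are placeholders:

```python
import zipfile


def list_torchscript_sources(archive_path):
    """Return the names of the generated .py files inside a TorchScript zip archive."""
    with zipfile.ZipFile(archive_path) as archive:
        return sorted(
            name for name in archive.namelist()
            if "/code/" in name and name.endswith(".py")
        )


def read_source(archive_path, member):
    """Return the text of one generated source file from the archive."""
    with zipfile.ZipFile(archive_path) as archive:
        return archive.read(member).decode("utf-8")
```

For example, `list_torchscript_sources("decoder.pt")` should list files such as the `container.py` mentioned above, and `read_source` prints one of them without extracting the whole archive to disk.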

city96 added a commit to city96/ComfyUI_ExtraModels that referenced this issue Nov 6, 2023
Quick and lazy, just ported the example code. Can't do much more without having the model arch.

openai/consistencydecoder#1
@mrsteyk

mrsteyk commented Nov 6, 2023

Uploaded the weights and "pseudo code" with the correct hparams that contribute to the weights.

@jiamings

jiamings commented Nov 7, 2023

The code above looks very much like a conditional UNet with concat conditioning (except that latents are upscaled by 8x using nearest neighbor upsampling). So for a latent of 4x32x32, it would be upsampled to 4x256x256 and then concatenated with the noisy input (3x256x256), then it looks like a regular UNet.
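
The shape arithmetic in that description can be checked with a few lines of plain Python. This is only a sketch of the bookkeeping (no tensors involved); `conditioned_input_shape` is a hypothetical helper, and shapes are channels x height x width without the batch dimension:

```python
def conditioned_input_shape(latent_shape, image_shape, scale=8):
    """Shape of cat([x, nearest_upsample(latent, scale)]) along the channel axis."""
    latent_c, latent_h, latent_w = latent_shape
    image_c, image_h, image_w = image_shape
    # Nearest-neighbour upsampling multiplies both spatial sizes by `scale`.
    upsampled = (latent_c, latent_h * scale, latent_w * scale)
    if upsampled[1:] != (image_h, image_w):
        raise ValueError(f"spatial mismatch: {upsampled[1:]} vs {(image_h, image_w)}")
    # Concatenation along channels adds the channel counts.
    return (image_c + latent_c, image_h, image_w)


print(conditioned_input_shape((4, 32, 32), (3, 256, 256)))  # (7, 256, 256)
```

So under this reading the first block after the concatenation (`embed_image` in the dump) would see 7 input channels.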

@madebyollin

@mrsteyk's code worked for me after some minor edits 👍


@mrsteyk

mrsteyk commented Nov 7, 2023

Yeah, I realised I had messed up the skip connections when I went to sleep. The up stages originally didn't have 4 non-resizing ConvResblocks.
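
For reference, the skip wiring in the TorchScript `forward` quoted earlier in this thread follows a last-in, first-out pattern: the output of `embed_image` and of every down block is saved, and each of the four `up_*_conv_*` blocks per level consumes the most recent unused one. A sketch that reproduces that pairing by name only (the helper `skip_wiring` and its parameters are illustrative, not from the released code):

```python
def skip_wiring(levels=4, down_convs=3, up_convs=4):
    """Pair each up block with the encoder block whose output it consumes (LIFO)."""
    # Encoder side: every block output is pushed, including the downsample blocks.
    stack = ["embed_image"]
    for level in range(levels):
        stack += [f"down_{level}_conv_{i}" for i in range(down_convs)]
        if level < levels - 1:  # the deepest level has no downsample block
            stack.append(f"down_{level}_downsamp")

    # Decoder side: levels run from deepest to shallowest, popping one skip per block.
    wiring = {}
    for level in reversed(range(levels)):
        for i in range(up_convs):
            wiring[f"up_{level}_conv_{i}"] = stack.pop()
    return wiring


for up_block, skip_source in skip_wiring().items():
    print(f"{up_block} <- {skip_source}")
```

With the defaults this gives 16 pairs, for example `up_3_conv_0 <- down_3_conv_2` and `up_0_conv_3 <- embed_image`, matching the `_16` and `_1` arguments in the dump. The `up_*_upsamp` blocks take no skip input there, which is why they do not appear in the mapping.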

@rvorias

rvorias commented Dec 14, 2023

Thanks for this commit. Did you test tiled_decode? Or is it not possible for this model?
