-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement MMDIT block that is necessary for flux (#592)
- Loading branch information
1 parent
539be41
commit 680cd6e
Showing
7 changed files
with
343 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
# Copyright 2024 Black Forest Labs. Inc. and Flux Authors | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
"""MMDIT Layers adapted from black-forest-labs' flux implementation | ||
https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py | ||
""" | ||
|
||
import torch.nn.functional as F | ||
import torch | ||
from torch import Tensor | ||
|
||
from .. import ops | ||
|
||
from .base import Theta, ThetaLayer | ||
from .linear import LinearLayer | ||
from .modulation import ModulationLayer | ||
from .norm import RMSNormLayer | ||
from .paged_llama_attention_block import PagedLlamaAttentionBlock | ||
|
||
|
||
def qk_norm(q, k, v, rms_q, rms_k): | ||
return rms_q(q).to(v), rms_k(k).to(v) | ||
|
||
|
||
# TODO: Work on unifying with the current RoPE layer | ||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: | ||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) | ||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) | ||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] | ||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] | ||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) | ||
|
||
|
||
def attention(q, k, v, pe): | ||
q, k = apply_rope(q, k, pe) # todo | ||
|
||
x = ops.scaled_dot_product_attention( | ||
q=q, k=k, v=v, a=None, is_causal=True, scale=None | ||
) | ||
x = ops.permute(x, (0, 2, 1, 3)) | ||
x = x.view(x.shape[0], x.shape[1], -1) | ||
|
||
return x | ||
|
||
|
||
class MMDITDoubleBlock(ThetaLayer): | ||
def __init__(self, theta, num_heads: int): | ||
super().__init__(theta) | ||
|
||
self.num_heads = num_heads | ||
self.add_module("img_mod", ModulationLayer(theta("img_mod"), double=True)) | ||
self.add_module("img_attn_qkv", LinearLayer(theta("img_attn.qkv"))) | ||
self.add_module( | ||
"img_attn_norm_q", | ||
RMSNormLayer(theta("img_attn.norm.query_norm"), epsilon=1e-6), | ||
) | ||
self.add_module( | ||
"img_attn_norm_k", | ||
RMSNormLayer(theta("img_attn.norm.key_norm"), epsilon=1e-6), | ||
) | ||
self.add_module("img_attn_proj", LinearLayer(theta("img_attn.proj"))) | ||
|
||
self.add_module("img_mlp1", LinearLayer(theta("img_mlp.0"))) | ||
self.add_module("img_mlp2", LinearLayer(theta("img_mlp.2"))) | ||
|
||
self.add_module("txt_mod", ModulationLayer(theta("txt_mod"), double=True)) | ||
self.add_module("txt_attn_qkv", LinearLayer(theta("txt_attn.qkv"))) | ||
self.add_module( | ||
"txt_attn_norm_q", | ||
RMSNormLayer(theta("txt_attn.norm.query_norm"), epsilon=1e-6), | ||
) | ||
self.add_module( | ||
"txt_attn_norm_k", | ||
RMSNormLayer(theta("txt_attn.norm.key_norm"), epsilon=1e-6), | ||
) | ||
self.add_module("txt_attn_proj", LinearLayer(theta("txt_attn.proj"))) | ||
|
||
self.add_module("txt_mlp1", LinearLayer(theta("txt_mlp.0"))) | ||
self.add_module("txt_mlp2", LinearLayer(theta("txt_mlp.2"))) | ||
|
||
def forward( | ||
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor | ||
) -> tuple[Tensor, Tensor]: | ||
img_mod1, img_mod2 = self.img_mod(vec) | ||
txt_mod1, txt_mod2 = self.txt_mod(vec) | ||
|
||
# prepare image for attention | ||
img_modulated = ops.layer_norm(img, None, None, eps=1e-6) | ||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift | ||
img_qkv = self.img_attn_qkv(img_modulated) | ||
img_qkv_2 = img_qkv.view( | ||
img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1 | ||
) # | ||
img_qkv_3 = ops.permute(img_qkv_2, (2, 0, 3, 1, 4)) | ||
img_q, img_k, img_v = img_qkv_3 | ||
img_q, img_k = qk_norm( | ||
img_q, img_k, img_v, self.img_attn_norm_q, self.img_attn_norm_k | ||
) | ||
|
||
# prepare text for attention | ||
txt_modulated = ops.layer_norm(txt, None, None, eps=1e-6) | ||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift | ||
txt_qkv = self.txt_attn_qkv(txt_modulated) | ||
txt_qkv_2 = txt_qkv.view( | ||
txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1 | ||
) # | ||
txt_qkv_3 = ops.permute(txt_qkv_2, (2, 0, 3, 1, 4)) | ||
txt_q, txt_k, txt_v = txt_qkv_3 | ||
txt_q, txt_k = qk_norm( | ||
txt_q, txt_k, txt_v, self.txt_attn_norm_q, self.txt_attn_norm_k | ||
) | ||
|
||
# run actual attention | ||
q = torch.cat((txt_q, img_q), dim=2) | ||
k = torch.cat((txt_k, img_k), dim=2) | ||
v = torch.cat((txt_v, img_v), dim=2) | ||
|
||
attn = attention(q, k, v, pe) | ||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] | ||
|
||
# calculate the image blocks | ||
# TODO: Refactor this for code reuse with the txt blocks | ||
img = img + img_mod1.gate * self.img_attn_proj(img_attn) | ||
img_mlp_in = (1 + img_mod2.scale) * ops.layer_norm( | ||
img, None, None, eps=1e-6 | ||
) + img_mod2.shift | ||
img_mlp_out1 = self.img_mlp1(img_mlp_in) | ||
img_mlp_out2 = ops.elementwise(F.gelu, img_mlp_out1) | ||
img_mlp_out3 = self.img_mlp2(img_mlp_out2) | ||
img = img + img_mod2.gate * img_mlp_out3 | ||
|
||
# calculate the text blocks | ||
txt = txt + txt_mod1.gate * self.txt_attn_proj(txt_attn) | ||
txt_mlp_in = (1 + txt_mod2.scale) * ops.layer_norm( | ||
txt, None, None, eps=1e-6 | ||
) + txt_mod2.shift | ||
txt_mlp_out1 = self.txt_mlp1(txt_mlp_in) | ||
# TODO: Unify with modulation layer by taking act_fn as an arg | ||
txt_mlp_out2 = ops.elementwise(F.gelu, txt_mlp_out1) | ||
txt_mlp_out3 = self.txt_mlp2(txt_mlp_out2) | ||
txt = txt + txt_mod2.gate * txt_mlp_out3 | ||
|
||
return img, txt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Copyright 2024 Black Forest Labs. Inc. and Flux Authors | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
"""Modulation Layer adapted from black-forest-labs' flux implementation | ||
https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py | ||
""" | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
|
||
from .. import ops | ||
|
||
from .base import Theta, ThetaLayer | ||
from .linear import LinearLayer | ||
|
||
|
||
class ModulationOut: | ||
def __init__(self, shift, scale, gate): | ||
self.shift = shift | ||
self.scale = scale | ||
self.gate = gate | ||
|
||
|
||
class ModulationLayer(ThetaLayer): | ||
def __init__(self, theta: Theta, double: bool): | ||
super().__init__(theta) | ||
|
||
self.is_double = double | ||
self.multiplier = 6 if double else 3 | ||
self.add_module("lin", LinearLayer(theta("lin"))) | ||
|
||
def forward(self, vec: torch.Tensor) -> tuple[ModulationOut, ModulationOut | None]: | ||
silu_result = ops.elementwise(F.silu, vec) | ||
out = self.lin(silu_result)[:, None, :].chunk(self.multiplier, dim=-1) | ||
|
||
return ( | ||
ModulationOut(*out[:3]), | ||
ModulationOut(*out[3:]) if self.is_double else None, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import logging | ||
|
||
logging.basicConfig(level=logging.DEBUG) | ||
|
||
import unittest | ||
|
||
import torch | ||
|
||
from iree.turbine import aot | ||
from sharktank.layers import ( | ||
MMDITDoubleBlock, | ||
) | ||
import sharktank.ops as ops | ||
from sharktank.layers.testing import ( | ||
make_mmdit_double_block_theta, | ||
) | ||
from sharktank.types.tensors import DefaultPrimitiveTensor | ||
|
||
|
||
class MMDITTest(unittest.TestCase): | ||
def setUp(self): | ||
torch.manual_seed(12345) | ||
self.hidden_size = 3072 | ||
self.num_heads = 24 | ||
self.batch_size = 3 | ||
|
||
def testDoubleExport(self): | ||
|
||
theta = make_mmdit_double_block_theta() | ||
mmdit = MMDITDoubleBlock( | ||
theta=theta, | ||
num_heads=self.num_heads, | ||
) | ||
|
||
img = torch.rand([self.batch_size, 1024, self.hidden_size]) | ||
txt = torch.rand([self.batch_size, 512, self.hidden_size]) | ||
vec = torch.rand([self.batch_size, self.hidden_size]) | ||
rot = torch.rand([self.batch_size, 1, 1536, 64, 2, 2]) | ||
mmdit.forward(img, txt, vec, rot) | ||
fxb = aot.FxProgramsBuilder(mmdit) | ||
|
||
@fxb.export_program(name="mmdit", args=(img, txt, vec, rot), strict=False) | ||
def _(model, img, txt, vec, rot) -> torch.Tensor: | ||
return model.forward(img, txt, vec, rot) | ||
|
||
output = aot.export(fxb) | ||
output.verify() | ||
asm = str(output.mlir_module) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |