Update attention.py by adding SparseCausalAttention implementation #380

Open · wants to merge 1 commit into base: main
90 changes: 90 additions & 0 deletions animatediff/models/attention.py
@@ -28,6 +28,96 @@ class Transformer3DModelOutput(BaseOutput):
xformers = None


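# Sparse-causal attention for video: each frame's queries attend to key/value tokens gathered from the first frame and the immediately preceding frame.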
class SparseCausalAttention2D(CrossAttention):
def _attention_with_prob(self, query, key, value, attention_mask=None):
if self.upcast_attention:
query = query.float()
key = key.float()

attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)

if attention_mask is not None:
attention_scores = attention_scores + attention_mask

if self.upcast_softmax:
attention_scores = attention_scores.float()

attention_probs = attention_scores.softmax(dim=-1)

# cast back to the original dtype
attention_probs = attention_probs.to(value.dtype)

# compute attention output
hidden_states = torch.bmm(attention_probs, value)

# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
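# also return the (pre-softmax, possibly upcast) attention scores for callers that need them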
return hidden_states, attention_scores

def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, exemplar_latent=None):
batch_size, sequence_length, _ = hidden_states.shape

encoder_hidden_states = encoder_hidden_states

if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = self.to_q(hidden_states)
dim = query.shape[-1]
query = self.reshape_heads_to_batch_dim(query)

if self.added_kv_proj_dim is not None:
raise NotImplementedError

encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)

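# index of the preceding frame for every frame; frame 0 has no predecessor and falls back to itself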
former_frame_index = torch.arange(video_length) - 1
former_frame_index[0] = 0

# query0 = rearrange(query, "(b f) d c -> b f d c", f=video_length)
# query0 = query0[:, [0]]
# query0 = rearrange(query0, "b f d c -> (b f) d c")
# key0 = rearrange(key, "(b f) d c -> b (f d) c", f=video_length)
# key0 = self.reshape_heads_to_batch_dim(key0)

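# sparse-causal pattern: for each frame, concatenate the tokens of frame 0 and of the preceding frame along the token dimension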
key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
key = rearrange(key, "b f d c -> (b f) d c")

value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2)
value = rearrange(value, "b f d c -> (b f) d c")

key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)

if attention_mask is not None:
if attention_mask.shape[-1] != query.shape[1]:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)

# attention, what we cannot get enough of
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value, attention_mask)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
# linear proj
hidden_states = self.to_out[0](hidden_states)

# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states


class Transformer3DModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
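
A minimal sketch of the key/value gathering this patch performs, for reviewers who want to check the indexing on dummy tensors (the shapes and variable names below are illustrative only, not taken from the patch): every frame ends up attending to twice as many key/value tokens, those of frame 0 plus those of its preceding frame, with frame 0 attending to its own tokens twice.

import torch
from einops import rearrange

batch, video_length, tokens, channels = 1, 4, 8, 16
key = torch.randn(batch * video_length, tokens, channels)  # per-frame keys, laid out as (b f) d c

former_frame_index = torch.arange(video_length) - 1
former_frame_index[0] = 0  # frame 0 has no predecessor, so it points at itself

key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
# concatenate frame-0 tokens with previous-frame tokens along the token axis
key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
key = rearrange(key, "b f d c -> (b f) d c")
print(key.shape)  # torch.Size([4, 16, 16]): each frame now sees 2 * tokens keys

The same gather is applied to value, so the softmax inside the attention call normalizes over the concatenated frame-0 and previous-frame tokens.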