From 9b02857c08543ed4100e37667790793782df40ce Mon Sep 17 00:00:00 2001
From: Cone <127710303+vancyland@users.noreply.github.com>
Date: Wed, 4 Sep 2024 16:10:56 +0800
Subject: [PATCH] Update attention.py

---
 animatediff/models/attention.py | 90 +++++++++++++++++++++++++++++++++
 1 file changed, 90 insertions(+)

diff --git a/animatediff/models/attention.py b/animatediff/models/attention.py
index ad23583c..d79e9740 100644
--- a/animatediff/models/attention.py
+++ b/animatediff/models/attention.py
@@ -28,6 +28,96 @@ class Transformer3DModelOutput(BaseOutput):
     xformers = None
 
 
+class SparseCausalAttention2D(CrossAttention):
+    def _attention_with_prob(self, query, key, value, attention_mask=None):
+        if self.upcast_attention:
+            query = query.float()
+            key = key.float()
+
+        attention_scores = torch.baddbmm(
+            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            query,
+            key.transpose(-1, -2),
+            beta=0,
+            alpha=self.scale,
+        )
+
+        if attention_mask is not None:
+            attention_scores = attention_scores + attention_mask
+
+        if self.upcast_softmax:
+            attention_scores = attention_scores.float()
+
+        attention_probs = attention_scores.softmax(dim=-1)
+
+        # cast back to the original dtype
+        attention_probs = attention_probs.to(value.dtype)
+
+        # compute attention output
+        hidden_states = torch.bmm(attention_probs, value)
+
+        # reshape hidden_states
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        return hidden_states, attention_scores
+
+    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, exemplar_latent=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        encoder_hidden_states = encoder_hidden_states
+
+        if self.group_norm is not None:
+            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = self.to_q(hidden_states)
+        dim = query.shape[-1]
+        query = self.reshape_heads_to_batch_dim(query)
+
+        if self.added_kv_proj_dim is not None:
+            raise NotImplementedError
+
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = self.to_k(encoder_hidden_states)
+        value = self.to_v(encoder_hidden_states)
+
+        former_frame_index = torch.arange(video_length) - 1
+        former_frame_index[0] = 0
+
+        # query0 = rearrange(query, "(b f) d c -> b f d c", f=video_length)
+        # query0 = query0[:, [0]]
+        # query0 = rearrange(query0, "b f d c -> (b f) d c")
+        # key0 = rearrange(key, "(b f) d c -> b (f d) c", f=video_length)
+        # key0 = self.reshape_heads_to_batch_dim(key0)
+
+        key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
+        key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
+        key = rearrange(key, "b f d c -> (b f) d c")
+
+        value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
+        value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2)
+        value = rearrange(value, "b f d c -> (b f) d c")
+
+        key = self.reshape_heads_to_batch_dim(key)
+        value = self.reshape_heads_to_batch_dim(value)
+
+        if attention_mask is not None:
+            if attention_mask.shape[-1] != query.shape[1]:
+                target_length = query.shape[1]
+                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+        # attention, what we cannot get enough of
+        if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+            hidden_states = self._attention(query, key, value, attention_mask)
+        else:
+            hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+        # linear proj
+        hidden_states = self.to_out[0](hidden_states)
+
+        # dropout
+        hidden_states = self.to_out[1](hidden_states)
+        return hidden_states
+
+
 class Transformer3DModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
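For context, here is a minimal standalone sketch of what `_attention_with_prob` computes. `torch.baddbmm` with `beta=0` ignores its first (uninitialized) argument and returns `alpha * (query @ key^T)`, i.e. the scaled attention scores in a single fused call; the method then returns the pre-softmax scores alongside the attention output. The toy sizes below are illustrative assumptions, not values from the patch; `scale` follows the usual `1/sqrt(head_dim)` convention of diffusers' `CrossAttention`.

```python
import torch

bh, q_len, kv_len, head_dim = 4, 64, 128, 64  # (batch*heads, ...) toy sizes (assumption)
scale = head_dim ** -0.5                       # stands in for self.scale (assumption)
query = torch.randn(bh, q_len, head_dim)
key = torch.randn(bh, kv_len, head_dim)
value = torch.randn(bh, kv_len, head_dim)

# baddbmm(empty, q, k^T, beta=0, alpha=scale) == scale * (q @ k^T);
# beta=0 means the empty tensor's contents are never read.
attention_scores = torch.baddbmm(
    torch.empty(bh, q_len, kv_len, dtype=query.dtype, device=query.device),
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=scale,
)
attention_probs = attention_scores.softmax(dim=-1)
hidden_states = torch.bmm(attention_probs, value)

assert attention_scores.shape == (bh, q_len, kv_len)
assert hidden_states.shape == (bh, q_len, head_dim)
```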
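And a sketch of the sparse-causal key/value gathering in `forward`: each frame's queries attend to the first frame's tokens concatenated with the previous frame's tokens (frame 0 falls back to itself), so the effective key/value length doubles while queries are left untouched. Tensor layouts follow the patch's `"(b f) d c"` convention; the toy sizes are assumptions.

```python
import torch
from einops import rearrange

batch, video_length, tokens, channels = 2, 8, 64, 320  # toy sizes (assumption)
key = torch.randn(batch * video_length, tokens, channels)  # layout "(b f) d c"

# Previous-frame index for every frame; the first frame points at itself.
former_frame_index = torch.arange(video_length) - 1
former_frame_index[0] = 0

key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
# Concatenate frame 0's tokens with each frame's former-frame tokens along
# the token axis (dim=2): every frame now sees both as its key/value set.
key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
key = rearrange(key, "b f d c -> (b f) d c")

assert key.shape == (batch * video_length, 2 * tokens, channels)
```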