# attention_processor.py, forked from HabanaAI/Model-References
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# Copyright (C) 2023 Habana Labs, Ltd. an Intel Company
###############################################################################
# Changes:
# - Added Fused SDPA
# - Added scaled_dot_product_attention
# - Replaced pytorch softmax with optimized softmax_fp8
# - Added fused mult changes
# - Added Support for quantized SDPA
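
# Note: "Fused SDPA" refers to habana_frameworks.torch.hpex.kernels.FusedSDPA, which
# AttnProcessor2_0 below uses by default; setting the PATCH_SDPA environment variable
# routes attention through the ScaledDotProductAttention module defined in this file.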
from importlib import import_module
from typing import Callable, Optional, Union
import os
import math
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention_processor import Attention
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None
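

# HPU helper modules: Softmax wraps Habana's fp8 softmax kernel and Matmul wraps
# torch.matmul as nn.Modules, presumably so both can be replaced or quantized at
# the module level (cf. the "quantized SDPA" entry in the change list above).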
class Softmax(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, dim=None, invAttnHead=None):
        # invAttnHead carries the attention scale factor (1/sqrt(head_dim)).
        return torch.ops.hpu.softmax_fp8(x, dim, None, None, invAttnHead)


class Matmul(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, *args, **kwargs):
        return torch.matmul(*args, **kwargs)


# ScaledDotProductAttention is based on torch.nn.functional.scaled_dot_product_attention
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.bmm1 = Matmul()
        self.bmm2 = Matmul()
        self.softmax = Softmax()
    def forward(self, query, key, value, attn_mask=None, dropout_p=0.0,
                is_causal=False, scale=None) -> torch.Tensor:
        # Efficient implementation, following the reference of
        # torch.nn.functional.scaled_dot_product_attention.
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
        # The scale is not folded into Q @ K^T here; it is handed to the fused
        # softmax_fp8 kernel through invAttnHead.
        invAttnHead = torch.tensor(scale_factor, dtype=torch.float32).to('hpu')

        attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
        if is_causal:
            assert attn_mask is None
            temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias = attn_bias + attn_mask

        if S < 128:
            # Short key sequences: compute K @ Q^T and softmax over dim=-2, the
            # transpose of the usual layout; the result is equivalent.
            attn_weight = self.bmm1(key, query.transpose(-2, -1))
            attn_weight = attn_weight + attn_bias.transpose(-2, -1)
            attn_weight = self.softmax(attn_weight, dim=-2, invAttnHead=invAttnHead)
            attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
            return self.bmm2(attn_weight.transpose(-2, -1), value)
        else:
            attn_weight = self.bmm1(query, key.transpose(-2, -1))
            attn_weight = attn_weight + attn_bias
            attn_weight = self.softmax(attn_weight, dim=-1, invAttnHead=invAttnHead)
            attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
            return self.bmm2(attn_weight, value)
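
# Minimal usage sketch (illustrative, assumes an HPU device with the Habana
# PyTorch bridge loaded); inputs follow the (batch, heads, seq_len, head_dim)
# layout expected by F.scaled_dot_product_attention:
#
#   sdpa = ScaledDotProductAttention()
#   q = torch.randn(1, 8, 64, 40, dtype=torch.bfloat16, device="hpu")
#   k = torch.randn(1, 8, 77, 40, dtype=torch.bfloat16, device="hpu")
#   v = torch.randn(1, 8, 77, 40, dtype=torch.bfloat16, device="hpu")
#   out = sdpa(q, k, v)  # -> shape (1, 8, 64, 40)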


# Copied from diffusers.models.attention_processor.AttnProcessor2_0
class AttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self, attention_module=None):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.attention_module = attention_module
    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        scale: float = 1.0,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
        query = attn.to_q(hidden_states, scale=scale)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states, scale=scale)
        value = attn.to_v(encoder_hidden_states, scale=scale)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        # hidden_states = F.scaled_dot_product_attention(
        #     query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        # )
        if os.environ.get('PATCH_SDPA') is not None:
            # PATCH_SDPA set: route through the module injected via `attention_module`
            # (e.g. the ScaledDotProductAttention defined above).
            hidden_states = self.attention_module(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )
        else:
            # Default: Habana's fused SDPA kernel, with recompute mode enabled.
            from habana_frameworks.torch.hpex.kernels import FusedSDPA
            import habana_frameworks.torch.hpu as ht
            with ht.sdp_kernel(enable_recompute=True):
                hidden_states = FusedSDPA.apply(query, key, value, attention_mask, 0.0, False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, scale=scale)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


AttentionProcessor = Union[
    AttnProcessor2_0,
]
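
# Illustrative wiring (a sketch, not part of this module): on Gaudi, the processor
# above could be registered on a diffusers UNet roughly as
#
#   from diffusers import UNet2DConditionModel
#   unet = UNet2DConditionModel.from_pretrained(...).to("hpu")   # model path elided
#   unet.set_attn_processor(
#       AttnProcessor2_0(attention_module=ScaledDotProductAttention())
#   )
#
# With the PATCH_SDPA environment variable set, __call__ dispatches to the injected
# attention_module; otherwise it falls back to Habana's FusedSDPA kernel.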