"""my_model.py: Vision Transformer (ViT) implemented in PyTorch."""
from functools import partial

import torch
import torch.nn as nn


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Randomly drop whole residual paths per sample (stochastic depth)."""
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dimensions.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize in place: 1 with probability keep_prob, else 0
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Per-sample stochastic depth applied to the main path of a residual block."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each patch."""

    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        # A strided convolution is equivalent to cutting patches and applying a shared linear projection.
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) does not match the model ({self.img_size[0]}*{self.img_size[1]})."
        # [B, C, H, W] -> [B, embed_dim, H/P, W/P] -> [B, num_patches, embed_dim]
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x


class Attention(nn.Module):
    """Multi-head self-attention."""

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        B, N, C = x.shape
        # [B, N, C] -> [B, N, 3, num_heads, head_dim] -> [3, B, num_heads, N, head_dim]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Scaled dot-product attention: [B, num_heads, N, N]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # Weighted sum over values, then merge heads back to [B, N, C]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Block(nn.Module):
    """Transformer encoder block: pre-norm attention and MLP, each with a residual connection."""

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        self.num_tokens = 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU
        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        # pos_drop randomly zeroes elements of the token embeddings with probability drop_ratio
        # to reduce overfitting.
        self.pos_drop = nn.Dropout(p=drop_ratio)
        # Stochastic-depth decay rule: torch.linspace(0, drop_path_ratio, depth) produces `depth`
        # evenly spaced values from 0 to drop_path_ratio, so dpr[i] is the drop-path probability of
        # the i-th Transformer block; drop_path_ratio sets the maximum rate, reached at the last block.
        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)
        self.has_logits = False
        self.pre_logits = nn.Identity()
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)

    def forward_features(self, x):
        x = self.patch_embed(x)  # [B, num_patches, embed_dim]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)  # [B, num_patches + 1, embed_dim]
        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        # Without distillation (no dist_token), return the encoded [CLS] token x[:, 0]; it aggregates
        # information from the whole image and is used for downstream prediction (e.g. classification).
        return x[:, 0]

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


def _init_vit_weights(m):
    """Weight initialization for ViT modules."""
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=.01)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        # After normalization, LayerNorm scales by `weight` and shifts by `bias`;
        # initialize to the identity transform (weight=1, bias=0).
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)


def create_model(num_classes=10):
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              num_classes=num_classes)
    return model
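

# A minimal usage sketch, not part of the original file: it assumes the module is run directly
# and simply checks that create_model() builds and that a dummy batch flows through the forward
# pass with the expected output shape. The batch size of 2 is arbitrary; 224x224 matches the
# default img_size above.
if __name__ == "__main__":
    model = create_model(num_classes=10)
    dummy = torch.randn(2, 3, 224, 224)  # batch of 2 RGB images at the expected resolution
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([2, 10])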