论文地址:https://arxiv.org/pdf/2010.11929.pdf

源码地址:google-research/vision_transformer (github.com)

文章引用源码:https://github.com/bubbliiiing/classification-pytorch

文章出处:https://blog.csdn.net/weixin_44791964/article/details/122637701

实现思路

Vision Transformer是Transformer的视觉版本,Transformer基本上已经成为了自然语言处理的标配,但是在视觉中的运用还受到限制。

Vision Transformer打破了这种NLP与CV的隔离,将Transformer应用于图像图块(patch)序列上,进一步完成图像分类任务。简单来理解,Vision Transformer就是将输入进来的图片,每隔一定的区域大小划分图片块。然后将划分后的图片块组合成序列,将组合后的结果传入Transformer特有的Multi-head Self-attention进行特征提取。最后利用Cls Token进行分类。

整体架构

与寻常的分类网络类似,整个Vision Transformer可以分为两部分,一部分是特征提取部分,另一部分是分类部分。

在特征提取部分,VIT所做的工作是特征提取。特征提取部分在图片中的对应区域是Patch+Position Embedding和Transformer Encoder。Patch+Position Embedding的作用主要是对输入进来的图片进行分块处理,每隔一定的区域大小划分图片块。然后将划分后的图片块组合成序列。在获得序列信息后,传入Transformer Encoder进行特征提取,这是Transformer特有的Multi-head Self-attention结构,通过自注意力机制,关注每个图片块的重要程度。

在分类部分,VIT所做的工作是利用提取到的特征进行分类。在进行特征提取的时候,我们会在图片序列中添加上Cls Token,该Token会作为一个单位的序列信息一起进行特征提取,提取的过程中,该Cls Token会与其它的特征进行特征交互,融合其它图片序列的特征。最终,我们利用Multi-head Self-attention结构提取特征后的Cls Token进行全连接分类。

网络结构详解

特征提取部分

a)Patch+Position Embedding

该部分作用:对输入进来的图片进行分块处理,每隔一定的区域大小划分图片块。然后将划分后的图片块组合成序列。

该部分首先对输入进来的图片进行分块处理,处理方式其实很简单,使用的是现成的卷积。由于卷积使用的是滑动窗口的思想,我们只需要设定特定的步长,就可以输入进来的图片进行分块处理了。

在VIT中,我们常设置这个卷积的卷积核大小为16x16,步长也为16x16,此时卷积就会每隔16个像素点进行一次特征提取,由于卷积核大小为16x16,两个图片区域的特征提取过程就不会有重叠。当我们输入的图片是[224, 224, 3]的时候,我们可以获得一个[14, 14, 768]的特征层。

下一步就是将这个特征层组合成序列,组合的方式非常简单,就是将高宽维度进行平铺,[14, 14, 768]在高宽维度平铺后,获得一个196, 768的特征层。平铺完成后,我们会在图片序列中添加上Cls Token,该Token会作为一个单位的序列信息一起进行特征提取,图中的这个0*就是Cls Token,我们此时获得一个197, 768的特征层。

添加完成Cls Token后,再为所有特征添加上位置信息,这样网络才有区分不同区域的能力。添加方式其实也非常简单,我们生成一个197, 768的参数矩阵,这个参数矩阵是可训练的,把这个矩阵加上197, 768的特征层即可。

到这里,Patch+Position Embedding就构建完成了,构建代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# [224, 224, 3]->[14, 14, 768]
class PatchEmbed(nn.Module):
def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_features=768, norm_layer=None, flatten=True):
super().__init__()
# 196 = 14 * 14
self.num_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
self.flatten = flatten

self.proj = nn.Conv2d(in_chans, num_features, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(num_features) if norm_layer else nn.Identity()

def forward(self, x):
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
# x = [b, 196, 768]
return x

class VisionTransformer(nn.Module):
def __init__(
self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,
depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,
norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU
):
super().__init__()
#-----------------------------------------------#
# 224, 224, 3 -> 196, 768
#-----------------------------------------------#
self.patch_embed = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)
num_patches = (224 // patch_size) * (224 // patch_size)
self.num_features = num_features
self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]
self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]

#--------------------------------------------------------------------------------------------------------------------#
# classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。
#
# 在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。
# 此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。
# 在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。
#--------------------------------------------------------------------------------------------------------------------#
# 196, 768 -> 197, 768
self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))
#--------------------------------------------------------------------------------------------------------------------#
# 为网络提取到的特征添加上位置信息。
# 以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768
# 此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。
#--------------------------------------------------------------------------------------------------------------------#
# 197, 768 -> 197, 768
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))

def forward_features(self, x):
# x = [b, 196, 768]
x = self.patch_embed(x)
# cls_token = [b, 1, 768]
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
# x = [b, 197, 768]
x = torch.cat((cls_token, x), dim=1)
# [1, 1, 768]
cls_token_pe = self.pos_embed[:, 0:1, :]
# [1, 196, 768]
img_token_pe = self.pos_embed[:, 1: , :]
# [1, 196, 768]->[1, 14, 14, 768]->[1, 768, 14, 14]
img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)
# 做插值,以防输入图片不是224*224
img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
# [1, 768, 14, 14]->[1, 14, 14, 768]->[1, 196, 768]
img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)
# [1, 197, 768]
pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)

x = self.pos_drop(x + pos_embed)

b)transformer encoder

在上一步获得shape为197, 768的序列信息后,将序列信息传入Transformer Encoder进行特征提取,这是Transformer特有的Multi-head Self-attention结构,通过自注意力机制,关注每个图片块的重要程度。

1)self-attention结构解析

看懂Self-attention结构,其实看懂下面这个动图就可以了,动图中存在一个序列的三个单位输入每一个序列单位的输入都可以通过三个处理(比如全连接)获得Query、Key、Value,Query是查询向量、Key是键向量、Value值向量。

如果我们想要获得input-1的输出,那么我们进行如下几步:
1、利用input-1的查询向量,分别乘上input-1、input-2、input-3的键向量,此时我们获得了三个score。
2、然后对这三个score取softmax,获得了input-1、input-2、input-3各自的重要程度。
3、然后将这个重要程度乘上input-1、input-2、input-3的值向量,求和。
4、此时我们获得了input-1的输出。

如图所示,我们进行如下几步:
1、input-1的查询向量为[1, 0, 2],分别乘上input-1、input-2、input-3的键向量,获得三个score为2,4,4。
2、然后对这三个score取softmax,获得了input-1、input-2、input-3各自的重要程度,获得三个重要程度为0.0,0.5,0.5。
3、然后将这个重要程度乘上input-1、input-2、input-3的值向量,求和,即
0.0 ∗ [ 1 , 2 , 3 ] + 0.5 ∗ [ 2 , 8 , 0 ] + 0.5 ∗ [ 2 , 6 , 3 ] = [ 2.0 , 7.0 , 1.5 ]
4、此时我们获得了input-1的输出 [2.0, 7.0, 1.5]。

上述的例子中,序列长度仅为3,每个单位序列的特征长度仅为3,在VIT的Transformer Encoder中,序列长度为197,每个单位序列的特征长度为768 // num_heads。但计算过程是一样的。在实际运算时,我们采用矩阵进行运算。
2)self-attention的矩阵运算

实际的矩阵运算过程如下图所示。我以实际矩阵为例子给大家解析:

输入的Query、Key、Value如下图所示:

首先利用 查询向量query 叉乘 转置后的键向量key,这一步可以通俗的理解为,利用查询向量去查询序列的特征,获得序列每个部分的重要程度score。

输出的每一行,都代表input-1、input-2、input-3,对当前input的贡献,我们对这个贡献值取一个softmax。

然后利用 score 叉乘 value,这一步可以通俗的理解为,将序列每个部分的重要程度重新施加到序列的值上去。

矩阵代码运算如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np

def soft_max(z):
t = np.exp(z)
a = np.exp(z) / np.expand_dims(np.sum(t, axis=1), 1)
return a

Query = np.array([
[1,0,2],
[2,2,2],
[2,1,3]
])

Key = np.array([
[0,1,1],
[4,4,0],
[2,3,1]
])

Value = np.array([
[1,2,3],
[2,8,0],
[2,6,3]
])

scores = Query @ Key.T
print(scores)
scores = soft_max(scores)
print(scores)
out = scores @ Value
print(out)

3)Multihead多头注意力机制

多头注意力机制的示意图如图所示:

这幅图给人的感觉略显迷茫,我们跳脱出这个图,直接从矩阵的shape入手会清晰很多。

在第一步进行图像的分割后,我们获得的特征层为197, 768。

在施加多头的时候,我们直接对196, 768的最后一维度进行分割,比如我们想分割成12个头,那么矩阵的shape就变成了196, 12, 64。

然后我们将196, 12, 64进行转置,将12放到前面去,获得的特征层为12, 196, 64。之后我们忽略这个12,把它和batch维度同等对待,只对196, 64进行处理,其实也就是上面的注意力机制的过程了。

img

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#--------------------------------------------------------------------------#
# Attention机制
# 将输入的特征qkv特征进行划分,首先生成query, key, value。query是查询向量、key是键向量、v是值向量。
# 然后利用 查询向量query 叉乘 转置后的键向量key,这一步可以通俗的理解为,利用查询向量去查询序列的特征,获得序列每个部分的重要程度score。
# 然后利用 score 叉乘 value,这一步可以通俗的理解为,将序列每个部分的重要程度重新施加到序列的值上去。
#--------------------------------------------------------------------------#
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
self.scale = (dim // num_heads) ** -0.5
# 768->768*3
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

def forward(self, x):
# batch, 196, 768
B, N, C = x.shape
# batch, 196, 768 -> batch, 196, 768*3 -> batch, 196, 3, 8, 768/8=96 -> 3, batch, 8, 196, 96
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
# 3 * 1, batch, 8, 196, 96 -> q, k, v = batch: 16, head: 8, patch: 196, each_head_attention_channels: 96
q, k, v = qkv[0], qkv[1], qkv[2]
# batch, 8, 196, 96 @ batch, 8, 96, 196 -> batch, 8, 196, 196
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
# batch, 8, 196, 196 @ batch, 8, 196, 96 -> batch, 8, 196, 96 -> batch, 196, 8, 96 -> batch, 196, 768
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
# batch, 196, 768 -> batch, 196, 768
x = self.proj(x)
# Dropout(batch, 196, 768)
x = self.proj_drop(x)
return x

4)TransformerBlock的构建

在完成MultiHeadSelfAttention的构建后,我们需要在其后加上两个全连接。就构建了整个TransformerBlock

block流程见下图:

img

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
drop_probs = (drop, drop)

self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop2 = nn.Dropout(drop_probs[1])

def forward(self, x):
# batch, 196, 768 -> batch, 196, 768
x = self.fc1(x)
# batch, 196, 768 -> batch, 196, 768
x = self.act(x)
# batch, 196, 768 -> batch, 196, 768
x = self.drop1(x)
# batch, 196, 768 -> batch, 196, 768
x = self.fc2(x)
# batch, 196, 768 -> batch, 196, 768
x = self.drop2(x)
return x

# a transoformer encoder block
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x

c)VIT模型构建

整个VIT模型由一个Patch+Position Embedding加上多个TransformerBlock组成。典型的TransforerBlock的数量为12个

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
class VisionTransformer(nn.Module):
def __init__(
self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,
depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,
norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU
):
super(VisionTransformer, self).__init__()
#-----------------------------------------------#
# 224, 224, 3 -> batch, 196, 768
#-----------------------------------------------#
self.patch_embed = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)
num_patches = (224 // patch_size) * (224 // patch_size)
self.num_features = num_features
self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]
self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]

#--------------------------------------------------------------------------------------------------------------------#
# classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。
#
# 在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。
# 此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。
# 在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。
#--------------------------------------------------------------------------------------------------------------------#
# 1, 1, 768
self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))
#--------------------------------------------------------------------------------------------------------------------#
# 为网络提取到的特征添加上位置信息。
# 以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768
# 此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。
#--------------------------------------------------------------------------------------------------------------------#
# 1, 197, 768
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))
# 1, 197, 768
self.pos_drop = nn.Dropout(p=drop_rate)

#-----------------------------------------------#
# 197, 768 -> 197, 768 12次
#-----------------------------------------------#
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.Sequential(
*[
Block(
dim = num_features,
num_heads = num_heads,
mlp_ratio = mlp_ratio,
qkv_bias = qkv_bias,
drop = drop_rate,
attn_drop = attn_drop_rate,
drop_path = dpr[i],
norm_layer = norm_layer,
act_layer = act_layer
)for i in range(depth)
]
)
self.norm = norm_layer(num_features)
self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()

def forward_features(self, x):
x = self.patch_embed(x)
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)

cls_token_pe = self.pos_embed[:, 0:1, :]
img_token_pe = self.pos_embed[:, 1: , :]

img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)
img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)
pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)

x = self.pos_drop(x + pos_embed)
x = self.blocks(x)
x = self.norm(x)
return x[:, 0]

def forward(self, x):
# # 整个Transformer Encoder = batch, 768
x = self.forward_features(x)
# 最后的MLP Header = batch, 768 -> 768 -> 1000 -> batch, 1000
x = self.head(x)
return x

def freeze_backbone(self):
backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
for module in backbone:
try:
for param in module.parameters():
param.requires_grad = False
except:
module.requires_grad = False

def Unfreeze_backbone(self):
backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
for module in backbone:
try:
for param in module.parameters():
param.requires_grad = True
except:
module.requires_grad = True

分类部分

在分类部分,VIT所做的工作是利用提取到的特征进行分类。

在进行特征提取的时候,我们会在图片序列中添加上Cls Token,该Token会作为一个单位的序列信息一起进行特征提取,提取的过程中,该Cls Token会与其它的特征进行特征交互,融合其它图片序列的特征。

最终,我们利用Multi-head Self-attention结构提取特征后的Cls Token进行全连接分类。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
class VisionTransformer(nn.Module):
def __init__(
self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,
depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,
norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU
):
super().__init__()
#-----------------------------------------------#
# 224, 224, 3 -> 196, 768
#-----------------------------------------------#
self.patch_embed = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)
num_patches = (224 // patch_size) * (224 // patch_size)
self.num_features = num_features
self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]
self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]

#--------------------------------------------------------------------------------------------------------------------#
# classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。
#
# 在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。
# 此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。
# 在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。
#--------------------------------------------------------------------------------------------------------------------#
# 196, 768 -> 197, 768
self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))
#--------------------------------------------------------------------------------------------------------------------#
# 为网络提取到的特征添加上位置信息。
# 以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768
# 此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。
#--------------------------------------------------------------------------------------------------------------------#
# 197, 768 -> 197, 768
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))
self.pos_drop = nn.Dropout(p=drop_rate)

#-----------------------------------------------#
# 197, 768 -> 197, 768 12次
#-----------------------------------------------#
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.Sequential(
*[
Block(
dim = num_features,
num_heads = num_heads,
mlp_ratio = mlp_ratio,
qkv_bias = qkv_bias,
drop = drop_rate,
attn_drop = attn_drop_rate,
drop_path = dpr[i],
norm_layer = norm_layer,
act_layer = act_layer
)for i in range(depth)
]
)
self.norm = norm_layer(num_features)
self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()

def forward_features(self, x):
x = self.patch_embed(x)
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)

cls_token_pe = self.pos_embed[:, 0:1, :]
img_token_pe = self.pos_embed[:, 1: , :]

img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)
img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)
pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)

x = self.pos_drop(x + pos_embed)
x = self.blocks(x)
x = self.norm(x)
return x[:, 0]

def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x

def freeze_backbone(self):
backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
for module in backbone:
try:
for param in module.parameters():
param.requires_grad = False
except:
module.requires_grad = False

def Unfreeze_backbone(self):
backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
for module in backbone:
try:
for param in module.parameters():
param.requires_grad = True
except:
module.requires_grad = True


VIT构建代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
import math
from collections import OrderedDict
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

#--------------------------------------#
# Gelu激活函数的实现
# 利用近似的数学公式
#--------------------------------------#
class GELU(nn.Module):
def __init__(self):
super(GELU, self).__init__()

def forward(self, x):
return 0.5 * x * (1 + F.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * torch.pow(x,3))))

def drop_path(x, drop_prob: float = 0., training: bool = False):
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
output = x.div(keep_prob) * random_tensor
return output

class DropPath(nn.Module):
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob

def forward(self, x):
return drop_path(x, self.drop_prob, self.training)

class PatchEmbed(nn.Module):
def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_features=768, norm_layer=None, flatten=True):
super().__init__()
self.num_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
self.flatten = flatten

self.proj = nn.Conv2d(in_chans, num_features, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(num_features) if norm_layer else nn.Identity()

def forward(self, x):
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x

#--------------------------------------------------------------------------------------------------------------------#
# Attention机制
# 将输入的特征qkv特征进行划分,首先生成query, key, value。query是查询向量、key是键向量、v是值向量。
# 然后利用 查询向量query 叉乘 转置后的键向量key,这一步可以通俗的理解为,利用查询向量去查询序列的特征,获得序列每个部分的重要程度score。
# 然后利用 score 叉乘 value,这一步可以通俗的理解为,将序列每个部分的重要程度重新施加到序列的值上去。
#--------------------------------------------------------------------------------------------------------------------#
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
self.scale = (dim // num_heads) ** -0.5

self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]

attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)

x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x

class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
drop_probs = (drop, drop)

self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop2 = nn.Dropout(drop_probs[1])

def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x

class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x

class VisionTransformer(nn.Module):
def __init__(
self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,
depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,
norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU
):
super().__init__()
#-----------------------------------------------#
# 224, 224, 3 -> 196, 768
#-----------------------------------------------#
self.patch_embed = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)
num_patches = (224 // patch_size) * (224 // patch_size)
self.num_features = num_features
self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]
self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]

#--------------------------------------------------------------------------------------------------------------------#
# classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。
#
# 在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。
# 此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。
# 在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。
#--------------------------------------------------------------------------------------------------------------------#
# 196, 768 -> 197, 768
self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))
#--------------------------------------------------------------------------------------------------------------------#
# 为网络提取到的特征添加上位置信息。
# 以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768
# 此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。
#--------------------------------------------------------------------------------------------------------------------#
# 197, 768 -> 197, 768
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))
self.pos_drop = nn.Dropout(p=drop_rate)

#-----------------------------------------------#
# 197, 768 -> 197, 768 12次
#-----------------------------------------------#
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.Sequential(
*[
Block(
dim = num_features,
num_heads = num_heads,
mlp_ratio = mlp_ratio,
qkv_bias = qkv_bias,
drop = drop_rate,
attn_drop = attn_drop_rate,
drop_path = dpr[i],
norm_layer = norm_layer,
act_layer = act_layer
)for i in range(depth)
]
)
self.norm = norm_layer(num_features)
self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()

def forward_features(self, x):
x = self.patch_embed(x)
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)

cls_token_pe = self.pos_embed[:, 0:1, :]
img_token_pe = self.pos_embed[:, 1: , :]

img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)
img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)
pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)

x = self.pos_drop(x + pos_embed)
x = self.blocks(x)
x = self.norm(x)
return x[:, 0]

def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x

def freeze_backbone(self):
backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
for module in backbone:
try:
for param in module.parameters():
param.requires_grad = False
except:
module.requires_grad = False

def Unfreeze_backbone(self):
backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
for module in backbone:
try:
for param in module.parameters():
param.requires_grad = True
except:
module.requires_grad = True


def vit(input_shape=[224, 224], pretrained=False, num_classes=1000):
model = VisionTransformer(input_shape)
if pretrained:
model.load_state_dict(torch.load("model_data/vit-patch_16.pth"))

if num_classes!=1000:
model.head = nn.Linear(model.num_features, num_classes)
return model