Coverage for src/bob/bio/face/pytorch/facexzoo/AttentionNets.py: 100%
"""
@author: Jun Wang
@date: 20201019
@contact: jun21wangustc@gmail.com
"""

# based on:
# https://github.com/tengshaofeng/ResidualAttentionNetwork-pytorch/tree/master/Residual-Attention-Network/model

import torch.nn as nn
class Flatten(nn.Module):
    def forward(self, x):
        # Collapse all non-batch dimensions into a single feature vector.
        return x.view(x.size(0), -1)
class ResidualBlock(nn.Module):
    """Pre-activation bottleneck residual block (BN-ReLU-Conv, three times)."""

    def __init__(self, input_channels, output_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.stride = stride
        self.bn1 = nn.BatchNorm2d(input_channels)
        self.relu = nn.ReLU(inplace=True)
        # 1x1 conv reduces to a quarter of the output channels.
        self.conv1 = nn.Conv2d(
            input_channels, output_channels // 4, 1, 1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(output_channels // 4)
        # 3x3 conv carries the stride, so any spatial downsampling happens here.
        self.conv2 = nn.Conv2d(
            output_channels // 4,
            output_channels // 4,
            3,
            stride,
            padding=1,
            bias=False,
        )
        self.bn3 = nn.BatchNorm2d(output_channels // 4)
        # 1x1 conv expands back to the full output channel count.
        self.conv3 = nn.Conv2d(
            output_channels // 4, output_channels, 1, 1, bias=False
        )
        # 1x1 projection for the shortcut when channels or stride change.
        self.conv4 = nn.Conv2d(
            input_channels, output_channels, 1, stride, bias=False
        )
    def forward(self, x):
        residual = x
        out = self.bn1(x)
        out1 = self.relu(out)
        out = self.conv1(out1)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)
        # Project the shortcut from the pre-activated input whenever the
        # channel count or spatial resolution changes.
        if (self.input_channels != self.output_channels) or (self.stride != 1):
            residual = self.conv4(out1)
        out += residual
        return out
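

# Illustrative shape check (not part of the original module): a hypothetical
# helper sketching how ResidualBlock projects the shortcut when the channel
# count or stride changes. The input sizes below are assumptions for
# illustration, not values taken from the original code.
def _residual_block_shape_sketch():
    import torch

    block = ResidualBlock(64, 256, stride=2).eval()
    x = torch.randn(2, 64, 56, 56)
    with torch.no_grad():
        y = block(x)
    # The 3x3 conv carries the stride and the 1x1 shortcut projection matches
    # it, so a 64x56x56 input maps to a 256x28x28 output.
    assert y.shape == (2, 256, 28, 28)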
class AttentionModule_stage1(nn.Module):
    # input size is 56*56
    def __init__(
        self,
        in_channels,
        out_channels,
        size1=(56, 56),
        size2=(28, 28),
        size3=(14, 14),
    ):
        super(AttentionModule_stage1, self).__init__()
        self.first_residual_blocks = ResidualBlock(in_channels, out_channels)
        # Trunk branch: plain residual processing at full resolution.
        self.trunk_branches = nn.Sequential(
            ResidualBlock(in_channels, out_channels),
            ResidualBlock(in_channels, out_channels),
        )
        # Mask branch: a three-level down/up pyramid (56 -> 28 -> 14 -> 7)
        # with skip connections between matching resolutions.
        self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.softmax1_blocks = ResidualBlock(in_channels, out_channels)
        self.skip1_connection_residual_block = ResidualBlock(
            in_channels, out_channels
        )
        self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.softmax2_blocks = ResidualBlock(in_channels, out_channels)
        self.skip2_connection_residual_block = ResidualBlock(
            in_channels, out_channels
        )
        self.mpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.softmax3_blocks = nn.Sequential(
            ResidualBlock(in_channels, out_channels),
            ResidualBlock(in_channels, out_channels),
        )
        self.interpolation3 = nn.UpsamplingBilinear2d(size=size3)
        self.softmax4_blocks = ResidualBlock(in_channels, out_channels)
        self.interpolation2 = nn.UpsamplingBilinear2d(size=size2)
        self.softmax5_blocks = ResidualBlock(in_channels, out_channels)
        self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)
        # 1x1 convs followed by a sigmoid turn the mask branch output into
        # per-pixel attention weights in (0, 1).
        self.softmax6_blocks = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels, out_channels, kernel_size=1, stride=1, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels, out_channels, kernel_size=1, stride=1, bias=False
            ),
            nn.Sigmoid(),
        )
        self.last_blocks = ResidualBlock(in_channels, out_channels)
    def forward(self, x):
        x = self.first_residual_blocks(x)
        out_trunk = self.trunk_branches(x)
        # Mask branch: downsample three times ...
        out_mpool1 = self.mpool1(x)
        out_softmax1 = self.softmax1_blocks(out_mpool1)
        out_skip1_connection = self.skip1_connection_residual_block(
            out_softmax1
        )
        out_mpool2 = self.mpool2(out_softmax1)
        out_softmax2 = self.softmax2_blocks(out_mpool2)
        out_skip2_connection = self.skip2_connection_residual_block(
            out_softmax2
        )
        out_mpool3 = self.mpool3(out_softmax2)
        out_softmax3 = self.softmax3_blocks(out_mpool3)
        # ... then upsample step by step, merging the skip connections on the way up.
        out_interp3 = self.interpolation3(out_softmax3) + out_softmax2
        out = out_interp3 + out_skip2_connection
        out_softmax4 = self.softmax4_blocks(out)
        out_interp2 = self.interpolation2(out_softmax4) + out_softmax1
        out = out_interp2 + out_skip1_connection
        out_softmax5 = self.softmax5_blocks(out)
        out_interp1 = self.interpolation1(out_softmax5) + out_trunk
        out_softmax6 = self.softmax6_blocks(out_interp1)
        # Residual (soft-mask) attention: (1 + mask) * trunk keeps the trunk
        # signal even where the mask is close to zero.
        out = (1 + out_softmax6) * out_trunk
        out_last = self.last_blocks(out)

        return out_last
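

# Illustrative shape check (not part of the original module): the attention
# modules are shape-preserving because the mask branch is upsampled back to
# the trunk resolution before the (1 + mask) * trunk gating. The 56x56 input
# below matches the stage-1 default sizes but is otherwise an assumption.
def _attention_stage1_shape_sketch():
    import torch

    module = AttentionModule_stage1(256, 256).eval()
    x = torch.randn(2, 256, 56, 56)
    with torch.no_grad():
        y = module(x)
    assert y.shape == (2, 256, 56, 56)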
class AttentionModule_stage2(nn.Module):
    # input image size is 28*28
    def __init__(
        self, in_channels, out_channels, size1=(28, 28), size2=(14, 14)
    ):
        super(AttentionModule_stage2, self).__init__()
        self.first_residual_blocks = ResidualBlock(in_channels, out_channels)
        self.trunk_branches = nn.Sequential(
            ResidualBlock(in_channels, out_channels),
            ResidualBlock(in_channels, out_channels),
        )
        # Same structure as stage 1, but the mask branch only has a
        # two-level pyramid (28 -> 14 -> 7).
        self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.softmax1_blocks = ResidualBlock(in_channels, out_channels)
        self.skip1_connection_residual_block = ResidualBlock(
            in_channels, out_channels
        )
        self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.softmax2_blocks = nn.Sequential(
            ResidualBlock(in_channels, out_channels),
            ResidualBlock(in_channels, out_channels),
        )
        self.interpolation2 = nn.UpsamplingBilinear2d(size=size2)
        self.softmax3_blocks = ResidualBlock(in_channels, out_channels)
        self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)
        self.softmax4_blocks = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels, out_channels, kernel_size=1, stride=1, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels, out_channels, kernel_size=1, stride=1, bias=False
            ),
            nn.Sigmoid(),
        )
        self.last_blocks = ResidualBlock(in_channels, out_channels)
    def forward(self, x):
        x = self.first_residual_blocks(x)
        out_trunk = self.trunk_branches(x)
        out_mpool1 = self.mpool1(x)
        out_softmax1 = self.softmax1_blocks(out_mpool1)
        out_skip1_connection = self.skip1_connection_residual_block(
            out_softmax1
        )
        out_mpool2 = self.mpool2(out_softmax1)
        out_softmax2 = self.softmax2_blocks(out_mpool2)
        out_interp2 = self.interpolation2(out_softmax2) + out_softmax1
        out = out_interp2 + out_skip1_connection
        out_softmax3 = self.softmax3_blocks(out)
        out_interp1 = self.interpolation1(out_softmax3) + out_trunk
        out_softmax4 = self.softmax4_blocks(out_interp1)
        out = (1 + out_softmax4) * out_trunk
        out_last = self.last_blocks(out)
        return out_last
class AttentionModule_stage3(nn.Module):
    # input image size is 14*14
    def __init__(self, in_channels, out_channels, size1=(14, 14)):
        super(AttentionModule_stage3, self).__init__()
        self.first_residual_blocks = ResidualBlock(in_channels, out_channels)
        self.trunk_branches = nn.Sequential(
            ResidualBlock(in_channels, out_channels),
            ResidualBlock(in_channels, out_channels),
        )
        # Single-level mask pyramid (14 -> 7) for the smallest feature maps.
        self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.softmax1_blocks = nn.Sequential(
            ResidualBlock(in_channels, out_channels),
            ResidualBlock(in_channels, out_channels),
        )
        self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)
        self.softmax2_blocks = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels, out_channels, kernel_size=1, stride=1, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels, out_channels, kernel_size=1, stride=1, bias=False
            ),
            nn.Sigmoid(),
        )
        self.last_blocks = ResidualBlock(in_channels, out_channels)
    def forward(self, x):
        x = self.first_residual_blocks(x)
        out_trunk = self.trunk_branches(x)
        out_mpool1 = self.mpool1(x)
        out_softmax1 = self.softmax1_blocks(out_mpool1)
        out_interp1 = self.interpolation1(out_softmax1) + out_trunk
        out_softmax2 = self.softmax2_blocks(out_interp1)
        out = (1 + out_softmax2) * out_trunk
        out_last = self.last_blocks(out)
        return out_last
class ResidualAttentionNet(nn.Module):
    def __init__(
        self,
        stage1_modules,
        stage2_modules,
        stage3_modules,
        feat_dim,
        out_h,
        out_w,
    ):
        super(ResidualAttentionNet, self).__init__()
        # Stem: 7x7 stride-2 conv halves the spatial resolution.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        attention_modules = []
        attention_modules.append(ResidualBlock(64, 256))
        # stage 1
        for i in range(stage1_modules):
            attention_modules.append(AttentionModule_stage1(256, 256))
        attention_modules.append(ResidualBlock(256, 512, 2))
        # stage 2
        for i in range(stage2_modules):
            attention_modules.append(AttentionModule_stage2(512, 512))
        attention_modules.append(ResidualBlock(512, 1024, 2))
        # stage 3
        for i in range(stage3_modules):
            attention_modules.append(AttentionModule_stage3(1024, 1024))
        # final residual blocks
        attention_modules.append(ResidualBlock(1024, 2048, 2))
        attention_modules.append(ResidualBlock(2048, 2048))
        attention_modules.append(ResidualBlock(2048, 2048))
        self.attention_body = nn.Sequential(*attention_modules)
        # output layer: flatten and project to the embedding dimension
        # (the Linear has no bias; BatchNorm1d provides the affine shift).
        self.output_layer = nn.Sequential(
            Flatten(),
            nn.Linear(2048 * out_h * out_w, feat_dim, False),
            nn.BatchNorm1d(feat_dim),
        )
    def forward(self, x):
        out = self.conv1(x)
        out = self.attention_body(out)
        out = self.output_layer(out)
        return out
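

# Minimal usage sketch (not part of the original file). The stage counts,
# feature dimension, and 112x112 input below are assumptions chosen so the
# strided blocks line up with the attention modules' default sizes
# (112 -> 56 -> 28 -> 14 -> 7); they are not values taken from this module.
if __name__ == "__main__":
    import torch

    net = ResidualAttentionNet(
        stage1_modules=1,
        stage2_modules=1,
        stage3_modules=1,
        feat_dim=512,
        out_h=7,
        out_w=7,
    ).eval()
    images = torch.randn(2, 3, 112, 112)
    with torch.no_grad():
        embeddings = net(images)
    print(embeddings.shape)  # expected: torch.Size([2, 512])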