Coverage for src/bob/bio/face/pytorch/facexzoo/TF_NAS.py: 67% (298 statements)
"""
@author: Yibo Hu, Jun Wang
@date: 20201019
@contact: jun21wangustc@gmail.com
"""

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
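
# Channel shuffle as used in ShuffleNet-style blocks: interleaves channels
# across groups so that grouped convolutions can exchange information.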
def channel_shuffle(x, groups):
    assert groups > 1
    batchsize, num_channels, height, width = x.size()
    assert num_channels % groups == 0
    channels_per_group = num_channels // groups
    # reshape
    x = x.view(batchsize, groups, channels_per_group, height, width)
    # transpose
    x = torch.transpose(x, 1, 2).contiguous()
    # flatten
    x = x.view(batchsize, -1, height, width)
    return x
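
# Returns the padding that keeps the spatial size unchanged for a stride-1
# convolution; only odd kernel sizes (or tuples of odd sizes) are accepted.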
def get_same_padding(kernel_size):
    if isinstance(kernel_size, tuple):
        assert len(kernel_size) == 2, "invalid kernel size: {}".format(
            kernel_size
        )
        p1 = get_same_padding(kernel_size[0])
        p2 = get_same_padding(kernel_size[1])
        return p1, p2
    assert isinstance(
        kernel_size, int
    ), "kernel size should be either `int` or `tuple`"
    assert kernel_size % 2 > 0, "kernel size should be odd number"
    return kernel_size // 2
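
# Swish activation: x * sigmoid(x), with an optional in-place variant.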
class Swish(nn.Module):
    def __init__(self, inplace=False):
        super(Swish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        if self.inplace:
            return x.mul_(x.sigmoid())
        else:
            return x * x.sigmoid()
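
# Hard-Swish activation: x * ReLU6(x + 3) / 6, a cheaper approximation of Swish.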
class HardSwish(nn.Module):
    def __init__(self, inplace=False):
        super(HardSwish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        if self.inplace:
            return x.mul_(F.relu6(x + 3.0, inplace=True) / 6.0)
        else:
            return x * F.relu6(x + 3.0) / 6.0
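
# Base class for layers built from a weight op, batch norm and activation,
# applied in the order given by `ops_order`; subclasses supply the weight op
# through `weight_call`.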
class BasicLayer(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        use_bn=True,
        affine=True,
        act_func="relu6",
        ops_order="weight_bn_act",
    ):
        super(BasicLayer, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_bn = use_bn
        self.affine = affine
        self.act_func = act_func
        self.ops_order = ops_order

        """ add modules """
        # batch norm
        if self.use_bn:
            if self.bn_before_weight:
                self.bn = nn.BatchNorm2d(
                    in_channels, affine=affine, track_running_stats=affine
                )
            else:
                self.bn = nn.BatchNorm2d(
                    out_channels, affine=affine, track_running_stats=affine
                )
        else:
            self.bn = None
        # activation
        if act_func == "relu":
            if self.ops_list[0] == "act":
                self.act = nn.ReLU(inplace=False)
            else:
                self.act = nn.ReLU(inplace=True)
        elif act_func == "relu6":
            if self.ops_list[0] == "act":
                self.act = nn.ReLU6(inplace=False)
            else:
                self.act = nn.ReLU6(inplace=True)
        elif act_func == "swish":
            if self.ops_list[0] == "act":
                self.act = Swish(inplace=False)
            else:
                self.act = Swish(inplace=True)
        elif act_func == "h-swish":
            if self.ops_list[0] == "act":
                self.act = HardSwish(inplace=False)
            else:
                self.act = HardSwish(inplace=True)
        else:
            self.act = None

    @property
    def ops_list(self):
        return self.ops_order.split("_")

    @property
    def bn_before_weight(self):
        for op in self.ops_list:
            if op == "bn":
                return True
            elif op == "weight":
                return False
        raise ValueError("Invalid ops_order: %s" % self.ops_order)

    def weight_call(self, x):
        raise NotImplementedError

    def forward(self, x):
        for op in self.ops_list:
            if op == "weight":
                x = self.weight_call(x)
            elif op == "bn":
                if self.bn is not None:
                    x = self.bn(x)
            elif op == "act":
                if self.act is not None:
                    x = self.act(x)
            else:
                raise ValueError("Unrecognized op: %s" % op)
        return x
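
# Convolution layer with "same" padding and an optional channel shuffle after
# grouped convolutions.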
class ConvLayer(BasicLayer):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        groups=1,
        has_shuffle=False,
        bias=False,
        use_bn=True,
        affine=True,
        act_func="relu6",
        ops_order="weight_bn_act",
    ):
        super(ConvLayer, self).__init__(
            in_channels, out_channels, use_bn, affine, act_func, ops_order
        )

        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        self.has_shuffle = has_shuffle
        self.bias = bias

        padding = get_same_padding(self.kernel_size)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=padding,
            groups=self.groups,
            bias=self.bias,
        )

    def weight_call(self, x):
        x = self.conv(x)
        if self.has_shuffle and self.groups > 1:
            x = channel_shuffle(x, self.groups)
        return x
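
# Fully connected layer with optional 1D batch norm and activation, applied in
# the order given by `ops_order`.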
class LinearLayer(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        use_bn=False,
        affine=False,
        act_func=None,
        ops_order="weight_bn_act",
    ):
        super(LinearLayer, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.use_bn = use_bn
        self.affine = affine
        self.act_func = act_func
        self.ops_order = ops_order

        """ add modules """
        # batch norm
        if self.use_bn:
            if self.bn_before_weight:
                self.bn = nn.BatchNorm1d(
                    in_features, affine=affine, track_running_stats=affine
                )
            else:
                self.bn = nn.BatchNorm1d(
                    out_features, affine=affine, track_running_stats=affine
                )
        else:
            self.bn = None
        # activation
        if act_func == "relu":
            if self.ops_list[0] == "act":
                self.act = nn.ReLU(inplace=False)
            else:
                self.act = nn.ReLU(inplace=True)
        elif act_func == "relu6":
            if self.ops_list[0] == "act":
                self.act = nn.ReLU6(inplace=False)
            else:
                self.act = nn.ReLU6(inplace=True)
        elif act_func == "tanh":
            self.act = nn.Tanh()
        elif act_func == "sigmoid":
            self.act = nn.Sigmoid()
        else:
            self.act = None
        # linear
        self.linear = nn.Linear(self.in_features, self.out_features, self.bias)

    @property
    def ops_list(self):
        return self.ops_order.split("_")

    @property
    def bn_before_weight(self):
        for op in self.ops_list:
            if op == "bn":
                return True
            elif op == "weight":
                return False
        raise ValueError("Invalid ops_order: %s" % self.ops_order)

    def forward(self, x):
        for op in self.ops_list:
            if op == "weight":
                x = self.linear(x)
            elif op == "bn":
                if self.bn is not None:
                    x = self.bn(x)
            elif op == "act":
                if self.act is not None:
                    x = self.act(x)
            else:
                raise ValueError("Unrecognized op: %s" % op)
        return x
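
# MobileNetV2-style inverted residual block: pointwise expansion, depthwise
# convolution, optional squeeze-and-excitation, pointwise projection, and a
# skip connection when input and output shapes match.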
class MBInvertedResBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        mid_channels,
        se_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        groups=1,
        has_shuffle=False,
        bias=False,
        use_bn=True,
        affine=True,
        act_func="relu6",
    ):
        super(MBInvertedResBlock, self).__init__()

        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.se_channels = se_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        self.has_shuffle = has_shuffle
        self.bias = bias
        self.use_bn = use_bn
        self.affine = affine
        self.act_func = act_func

        # inverted bottleneck
        if mid_channels > in_channels:
            inverted_bottleneck = OrderedDict(
                [
                    (
                        "conv",
                        nn.Conv2d(
                            in_channels,
                            mid_channels,
                            1,
                            1,
                            0,
                            groups=groups,
                            bias=bias,
                        ),
                    ),
                ]
            )
            if use_bn:
                inverted_bottleneck["bn"] = nn.BatchNorm2d(
                    mid_channels, affine=affine, track_running_stats=affine
                )
            if act_func == "relu":
                inverted_bottleneck["act"] = nn.ReLU(inplace=True)
            elif act_func == "relu6":
                inverted_bottleneck["act"] = nn.ReLU6(inplace=True)
            elif act_func == "swish":
                inverted_bottleneck["act"] = Swish(inplace=True)
            elif act_func == "h-swish":
                inverted_bottleneck["act"] = HardSwish(inplace=True)
            self.inverted_bottleneck = nn.Sequential(inverted_bottleneck)
        else:
            self.inverted_bottleneck = None
            self.mid_channels = in_channels
            mid_channels = in_channels

        # depthwise convolution
        padding = get_same_padding(self.kernel_size)
        depth_conv = OrderedDict(
            [
                (
                    "conv",
                    nn.Conv2d(
                        mid_channels,
                        mid_channels,
                        kernel_size,
                        stride,
                        padding,
                        groups=mid_channels,
                        bias=bias,
                    ),
                ),
            ]
        )
        if use_bn:
            depth_conv["bn"] = nn.BatchNorm2d(
                mid_channels, affine=affine, track_running_stats=affine
            )
        if act_func == "relu":
            depth_conv["act"] = nn.ReLU(inplace=True)
        elif act_func == "relu6":
            depth_conv["act"] = nn.ReLU6(inplace=True)
        elif act_func == "swish":
            depth_conv["act"] = Swish(inplace=True)
        elif act_func == "h-swish":
            depth_conv["act"] = HardSwish(inplace=True)
        self.depth_conv = nn.Sequential(depth_conv)

        # se model
        if se_channels > 0:
            squeeze_excite = OrderedDict(
                [
                    (
                        "conv_reduce",
                        nn.Conv2d(
                            mid_channels,
                            se_channels,
                            1,
                            1,
                            0,
                            groups=groups,
                            bias=True,
                        ),
                    ),
                ]
            )
            if act_func == "relu":
                squeeze_excite["act"] = nn.ReLU(inplace=True)
            elif act_func == "relu6":
                squeeze_excite["act"] = nn.ReLU6(inplace=True)
            elif act_func == "swish":
                squeeze_excite["act"] = Swish(inplace=True)
            elif act_func == "h-swish":
                squeeze_excite["act"] = HardSwish(inplace=True)
            squeeze_excite["conv_expand"] = nn.Conv2d(
                se_channels, mid_channels, 1, 1, 0, groups=groups, bias=True
            )
            self.squeeze_excite = nn.Sequential(squeeze_excite)
        else:
            self.squeeze_excite = None
            self.se_channels = 0

        # pointwise linear
        point_linear = OrderedDict(
            [
                (
                    "conv",
                    nn.Conv2d(
                        mid_channels,
                        out_channels,
                        1,
                        1,
                        0,
                        groups=groups,
                        bias=bias,
                    ),
                ),
            ]
        )
        if use_bn:
            point_linear["bn"] = nn.BatchNorm2d(
                out_channels, affine=affine, track_running_stats=affine
            )
        self.point_linear = nn.Sequential(point_linear)

        # residual flag
        self.has_residual = (in_channels == out_channels) and (stride == 1)

    def forward(self, x):
        res = x

        if self.inverted_bottleneck is not None:
            x = self.inverted_bottleneck(x)
            if self.has_shuffle and self.groups > 1:
                x = channel_shuffle(x, self.groups)

        x = self.depth_conv(x)
        if self.squeeze_excite is not None:
            x_se = F.adaptive_avg_pool2d(x, 1)
            x = x * torch.sigmoid(self.squeeze_excite(x_se))

        x = self.point_linear(x)
        if self.has_shuffle and self.groups > 1:
            x = channel_shuffle(x, self.groups)

        if self.has_residual:
            x += res

        return x
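
# Flattens all dimensions except the batch dimension.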
class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)
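
# TF-NAS-A backbone (presumably the architecture searched by TF-NAS), used here
# as a face-recognition feature extractor: stem layers, six stages of inverted
# residual blocks, a 1x1 feature-mix layer, and a linear head producing a
# `feat_dim`-dimensional embedding.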
class TF_NAS_A(nn.Module):
    def __init__(self, out_h, out_w, feat_dim, drop_ratio=0.0):
        super(TF_NAS_A, self).__init__()
        self.drop_ratio = drop_ratio

        self.first_stem = ConvLayer(
            3, 32, kernel_size=3, stride=1, act_func="relu"
        )
        self.second_stem = MBInvertedResBlock(
            32, 32, 8, 16, kernel_size=3, stride=1, act_func="relu"
        )
        self.stage1 = nn.Sequential(
            MBInvertedResBlock(
                16, 83, 32, 24, kernel_size=3, stride=2, act_func="relu"
            ),
            MBInvertedResBlock(
                24, 128, 0, 24, kernel_size=5, stride=1, act_func="relu"
            ),
        )
        self.stage2 = nn.Sequential(
            MBInvertedResBlock(
                24, 138, 48, 40, kernel_size=3, stride=2, act_func="swish"
            ),
            MBInvertedResBlock(
                40, 297, 0, 40, kernel_size=3, stride=1, act_func="swish"
            ),
            MBInvertedResBlock(
                40, 170, 80, 40, kernel_size=5, stride=1, act_func="swish"
            ),
        )
        self.stage3 = nn.Sequential(
            MBInvertedResBlock(
                40, 248, 80, 80, kernel_size=5, stride=2, act_func="swish"
            ),
            MBInvertedResBlock(
                80, 500, 0, 80, kernel_size=3, stride=1, act_func="swish"
            ),
            MBInvertedResBlock(
                80, 424, 0, 80, kernel_size=3, stride=1, act_func="swish"
            ),
            MBInvertedResBlock(
                80, 477, 0, 80, kernel_size=3, stride=1, act_func="swish"
            ),
        )
        self.stage4 = nn.Sequential(
            MBInvertedResBlock(
                80, 504, 160, 112, kernel_size=3, stride=1, act_func="swish"
            ),
            MBInvertedResBlock(
                112, 796, 0, 112, kernel_size=3, stride=1, act_func="swish"
            ),
            MBInvertedResBlock(
                112, 723, 224, 112, kernel_size=3, stride=1, act_func="swish"
            ),
            MBInvertedResBlock(
                112, 555, 224, 112, kernel_size=3, stride=1, act_func="swish"
            ),
        )
        self.stage5 = nn.Sequential(
            MBInvertedResBlock(
                112, 813, 0, 192, kernel_size=3, stride=2, act_func="swish"
            ),
            MBInvertedResBlock(
                192, 1370, 0, 192, kernel_size=3, stride=1, act_func="swish"
            ),
            MBInvertedResBlock(
                192, 1138, 384, 192, kernel_size=3, stride=1, act_func="swish"
            ),
            MBInvertedResBlock(
                192, 1359, 384, 192, kernel_size=3, stride=1, act_func="swish"
            ),
        )
        self.stage6 = nn.Sequential(
            MBInvertedResBlock(
                192, 1203, 384, 320, kernel_size=5, stride=1, act_func="swish"
            ),
        )
        self.feature_mix_layer = ConvLayer(
            320, 1280, kernel_size=1, stride=1, act_func="none"
        )
        self.output_layer = nn.Sequential(
            nn.Dropout(self.drop_ratio),
            Flatten(),
            nn.Linear(1280 * out_h * out_w, feat_dim),
            nn.BatchNorm1d(feat_dim),
        )

        self._initialization()

    def forward(self, x):
        x = self.first_stem(x)
        x = self.second_stem(x)
        for block in self.stage1:
            x = block(x)
        for block in self.stage2:
            x = block(x)
        for block in self.stage3:
            x = block(x)
        for block in self.stage4:
            x = block(x)
        for block in self.stage5:
            x = block(x)
        for block in self.stage6:
            x = block(x)
        x = self.feature_mix_layer(x)
        x = self.output_layer(x)
        return x

    def _initialization(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
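
# Quick smoke test: runs a random batch through the network. Note that both the
# input and the model are moved to the GPU, so a CUDA device is assumed.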
if __name__ == "__main__":
    x = torch.rand((2, 3, 112, 112))
    net = TF_NAS_A(7, 7, 512, drop_ratio=0.0)

    x = x.cuda()
    net = net.cuda()

    out = net(x)
    print(out.size())