Coverage for src/bob/bio/face/pytorch/facexzoo/TF_NAS.py: 67%

298 statements  

coverage.py v7.6.0, created at 2024-07-13 00:04 +0200

"""
@author: Yibo Hu, Jun Wang
@date: 20201019
@contact: jun21wangustc@gmail.com
"""

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F


def channel_shuffle(x, groups):
    assert groups > 1
    batchsize, num_channels, height, width = x.size()
    assert num_channels % groups == 0
    channels_per_group = num_channels // groups
    # reshape
    x = x.view(batchsize, groups, channels_per_group, height, width)
    # transpose
    x = torch.transpose(x, 1, 2).contiguous()
    # flatten
    x = x.view(batchsize, -1, height, width)
    return x
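
# Minimal usage sketch (not part of the original FaceX-Zoo source; the helper
# name below is hypothetical): with 4 channels split into 2 groups,
# channel_shuffle interleaves the groups, so channel order [0, 1, 2, 3]
# becomes [0, 2, 1, 3].
def _channel_shuffle_example():
    x = torch.arange(16, dtype=torch.float32).view(1, 4, 2, 2)
    y = channel_shuffle(x, groups=2)
    assert y.shape == x.shape
    # y[:, 1] now holds the feature map that was x[:, 2].
    assert torch.equal(y[:, 1], x[:, 2])
    return y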

def get_same_padding(kernel_size):
    if isinstance(kernel_size, tuple):
        assert len(kernel_size) == 2, "invalid kernel size: {}".format(
            kernel_size
        )
        p1 = get_same_padding(kernel_size[0])
        p2 = get_same_padding(kernel_size[1])
        return p1, p2
    assert isinstance(
        kernel_size, int
    ), "kernel size should be either `int` or `tuple`"
    assert kernel_size % 2 > 0, "kernel size should be odd number"
    return kernel_size // 2
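
# Illustrative examples (added for clarity, not in the original source):
# get_same_padding(3) == 1 and get_same_padding(5) == 2, so a stride-1
# convolution keeps the spatial size; get_same_padding((3, 5)) == (1, 2).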

class Swish(nn.Module):
    def __init__(self, inplace=False):
        super(Swish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        if self.inplace:
            return x.mul_(x.sigmoid())
        else:
            return x * x.sigmoid()

class HardSwish(nn.Module):
    def __init__(self, inplace=False):
        super(HardSwish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        if self.inplace:
            return x.mul_(F.relu6(x + 3.0, inplace=True) / 6.0)
        else:
            return x * F.relu6(x + 3.0) / 6.0
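
# Added note (not in the original source): Swish(x) = x * sigmoid(x) and
# HardSwish(x) = x * relu6(x + 3) / 6. With inplace=True both modules mutate
# their input tensor via mul_(), which saves memory but must not be used on a
# tensor that is still needed elsewhere in the graph.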

class BasicLayer(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        use_bn=True,
        affine=True,
        act_func="relu6",
        ops_order="weight_bn_act",
    ):
        super(BasicLayer, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_bn = use_bn
        self.affine = affine
        self.act_func = act_func
        self.ops_order = ops_order

        """ add modules """
        # batch norm
        if self.use_bn:
            if self.bn_before_weight:
                self.bn = nn.BatchNorm2d(
                    in_channels, affine=affine, track_running_stats=affine
                )
            else:
                self.bn = nn.BatchNorm2d(
                    out_channels, affine=affine, track_running_stats=affine
                )
        else:
            self.bn = None
        # activation
        if act_func == "relu":
            if self.ops_list[0] == "act":
                self.act = nn.ReLU(inplace=False)
            else:
                self.act = nn.ReLU(inplace=True)
        elif act_func == "relu6":
            if self.ops_list[0] == "act":
                self.act = nn.ReLU6(inplace=False)
            else:
                self.act = nn.ReLU6(inplace=True)
        elif act_func == "swish":
            if self.ops_list[0] == "act":
                self.act = Swish(inplace=False)
            else:
                self.act = Swish(inplace=True)
        elif act_func == "h-swish":
            if self.ops_list[0] == "act":
                self.act = HardSwish(inplace=False)
            else:
                self.act = HardSwish(inplace=True)
        else:
            self.act = None

    @property
    def ops_list(self):
        return self.ops_order.split("_")

    @property
    def bn_before_weight(self):
        for op in self.ops_list:
            if op == "bn":
                return True
            elif op == "weight":
                return False
        raise ValueError("Invalid ops_order: %s" % self.ops_order)

    def weight_call(self, x):
        raise NotImplementedError

    def forward(self, x):
        for op in self.ops_list:
            if op == "weight":
                x = self.weight_call(x)
            elif op == "bn":
                if self.bn is not None:
                    x = self.bn(x)
            elif op == "act":
                if self.act is not None:
                    x = self.act(x)
            else:
                raise ValueError("Unrecognized op: %s" % op)
        return x
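
# Added clarification (not in the original source): ops_order controls the
# order in which the weight op, batch norm, and activation run in forward().
# For the default "weight_bn_act", ops_list is ["weight", "bn", "act"] and
# bn_before_weight is False, so BatchNorm is built on out_channels.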

class ConvLayer(BasicLayer):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        groups=1,
        has_shuffle=False,
        bias=False,
        use_bn=True,
        affine=True,
        act_func="relu6",
        ops_order="weight_bn_act",
    ):
        super(ConvLayer, self).__init__(
            in_channels, out_channels, use_bn, affine, act_func, ops_order
        )

        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        self.has_shuffle = has_shuffle
        self.bias = bias

        padding = get_same_padding(self.kernel_size)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=padding,
            groups=self.groups,
            bias=self.bias,
        )

    def weight_call(self, x):
        x = self.conv(x)
        if self.has_shuffle and self.groups > 1:
            x = channel_shuffle(x, self.groups)
        return x
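
# Minimal usage sketch (hypothetical, not part of the original source): a 3x3
# stride-2 ConvLayer halves the spatial resolution, while the "same" padding
# keeps the arithmetic exact for odd kernel sizes.
def _conv_layer_example():
    layer = ConvLayer(3, 32, kernel_size=3, stride=2, act_func="relu")
    out = layer(torch.rand(2, 3, 112, 112))
    assert out.shape == (2, 32, 56, 56)
    return out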

class LinearLayer(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        use_bn=False,
        affine=False,
        act_func=None,
        ops_order="weight_bn_act",
    ):
        super(LinearLayer, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.use_bn = use_bn
        self.affine = affine
        self.act_func = act_func
        self.ops_order = ops_order

        """ add modules """
        # batch norm
        if self.use_bn:
            if self.bn_before_weight:
                self.bn = nn.BatchNorm1d(
                    in_features, affine=affine, track_running_stats=affine
                )
            else:
                self.bn = nn.BatchNorm1d(
                    out_features, affine=affine, track_running_stats=affine
                )
        else:
            self.bn = None
        # activation
        if act_func == "relu":
            if self.ops_list[0] == "act":
                self.act = nn.ReLU(inplace=False)
            else:
                self.act = nn.ReLU(inplace=True)
        elif act_func == "relu6":
            if self.ops_list[0] == "act":
                self.act = nn.ReLU6(inplace=False)
            else:
                self.act = nn.ReLU6(inplace=True)
        elif act_func == "tanh":
            self.act = nn.Tanh()
        elif act_func == "sigmoid":
            self.act = nn.Sigmoid()
        else:
            self.act = None
        # linear
        self.linear = nn.Linear(self.in_features, self.out_features, self.bias)

    @property
    def ops_list(self):
        return self.ops_order.split("_")

    @property
    def bn_before_weight(self):
        for op in self.ops_list:
            if op == "bn":
                return True
            elif op == "weight":
                return False
        raise ValueError("Invalid ops_order: %s" % self.ops_order)

    def forward(self, x):
        for op in self.ops_list:
            if op == "weight":
                x = self.linear(x)
            elif op == "bn":
                if self.bn is not None:
                    x = self.bn(x)
            elif op == "act":
                if self.act is not None:
                    x = self.act(x)
            else:
                raise ValueError("Unrecognized op: %s" % op)
        return x
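
# Added note (not in the original source): LinearLayer mirrors BasicLayer's
# weight/bn/act ordering for fully connected layers, but it is not used by
# TF_NAS_A below, which builds its embedding head directly from nn.Linear and
# nn.BatchNorm1d.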

class MBInvertedResBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        mid_channels,
        se_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        groups=1,
        has_shuffle=False,
        bias=False,
        use_bn=True,
        affine=True,
        act_func="relu6",
    ):
        super(MBInvertedResBlock, self).__init__()

        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.se_channels = se_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        self.has_shuffle = has_shuffle
        self.bias = bias
        self.use_bn = use_bn
        self.affine = affine
        self.act_func = act_func

        # inverted bottleneck
        if mid_channels > in_channels:
            inverted_bottleneck = OrderedDict(
                [
                    (
                        "conv",
                        nn.Conv2d(
                            in_channels,
                            mid_channels,
                            1,
                            1,
                            0,
                            groups=groups,
                            bias=bias,
                        ),
                    ),
                ]
            )
            if use_bn:
                inverted_bottleneck["bn"] = nn.BatchNorm2d(
                    mid_channels, affine=affine, track_running_stats=affine
                )
            if act_func == "relu":
                inverted_bottleneck["act"] = nn.ReLU(inplace=True)
            elif act_func == "relu6":
                inverted_bottleneck["act"] = nn.ReLU6(inplace=True)
            elif act_func == "swish":
                inverted_bottleneck["act"] = Swish(inplace=True)
            elif act_func == "h-swish":
                inverted_bottleneck["act"] = HardSwish(inplace=True)
            self.inverted_bottleneck = nn.Sequential(inverted_bottleneck)
        else:
            self.inverted_bottleneck = None
            self.mid_channels = in_channels
            mid_channels = in_channels

        # depthwise convolution
        padding = get_same_padding(self.kernel_size)
        depth_conv = OrderedDict(
            [
                (
                    "conv",
                    nn.Conv2d(
                        mid_channels,
                        mid_channels,
                        kernel_size,
                        stride,
                        padding,
                        groups=mid_channels,
                        bias=bias,
                    ),
                ),
            ]
        )
        if use_bn:
            depth_conv["bn"] = nn.BatchNorm2d(
                mid_channels, affine=affine, track_running_stats=affine
            )
        if act_func == "relu":
            depth_conv["act"] = nn.ReLU(inplace=True)
        elif act_func == "relu6":
            depth_conv["act"] = nn.ReLU6(inplace=True)
        elif act_func == "swish":
            depth_conv["act"] = Swish(inplace=True)
        elif act_func == "h-swish":
            depth_conv["act"] = HardSwish(inplace=True)
        self.depth_conv = nn.Sequential(depth_conv)

        # se model
        if se_channels > 0:
            squeeze_excite = OrderedDict(
                [
                    (
                        "conv_reduce",
                        nn.Conv2d(
                            mid_channels,
                            se_channels,
                            1,
                            1,
                            0,
                            groups=groups,
                            bias=True,
                        ),
                    ),
                ]
            )
            if act_func == "relu":
                squeeze_excite["act"] = nn.ReLU(inplace=True)
            elif act_func == "relu6":
                squeeze_excite["act"] = nn.ReLU6(inplace=True)
            elif act_func == "swish":
                squeeze_excite["act"] = Swish(inplace=True)
            elif act_func == "h-swish":
                squeeze_excite["act"] = HardSwish(inplace=True)
            squeeze_excite["conv_expand"] = nn.Conv2d(
                se_channels, mid_channels, 1, 1, 0, groups=groups, bias=True
            )
            self.squeeze_excite = nn.Sequential(squeeze_excite)
        else:
            self.squeeze_excite = None
            self.se_channels = 0

        # pointwise linear
        point_linear = OrderedDict(
            [
                (
                    "conv",
                    nn.Conv2d(
                        mid_channels,
                        out_channels,
                        1,
                        1,
                        0,
                        groups=groups,
                        bias=bias,
                    ),
                ),
            ]
        )
        if use_bn:
            point_linear["bn"] = nn.BatchNorm2d(
                out_channels, affine=affine, track_running_stats=affine
            )
        self.point_linear = nn.Sequential(point_linear)

        # residual flag
        self.has_residual = (in_channels == out_channels) and (stride == 1)

    def forward(self, x):
        res = x

        if self.inverted_bottleneck is not None:
            x = self.inverted_bottleneck(x)
            if self.has_shuffle and self.groups > 1:
                x = channel_shuffle(x, self.groups)

        x = self.depth_conv(x)
        if self.squeeze_excite is not None:
            x_se = F.adaptive_avg_pool2d(x, 1)
            x = x * torch.sigmoid(self.squeeze_excite(x_se))

        x = self.point_linear(x)
        if self.has_shuffle and self.groups > 1:
            x = channel_shuffle(x, self.groups)

        if self.has_residual:
            x += res

        return x
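
# Minimal usage sketch (hypothetical, not part of the original source): the
# block expands 16 channels to a 64-channel bottleneck, applies a depthwise
# 3x3 conv and an SE gate with 8 squeeze channels, then projects back to 16
# channels; since stride == 1 and in_channels == out_channels, the residual
# connection is active and the output shape matches the input.
def _mb_block_example():
    block = MBInvertedResBlock(
        16, 64, 8, 16, kernel_size=3, stride=1, act_func="swish"
    )
    out = block(torch.rand(2, 16, 56, 56))
    assert out.shape == (2, 16, 56, 56)
    return out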

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)
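
# Added overview (not in the original source): TF_NAS_A is the "A" architecture
# from TF-NAS (Hu et al., "TF-NAS: Rethinking Three Search Freedoms of
# Latency-Constrained Differentiable Neural Architecture Search", ECCV 2020),
# used here as a face-recognition backbone: a conv stem, six stages of
# MBInvertedResBlock units, a 1x1 feature-mix convolution to 1280 channels, and
# a dropout + flatten + linear + BatchNorm1d head producing the embedding.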

class TF_NAS_A(nn.Module):
    def __init__(self, out_h, out_w, feat_dim, drop_ratio=0.0):
        super(TF_NAS_A, self).__init__()
        self.drop_ratio = drop_ratio

        self.first_stem = ConvLayer(
            3, 32, kernel_size=3, stride=1, act_func="relu"
        )
        self.second_stem = MBInvertedResBlock(
            32, 32, 8, 16, kernel_size=3, stride=1, act_func="relu"
        )
        self.stage1 = nn.Sequential(
            MBInvertedResBlock(
                16, 83, 32, 24, kernel_size=3, stride=2, act_func="relu"
            ),
            MBInvertedResBlock(
                24, 128, 0, 24, kernel_size=5, stride=1, act_func="relu"
            ),
        )
        self.stage2 = nn.Sequential(
            MBInvertedResBlock(
                24, 138, 48, 40, kernel_size=3, stride=2, act_func="swish"
            ),
            MBInvertedResBlock(
                40, 297, 0, 40, kernel_size=3, stride=1, act_func="swish"
            ),
            MBInvertedResBlock(
                40, 170, 80, 40, kernel_size=5, stride=1, act_func="swish"
            ),
        )
        self.stage3 = nn.Sequential(
            MBInvertedResBlock(
                40, 248, 80, 80, kernel_size=5, stride=2, act_func="swish"
            ),
            MBInvertedResBlock(
                80, 500, 0, 80, kernel_size=3, stride=1, act_func="swish"
            ),
            MBInvertedResBlock(
                80, 424, 0, 80, kernel_size=3, stride=1, act_func="swish"
            ),
            MBInvertedResBlock(
                80, 477, 0, 80, kernel_size=3, stride=1, act_func="swish"
            ),
        )
        self.stage4 = nn.Sequential(
            MBInvertedResBlock(
                80, 504, 160, 112, kernel_size=3, stride=1, act_func="swish"
            ),
            MBInvertedResBlock(
                112, 796, 0, 112, kernel_size=3, stride=1, act_func="swish"
            ),
            MBInvertedResBlock(
                112, 723, 224, 112, kernel_size=3, stride=1, act_func="swish"
            ),
            MBInvertedResBlock(
                112, 555, 224, 112, kernel_size=3, stride=1, act_func="swish"
            ),
        )
        self.stage5 = nn.Sequential(
            MBInvertedResBlock(
                112, 813, 0, 192, kernel_size=3, stride=2, act_func="swish"
            ),
            MBInvertedResBlock(
                192, 1370, 0, 192, kernel_size=3, stride=1, act_func="swish"
            ),
            MBInvertedResBlock(
                192, 1138, 384, 192, kernel_size=3, stride=1, act_func="swish"
            ),
            MBInvertedResBlock(
                192, 1359, 384, 192, kernel_size=3, stride=1, act_func="swish"
            ),
        )
        self.stage6 = nn.Sequential(
            MBInvertedResBlock(
                192, 1203, 384, 320, kernel_size=5, stride=1, act_func="swish"
            ),
        )
        self.feature_mix_layer = ConvLayer(
            320, 1280, kernel_size=1, stride=1, act_func="none"
        )
        self.output_layer = nn.Sequential(
            nn.Dropout(self.drop_ratio),
            Flatten(),
            nn.Linear(1280 * out_h * out_w, feat_dim),
            nn.BatchNorm1d(feat_dim),
        )

        self._initialization()

    def forward(self, x):
        x = self.first_stem(x)
        x = self.second_stem(x)
        for block in self.stage1:
            x = block(x)
        for block in self.stage2:
            x = block(x)
        for block in self.stage3:
            x = block(x)
        for block in self.stage4:
            x = block(x)
        for block in self.stage5:
            x = block(x)
        for block in self.stage6:
            x = block(x)
        x = self.feature_mix_layer(x)
        x = self.output_layer(x)
        return x

    def _initialization(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

if __name__ == "__main__":
    x = torch.rand((2, 3, 112, 112))
    net = TF_NAS_A(7, 7, 512, drop_ratio=0.0)

    x = x.cuda()
    net = net.cuda()

    out = net(x)
    print(out.size())
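
    # Added note (not part of the original demo): the .cuda() calls above
    # require a GPU. On a CPU-only machine the same smoke test can be run by
    # dropping the two .cuda() lines; with a 112x112 input and
    # out_h = out_w = 7, the printed embedding size is torch.Size([2, 512]).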