Coverage for src/bob/bio/face/pytorch/facexzoo/HRNet.py: 88%

289 statements  

coverage.py v7.6.0, created at 2024-07-13 00:04 +0200

1""" 

2@author: Hanbin Dai, Jun Wang 

3@date: 20201020 

4@contact: daihanbin.ac@gmail.com, jun21wangustc@gmail.com 

5""" 

6 

7# based on: 

8# https://github.com/HRNet/HRNet-Image-Classification/blob/master/lib/models/cls_hrnet.py 

9 

10import logging 

11import os 

12 

13import torch 

14import torch._utils 

15import torch.nn as nn 

16 

17from torch.nn import BatchNorm1d, Linear, Module, Sequential 

18 

19BN_MOMENTUM = 0.1 

20logger = logging.getLogger(__name__) 

21 

22 

class Flatten(Module):
    def forward(self, x):
        return x.view(x.size(0), -1)



def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )



class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out



class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out



class HighResolutionModule(nn.Module):
    def __init__(
        self,
        num_branches,
        blocks,
        num_blocks,
        num_inchannels,
        num_channels,
        fuse_method,
        multi_scale_output=True,
    ):
        super(HighResolutionModule, self).__init__()
        self._check_branches(
            num_branches, blocks, num_blocks, num_inchannels, num_channels
        )

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels
        )
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(False)

    def _check_branches(
        self, num_branches, blocks, num_blocks, num_inchannels, num_channels
    ):
        if num_branches != len(num_blocks):
            error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(
                num_branches, len(num_blocks)
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format(
                num_branches, len(num_channels)
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format(
                num_branches, len(num_inchannels)
            )
            logger.error(error_msg)
            raise ValueError(error_msg)


    def _make_one_branch(
        self, branch_index, block, num_blocks, num_channels, stride=1
    ):
        downsample = None
        if (
            stride != 1
            or self.num_inchannels[branch_index]
            != num_channels[branch_index] * block.expansion
        ):
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.num_inchannels[branch_index],
                    num_channels[branch_index] * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(
                    num_channels[branch_index] * block.expansion,
                    momentum=BN_MOMENTUM,
                ),
            )

        layers = []
        layers.append(
            block(
                self.num_inchannels[branch_index],
                num_channels[branch_index],
                stride,
                downsample,
            )
        )
        self.num_inchannels[branch_index] = (
            num_channels[branch_index] * block.expansion
        )
        for i in range(1, num_blocks[branch_index]):
            layers.append(
                block(
                    self.num_inchannels[branch_index],
                    num_channels[branch_index],
                )
            )

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels)
            )

        return nn.ModuleList(branches)


    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_inchannels[j],
                                num_inchannels[i],
                                1,
                                1,
                                0,
                                bias=False,
                            ),
                            nn.BatchNorm2d(
                                num_inchannels[i], momentum=BN_MOMENTUM
                            ),
                            nn.Upsample(
                                scale_factor=2 ** (j - i), mode="nearest"
                            ),
                        )
                    )
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3,
                                        2,
                                        1,
                                        bias=False,
                                    ),
                                    nn.BatchNorm2d(
                                        num_outchannels_conv3x3,
                                        momentum=BN_MOMENTUM,
                                    ),
                                )
                            )
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3,
                                        2,
                                        1,
                                        bias=False,
                                    ),
                                    nn.BatchNorm2d(
                                        num_outchannels_conv3x3,
                                        momentum=BN_MOMENTUM,
                                    ),
                                    nn.ReLU(False),
                                )
                            )
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels


    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse



blocks_dict = {"BASIC": BasicBlock, "BOTTLENECK": Bottleneck}


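# Illustrative sketch only (not part of the original file): it shows how a single
# HighResolutionModule fuses two branches of different resolutions. The branch
# widths (18 and 36 channels) and the 56x56 / 28x28 feature-map sizes are values
# assumed for the demo, not constants used anywhere in this module.
def _demo_high_resolution_module():
    module = HighResolutionModule(
        num_branches=2,
        blocks=blocks_dict["BASIC"],
        num_blocks=[4, 4],
        num_inchannels=[18, 36],
        num_channels=[18, 36],
        fuse_method="SUM",
        multi_scale_output=True,
    )
    module.eval()
    with torch.no_grad():
        # One high-resolution and one low-resolution feature map.
        outputs = module(
            [torch.randn(1, 18, 56, 56), torch.randn(1, 36, 28, 28)]
        )
    # Each output keeps its branch's resolution but aggregates both branches.
    return [o.shape for o in outputs]  # [(1, 18, 56, 56), (1, 36, 28, 28)]

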

class HighResolutionNet(nn.Module):
    def __init__(self, cfg, **kwargs):
        super(HighResolutionNet, self).__init__()

        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        # self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.conv2 = nn.Conv2d(
            64, 64, kernel_size=3, stride=1, padding=1, bias=False
        )

        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)

        self.stage1_cfg = cfg["MODEL"]["EXTRA"]["STAGE1"]
        num_channels = self.stage1_cfg["NUM_CHANNELS"][0]
        block = blocks_dict[self.stage1_cfg["BLOCK"]]
        num_blocks = self.stage1_cfg["NUM_BLOCKS"][0]
        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
        stage1_out_channel = block.expansion * num_channels

        self.stage2_cfg = cfg["MODEL"]["EXTRA"]["STAGE2"]
        num_channels = self.stage2_cfg["NUM_CHANNELS"]
        block = blocks_dict[self.stage2_cfg["BLOCK"]]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        self.transition1 = self._make_transition_layer(
            [stage1_out_channel], num_channels
        )
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels
        )

        self.stage3_cfg = cfg["MODEL"]["EXTRA"]["STAGE3"]
        num_channels = self.stage3_cfg["NUM_CHANNELS"]
        block = blocks_dict[self.stage3_cfg["BLOCK"]]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        self.transition2 = self._make_transition_layer(
            pre_stage_channels, num_channels
        )
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels
        )

        self.stage4_cfg = cfg["MODEL"]["EXTRA"]["STAGE4"]
        num_channels = self.stage4_cfg["NUM_CHANNELS"]
        block = blocks_dict[self.stage4_cfg["BLOCK"]]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        self.transition3 = self._make_transition_layer(
            pre_stage_channels, num_channels
        )
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=True
        )

        # Classification Head
        (
            self.incre_modules,
            self.downsamp_modules,
            self.final_layer,
        ) = self._make_head(pre_stage_channels)

        # self.classifier = nn.Linear(2048, 1000)
        self.output_layer = Sequential(
            Flatten(),
            Linear(
                2048 * cfg["MODEL"]["out_h"] * cfg["MODEL"]["out_w"],
                cfg["MODEL"]["feat_dim"],
                False,
            ),
            # Note: the hard-coded 512 assumes cfg["MODEL"]["feat_dim"] == 512.
            BatchNorm1d(512),
        )


    def _make_head(self, pre_stage_channels):
        head_block = Bottleneck
        head_channels = [32, 64, 128, 256]

        # Increasing the #channels on each resolution
        # from C, 2C, 4C, 8C to 128, 256, 512, 1024
        incre_modules = []
        for i, channels in enumerate(pre_stage_channels):
            incre_module = self._make_layer(
                head_block, channels, head_channels[i], 1, stride=1
            )
            incre_modules.append(incre_module)
        incre_modules = nn.ModuleList(incre_modules)

        # downsampling modules
        downsamp_modules = []
        for i in range(len(pre_stage_channels) - 1):
            in_channels = head_channels[i] * head_block.expansion
            out_channels = head_channels[i + 1] * head_block.expansion

            downsamp_module = nn.Sequential(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                ),
                nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM),
                nn.ReLU(inplace=True),
            )

            downsamp_modules.append(downsamp_module)
        downsamp_modules = nn.ModuleList(downsamp_modules)

        final_layer = nn.Sequential(
            nn.Conv2d(
                in_channels=head_channels[3] * head_block.expansion,
                out_channels=2048,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.BatchNorm2d(2048, momentum=BN_MOMENTUM),
            nn.ReLU(inplace=True),
        )

        return incre_modules, downsamp_modules, final_layer


    def _make_transition_layer(
        self, num_channels_pre_layer, num_channels_cur_layer
    ):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                3,
                                1,
                                1,
                                bias=False,
                            ),
                            nn.BatchNorm2d(
                                num_channels_cur_layer[i], momentum=BN_MOMENTUM
                            ),
                            nn.ReLU(inplace=True),
                        )
                    )
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = (
                        num_channels_cur_layer[i]
                        if j == i - num_branches_pre
                        else inchannels
                    )
                    conv3x3s.append(
                        nn.Sequential(
                            nn.Conv2d(
                                inchannels, outchannels, 3, 2, 1, bias=False
                            ),
                            nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
                            nn.ReLU(inplace=True),
                        )
                    )
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)


    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(inplanes, planes))

        return nn.Sequential(*layers)


    def _make_stage(
        self, layer_config, num_inchannels, multi_scale_output=True
    ):
        num_modules = layer_config["NUM_MODULES"]
        num_branches = layer_config["NUM_BRANCHES"]
        num_blocks = layer_config["NUM_BLOCKS"]
        num_channels = layer_config["NUM_CHANNELS"]
        block = blocks_dict[layer_config["BLOCK"]]
        fuse_method = layer_config["FUSE_METHOD"]

        modules = []
        for i in range(num_modules):
            # multi_scale_output is only used in the last module
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True

            modules.append(
                HighResolutionModule(
                    num_branches,
                    block,
                    num_blocks,
                    num_inchannels,
                    num_channels,
                    fuse_method,
                    reset_multi_scale_output,
                )
            )
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels


    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg["NUM_BRANCHES"]):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg["NUM_BRANCHES"]):
            if self.transition2[i] is not None:
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg["NUM_BRANCHES"]):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage4(x_list)

        # Classification Head
        y = self.incre_modules[0](y_list[0])
        for i in range(len(self.downsamp_modules)):
            y = self.incre_modules[i + 1](
                y_list[i + 1]
            ) + self.downsamp_modules[i](y)

        y = self.final_layer(y)
        """
        if torch._C._get_tracing_state():
            y = y.flatten(start_dim=2).mean(dim=2)
        else:
            y = F.avg_pool2d(y, kernel_size=y.size()[2:]).view(y.size(0), -1)

        y = self.classifier(y)
        """
        y = self.output_layer(y)
        return y


    def init_weights(self, pretrained=""):
        logger.info("=> init weights from normal distribution")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        if os.path.isfile(pretrained):
            pretrained_dict = torch.load(pretrained)
            logger.info("=> loading pretrained model {}".format(pretrained))
            model_dict = self.state_dict()
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items()
                if k in model_dict.keys()
            }
            for k, _ in pretrained_dict.items():
                logger.info(
                    "=> loading {} pretrained model {}".format(k, pretrained)
                )
            model_dict.update(pretrained_dict)
            self.load_state_dict(model_dict)



def get_cls_net(config, **kwargs):
    model = HighResolutionNet(config, **kwargs)
    model.init_weights()
    return model
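

# Minimal usage sketch (not part of the original file): it builds the network
# from a hand-written config dict and runs a dummy forward pass. The concrete
# values below (HRNet-W18-style channel widths, out_h/out_w of 7 for a 112x112
# crop, feat_dim of 512) are assumptions chosen for illustration; only the key
# names are dictated by the code above.
if __name__ == "__main__":
    example_cfg = {
        "MODEL": {
            "out_h": 7,
            "out_w": 7,
            "feat_dim": 512,
            "EXTRA": {
                "STAGE1": {
                    "NUM_MODULES": 1,
                    "NUM_BRANCHES": 1,
                    "BLOCK": "BOTTLENECK",
                    "NUM_BLOCKS": [4],
                    "NUM_CHANNELS": [64],
                    "FUSE_METHOD": "SUM",
                },
                "STAGE2": {
                    "NUM_MODULES": 1,
                    "NUM_BRANCHES": 2,
                    "BLOCK": "BASIC",
                    "NUM_BLOCKS": [4, 4],
                    "NUM_CHANNELS": [18, 36],
                    "FUSE_METHOD": "SUM",
                },
                "STAGE3": {
                    "NUM_MODULES": 4,
                    "NUM_BRANCHES": 3,
                    "BLOCK": "BASIC",
                    "NUM_BLOCKS": [4, 4, 4],
                    "NUM_CHANNELS": [18, 36, 72],
                    "FUSE_METHOD": "SUM",
                },
                "STAGE4": {
                    "NUM_MODULES": 3,
                    "NUM_BRANCHES": 4,
                    "BLOCK": "BASIC",
                    "NUM_BLOCKS": [4, 4, 4, 4],
                    "NUM_CHANNELS": [18, 36, 72, 144],
                    "FUSE_METHOD": "SUM",
                },
            },
        }
    }

    net = get_cls_net(example_cfg)
    net.eval()
    with torch.no_grad():
        # One 112x112 RGB crop in; one (1, feat_dim) embedding out.
        embedding = net(torch.randn(1, 3, 112, 112))
    print(embedding.shape)  # torch.Size([1, 512])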