Coverage for src/bob/bio/face/pytorch/facexzoo/HRNet.py: 88% (289 statements)

1"""
2@author: Hanbin Dai, Jun Wang
3@date: 20201020
4@contact: daihanbin.ac@gmail.com, jun21wangustc@gmail.com
5"""
7# based on:
8# https://github.com/HRNet/HRNet-Image-Classification/blob/master/lib/models/cls_hrnet.py
10import logging
11import os
13import torch
14import torch._utils
15import torch.nn as nn
17from torch.nn import BatchNorm1d, Linear, Module, Sequential
19BN_MOMENTUM = 0.1
20logger = logging.getLogger(__name__)
23class Flatten(Module):
24 def forward(self, x):
25 return x.view(x.size(0), -1)
28def conv3x3(in_planes, out_planes, stride=1):
29 """3x3 convolution with padding"""
30 return nn.Conv2d(
31 in_planes,
32 out_planes,
33 kernel_size=3,
34 stride=stride,
35 padding=1,
36 bias=False,
37 )
40class BasicBlock(nn.Module):
41 expansion = 1
43 def __init__(self, inplanes, planes, stride=1, downsample=None):
44 super(BasicBlock, self).__init__()
45 self.conv1 = conv3x3(inplanes, planes, stride)
46 self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
47 self.relu = nn.ReLU(inplace=True)
48 self.conv2 = conv3x3(planes, planes)
49 self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
50 self.downsample = downsample
51 self.stride = stride
53 def forward(self, x):
54 residual = x
56 out = self.conv1(x)
57 out = self.bn1(out)
58 out = self.relu(out)
60 out = self.conv2(out)
61 out = self.bn2(out)
63 if self.downsample is not None:
64 residual = self.downsample(x)
66 out += residual
67 out = self.relu(out)
69 return out
72class Bottleneck(nn.Module):
73 expansion = 4
75 def __init__(self, inplanes, planes, stride=1, downsample=None):
76 super(Bottleneck, self).__init__()
77 self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
78 self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
79 self.conv2 = nn.Conv2d(
80 planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
81 )
82 self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
83 self.conv3 = nn.Conv2d(
84 planes, planes * self.expansion, kernel_size=1, bias=False
85 )
86 self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM)
87 self.relu = nn.ReLU(inplace=True)
88 self.downsample = downsample
89 self.stride = stride
91 def forward(self, x):
92 residual = x
94 out = self.conv1(x)
95 out = self.bn1(out)
96 out = self.relu(out)
98 out = self.conv2(out)
99 out = self.bn2(out)
100 out = self.relu(out)
102 out = self.conv3(out)
103 out = self.bn3(out)
105 if self.downsample is not None:
106 residual = self.downsample(x)
108 out += residual
109 out = self.relu(out)
111 return out
114class HighResolutionModule(nn.Module):
115 def __init__(
116 self,
117 num_branches,
118 blocks,
119 num_blocks,
120 num_inchannels,
121 num_channels,
122 fuse_method,
123 multi_scale_output=True,
124 ):
125 super(HighResolutionModule, self).__init__()
126 self._check_branches(
127 num_branches, blocks, num_blocks, num_inchannels, num_channels
128 )
130 self.num_inchannels = num_inchannels
131 self.fuse_method = fuse_method
132 self.num_branches = num_branches
134 self.multi_scale_output = multi_scale_output
136 self.branches = self._make_branches(
137 num_branches, blocks, num_blocks, num_channels
138 )
139 self.fuse_layers = self._make_fuse_layers()
140 self.relu = nn.ReLU(False)
142 def _check_branches(
143 self, num_branches, blocks, num_blocks, num_inchannels, num_channels
144 ):
145 if num_branches != len(num_blocks):
146 error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(
147 num_branches, len(num_blocks)
148 )
149 logger.error(error_msg)
150 raise ValueError(error_msg)
152 if num_branches != len(num_channels):
153 error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format(
154 num_branches, len(num_channels)
155 )
156 logger.error(error_msg)
157 raise ValueError(error_msg)
159 if num_branches != len(num_inchannels):
160 error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format(
161 num_branches, len(num_inchannels)
162 )
163 logger.error(error_msg)
164 raise ValueError(error_msg)
166 def _make_one_branch(
167 self, branch_index, block, num_blocks, num_channels, stride=1
168 ):
169 downsample = None
170 if (
171 stride != 1
172 or self.num_inchannels[branch_index]
173 != num_channels[branch_index] * block.expansion
174 ):
175 downsample = nn.Sequential(
176 nn.Conv2d(
177 self.num_inchannels[branch_index],
178 num_channels[branch_index] * block.expansion,
179 kernel_size=1,
180 stride=stride,
181 bias=False,
182 ),
183 nn.BatchNorm2d(
184 num_channels[branch_index] * block.expansion,
185 momentum=BN_MOMENTUM,
186 ),
187 )
189 layers = []
190 layers.append(
191 block(
192 self.num_inchannels[branch_index],
193 num_channels[branch_index],
194 stride,
195 downsample,
196 )
197 )
198 self.num_inchannels[branch_index] = (
199 num_channels[branch_index] * block.expansion
200 )
201 for i in range(1, num_blocks[branch_index]):
202 layers.append(
203 block(
204 self.num_inchannels[branch_index],
205 num_channels[branch_index],
206 )
207 )
209 return nn.Sequential(*layers)
211 def _make_branches(self, num_branches, block, num_blocks, num_channels):
212 branches = []
214 for i in range(num_branches):
215 branches.append(
216 self._make_one_branch(i, block, num_blocks, num_channels)
217 )
219 return nn.ModuleList(branches)
221 def _make_fuse_layers(self):
222 if self.num_branches == 1:
223 return None
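
        # For each output branch i, build one fuse layer per input branch j:
        #   j > i : 1x1 conv + BN, then nearest-neighbour upsampling by 2**(j - i)
        #   j == i: identity (stored as None)
        #   j < i : a chain of stride-2 3x3 conv + BN blocks (ReLU on all but
        #           the last step) that downsamples by 2**(i - j)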
        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_inchannels[j],
                                num_inchannels[i],
                                1,
                                1,
                                0,
                                bias=False,
                            ),
                            nn.BatchNorm2d(
                                num_inchannels[i], momentum=BN_MOMENTUM
                            ),
                            nn.Upsample(
                                scale_factor=2 ** (j - i), mode="nearest"
                            ),
                        )
                    )
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3,
                                        2,
                                        1,
                                        bias=False,
                                    ),
                                    nn.BatchNorm2d(
                                        num_outchannels_conv3x3,
                                        momentum=BN_MOMENTUM,
                                    ),
                                )
                            )
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3,
                                        2,
                                        1,
                                        bias=False,
                                    ),
                                    nn.BatchNorm2d(
                                        num_outchannels_conv3x3,
                                        momentum=BN_MOMENTUM,
                                    ),
                                    nn.ReLU(False),
                                )
                            )
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse


blocks_dict = {"BASIC": BasicBlock, "BOTTLENECK": Bottleneck}


class HighResolutionNet(nn.Module):
    def __init__(self, cfg, **kwargs):
        super(HighResolutionNet, self).__init__()

        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        # self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.conv2 = nn.Conv2d(
            64, 64, kernel_size=3, stride=1, padding=1, bias=False
        )

        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)

        self.stage1_cfg = cfg["MODEL"]["EXTRA"]["STAGE1"]
        num_channels = self.stage1_cfg["NUM_CHANNELS"][0]
        block = blocks_dict[self.stage1_cfg["BLOCK"]]
        num_blocks = self.stage1_cfg["NUM_BLOCKS"][0]
        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
        stage1_out_channel = block.expansion * num_channels

        self.stage2_cfg = cfg["MODEL"]["EXTRA"]["STAGE2"]
        num_channels = self.stage2_cfg["NUM_CHANNELS"]
        block = blocks_dict[self.stage2_cfg["BLOCK"]]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        self.transition1 = self._make_transition_layer(
            [stage1_out_channel], num_channels
        )
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels
        )

        self.stage3_cfg = cfg["MODEL"]["EXTRA"]["STAGE3"]
        num_channels = self.stage3_cfg["NUM_CHANNELS"]
        block = blocks_dict[self.stage3_cfg["BLOCK"]]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        self.transition2 = self._make_transition_layer(
            pre_stage_channels, num_channels
        )
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels
        )

        self.stage4_cfg = cfg["MODEL"]["EXTRA"]["STAGE4"]
        num_channels = self.stage4_cfg["NUM_CHANNELS"]
        block = blocks_dict[self.stage4_cfg["BLOCK"]]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        self.transition3 = self._make_transition_layer(
            pre_stage_channels, num_channels
        )
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=True
        )

        # Classification Head
        (
            self.incre_modules,
            self.downsamp_modules,
            self.final_layer,
        ) = self._make_head(pre_stage_channels)

        # self.classifier = nn.Linear(2048, 1000)
        self.output_layer = Sequential(
            Flatten(),
            Linear(
                2048 * cfg["MODEL"]["out_h"] * cfg["MODEL"]["out_w"],
                cfg["MODEL"]["feat_dim"],
                False,
            ),
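            # Note: the hard-coded BatchNorm1d(512) below implies that
            # cfg["MODEL"]["feat_dim"] is expected to be 512.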
            BatchNorm1d(512),
        )

    def _make_head(self, pre_stage_channels):
        head_block = Bottleneck
        head_channels = [32, 64, 128, 256]

        # Increasing the #channels on each resolution
        # from C, 2C, 4C, 8C to 128, 256, 512, 1024
        incre_modules = []
        for i, channels in enumerate(pre_stage_channels):
            incre_module = self._make_layer(
                head_block, channels, head_channels[i], 1, stride=1
            )
            incre_modules.append(incre_module)
        incre_modules = nn.ModuleList(incre_modules)

        # downsampling modules
        downsamp_modules = []
        for i in range(len(pre_stage_channels) - 1):
            in_channels = head_channels[i] * head_block.expansion
            out_channels = head_channels[i + 1] * head_block.expansion

            downsamp_module = nn.Sequential(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                ),
                nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM),
                nn.ReLU(inplace=True),
            )

            downsamp_modules.append(downsamp_module)
        downsamp_modules = nn.ModuleList(downsamp_modules)

        final_layer = nn.Sequential(
            nn.Conv2d(
                in_channels=head_channels[3] * head_block.expansion,
                out_channels=2048,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.BatchNorm2d(2048, momentum=BN_MOMENTUM),
            nn.ReLU(inplace=True),
        )

        return incre_modules, downsamp_modules, final_layer

    def _make_transition_layer(
        self, num_channels_pre_layer, num_channels_cur_layer
    ):
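        # One transition module per branch of the new stage: a 3x3 conv block
        # when the channel count changes, None when it already matches, and a
        # stride-2 downsampling chain for each newly added (lower-resolution)
        # branch.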
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                3,
                                1,
                                1,
                                bias=False,
                            ),
                            nn.BatchNorm2d(
                                num_channels_cur_layer[i], momentum=BN_MOMENTUM
                            ),
                            nn.ReLU(inplace=True),
                        )
                    )
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = (
                        num_channels_cur_layer[i]
                        if j == i - num_branches_pre
                        else inchannels
                    )
                    conv3x3s.append(
                        nn.Sequential(
                            nn.Conv2d(
                                inchannels, outchannels, 3, 2, 1, bias=False
                            ),
                            nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
                            nn.ReLU(inplace=True),
                        )
                    )
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(inplanes, planes))

        return nn.Sequential(*layers)

    def _make_stage(
        self, layer_config, num_inchannels, multi_scale_output=True
    ):
        num_modules = layer_config["NUM_MODULES"]
        num_branches = layer_config["NUM_BRANCHES"]
        num_blocks = layer_config["NUM_BLOCKS"]
        num_channels = layer_config["NUM_CHANNELS"]
        block = blocks_dict[layer_config["BLOCK"]]
        fuse_method = layer_config["FUSE_METHOD"]

        modules = []
        for i in range(num_modules):
            # multi_scale_output only matters for the last module
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True

            modules.append(
                HighResolutionModule(
                    num_branches,
                    block,
                    num_blocks,
                    num_inchannels,
                    num_channels,
                    fuse_method,
                    reset_multi_scale_output,
                )
            )
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg["NUM_BRANCHES"]):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg["NUM_BRANCHES"]):
            if self.transition2[i] is not None:
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg["NUM_BRANCHES"]):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage4(x_list)

        # Classification Head
        y = self.incre_modules[0](y_list[0])
        for i in range(len(self.downsamp_modules)):
            y = self.incre_modules[i + 1](
                y_list[i + 1]
            ) + self.downsamp_modules[i](y)

        y = self.final_layer(y)
        """
        if torch._C._get_tracing_state():
            y = y.flatten(start_dim=2).mean(dim=2)
        else:
            y = F.avg_pool2d(y, kernel_size=y.size()[2:]).view(y.size(0), -1)

        y = self.classifier(y)
        """
        y = self.output_layer(y)
        return y

    def init_weights(
        self,
        pretrained="",
    ):
        logger.info("=> init weights from normal distribution")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        if os.path.isfile(pretrained):
            pretrained_dict = torch.load(pretrained)
            logger.info("=> loading pretrained model {}".format(pretrained))
            model_dict = self.state_dict()
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items()
                if k in model_dict.keys()
            }
            for k, _ in pretrained_dict.items():
                logger.info(
                    "=> loading {} pretrained model {}".format(k, pretrained)
                )
            model_dict.update(pretrained_dict)
            self.load_state_dict(model_dict)


def get_cls_net(config, **kwargs):
    model = HighResolutionNet(config, **kwargs)
    model.init_weights()
    return model
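

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal example of the
# keys HighResolutionNet reads from `cfg`. The channel/block numbers below are
# illustrative HRNet-W18-style assumptions, not the configuration shipped with
# FaceX-Zoo or bob.bio.face; adjust them to your own setup.
if __name__ == "__main__":
    _example_cfg = {
        "MODEL": {
            "feat_dim": 512,  # must be 512 to match BatchNorm1d(512) in output_layer
            "out_h": 7,  # spatial size of the final 2048-channel map
            "out_w": 7,  # (7x7 for a 112x112 input with this stem)
            "EXTRA": {
                "STAGE1": {
                    "NUM_MODULES": 1,
                    "NUM_BRANCHES": 1,
                    "BLOCK": "BOTTLENECK",
                    "NUM_BLOCKS": [4],
                    "NUM_CHANNELS": [64],
                    "FUSE_METHOD": "SUM",
                },
                "STAGE2": {
                    "NUM_MODULES": 1,
                    "NUM_BRANCHES": 2,
                    "BLOCK": "BASIC",
                    "NUM_BLOCKS": [4, 4],
                    "NUM_CHANNELS": [18, 36],
                    "FUSE_METHOD": "SUM",
                },
                "STAGE3": {
                    "NUM_MODULES": 4,
                    "NUM_BRANCHES": 3,
                    "BLOCK": "BASIC",
                    "NUM_BLOCKS": [4, 4, 4],
                    "NUM_CHANNELS": [18, 36, 72],
                    "FUSE_METHOD": "SUM",
                },
                "STAGE4": {
                    "NUM_MODULES": 3,
                    "NUM_BRANCHES": 4,
                    "BLOCK": "BASIC",
                    "NUM_BLOCKS": [4, 4, 4, 4],
                    "NUM_CHANNELS": [18, 36, 72, 144],
                    "FUSE_METHOD": "SUM",
                },
            },
        }
    }
    # Build the network and run a dummy batch of two 112x112 RGB crops.
    _model = get_cls_net(_example_cfg)
    _embedding = _model(torch.randn(2, 3, 112, 112))
    print(_embedding.shape)  # expected: torch.Size([2, 512])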