Coverage for src/bob/bio/face/pytorch/facexzoo/resnest/resnet.py: 76%

127 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-13 00:04 +0200

1# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 

2# Created by: Hang Zhang 

3# Email: zhanghang0704@gmail.com 

4# Copyright (c) 2020 

5# 

6# LICENSE file in the root directory of this source tree 

7# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 

8"""ResNet variants""" 

9import torch.nn as nn 

10 

11from .splat import SplAtConv2d 

12 

13__all__ = ["ResNet", "Bottleneck"] 

14 

15 

class DropBlock2D(object):
    """Placeholder for the DropBlock regularisation layer.

    DropBlock is not implemented in this module: instantiating this class
    always raises, whatever arguments are given.

    Raises
    ------
    NotImplementedError
        Always, on construction.
    """

    def __init__(self, *args, **kwargs):
        # Fail loudly with an explanation instead of a bare exception, so a
        # caller asking for DropBlock understands why construction failed.
        raise NotImplementedError(
            "DropBlock2D is not implemented in this module; "
            "use dropblock_prob=0.0 to disable DropBlock."
        )

19 

20 

class GlobalAvgPool2d(nn.Module):
    """Global average pooling over the input's spatial dimensions.

    The input is adaptively average-pooled to a 1x1 spatial size and the
    result is reshaped to ``(inputs.size(0), -1)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        # Pool every channel down to a single value ...
        pooled = nn.functional.adaptive_avg_pool2d(inputs, 1)
        # ... then drop the trailing 1x1 spatial axes.
        leading = inputs.size(0)
        return pooled.view(leading, -1)

30 

31 

class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 conv -> 3x3 conv -> 1x1 conv + residual.

    Parameters
    ----------
    inplanes : int
        Number of input channels of the first 1x1 convolution.
    planes : int
        Base width; the block outputs ``planes * expansion`` channels.
    stride : int, default 1
        Stride of the 3x3 stage. When ``avd`` is active the stride is applied
        by an average pool instead and the convolution uses stride 1.
    downsample : callable or None, default None
        Applied to the block input to build the residual branch; when None
        the input itself is used as the residual.
    radix : int, default 1
        ``radix >= 1`` builds the 3x3 stage with ``SplAtConv2d``;
        ``radix == 0`` uses a plain (or rectified) convolution followed by a
        normalization layer and ReLU.
    cardinality : int, default 1
        ``groups`` argument of the 3x3 convolution.
    bottleneck_width : int, default 64
        Scales the inner width:
        ``int(planes * bottleneck_width / 64) * cardinality``.
    avd : bool, default False
        Enable the average-pool downsampling layer; it is only created when
        ``stride > 1`` or ``is_first`` is true.
    avd_first : bool, default False
        Apply the average pool before (True) or after (False) the 3x3 stage.
    dilation : int, default 1
        Dilation (and padding) of the 3x3 convolution.
    is_first : bool, default False
        See ``avd``.
    rectified_conv : bool, default False
        Forwarded to ``SplAtConv2d`` as ``rectify``; with ``radix == 0`` it
        selects ``rfconv.RFConv2d`` for the 3x3 stage.
    rectify_avg : bool, default False
        Forwarded to the rectified convolution.
    norm_layer : callable or None, default None
        Factory ``norm_layer(num_channels)`` for normalization layers.
        ``None`` falls back to ``nn.BatchNorm2d``.
    dropblock_prob : float, default 0.0
        When greater than 0, DropBlock layers are instantiated.
    last_gamma : bool, default False
        Zero-initialise the weight of the last normalization layer.
    """

    # pylint: disable=unused-argument
    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        radix=1,
        cardinality=1,
        bottleneck_width=64,
        avd=False,
        avd_first=False,
        dilation=1,
        is_first=False,
        rectified_conv=False,
        rectify_avg=False,
        norm_layer=None,
        dropblock_prob=0.0,
        last_gamma=False,
    ):
        super().__init__()
        # The signature defaults to None, which used to crash on the first
        # ``norm_layer(...)`` call; fall back to standard batch norm.
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        out_planes = planes * self.expansion
        group_width = int(planes * (bottleneck_width / 64.0)) * cardinality

        self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
        self.bn1 = norm_layer(group_width)
        self.dropblock_prob = dropblock_prob
        self.radix = radix
        self.avd = avd and (stride > 1 or is_first)
        self.avd_first = avd_first

        if self.avd:
            # The average pool takes over the spatial stride, so the 3x3
            # convolution below runs with stride 1.
            self.avd_layer = nn.AvgPool2d(3, stride, padding=1)
            stride = 1

        if dropblock_prob > 0.0:
            self.dropblock1 = DropBlock2D(dropblock_prob, 3)
            # forward() applies dropblock2 only on the radix == 0 path (for
            # radix >= 1 dropblock_prob is handed to SplAtConv2d), so create
            # it under the same condition.
            if radix == 0:
                self.dropblock2 = DropBlock2D(dropblock_prob, 3)
            self.dropblock3 = DropBlock2D(dropblock_prob, 3)

        # Arguments shared by every flavour of the 3x3 stage.
        conv2_kwargs = dict(
            kernel_size=3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            groups=cardinality,
            bias=False,
        )
        if radix >= 1:
            self.conv2 = SplAtConv2d(
                group_width,
                group_width,
                radix=radix,
                rectify=rectified_conv,
                rectify_avg=rectify_avg,
                norm_layer=norm_layer,
                dropblock_prob=dropblock_prob,
                **conv2_kwargs,
            )
        elif rectified_conv:
            # Optional dependency, imported only when actually requested.
            from rfconv import RFConv2d

            self.conv2 = RFConv2d(
                group_width,
                group_width,
                average_mode=rectify_avg,
                **conv2_kwargs,
            )
            self.bn2 = norm_layer(group_width)
        else:
            self.conv2 = nn.Conv2d(group_width, group_width, **conv2_kwargs)
            self.bn2 = norm_layer(group_width)

        self.conv3 = nn.Conv2d(
            group_width, out_planes, kernel_size=1, bias=False
        )
        self.bn3 = norm_layer(out_planes)

        if last_gamma:
            nn.init.zeros_(self.bn3.weight)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.dilation = dilation
        # Note: this is the stride of the 3x3 stage, i.e. 1 when avd is active.
        self.stride = stride

    def forward(self, x):
        """Run the block on ``x`` and return the activated residual sum."""
        residual = x
        use_dropblock = self.dropblock_prob > 0.0

        out = self.conv1(x)
        out = self.bn1(out)
        if use_dropblock:
            out = self.dropblock1(out)
        out = self.relu(out)

        if self.avd and self.avd_first:
            out = self.avd_layer(out)

        out = self.conv2(out)
        if self.radix == 0:
            # Only the plain / rectified convolution needs an explicit
            # norm + ReLU here; this path is skipped for radix >= 1.
            out = self.bn2(out)
            if use_dropblock:
                out = self.dropblock2(out)
            out = self.relu(out)

        if self.avd and not self.avd_first:
            out = self.avd_layer(out)

        out = self.conv3(out)
        out = self.bn3(out)
        if use_dropblock:
            out = self.dropblock3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

168 

169 

class ResNet(nn.Module):
    """ResNet variants (residual stages only).

    Only the four residual stages ``layer1`` .. ``layer4`` are built and run.
    The input stem (conv1 / bn1 / relu / maxpool) and the classification head
    (global average pool / dropout / fc) of a full ResNet are not created
    here, so :meth:`forward` maps a feature map straight through the four
    stages. The first block of ``layer1`` is constructed with
    ``stem_width * 2`` input channels when ``deep_stem`` is set, 64 otherwise.

    Parameters
    ----------
    block : Block
        Class for the residual block (e.g. ``Bottleneck``). It must expose an
        ``expansion`` attribute and accept the keyword arguments passed by
        ``_make_layer``.
    layers : list of int
        Number of blocks in each of the four stages.
    radix, groups, bottleneck_width : int
        Forwarded to every block (``groups`` as ``cardinality``).
    num_classes : int, default 1000
        Number of classification classes. Currently unused because the
        classification head is disabled; kept for interface compatibility.
    dilated : bool, default False
        Apply the dilation strategy (stride 1 with dilation 2 / 4 in the last
        two stages), typically used in Semantic Segmentation.
    dilation : int, default 1
        ``4`` behaves like ``dilated=True``; ``2`` dilates only the last
        stage; any other value uses stride 2 in both last stages.
    deep_stem : bool, default False
        Only selects the input width of ``layer1`` (see above).
    stem_width : int, default 64
        See ``deep_stem``.
    avg_down : bool, default False
        Build the downsample branch as average pool + 1x1 stride-1 conv
        instead of a strided 1x1 conv.
    rectified_conv, rectify_avg, avd, avd_first, last_gamma : bool
        Forwarded to every block.
    final_drop : float, default 0.0
        Currently unused because the classification head is disabled; kept
        for interface compatibility.
    dropblock_prob : float, default 0
        Forwarded to the blocks of ``layer3`` and ``layer4`` only.
    norm_layer : object
        Normalization layer used in backbone network
        (default: :class:`torch.nn.BatchNorm2d`).

    Reference:

        - He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.

        - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
    """

    # pylint: disable=unused-variable
    def __init__(
        self,
        block,
        layers,
        radix=1,
        groups=1,
        bottleneck_width=64,
        num_classes=1000,
        dilated=False,
        dilation=1,
        deep_stem=False,
        stem_width=64,
        avg_down=False,
        rectified_conv=False,
        rectify_avg=False,
        avd=False,
        avd_first=False,
        final_drop=0.0,
        dropblock_prob=0,
        last_gamma=False,
        norm_layer=nn.BatchNorm2d,
    ):
        # Initialise the nn.Module machinery before touching any attribute.
        super().__init__()
        self.cardinality = groups
        self.bottleneck_width = bottleneck_width
        # ResNet-D params
        self.inplanes = stem_width * 2 if deep_stem else 64
        self.avg_down = avg_down
        self.last_gamma = last_gamma
        # ResNeSt params
        self.radix = radix
        self.avd = avd
        self.avd_first = avd_first
        self.rectified_conv = rectified_conv
        self.rectify_avg = rectify_avg

        # NOTE: no input stem (conv1 / bn1 / relu / maxpool) is created; the
        # network starts directly at layer1, which uses stride 2.
        self.layer1 = self._make_layer(
            block,
            64,
            layers[0],
            stride=2,
            norm_layer=norm_layer,
            is_first=False,
        )
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, norm_layer=norm_layer
        )

        # (stride, dilation) of the last two stages.
        if dilated or dilation == 4:
            stage3, stage4 = (1, 2), (1, 4)
        elif dilation == 2:
            stage3, stage4 = (2, 1), (1, 2)
        else:
            stage3, stage4 = (2, 1), (2, 1)

        self.layer3 = self._make_layer(
            block,
            256,
            layers[2],
            stride=stage3[0],
            dilation=stage3[1],
            norm_layer=norm_layer,
            dropblock_prob=dropblock_prob,
        )
        self.layer4 = self._make_layer(
            block,
            512,
            layers[3],
            stride=stage4[0],
            dilation=stage4[1],
            norm_layer=norm_layer,
            dropblock_prob=dropblock_prob,
        )
        # NOTE: no classification head (avgpool / dropout / fc) and no custom
        # weight initialisation: layers keep their constructor-time init.

    def _make_layer(
        self,
        block,
        planes,
        blocks,
        stride=1,
        dilation=1,
        norm_layer=None,
        dropblock_prob=0.0,
        is_first=True,
    ):
        """Build one residual stage of ``blocks`` blocks as ``nn.Sequential``.

        The first block receives ``stride`` and an optional downsample branch
        and uses dilation 1 (for ``dilation`` 1 or 2) or 2 (for ``dilation``
        4); the remaining blocks use ``dilation`` unchanged and the default
        stride of ``block``. ``self.inplanes`` is updated to
        ``planes * block.expansion``.

        ``norm_layer`` must be callable whenever a downsample branch is
        needed, since it is invoked to build that branch.

        Raises
        ------
        RuntimeError
            If ``dilation`` is not 1, 2 or 4.
        """
        if dilation in (1, 2):
            first_dilation = 1
        elif dilation == 4:
            first_dilation = 2
        else:
            raise RuntimeError("=> unknown dilation size: {}".format(dilation))

        out_planes = planes * block.expansion

        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            down_layers = []
            if self.avg_down:
                # Pool with the stage stride (a no-op 1x1 pool for dilated
                # stages), then project with a stride-1 1x1 conv.
                pool_size = stride if dilation == 1 else 1
                down_layers.append(
                    nn.AvgPool2d(
                        kernel_size=pool_size,
                        stride=pool_size,
                        ceil_mode=True,
                        count_include_pad=False,
                    )
                )
                conv_stride = 1
            else:
                conv_stride = stride
            down_layers.append(
                nn.Conv2d(
                    self.inplanes,
                    out_planes,
                    kernel_size=1,
                    stride=conv_stride,
                    bias=False,
                )
            )
            down_layers.append(norm_layer(out_planes))
            downsample = nn.Sequential(*down_layers)

        # Keyword arguments common to every block of the stage.
        shared = dict(
            radix=self.radix,
            cardinality=self.cardinality,
            bottleneck_width=self.bottleneck_width,
            avd=self.avd,
            avd_first=self.avd_first,
            rectified_conv=self.rectified_conv,
            rectify_avg=self.rectify_avg,
            norm_layer=norm_layer,
            dropblock_prob=dropblock_prob,
            last_gamma=self.last_gamma,
        )

        layers = [
            block(
                self.inplanes,
                planes,
                stride,
                downsample=downsample,
                dilation=first_dilation,
                is_first=is_first,
                **shared,
            )
        ]

        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes, planes, dilation=dilation, **shared)
            )

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through layer1..layer4 and return the last feature map.

        No stem is applied before layer1 and no pooling / flatten / fc is
        applied after layer4.
        """
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x