# src/bob/bio/face/pytorch/facexzoo/resnest/resnet.py
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Created by: Hang Zhang
# Email: zhanghang0704@gmail.com
# Copyright (c) 2020
#
# LICENSE file in the root directory of this source tree
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""ResNet variants"""
import torch.nn as nn

from .splat import SplAtConv2d

__all__ = ["ResNet", "Bottleneck"]


class DropBlock2D(object):
    """Placeholder only: DropBlock is not available in this port, so any
    configuration with dropblock_prob > 0.0 fails at construction time."""

    def __init__(self, *args, **kwargs):
        raise NotImplementedError


class GlobalAvgPool2d(nn.Module):
    def __init__(self):
        """Global average pooling over the input's spatial dimensions"""
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, inputs):
        return nn.functional.adaptive_avg_pool2d(inputs, 1).view(
            inputs.size(0), -1
        )
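

# Illustrative usage (not part of the original file):
#   GlobalAvgPool2d()(torch.randn(2, 8, 7, 7)).shape -> torch.Size([2, 8])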


class Bottleneck(nn.Module):
    """ResNet Bottleneck"""

    # pylint: disable=unused-argument
    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        radix=1,
        cardinality=1,
        bottleneck_width=64,
        avd=False,
        avd_first=False,
        dilation=1,
        is_first=False,
        rectified_conv=False,
        rectify_avg=False,
        norm_layer=None,
        dropblock_prob=0.0,
        last_gamma=False,
    ):
        super(Bottleneck, self).__init__()
        group_width = int(planes * (bottleneck_width / 64.0)) * cardinality
        self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
        self.bn1 = norm_layer(group_width)
        self.dropblock_prob = dropblock_prob
        self.radix = radix
        self.avd = avd and (stride > 1 or is_first)
        self.avd_first = avd_first

        if self.avd:
            # Average-pool downsampling (avd): move the stride out of the
            # 3x3 convolution into a separate AvgPool2d layer.
            self.avd_layer = nn.AvgPool2d(3, stride, padding=1)
            stride = 1

        if dropblock_prob > 0.0:
            self.dropblock1 = DropBlock2D(dropblock_prob, 3)
            if radix == 1:
                self.dropblock2 = DropBlock2D(dropblock_prob, 3)
            self.dropblock3 = DropBlock2D(dropblock_prob, 3)
        if radix >= 1:
            # SplAtConv2d (split-attention) normalizes internally, so no
            # separate bn2 is created on this path.
            self.conv2 = SplAtConv2d(
                group_width,
                group_width,
                kernel_size=3,
                stride=stride,
                padding=dilation,
                dilation=dilation,
                groups=cardinality,
                bias=False,
                radix=radix,
                rectify=rectified_conv,
                rectify_avg=rectify_avg,
                norm_layer=norm_layer,
                dropblock_prob=dropblock_prob,
            )
        elif rectified_conv:
            from rfconv import RFConv2d

            self.conv2 = RFConv2d(
                group_width,
                group_width,
                kernel_size=3,
                stride=stride,
                padding=dilation,
                dilation=dilation,
                groups=cardinality,
                bias=False,
                average_mode=rectify_avg,
            )
            self.bn2 = norm_layer(group_width)
        else:
            self.conv2 = nn.Conv2d(
                group_width,
                group_width,
                kernel_size=3,
                stride=stride,
                padding=dilation,
                dilation=dilation,
                groups=cardinality,
                bias=False,
            )
            self.bn2 = norm_layer(group_width)
        self.conv3 = nn.Conv2d(
            group_width, planes * 4, kernel_size=1, bias=False
        )
        self.bn3 = norm_layer(planes * 4)

        if last_gamma:
            # Zero-initialize the last BN scale so the residual branch
            # initially contributes nothing (identity-like start).
            from torch.nn.init import zeros_

            zeros_(self.bn3.weight)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.dilation = dilation
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        if self.dropblock_prob > 0.0:
            out = self.dropblock1(out)
        out = self.relu(out)

        if self.avd and self.avd_first:
            out = self.avd_layer(out)

        out = self.conv2(out)
        if self.radix == 0:
            out = self.bn2(out)
            if self.dropblock_prob > 0.0:
                out = self.dropblock2(out)
            out = self.relu(out)

        if self.avd and not self.avd_first:
            out = self.avd_layer(out)

        out = self.conv3(out)
        out = self.bn3(out)
        if self.dropblock_prob > 0.0:
            out = self.dropblock3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
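

# Illustrative shape check (not part of the original file; the helper name
# below is hypothetical). With radix >= 1 the 3x3 stage is a SplAtConv2d,
# which normalizes internally, hence bn2 exists only when radix == 0.
def _bottleneck_demo():
    import torch

    blk = Bottleneck(256, 64, radix=2, norm_layer=nn.BatchNorm2d)
    out = blk(torch.randn(2, 256, 56, 56))
    return out.shape  # torch.Size([2, 256, 56, 56])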


class ResNet(nn.Module):
    """ResNet variants

    Parameters
    ----------
    block : Block
        Class for the residual block (here :class:`Bottleneck`).
    layers : list of int
        Number of blocks in each of the four stages.
    num_classes : int, default 1000
        Number of classification classes.
    dilated : bool, default False
        Apply the dilation strategy to a pretrained ResNet, yielding a
        stride-8 model as typically used in semantic segmentation.
    norm_layer : object
        Normalization layer used in the backbone network (default:
        :class:`torch.nn.BatchNorm2d`; substitute a synchronized variant
        for cross-GPU batch normalization).

    Reference:

    - He, Kaiming, et al. "Deep residual learning for image recognition."
      Proceedings of the IEEE Conference on Computer Vision and Pattern
      Recognition. 2016.
    - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by
      dilated convolutions."
    """

    # pylint: disable=unused-variable
    def __init__(
        self,
        block,
        layers,
        radix=1,
        groups=1,
        bottleneck_width=64,
        num_classes=1000,
        dilated=False,
        dilation=1,
        deep_stem=False,
        stem_width=64,
        avg_down=False,
        rectified_conv=False,
        rectify_avg=False,
        avd=False,
        avd_first=False,
        final_drop=0.0,
        dropblock_prob=0,
        last_gamma=False,
        norm_layer=nn.BatchNorm2d,
    ):
        self.cardinality = groups
        self.bottleneck_width = bottleneck_width
        # ResNet-D params
        self.inplanes = stem_width * 2 if deep_stem else 64
        self.avg_down = avg_down
        self.last_gamma = last_gamma
        # ResNeSt params
        self.radix = radix
        self.avd = avd
        self.avd_first = avd_first

        super(ResNet, self).__init__()
        self.rectified_conv = rectified_conv
        self.rectify_avg = rectify_avg
232 """
233 if rectified_conv:
234 from rfconv import RFConv2d
236 conv_layer = RFConv2d
237 else:
238 conv_layer = nn.Conv2d
239 conv_kwargs = {"average_mode": rectify_avg} if rectified_conv else {}
240 if deep_stem:
241 self.conv1 = nn.Sequential(
242 conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False, **conv_kwargs),
243 norm_layer(stem_width),
244 nn.ReLU(inplace=True),
245 conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
246 norm_layer(stem_width),
247 nn.ReLU(inplace=True),
248 conv_layer(stem_width, stem_width*2, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
249 )
250 else:
251 self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3,
252 bias=False, **conv_kwargs)
253 self.bn1 = norm_layer(self.inplanes)
254 self.relu = nn.ReLU(inplace=True)
255 self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
256 """
        # self.layer1 = self._make_layer(
        #     block, 64, layers[0], norm_layer=norm_layer, is_first=False
        # )
        self.layer1 = self._make_layer(
            block,
            64,
            layers[0],
            stride=2,
            norm_layer=norm_layer,
            is_first=False,
        )
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, norm_layer=norm_layer
        )
        # Dilation strategy: trade stride for dilation in layer3/layer4 to
        # keep a higher-resolution feature map.
        if dilated or dilation == 4:
            self.layer3 = self._make_layer(
                block,
                256,
                layers[2],
                stride=1,
                dilation=2,
                norm_layer=norm_layer,
                dropblock_prob=dropblock_prob,
            )
            self.layer4 = self._make_layer(
                block,
                512,
                layers[3],
                stride=1,
                dilation=4,
                norm_layer=norm_layer,
                dropblock_prob=dropblock_prob,
            )
        elif dilation == 2:
            self.layer3 = self._make_layer(
                block,
                256,
                layers[2],
                stride=2,
                dilation=1,
                norm_layer=norm_layer,
                dropblock_prob=dropblock_prob,
            )
            self.layer4 = self._make_layer(
                block,
                512,
                layers[3],
                stride=1,
                dilation=2,
                norm_layer=norm_layer,
                dropblock_prob=dropblock_prob,
            )
        else:
            self.layer3 = self._make_layer(
                block,
                256,
                layers[2],
                stride=2,
                norm_layer=norm_layer,
                dropblock_prob=dropblock_prob,
            )
            self.layer4 = self._make_layer(
                block,
                512,
                layers[3],
                stride=2,
                norm_layer=norm_layer,
                dropblock_prob=dropblock_prob,
            )
324 """
325 self.avgpool = GlobalAvgPool2d()
326 self.drop = nn.Dropout(final_drop) if final_drop > 0.0 else None
327 self.fc = nn.Linear(512 * block.expansion, num_classes)
329 for m in self.modules():
330 if isinstance(m, nn.Conv2d):
331 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
332 m.weight.data.normal_(0, math.sqrt(2. / n))
333 elif isinstance(m, norm_layer):
334 m.weight.data.fill_(1)
335 m.bias.data.zero_()
336 """

    def _make_layer(
        self,
        block,
        planes,
        blocks,
        stride=1,
        dilation=1,
        norm_layer=None,
        dropblock_prob=0.0,
        is_first=True,
    ):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            down_layers = []
            if self.avg_down:
                # ResNet-D shortcut: average-pool first, then a stride-1
                # 1x1 convolution, instead of a strided 1x1 convolution.
                if dilation == 1:
                    down_layers.append(
                        nn.AvgPool2d(
                            kernel_size=stride,
                            stride=stride,
                            ceil_mode=True,
                            count_include_pad=False,
                        )
                    )
                else:
                    down_layers.append(
                        nn.AvgPool2d(
                            kernel_size=1,
                            stride=1,
                            ceil_mode=True,
                            count_include_pad=False,
                        )
                    )
                down_layers.append(
                    nn.Conv2d(
                        self.inplanes,
                        planes * block.expansion,
                        kernel_size=1,
                        stride=1,
                        bias=False,
                    )
                )
            else:
                down_layers.append(
                    nn.Conv2d(
                        self.inplanes,
                        planes * block.expansion,
                        kernel_size=1,
                        stride=stride,
                        bias=False,
                    )
                )
            down_layers.append(norm_layer(planes * block.expansion))
            downsample = nn.Sequential(*down_layers)

        layers = []
        if dilation == 1 or dilation == 2:
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    stride,
                    downsample=downsample,
                    radix=self.radix,
                    cardinality=self.cardinality,
                    bottleneck_width=self.bottleneck_width,
                    avd=self.avd,
                    avd_first=self.avd_first,
                    dilation=1,
                    is_first=is_first,
                    rectified_conv=self.rectified_conv,
                    rectify_avg=self.rectify_avg,
                    norm_layer=norm_layer,
                    dropblock_prob=dropblock_prob,
                    last_gamma=self.last_gamma,
                )
            )
        elif dilation == 4:
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    stride,
                    downsample=downsample,
                    radix=self.radix,
                    cardinality=self.cardinality,
                    bottleneck_width=self.bottleneck_width,
                    avd=self.avd,
                    avd_first=self.avd_first,
                    dilation=2,
                    is_first=is_first,
                    rectified_conv=self.rectified_conv,
                    rectify_avg=self.rectify_avg,
                    norm_layer=norm_layer,
                    dropblock_prob=dropblock_prob,
                    last_gamma=self.last_gamma,
                )
            )
        else:
            raise RuntimeError("=> unknown dilation size: {}".format(dilation))

        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    radix=self.radix,
                    cardinality=self.cardinality,
                    bottleneck_width=self.bottleneck_width,
                    avd=self.avd,
                    avd_first=self.avd_first,
                    dilation=dilation,
                    rectified_conv=self.rectified_conv,
                    rectify_avg=self.rectify_avg,
                    norm_layer=norm_layer,
                    dropblock_prob=dropblock_prob,
                    last_gamma=self.last_gamma,
                )
            )

        return nn.Sequential(*layers)
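
    # Illustrative (not part of the original file): _make_layer(Bottleneck,
    # 64, 3, stride=2) returns an nn.Sequential of one strided block that
    # carries the downsample shortcut, followed by two stride-1 blocks.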

    def forward(self, x):
        # Stem disabled in this port:
        # x = self.conv1(x)
        # x = self.bn1(x)
        # x = self.relu(x)
        # x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # Head disabled in this port:
        # x = self.avgpool(x)
        # x = x.view(x.size(0), -1)
        # x = torch.flatten(x, 1)
        # if self.drop:
        #     x = self.drop(x)
        # x = self.fc(x)
        return x
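

# Minimal usage sketch (not part of the original file). It assumes the
# sibling `splat` module is importable, e.g. when run as
# `python -m bob.bio.face.pytorch.facexzoo.resnest.resnet`. Since the stem
# is disabled above, the network consumes a stem-shaped feature map.
if __name__ == "__main__":
    import torch

    net = ResNet(
        Bottleneck,
        [3, 4, 6, 3],  # ResNeSt-50-style stage depths
        radix=2,
        groups=1,
        bottleneck_width=64,
        deep_stem=True,  # inplanes = stem_width * 2 = 64
        stem_width=32,
        avg_down=True,
        avd=True,
        avd_first=False,
    )
    feats = torch.randn(1, 64, 56, 56)  # stand-in for a stem output
    print(net(feats).shape)  # torch.Size([1, 2048, 4, 4])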