Coverage for src/bob/bio/face/pytorch/facexzoo/resnest/resnest.py: 76%
29 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-13 00:04 +0200
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-13 00:04 +0200
1"""
2@author: Jun Wang
3@date: 20210301
4@contact: jun21wangustc@gmail.com
5"""
7# based on:
8# https://github.com/zhanghang1989/ResNeSt/blob/master/resnest/torch/resnest.py
10import torch
11import torch.nn as nn
13from .resnet import Bottleneck, ResNet
class Flatten(nn.Module):
    """Flatten every dimension except the leading (batch) one."""

    def forward(self, input):
        """Return ``input`` reshaped to ``(input.size(0), -1)``.

        ``torch.flatten`` is used instead of ``Tensor.view`` so that
        non-contiguous tensors and empty batches are handled; ``view``
        raises on both.
        """
        if input.dim() >= 2:
            return torch.flatten(input, start_dim=1)
        # 1-D input: keep the legacy (N,) -> (N, 1) result of view(N, -1).
        return input.reshape(input.size(0), -1)
def l2_norm(input, axis=1, eps=1e-12):
    """Scale ``input`` to unit L2 norm along ``axis``.

    Args:
        input: tensor to normalise.
        axis: dimension along which the L2 norm is computed. The norm is
            kept as a size-1 dimension so it broadcasts against ``input``.
        eps: lower bound applied to the norm before dividing. All-zero
            slices therefore yield zeros instead of NaN. Results for slices
            whose norm is at least ``eps`` are unchanged.

    Returns:
        Tensor with the same shape as ``input``.
    """
    norm = torch.norm(input, 2, axis, True)
    # Clamp to avoid 0/0 -> NaN; same convention as F.normalize.
    return torch.div(input, norm.clamp_min(eps))
class ResNeSt(nn.Module):
    """ResNeSt backbone wrapped with an input stem and an embedding head.

    The 3-channel input goes through a 3x3 conv stem (``input_layer``), the
    ResNeSt ``body`` built from ``.resnet``, and an ``output_layer``. The
    output layer expects a 2048-channel feature map of size
    ``out_h x out_w`` and projects it to a ``feat_dim`` embedding.
    ``forward`` returns that embedding L2-normalised along dim 1.

    Args:
        num_layers: network depth; one of 50, 101, 200 or 269.
        drop_ratio: dropout probability applied before the final Linear.
        feat_dim: size of the output embedding.
        out_h: height of the feature map produced by ``body``.
        out_w: width of the feature map produced by ``body``. Together with
            ``out_h`` it must match the real body output, because the Linear
            input size is ``2048 * out_h * out_w``.

    Raises:
        ValueError: if ``num_layers`` is not a supported depth.
    """

    # depth -> (blocks per stage, deep-stem width). All variants otherwise
    # share the same ResNeSt hyper-parameters (see the ResNet call below).
    _ARCH_SETTINGS = {
        50: ((3, 4, 6, 3), 32),
        101: ((3, 4, 23, 3), 64),
        200: ((3, 24, 36, 3), 64),
        269: ((3, 30, 48, 8), 64),
    }

    def __init__(self, num_layers, drop_ratio, feat_dim, out_h=7, out_w=7):
        super().__init__()
        # Fail fast: previously an unknown depth left `self.body` undefined
        # and only surfaced as an AttributeError inside forward().
        if num_layers not in self._ARCH_SETTINGS:
            supported = ", ".join(str(d) for d in sorted(self._ARCH_SETTINGS))
            raise ValueError(
                f"Unsupported ResNeSt depth {num_layers!r}; "
                f"expected one of: {supported}."
            )
        layers, stem_width = self._ARCH_SETTINGS[num_layers]

        # Submodules are created in the original order (input, output, body)
        # so parameter initialisation and state_dict key order are unchanged.
        self.input_layer = nn.Sequential(
            nn.Conv2d(3, 64, (3, 3), 1, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.PReLU(64),
        )
        self.output_layer = nn.Sequential(
            nn.BatchNorm2d(2048),
            nn.Dropout(drop_ratio),
            Flatten(),
            nn.Linear(2048 * out_h * out_w, feat_dim),
            nn.BatchNorm1d(feat_dim),
        )
        self.body = ResNet(
            Bottleneck,
            list(layers),
            radix=2,
            groups=1,
            bottleneck_width=64,
            deep_stem=True,
            stem_width=stem_width,
            avg_down=True,
            avd=True,
            avd_first=False,
        )

    def forward(self, x):
        """Embed a batch of images; returns L2-normalised ``feat_dim`` vectors."""
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_layer(x)
        return l2_norm(x)