Coverage for src/bob/bio/face/pytorch/backbones/iresnet.py: 0%
117 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-13 00:04 +0200
1import torch
3from torch import nn
# Names exported by ``from <module> import *``: only the model factory
# functions; the block/backbone classes and conv helpers are not listed.
__all__ = ["iresnet18", "iresnet34", "iresnet50", "iresnet100", "iresnet200"]
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Create a bias-free 3x3 ``nn.Conv2d``.

    The padding is set equal to ``dilation``, so with ``stride=1`` the
    spatial size of the input is preserved.

    Parameters
    ----------
    in_planes : int
        Number of input channels.
    out_planes : int
        Number of output channels.
    stride : int
        Convolution stride.
    groups : int
        Number of blocked connections from input to output channels.
    dilation : int
        Kernel dilation (also used as the padding).

    Returns
    -------
    torch.nn.Conv2d
        The configured convolution layer.
    """
    options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **options)
def conv1x1(in_planes, out_planes, stride=1):
    """Create a bias-free 1x1 ``nn.Conv2d`` with no padding.

    In this module it is used for the projection shortcut built by
    ``IResNet._make_layer``.

    Parameters
    ----------
    in_planes : int
        Number of input channels.
    out_planes : int
        Number of output channels.
    stride : int
        Convolution stride.

    Returns
    -------
    torch.nn.Conv2d
        The configured convolution layer.
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer
class IBasicBlock(nn.Module):
    """Residual block used by :class:`IResNet`.

    Residual branch: BN -> 3x3 conv -> BN -> PReLU -> 3x3 conv (carrying the
    block's ``stride``) -> BN. The branch output is summed with the block
    input, or with ``downsample(input)`` when a ``downsample`` module is
    given. No activation is applied after the sum.

    Parameters
    ----------
    inplanes : int
        Number of input channels.
    planes : int
        Number of output channels.
    stride : int
        Stride of the second 3x3 convolution.
    downsample : torch.nn.Module or None
        Optional module applied to the input to form the shortcut; needed
        whenever the branch changes the shape of the tensor.
    groups : int
        Must be 1.
    base_width : int
        Must be 64.
    dilation : int
        Must be <= 1.

    Raises
    ------
    ValueError
        If ``groups != 1`` or ``base_width != 64``.
    NotImplementedError
        If ``dilation > 1``.
    """

    expansion = 1

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
    ):
        super().__init__()
        if not (groups == 1 and base_width == 64):
            raise ValueError(
                "BasicBlock only supports groups=1 and base_width=64"
            )
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock"
            )

        def batch_norm(channels):
            return nn.BatchNorm2d(channels, eps=1e-05)

        # Submodules are registered in this exact order: IResNet initialises
        # weights by walking ``modules()``, so the order decides which random
        # draws each layer receives.
        self.bn1 = batch_norm(inplanes)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = batch_norm(planes)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = batch_norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``branch(x) + shortcut(x)`` for a 4D feature map ``x``."""
        branch = x
        for stage in (
            self.bn1,
            self.conv1,
            self.bn2,
            self.prelu,
            self.conv2,
            self.bn3,
        ):
            branch = stage(branch)
        if self.downsample is None:
            shortcut = x
        else:
            shortcut = self.downsample(x)
        branch += shortcut
        return branch
class IResNet(nn.Module):
    """IResNet face-recognition backbone.

    As built below: a stem (3x3 conv with stride 1, BN, PReLU), four stages
    of residual ``block`` instances with 64, 128, 256 and 512 planes whose
    first block uses stride 2, then BN, flatten, dropout, a linear
    projection to ``num_features`` and a ``BatchNorm1d`` whose scale is
    frozen at 1.

    Parameters
    ----------
    block : type
        Residual block class (e.g. ``IBasicBlock``). It must expose an
        ``expansion`` attribute and accept the arguments passed in
        ``_make_layer``.
    layers : sequence of int
        Number of blocks in each of the four stages.
    dropout : float
        Dropout probability applied before the final linear layer.
    num_features : int
        Size of the output embedding.
    zero_init_residual : bool
        If True, zero the ``bn2`` scale of every ``IBasicBlock``.
    groups, width_per_group : int
        Forwarded to every block (``IBasicBlock`` only accepts 1 and 64).
    replace_stride_with_dilation : sequence of 3 bool, optional
        For stages 2-4, replace the stride with dilation.
    fp16 : bool
        Run the convolutional trunk under CUDA autocast (mixed precision).
        The projection head always runs in float32.

    Raises
    ------
    ValueError
        If ``replace_stride_with_dilation`` does not have 3 elements.

    NOTE(review): ``fc_scale = 7 * 7`` assumes a 7x7 final feature map. With
    a stride-1 stem and four stride-2 stages this presumably means 112x112
    inputs; other input sizes make ``self.fc`` fail with a shape mismatch.
    """

    fc_scale = 7 * 7

    def __init__(
        self,
        block,
        layers,
        dropout=0,
        num_features=512,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        fp16=False,
    ):
        super().__init__()
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(
                    replace_stride_with_dilation
                )
            )
        self.groups = groups
        self.base_width = width_per_group
        # Stem keeps the resolution (stride 1); downsampling starts at layer1.
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(
            block,
            128,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0],
        )
        self.layer3 = self._make_layer(
            block,
            256,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1],
        )
        self.layer4 = self._make_layer(
            block,
            512,
            layers[3],
            stride=2,
            dilate=replace_stride_with_dilation[2],
        )
        self.bn2 = nn.BatchNorm2d(
            512 * block.expansion,
            eps=1e-05,
        )
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        # Fix the embedding BN scale at 1; only its bias remains trainable.
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        # The iteration order of modules() determines which random draws each
        # convolution receives.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            # NOTE(review): this zeroes ``bn2`` (the BN before PReLU/conv2),
            # not ``bn3``, the last BN of the IBasicBlock residual branch that
            # the usual "zero-init residual" trick targets. Confirm intended.
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks with ``planes`` channels.

        Only the first block may change resolution or channel count. It gets
        a 1x1 conv + BN projection shortcut when ``stride != 1`` or the
        channel count changes. The remaining blocks use stride 1. When
        ``dilate`` is set, the stride is folded into ``self.dilation``
        instead of being applied.
        """
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(
                    planes * block.expansion,
                    eps=1e-05,
                ),
            )
        layers = [
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
            )
        ]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                )
            )
        return nn.Sequential(*layers)

    def forward(self, x):
        """Compute the embedding of shape ``(batch, num_features)``.

        ``x`` is presumably a ``(batch, 3, 112, 112)`` image tensor (see the
        ``fc_scale`` note in the class docstring). TODO confirm.
        """
        # ``torch.cuda.amp.autocast`` is deprecated; ``torch.autocast`` with
        # ``device_type="cuda"`` is the equivalent replacement.
        with torch.autocast(device_type="cuda", enabled=self.fp16):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.prelu(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.bn2(x)
            x = torch.flatten(x, 1)
            x = self.dropout(x)
        # The head runs outside autocast so that ``.float()`` really yields a
        # float32 projection. Inside autocast, Linear would recast to half.
        x = self.fc(x.float() if self.fp16 else x)
        x = self.features(x)
        return x
219def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
220 model = IResNet(block, layers, **kwargs)
221 if pretrained:
222 map_location = (
223 torch.device("cuda")
224 if torch.cuda.is_available()
225 else torch.device("cpu")
226 )
227 state_dict = torch.load(pretrained, map_location=map_location)
228 model.load_state_dict(state_dict)
230 return model
def iresnet18(pretrained=False, progress=True, **kwargs):
    """Build IResNet-18: ``IBasicBlock`` stages with 2, 2, 2, 2 blocks.

    Parameters
    ----------
    pretrained : bool or path-like
        Falsy to keep the random initialisation; otherwise forwarded to
        ``_iresnet`` as the checkpoint location for ``torch.load``.
    progress : bool
        Forwarded to ``_iresnet``, which does not use it.
    **kwargs
        Extra keyword arguments for ``IResNet`` (e.g. ``fp16``,
        ``num_features``, ``dropout``).
    """
    return _iresnet(
        arch="iresnet18",
        block=IBasicBlock,
        layers=[2, 2, 2, 2],
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )
def iresnet34(pretrained=False, progress=True, **kwargs):
    """Build IResNet-34: ``IBasicBlock`` stages with 3, 4, 6, 3 blocks.

    Parameters
    ----------
    pretrained : bool or path-like
        Falsy to keep the random initialisation; otherwise forwarded to
        ``_iresnet`` as the checkpoint location for ``torch.load``.
    progress : bool
        Forwarded to ``_iresnet``, which does not use it.
    **kwargs
        Extra keyword arguments for ``IResNet``.
    """
    return _iresnet(
        arch="iresnet34",
        block=IBasicBlock,
        layers=[3, 4, 6, 3],
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )
def iresnet50(pretrained=False, progress=True, **kwargs):
    """Build IResNet-50: ``IBasicBlock`` stages with 3, 4, 14, 3 blocks.

    Parameters
    ----------
    pretrained : bool or path-like
        Falsy to keep the random initialisation; otherwise forwarded to
        ``_iresnet`` as the checkpoint location for ``torch.load``.
    progress : bool
        Forwarded to ``_iresnet``, which does not use it.
    **kwargs
        Extra keyword arguments for ``IResNet``.
    """
    return _iresnet(
        arch="iresnet50",
        block=IBasicBlock,
        layers=[3, 4, 14, 3],
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )
def iresnet100(pretrained=False, progress=True, **kwargs):
    """Build IResNet-100: ``IBasicBlock`` stages with 3, 13, 30, 3 blocks.

    Parameters
    ----------
    pretrained : bool or path-like
        Falsy to keep the random initialisation; otherwise forwarded to
        ``_iresnet`` as the checkpoint location for ``torch.load``.
    progress : bool
        Forwarded to ``_iresnet``, which does not use it.
    **kwargs
        Extra keyword arguments for ``IResNet``.
    """
    return _iresnet(
        arch="iresnet100",
        block=IBasicBlock,
        layers=[3, 13, 30, 3],
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )
def iresnet200(pretrained=False, progress=True, **kwargs):
    """Build IResNet-200: ``IBasicBlock`` stages with 6, 26, 60, 6 blocks.

    Parameters
    ----------
    pretrained : bool or path-like
        Falsy to keep the random initialisation; otherwise forwarded to
        ``_iresnet`` as the checkpoint location for ``torch.load``.
    progress : bool
        Forwarded to ``_iresnet``, which does not use it.
    **kwargs
        Extra keyword arguments for ``IResNet``.
    """
    return _iresnet(
        arch="iresnet200",
        block=IBasicBlock,
        layers=[6, 26, 60, 6],
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )