Coverage for src/bob/bio/face/pytorch/backbones/iresnet.py: 0%

117 statements  

import torch

from torch import nn

__all__ = ["iresnet18", "iresnet34", "iresnet50", "iresnet100", "iresnet200"]


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=1, stride=stride, bias=False
    )
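# Illustrative shape check (example values): with the default stride=1, the
# padding=dilation choice above keeps the spatial size unchanged, e.g.
#   conv3x3(64, 128)(torch.randn(1, 64, 56, 56))  # -> shape (1, 128, 56, 56)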

class IBasicBlock(nn.Module):
    expansion = 1

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
    ):
        super(IBasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError(
                "BasicBlock only supports groups=1 and base_width=64"
            )
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock"
            )
        self.bn1 = nn.BatchNorm2d(
            inplanes,
            eps=1e-05,
        )
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(
            planes,
            eps=1e-05,
        )
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(
            planes,
            eps=1e-05,
        )
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.bn1(x)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out
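# Illustrative stride-2 check (example values): a stride-2 IBasicBlock halves
# the spatial size, so the skip connection needs a matching downsample
# branch, e.g.
#   ds = nn.Sequential(conv1x1(64, 64, 2), nn.BatchNorm2d(64, eps=1e-05))
#   IBasicBlock(64, 64, stride=2, downsample=ds)(torch.randn(1, 64, 56, 56))
#   # -> shape (1, 64, 28, 28)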

class IResNet(nn.Module):
    fc_scale = 7 * 7

    def __init__(
        self,
        block,
        layers,
        dropout=0,
        num_features=512,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        fp16=False,
    ):
        super(IResNet, self).__init__()
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(
                    replace_stride_with_dilation
                )
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(
            block,
            128,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0],
        )
        self.layer3 = self._make_layer(
            block,
            256,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1],
        )
        self.layer4 = self._make_layer(
            block,
            512,
            layers[3],
            stride=2,
            dilate=replace_stride_with_dilation[2],
        )
        self.bn2 = nn.BatchNorm2d(
            512 * block.expansion,
            eps=1e-05,
        )
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(
                    planes * block.expansion,
                    eps=1e-05,
                ),
            )
        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                )
            )

        return nn.Sequential(*layers)
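    # Note on _make_layer above: only the first block of each stage carries
    # the stride and the downsample branch; the remaining blocks run at the
    # reduced resolution with stride 1.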

    def forward(self, x):
        with torch.cuda.amp.autocast(self.fp16):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.prelu(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.bn2(x)
            x = torch.flatten(x, 1)
            x = self.dropout(x)
        x = self.fc(x.float() if self.fp16 else x)
        x = self.features(x)
        return x
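    # Mixed-precision note: with fp16=True the convolutional trunk runs under
    # torch.cuda.amp.autocast, and activations are cast back to float32 via
    # x.float() before the final fully connected layer.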

def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    model = IResNet(block, layers, **kwargs)
    if pretrained:
        map_location = (
            torch.device("cuda")
            if torch.cuda.is_available()
            else torch.device("cpu")
        )
        state_dict = torch.load(pretrained, map_location=map_location)
        model.load_state_dict(state_dict)

    return model
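# Note: despite the boolean-looking `pretrained=False` defaults in the factory
# functions below, `pretrained` is treated here as a checkpoint path and passed
# directly to torch.load; `progress` is accepted but never used.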

def iresnet18(pretrained=False, progress=True, **kwargs):
    return _iresnet(
        "iresnet18", IBasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs
    )


def iresnet34(pretrained=False, progress=True, **kwargs):
    return _iresnet(
        "iresnet34", IBasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs
    )


def iresnet50(pretrained=False, progress=True, **kwargs):
    return _iresnet(
        "iresnet50", IBasicBlock, [3, 4, 14, 3], pretrained, progress, **kwargs
    )


def iresnet100(pretrained=False, progress=True, **kwargs):
    return _iresnet(
        "iresnet100",
        IBasicBlock,
        [3, 13, 30, 3],
        pretrained,
        progress,
        **kwargs,
    )


def iresnet200(pretrained=False, progress=True, **kwargs):
    return _iresnet(
        "iresnet200",
        IBasicBlock,
        [6, 26, 60, 6],
        pretrained,
        progress,
        **kwargs,
    )
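# Minimal usage sketch (illustrative): the flattened feature size assumes a
# 112x112 input, since fc_scale = 7 * 7 and four stride-2 stages reduce
# 112 -> 7.
if __name__ == "__main__":
    model = iresnet34()
    model.eval()
    with torch.no_grad():
        embeddings = model(torch.randn(2, 3, 112, 112))
    print(embeddings.shape)  # torch.Size([2, 512])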