
1""" 

2@author: Jun Wang 

3@date: 20210322 

4@contact: jun21wangustc@gmail.com 

5""" 

6 

7# based on: 

8# https://github.com/clovaai/rexnet/blob/master/rexnetv1.py 

9""" 

10ReXNet 

11Copyright (c) 2020-present NAVER Corp. 

12MIT license 

13""" 

14 

15from math import ceil 

16 

17import torch 

18import torch.nn as nn 

19 

20 

21class Flatten(nn.Module): 

22 def forward(self, input): 

23 return input.view(input.size(0), -1) 

24 

25 

# Memory-efficient Swish using torch.jit.script, borrowed from the code in
# https://twitter.com/jeremyphoward/status/1188251041835315200
# Currently the memory-efficient Swish is used by default:
USE_MEMORY_EFFICIENT_SWISH = True

if USE_MEMORY_EFFICIENT_SWISH:

    @torch.jit.script
    def swish_fwd(x):
        return x.mul(torch.sigmoid(x))

    @torch.jit.script
    def swish_bwd(x, grad_output):
        x_sigmoid = torch.sigmoid(x)
        return grad_output * (x_sigmoid * (1.0 + x * (1.0 - x_sigmoid)))

    class SwishJitImplementation(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return swish_fwd(x)

        @staticmethod
        def backward(ctx, grad_output):
            x = ctx.saved_tensors[0]
            return swish_bwd(x, grad_output)

    def swish(x, inplace=False):
        return SwishJitImplementation.apply(x)

else:

    def swish(x, inplace=False):
        return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())

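# Both code paths compute swish(x) = x * sigmoid(x).  The JIT-scripted
# autograd.Function above saves only the input tensor and recomputes
# sigmoid(x) during the backward pass, using
#   d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x))),
# which is what makes it more memory-friendly than the plain version.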

class Swish(nn.Module):
    def __init__(self, inplace=True):
        super(Swish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return swish(x, self.inplace)


def ConvBNAct(
    out,
    in_channels,
    channels,
    kernel=1,
    stride=1,
    pad=0,
    num_group=1,
    active=True,
    relu6=False,
):
    out.append(
        nn.Conv2d(
            in_channels,
            channels,
            kernel,
            stride,
            pad,
            groups=num_group,
            bias=False,
        )
    )
    out.append(nn.BatchNorm2d(channels))
    if active:
        out.append(nn.ReLU6(inplace=True) if relu6 else nn.ReLU(inplace=True))


def ConvBNSwish(
    out, in_channels, channels, kernel=1, stride=1, pad=0, num_group=1
):
    out.append(
        nn.Conv2d(
            in_channels,
            channels,
            kernel,
            stride,
            pad,
            groups=num_group,
            bias=False,
        )
    )
    out.append(nn.BatchNorm2d(channels))
    out.append(Swish())


class SE(nn.Module):
    # Squeeze-and-Excitation block: global average pooling followed by a
    # small bottleneck of 1x1 convolutions that produces per-channel weights.
    def __init__(self, in_channels, channels, se_ratio=12):
        super(SE, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(
                in_channels, channels // se_ratio, kernel_size=1, padding=0
            ),
            nn.BatchNorm2d(channels // se_ratio),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // se_ratio, channels, kernel_size=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.fc(y)
        return x * y


class LinearBottleneck(nn.Module):
    def __init__(
        self,
        in_channels,
        channels,
        t,
        stride,
        use_se=True,
        se_ratio=12,
        **kwargs,
    ):
        super(LinearBottleneck, self).__init__(**kwargs)
        self.use_shortcut = stride == 1 and in_channels <= channels
        self.in_channels = in_channels
        self.out_channels = channels

        out = []
        if t != 1:
            dw_channels = in_channels * t
            ConvBNSwish(out, in_channels=in_channels, channels=dw_channels)
        else:
            dw_channels = in_channels

        ConvBNAct(
            out,
            in_channels=dw_channels,
            channels=dw_channels,
            kernel=3,
            stride=stride,
            pad=1,
            num_group=dw_channels,
            active=False,
        )

        if use_se:
            out.append(SE(dw_channels, dw_channels, se_ratio))

        out.append(nn.ReLU6())
        ConvBNAct(
            out,
            in_channels=dw_channels,
            channels=channels,
            active=False,
            relu6=True,
        )
        self.out = nn.Sequential(*out)


    def forward(self, x):
        out = self.out(x)
        if self.use_shortcut:
            # ReXNet-style partial residual: only the first in_channels
            # feature maps of the output receive the identity shortcut.
            out[:, 0 : self.in_channels] += x

        return out


class ReXNetV1(nn.Module):
    def __init__(
        self,
        input_ch=16,
        final_ch=180,
        width_mult=1.0,
        depth_mult=1.0,
        use_se=True,
        se_ratio=12,
        out_h=7,
        out_w=7,
        feat_dim=512,
        dropout_ratio=0.2,
        bn_momentum=0.9,
    ):
        super(ReXNetV1, self).__init__()

        layers = [1, 2, 2, 3, 3, 5]
        strides = [1, 2, 2, 2, 1, 2]
        use_ses = [False, False, True, True, True, True]

        layers = [ceil(element * depth_mult) for element in layers]
        strides = sum(
            [
                [element] + [1] * (layers[idx] - 1)
                for idx, element in enumerate(strides)
            ],
            [],
        )
        if use_se:
            use_ses = sum(
                [
                    [element] * layers[idx]
                    for idx, element in enumerate(use_ses)
                ],
                [],
            )
        else:
            use_ses = [False] * sum(layers[:])
        ts = [1] * layers[0] + [6] * sum(layers[1:])

        self.depth = sum(layers[:]) * 3
        stem_channel = 32 / width_mult if width_mult < 1.0 else 32
        inplanes = input_ch / width_mult if width_mult < 1.0 else input_ch

        features = []
        in_channels_group = []
        channels_group = []

        # The following channel configuration is a simple scheme that makes
        # every block an expansion layer (output width grows monotonically).
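        # Illustrative arithmetic for the default arguments (input_ch=16,
        # final_ch=180, depth_mult=1.0, so depth // 3 == 16 blocks): inplanes
        # grows by final_ch / (depth // 3) = 11.25 per block, so the per-block
        # output widths rise roughly linearly from 16 to about 185 channels.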

        for i in range(self.depth // 3):
            if i == 0:
                in_channels_group.append(int(round(stem_channel * width_mult)))
                channels_group.append(int(round(inplanes * width_mult)))
            else:
                in_channels_group.append(int(round(inplanes * width_mult)))
                inplanes += final_ch / (self.depth // 3 * 1.0)
                channels_group.append(int(round(inplanes * width_mult)))

        # ConvBNSwish(features, 3, int(round(stem_channel * width_mult)), kernel=3, stride=2, pad=1)
        ConvBNSwish(
            features,
            3,
            int(round(stem_channel * width_mult)),
            kernel=3,
            stride=1,
            pad=1,
        )

        for block_idx, (in_c, c, t, s, se) in enumerate(
            zip(in_channels_group, channels_group, ts, strides, use_ses)
        ):
            features.append(
                LinearBottleneck(
                    in_channels=in_c,
                    channels=c,
                    t=t,
                    stride=s,
                    use_se=se,
                    se_ratio=se_ratio,
                )
            )

        # pen_channels = int(1280 * width_mult)
        pen_channels = int(512 * width_mult)
        ConvBNSwish(features, c, pen_channels)

        # features.append(nn.AdaptiveAvgPool2d(1))
        self.features = nn.Sequential(*features)
        self.output_layer = nn.Sequential(
            nn.BatchNorm2d(512),
            nn.Dropout(dropout_ratio),
            Flatten(),
            nn.Linear(512 * out_h * out_w, feat_dim),
            nn.BatchNorm1d(feat_dim),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.output_layer(x)
        return x
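

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the FaceX-Zoo training code).
    # It assumes 112x112 face crops: the stem uses stride 1 and the block
    # strides [1, 2, 2, 2, 1, 2] downsample by a factor of 16, so a 112x112
    # input yields the 7x7 feature map expected by
    # nn.Linear(512 * out_h * out_w, feat_dim) with out_h = out_w = 7.
    model = ReXNetV1()
    model.eval()
    with torch.no_grad():
        embedding = model(torch.randn(2, 3, 112, 112))
    print(embedding.shape)  # torch.Size([2, 512])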