Coverage for src/bob/bio/face/pytorch/facexzoo/MobileFaceNets.py: 100%

81 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-13 00:04 +0200

1""" 

2@author: Jun Wang 

3@date: 20201019 

4@contact: jun21wangustc@gmail.com 

5""" 

6 

7# based on: 

8# https://github.com/TreB1eN/InsightFace_Pytorch/blob/master/model.py 

9 

10from torch.nn import ( 

11 BatchNorm1d, 

12 BatchNorm2d, 

13 Conv2d, 

14 Linear, 

15 Module, 

16 PReLU, 

17 Sequential, 

18) 

19 

20 

21class Flatten(Module): 

22 def forward(self, input): 

23 return input.view(input.size(0), -1) 

24 

25 

26class Conv_block(Module): 

27 def __init__( 

28 self, 

29 in_c, 

30 out_c, 

31 kernel=(1, 1), 

32 stride=(1, 1), 

33 padding=(0, 0), 

34 groups=1, 

35 ): 

36 super(Conv_block, self).__init__() 

37 self.conv = Conv2d( 

38 in_c, 

39 out_channels=out_c, 

40 kernel_size=kernel, 

41 groups=groups, 

42 stride=stride, 

43 padding=padding, 

44 bias=False, 

45 ) 

46 self.bn = BatchNorm2d(out_c) 

47 self.prelu = PReLU(out_c) 

48 

49 def forward(self, x): 

50 x = self.conv(x) 

51 x = self.bn(x) 

52 x = self.prelu(x) 

53 return x 

54 

55 

56class Linear_block(Module): 

57 def __init__( 

58 self, 

59 in_c, 

60 out_c, 

61 kernel=(1, 1), 

62 stride=(1, 1), 

63 padding=(0, 0), 

64 groups=1, 

65 ): 

66 super(Linear_block, self).__init__() 

67 self.conv = Conv2d( 

68 in_c, 

69 out_channels=out_c, 

70 kernel_size=kernel, 

71 groups=groups, 

72 stride=stride, 

73 padding=padding, 

74 bias=False, 

75 ) 

76 self.bn = BatchNorm2d(out_c) 

77 

78 def forward(self, x): 

79 x = self.conv(x) 

80 x = self.bn(x) 

81 return x 

82 

83 

84class Depth_Wise(Module): 

85 def __init__( 

86 self, 

87 in_c, 

88 out_c, 

89 residual=False, 

90 kernel=(3, 3), 

91 stride=(2, 2), 

92 padding=(1, 1), 

93 groups=1, 

94 ): 

95 super(Depth_Wise, self).__init__() 

96 self.conv = Conv_block( 

97 in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1) 

98 ) 

99 self.conv_dw = Conv_block( 

100 groups, 

101 groups, 

102 groups=groups, 

103 kernel=kernel, 

104 padding=padding, 

105 stride=stride, 

106 ) 

107 self.project = Linear_block( 

108 groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1) 

109 ) 

110 self.residual = residual 

111 

112 def forward(self, x): 

113 if self.residual: 

114 short_cut = x 

115 x = self.conv(x) 

116 x = self.conv_dw(x) 

117 x = self.project(x) 

118 if self.residual: 

119 output = short_cut + x 

120 else: 

121 output = x 

122 return output 

123 

124 

125class Residual(Module): 

126 def __init__( 

127 self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1) 

128 ): 

129 super(Residual, self).__init__() 

130 modules = [] 

131 for _ in range(num_block): 

132 modules.append( 

133 Depth_Wise( 

134 c, 

135 c, 

136 residual=True, 

137 kernel=kernel, 

138 padding=padding, 

139 stride=stride, 

140 groups=groups, 

141 ) 

142 ) 

143 self.model = Sequential(*modules) 

144 

145 def forward(self, x): 

146 return self.model(x) 

147 

148 

149class MobileFaceNet(Module): 

150 def __init__(self, embedding_size, out_h, out_w): 

151 super(MobileFaceNet, self).__init__() 

152 self.conv1 = Conv_block( 

153 3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1) 

154 ) 

155 self.conv2_dw = Conv_block( 

156 64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64 

157 ) 

158 self.conv_23 = Depth_Wise( 

159 64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128 

160 ) 

161 self.conv_3 = Residual( 

162 64, 

163 num_block=4, 

164 groups=128, 

165 kernel=(3, 3), 

166 stride=(1, 1), 

167 padding=(1, 1), 

168 ) 

169 self.conv_34 = Depth_Wise( 

170 64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256 

171 ) 

172 self.conv_4 = Residual( 

173 128, 

174 num_block=6, 

175 groups=256, 

176 kernel=(3, 3), 

177 stride=(1, 1), 

178 padding=(1, 1), 

179 ) 

180 self.conv_45 = Depth_Wise( 

181 128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512 

182 ) 

183 self.conv_5 = Residual( 

184 128, 

185 num_block=2, 

186 groups=256, 

187 kernel=(3, 3), 

188 stride=(1, 1), 

189 padding=(1, 1), 

190 ) 

191 self.conv_6_sep = Conv_block( 

192 128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0) 

193 ) 

194 # self.conv_6_dw = Linear_block(512, 512, groups=512, kernel=(7,7), stride=(1, 1), padding=(0, 0)) 

195 # self.conv_6_dw = Linear_block(512, 512, groups=512, kernel=(4,7), stride=(1, 1), padding=(0, 0)) 

196 self.conv_6_dw = Linear_block( 

197 512, 

198 512, 

199 groups=512, 

200 kernel=(out_h, out_w), 

201 stride=(1, 1), 

202 padding=(0, 0), 

203 ) 

204 self.conv_6_flatten = Flatten() 

205 self.linear = Linear(512, embedding_size, bias=False) 

206 self.bn = BatchNorm1d(embedding_size) 

207 

208 def forward(self, x): 

209 out = self.conv1(x) 

210 out = self.conv2_dw(out) 

211 out = self.conv_23(out) 

212 out = self.conv_3(out) 

213 out = self.conv_34(out) 

214 out = self.conv_4(out) 

215 out = self.conv_45(out) 

216 out = self.conv_5(out) 

217 out = self.conv_6_sep(out) 

218 out = self.conv_6_dw(out) 

219 out = self.conv_6_flatten(out) 

220 out = self.linear(out) 

221 out = self.bn(out) 

222 return out