Coverage for src/bob/bio/face/pytorch/facexzoo/MobileFaceNets.py: 100%
81 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-13 00:04 +0200
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-13 00:04 +0200
1"""
2@author: Jun Wang
3@date: 20201019
4@contact: jun21wangustc@gmail.com
5"""
7# based on:
8# https://github.com/TreB1eN/InsightFace_Pytorch/blob/master/model.py
10from torch.nn import (
11 BatchNorm1d,
12 BatchNorm2d,
13 Conv2d,
14 Linear,
15 Module,
16 PReLU,
17 Sequential,
18)
class Flatten(Module):
    """Flatten every dimension except the batch dimension into one axis."""

    def forward(self, input):
        # Preserve dim 0 (batch); -1 lets view infer the folded size.
        return input.view(input.shape[0], -1)
class Conv_block(Module):
    """Conv2d (bias-free) -> BatchNorm2d -> PReLU building block.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        kernel: convolution kernel size, (h, w).
        stride: convolution stride, (h, w).
        padding: zero-padding, (h, w).
        groups: grouped-convolution group count (set to in_c for depthwise).
    """

    def __init__(
        self,
        in_c,
        out_c,
        kernel=(1, 1),
        stride=(1, 1),
        padding=(0, 0),
        groups=1,
    ):
        super(Conv_block, self).__init__()
        # Bias is redundant here: the BatchNorm that follows has its own shift.
        self.conv = Conv2d(
            in_c,
            out_channels=out_c,
            kernel_size=kernel,
            groups=groups,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = BatchNorm2d(out_c)
        self.prelu = PReLU(out_c)

    def forward(self, x):
        # conv -> norm -> activation, expressed as one chained call.
        return self.prelu(self.bn(self.conv(x)))
class Linear_block(Module):
    """Conv2d (bias-free) -> BatchNorm2d, with NO activation ("linear" block).

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        kernel: convolution kernel size, (h, w).
        stride: convolution stride, (h, w).
        padding: zero-padding, (h, w).
        groups: grouped-convolution group count (set to in_c for depthwise).
    """

    def __init__(
        self,
        in_c,
        out_c,
        kernel=(1, 1),
        stride=(1, 1),
        padding=(0, 0),
        groups=1,
    ):
        super(Linear_block, self).__init__()
        # No bias: BatchNorm's affine shift takes its place.
        self.conv = Conv2d(
            in_c,
            out_channels=out_c,
            kernel_size=kernel,
            groups=groups,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = BatchNorm2d(out_c)

    def forward(self, x):
        # Identical to Conv_block.forward but without the PReLU.
        return self.bn(self.conv(x))
class Depth_Wise(Module):
    """Depthwise-separable unit: 1x1 expand -> depthwise conv -> 1x1 project.

    NOTE(review): following the upstream FaceX-Zoo code, ``groups`` doubles
    as the expanded (intermediate) channel count, not just the group count.

    Args:
        in_c: input channel count.
        out_c: output channel count.
        residual: if True, add an identity shortcut (caller must ensure
            in_c == out_c and stride keeps the spatial size unchanged).
        kernel: depthwise kernel size.
        stride: depthwise stride.
        padding: depthwise padding.
        groups: intermediate channel count AND depthwise group count.
    """

    def __init__(
        self,
        in_c,
        out_c,
        residual=False,
        kernel=(3, 3),
        stride=(2, 2),
        padding=(1, 1),
        groups=1,
    ):
        super(Depth_Wise, self).__init__()
        # Pointwise expansion to `groups` channels.
        self.conv = Conv_block(
            in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)
        )
        # Depthwise convolution: one filter per channel.
        self.conv_dw = Conv_block(
            groups,
            groups,
            groups=groups,
            kernel=kernel,
            padding=padding,
            stride=stride,
        )
        # Activation-free pointwise projection back down to out_c.
        self.project = Linear_block(
            groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)
        )
        self.residual = residual

    def forward(self, x):
        out = self.project(self.conv_dw(self.conv(x)))
        # Identity shortcut only when configured (tensor addition commutes,
        # so `x + out` matches the original `short_cut + x` exactly).
        return x + out if self.residual else out
class Residual(Module):
    """A stack of ``num_block`` residual Depth_Wise units (stride 1 keeps
    spatial size, so the identity shortcut inside each unit is valid)."""

    def __init__(
        self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)
    ):
        super(Residual, self).__init__()
        # Build all identically-configured residual units in one comprehension.
        self.model = Sequential(
            *[
                Depth_Wise(
                    c,
                    c,
                    residual=True,
                    kernel=kernel,
                    padding=padding,
                    stride=stride,
                    groups=groups,
                )
                for _ in range(num_block)
            ]
        )

    def forward(self, x):
        return self.model(x)
class MobileFaceNet(Module):
    """MobileFaceNet face-embedding backbone.

    Maps a 3-channel image to an ``embedding_size``-dim feature vector.

    Args:
        embedding_size: dimensionality of the output embedding.
        out_h: feature-map height entering the final depthwise layer
            (e.g. 7 for 112x112 input — TODO confirm against caller).
        out_w: feature-map width entering the final depthwise layer.
    """

    def __init__(self, embedding_size, out_h, out_w):
        super(MobileFaceNet, self).__init__()
        # Stem: stride-2 conv then a depthwise conv at full resolution.
        self.conv1 = Conv_block(
            3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1)
        )
        self.conv2_dw = Conv_block(
            64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64
        )
        # Stage transitions downsample (stride 2); Residual stacks refine.
        self.conv_23 = Depth_Wise(
            64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128
        )
        self.conv_3 = Residual(
            64,
            num_block=4,
            groups=128,
            kernel=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
        )
        self.conv_34 = Depth_Wise(
            64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256
        )
        self.conv_4 = Residual(
            128,
            num_block=6,
            groups=256,
            kernel=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
        )
        self.conv_45 = Depth_Wise(
            128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512
        )
        # NOTE(review): groups=256 here (not 512) matches the upstream
        # FaceX-Zoo code — kept as-is to stay weight-compatible.
        self.conv_5 = Residual(
            128,
            num_block=2,
            groups=256,
            kernel=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
        )
        self.conv_6_sep = Conv_block(
            128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)
        )
        # Global depthwise conv: kernel equals the feature-map size, so the
        # spatial dims collapse to 1x1.
        self.conv_6_dw = Linear_block(
            512,
            512,
            groups=512,
            kernel=(out_h, out_w),
            stride=(1, 1),
            padding=(0, 0),
        )
        self.conv_6_flatten = Flatten()
        # Final linear projection to the embedding, normalized by BatchNorm1d.
        self.linear = Linear(512, embedding_size, bias=False)
        self.bn = BatchNorm1d(embedding_size)

    def forward(self, x):
        # The network is a pure feed-forward chain; run it as a pipeline.
        feat = x
        for layer in (
            self.conv1,
            self.conv2_dw,
            self.conv_23,
            self.conv_3,
            self.conv_34,
            self.conv_4,
            self.conv_45,
            self.conv_5,
            self.conv_6_sep,
            self.conv_6_dw,
            self.conv_6_flatten,
            self.linear,
            self.bn,
        ):
            feat = layer(feat)
        return feat