Coverage for src/bob/bio/face/pytorch/facexzoo/ReXNets.py: 86%
111 statements
"""
@author: Jun Wang
@date: 20210322
@contact: jun21wangustc@gmail.com
"""

# based on:
# https://github.com/clovaai/rexnet/blob/master/rexnetv1.py
"""
ReXNet
Copyright (c) 2020-present NAVER Corp.
MIT license
"""

from math import ceil

import torch
import torch.nn as nn

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

# Memory-efficient Swish using torch.jit.script, borrowed from the code in
# https://twitter.com/jeremyphoward/status/1188251041835315200
# The memory-efficient Swish is used by default:
USE_MEMORY_EFFICIENT_SWISH = True
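# Note: the scripted autograd Function below saves only the input tensor for the
# backward pass and recomputes sigmoid(x) there, instead of keeping the
# intermediate sigmoid activation alive; numerically it is identical to
# x * sigmoid(x), i.e. torch.nn.functional.silu.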

if USE_MEMORY_EFFICIENT_SWISH:

    @torch.jit.script
    def swish_fwd(x):
        return x.mul(torch.sigmoid(x))

    @torch.jit.script
    def swish_bwd(x, grad_output):
        x_sigmoid = torch.sigmoid(x)
        return grad_output * (x_sigmoid * (1.0 + x * (1.0 - x_sigmoid)))

    class SwishJitImplementation(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return swish_fwd(x)

        @staticmethod
        def backward(ctx, grad_output):
            x = ctx.saved_tensors[0]
            return swish_bwd(x, grad_output)

    def swish(x, inplace=False):
        return SwishJitImplementation.apply(x)

else:

    def swish(x, inplace=False):
        return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())

class Swish(nn.Module):
    def __init__(self, inplace=True):
        super(Swish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return swish(x, self.inplace)

def ConvBNAct(
    out,
    in_channels,
    channels,
    kernel=1,
    stride=1,
    pad=0,
    num_group=1,
    active=True,
    relu6=False,
):
    # Appends a Conv-BN(-ReLU or ReLU6) stack to the layer list `out` in place.
    out.append(
        nn.Conv2d(
            in_channels,
            channels,
            kernel,
            stride,
            pad,
            groups=num_group,
            bias=False,
        )
    )
    out.append(nn.BatchNorm2d(channels))
    if active:
        out.append(nn.ReLU6(inplace=True) if relu6 else nn.ReLU(inplace=True))

def ConvBNSwish(
    out, in_channels, channels, kernel=1, stride=1, pad=0, num_group=1
):
    # Same as ConvBNAct, but with a Swish activation.
    out.append(
        nn.Conv2d(
            in_channels,
            channels,
            kernel,
            stride,
            pad,
            groups=num_group,
            bias=False,
        )
    )
    out.append(nn.BatchNorm2d(channels))
    out.append(Swish())
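
# Squeeze-and-Excitation gate (describes the SE class below): a global average
# pool squeezes each channel to a single value, a 1x1-conv bottleneck reduced by
# `se_ratio` with a final sigmoid produces per-channel weights, and the input
# feature map is rescaled channel-wise by those weights.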
class SE(nn.Module):
    def __init__(self, in_channels, channels, se_ratio=12):
        super(SE, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(
                in_channels, channels // se_ratio, kernel_size=1, padding=0
            ),
            nn.BatchNorm2d(channels // se_ratio),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // se_ratio, channels, kernel_size=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.fc(y)
        return x * y
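
# Inverted-bottleneck block (describes LinearBottleneck below): an optional 1x1
# Swish expansion by the factor `t`, a 3x3 depthwise conv, an optional SE gate,
# ReLU6, and a linear 1x1 projection; when the stride is 1 and the block widens
# the channels, the input is added to the first in_channels output channels.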
class LinearBottleneck(nn.Module):
    def __init__(
        self,
        in_channels,
        channels,
        t,
        stride,
        use_se=True,
        se_ratio=12,
        **kwargs,
    ):
        super(LinearBottleneck, self).__init__(**kwargs)
        self.use_shortcut = stride == 1 and in_channels <= channels
        self.in_channels = in_channels
        self.out_channels = channels

        out = []
        if t != 1:
            dw_channels = in_channels * t
            ConvBNSwish(out, in_channels=in_channels, channels=dw_channels)
        else:
            dw_channels = in_channels

        ConvBNAct(
            out,
            in_channels=dw_channels,
            channels=dw_channels,
            kernel=3,
            stride=stride,
            pad=1,
            num_group=dw_channels,
            active=False,
        )

        if use_se:
            out.append(SE(dw_channels, dw_channels, se_ratio))

        out.append(nn.ReLU6())
        ConvBNAct(
            out,
            in_channels=dw_channels,
            channels=channels,
            active=False,
            relu6=True,
        )
        self.out = nn.Sequential(*out)

    def forward(self, x):
        out = self.out(x)
        if self.use_shortcut:
            # ReXNet's partial residual: the identity is added only to the
            # first in_channels channels of the (wider) output.
            out[:, 0 : self.in_channels] += x

        return out

class ReXNetV1(nn.Module):
    def __init__(
        self,
        input_ch=16,
        final_ch=180,
        width_mult=1.0,
        depth_mult=1.0,
        use_se=True,
        se_ratio=12,
        out_h=7,
        out_w=7,
        feat_dim=512,
        dropout_ratio=0.2,
        bn_momentum=0.9,
    ):
        super(ReXNetV1, self).__init__()

        layers = [1, 2, 2, 3, 3, 5]
        strides = [1, 2, 2, 2, 1, 2]
        use_ses = [False, False, True, True, True, True]

        layers = [ceil(element * depth_mult) for element in layers]
        strides = sum(
            [
                [element] + [1] * (layers[idx] - 1)
                for idx, element in enumerate(strides)
            ],
            [],
        )
        if use_se:
            use_ses = sum(
                [
                    [element] * layers[idx]
                    for idx, element in enumerate(use_ses)
                ],
                [],
            )
        else:
            use_ses = [False] * sum(layers[:])
        ts = [1] * layers[0] + [6] * sum(layers[1:])

        self.depth = sum(layers[:]) * 3
        stem_channel = 32 / width_mult if width_mult < 1.0 else 32
        inplanes = input_ch / width_mult if width_mult < 1.0 else input_ch

        features = []
        in_channels_group = []
        channels_group = []

        # The following channel configuration is a simple instance that makes
        # every layer an expand layer: the output channels grow by
        # final_ch / (self.depth // 3) per block (11.25 with the defaults).
        for i in range(self.depth // 3):
            if i == 0:
                in_channels_group.append(int(round(stem_channel * width_mult)))
                channels_group.append(int(round(inplanes * width_mult)))
            else:
                in_channels_group.append(int(round(inplanes * width_mult)))
                inplanes += final_ch / (self.depth // 3 * 1.0)
                channels_group.append(int(round(inplanes * width_mult)))

        # Stem convolution. The reference ImageNet ReXNet uses stride 2 here
        # (see the commented-out call below); this version keeps stride 1.
        # ConvBNSwish(features, 3, int(round(stem_channel * width_mult)), kernel=3, stride=2, pad=1)
        ConvBNSwish(
            features,
            3,
            int(round(stem_channel * width_mult)),
            kernel=3,
            stride=1,
            pad=1,
        )

        for block_idx, (in_c, c, t, s, se) in enumerate(
            zip(in_channels_group, channels_group, ts, strides, use_ses)
        ):
            features.append(
                LinearBottleneck(
                    in_channels=in_c,
                    channels=c,
                    t=t,
                    stride=s,
                    use_se=se,
                    se_ratio=se_ratio,
                )
            )

        # Penultimate 1x1 conv: 512 channels here instead of the reference
        # implementation's 1280 (see the commented-out line below), matching
        # the 512-channel output head that follows (for width_mult=1.0).
        # pen_channels = int(1280 * width_mult)
        pen_channels = int(512 * width_mult)
        ConvBNSwish(features, c, pen_channels)

        # features.append(nn.AdaptiveAvgPool2d(1))
        self.features = nn.Sequential(*features)
        self.output_layer = nn.Sequential(
            nn.BatchNorm2d(512),
            nn.Dropout(dropout_ratio),
            Flatten(),
            nn.Linear(512 * out_h * out_w, feat_dim),
            nn.BatchNorm1d(feat_dim),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.output_layer(x)
        return x
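

# A minimal usage sketch (not part of the original module): with the default
# configuration the stem has stride 1 and the blocks downsample by a factor of
# 16 overall, so a 112 x 112 face crop (112 = 7 * 16, consistent with
# out_h = out_w = 7) is mapped to a feat_dim-dimensional embedding. The batch
# size and input size below are illustrative assumptions.
if __name__ == "__main__":
    import torch.nn.functional as F

    # Sanity check: the custom Swish activation matches PyTorch's built-in SiLU.
    t = torch.randn(4, 8)
    assert torch.allclose(Swish()(t), F.silu(t), atol=1e-6)

    # Build the backbone with its defaults and run a dummy batch of face crops.
    model = ReXNetV1()
    model.eval()
    with torch.no_grad():
        embeddings = model(torch.randn(2, 3, 112, 112))
    print(embeddings.shape)  # expected: torch.Size([2, 512])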