Coverage for src/bob/bio/face/pytorch/facexzoo/resnest/ablation.py: 36%
45 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-13 00:04 +0200
1# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2# Created by: Hang Zhang
3# Email: zhanghang0704@gmail.com
4# Copyright (c) 2020
5#
6# LICENSE file in the root directory of this source tree
7# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
8"""ResNeSt ablation study models"""
10import torch
12from .resnet import Bottleneck, ResNet
# Public constructors, one per ablation configuration. The suffix of each
# name encodes ``<radix>s<groups>x<bottleneck_width>d``.
__all__ = [
    # groups=1, bottleneck width 64
    "resnest50_fast_1s1x64d",
    "resnest50_fast_2s1x64d",
    "resnest50_fast_4s1x64d",
    # groups=2, bottleneck width 40
    "resnest50_fast_1s2x40d",
    "resnest50_fast_2s2x40d",
    "resnest50_fast_4s2x40d",
    # groups=4, bottleneck width 24
    "resnest50_fast_1s4x24d",
]
# Download URL template: ``{model name}-{checksum prefix}.pth``. With
# ``check_hash=True``, torch.hub checks the downloaded file against the hash
# prefix embedded in this file name.
_url_format = "https://s3.us-west-1.wasabisys.com/resnest/torch/{}-{}.pth"

# Model name -> checksum prefix of its pretrained weights file.
_model_sha256 = {
    "resnest50_fast_1s1x64d": "d8fbf808",
    "resnest50_fast_2s1x64d": "44938639",
    "resnest50_fast_4s1x64d": "f74f3fc3",
    "resnest50_fast_1s2x40d": "32830b84",
    "resnest50_fast_2s2x40d": "9d126481",
    "resnest50_fast_4s2x40d": "41d14ed0",
    "resnest50_fast_1s4x24d": "d4a4f76f",
}
def short_hash(name):
    """Return the first 8 characters of the checksum registered for ``name``.

    Args:
        name: Model name, looked up in ``_model_sha256``.

    Returns:
        The checksum prefix used in the model's download file name.

    Raises:
        ValueError: if ``name`` has no registered checksum.
    """
    checksum = _model_sha256.get(name)
    if checksum is None:
        raise ValueError(f"Pretrained model for {name} is not available.")
    return checksum[:8]
# Model name -> pretrained-weights URL, built from ``_url_format`` with the
# model name and its checksum prefix (e.g. ``...-d8fbf808.pth``).
resnest_model_urls = {
    arch: _url_format.format(arch, short_hash(arch)) for arch in _model_sha256
}
def _resnest50_fast(
    arch, pretrained, resnet_kwargs, *, radix, groups, bottleneck_width
):
    """Shared builder for the ResNeSt-50 "fast" ablation variants.

    Every public constructor in this module builds the same ``ResNet``:
    ``Bottleneck`` blocks in a ``[3, 4, 6, 3]`` layout with
    ``deep_stem=True``, ``stem_width=32``, ``avg_down=True``, ``avd=True``
    and ``avd_first=True``. The variants differ only in ``radix``,
    ``groups`` and ``bottleneck_width``, which each function name encodes as
    ``<radix>s<groups>x<bottleneck_width>d``.

    Args:
        arch: Model name; the key looked up in ``resnest_model_urls`` when
            ``pretrained`` is true.
        pretrained: If true, download the weights for ``arch`` via
            ``torch.hub.load_state_dict_from_url`` (progress bar shown,
            ``check_hash=True``) and load them into the model.
        resnet_kwargs: Dict of extra keyword arguments forwarded to
            ``ResNet``. It is passed as a dict rather than splatted into this
            helper, so caller keywords can never collide with this helper's
            own parameter names.
        radix, groups, bottleneck_width: Variant-specific ``ResNet``
            arguments.

    Returns:
        The constructed ``ResNet`` model.

    Raises:
        TypeError: if ``resnet_kwargs`` repeats one of the fixed ``ResNet``
            keyword arguments (same as passing it twice in a direct call).
    """
    model = ResNet(
        Bottleneck,
        [3, 4, 6, 3],
        radix=radix,
        groups=groups,
        bottleneck_width=bottleneck_width,
        deep_stem=True,
        stem_width=32,
        avg_down=True,
        avd=True,
        avd_first=True,
        **resnet_kwargs,
    )
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            resnest_model_urls[arch],
            progress=True,
            check_hash=True,
        )
        model.load_state_dict(state_dict)
    return model


def resnest50_fast_1s1x64d(
    pretrained=False, root="~/.encoding/models", **kwargs
):
    """ResNeSt-50 fast ablation model: radix 1, 1 group, bottleneck width 64.

    Args:
        pretrained: If true, load the weights from ``resnest_model_urls``.
        root: Unused; kept for signature compatibility. Downloads go to
            torch.hub's default ``model_dir``.
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.
    """
    return _resnest50_fast(
        "resnest50_fast_1s1x64d",
        pretrained,
        kwargs,
        radix=1,
        groups=1,
        bottleneck_width=64,
    )


def resnest50_fast_2s1x64d(
    pretrained=False, root="~/.encoding/models", **kwargs
):
    """ResNeSt-50 fast ablation model: radix 2, 1 group, bottleneck width 64.

    Args:
        pretrained: If true, load the weights from ``resnest_model_urls``.
        root: Unused; kept for signature compatibility. Downloads go to
            torch.hub's default ``model_dir``.
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.
    """
    return _resnest50_fast(
        "resnest50_fast_2s1x64d",
        pretrained,
        kwargs,
        radix=2,
        groups=1,
        bottleneck_width=64,
    )


def resnest50_fast_4s1x64d(
    pretrained=False, root="~/.encoding/models", **kwargs
):
    """ResNeSt-50 fast ablation model: radix 4, 1 group, bottleneck width 64.

    Args:
        pretrained: If true, load the weights from ``resnest_model_urls``.
        root: Unused; kept for signature compatibility. Downloads go to
            torch.hub's default ``model_dir``.
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.
    """
    return _resnest50_fast(
        "resnest50_fast_4s1x64d",
        pretrained,
        kwargs,
        radix=4,
        groups=1,
        bottleneck_width=64,
    )


def resnest50_fast_1s2x40d(
    pretrained=False, root="~/.encoding/models", **kwargs
):
    """ResNeSt-50 fast ablation model: radix 1, 2 groups, bottleneck width 40.

    Args:
        pretrained: If true, load the weights from ``resnest_model_urls``.
        root: Unused; kept for signature compatibility. Downloads go to
            torch.hub's default ``model_dir``.
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.
    """
    return _resnest50_fast(
        "resnest50_fast_1s2x40d",
        pretrained,
        kwargs,
        radix=1,
        groups=2,
        bottleneck_width=40,
    )


def resnest50_fast_2s2x40d(
    pretrained=False, root="~/.encoding/models", **kwargs
):
    """ResNeSt-50 fast ablation model: radix 2, 2 groups, bottleneck width 40.

    Args:
        pretrained: If true, load the weights from ``resnest_model_urls``.
        root: Unused; kept for signature compatibility. Downloads go to
            torch.hub's default ``model_dir``.
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.
    """
    return _resnest50_fast(
        "resnest50_fast_2s2x40d",
        pretrained,
        kwargs,
        radix=2,
        groups=2,
        bottleneck_width=40,
    )


def resnest50_fast_4s2x40d(
    pretrained=False, root="~/.encoding/models", **kwargs
):
    """ResNeSt-50 fast ablation model: radix 4, 2 groups, bottleneck width 40.

    Args:
        pretrained: If true, load the weights from ``resnest_model_urls``.
        root: Unused; kept for signature compatibility. Downloads go to
            torch.hub's default ``model_dir``.
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.
    """
    return _resnest50_fast(
        "resnest50_fast_4s2x40d",
        pretrained,
        kwargs,
        radix=4,
        groups=2,
        bottleneck_width=40,
    )


def resnest50_fast_1s4x24d(
    pretrained=False, root="~/.encoding/models", **kwargs
):
    """ResNeSt-50 fast ablation model: radix 1, 4 groups, bottleneck width 24.

    Args:
        pretrained: If true, load the weights from ``resnest_model_urls``.
        root: Unused; kept for signature compatibility. Downloads go to
            torch.hub's default ``model_dir``.
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.
    """
    return _resnest50_fast(
        "resnest50_fast_1s4x24d",
        pretrained,
        kwargs,
        radix=1,
        groups=4,
        bottleneck_width=24,
    )