Coverage for src/bob/bio/face/pytorch/facexzoo/resnest/ablation.py: 36%

45 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-13 00:04 +0200

1# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 

2# Created by: Hang Zhang 

3# Email: zhanghang0704@gmail.com 

4# Copyright (c) 2020 

5# 

6# See the LICENSE file in the root directory of this source tree.

7# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 

8"""ResNeSt ablation study models""" 

9 

10import torch 

11 

12from .resnet import Bottleneck, ResNet 

13 

# Public factory functions exported by ``from <module> import *``.
# Each suffix encodes the ablation configuration passed to ``ResNet``:
# "<radix>s<groups>x<bottleneck_width>d" (e.g. "2s2x40d" -> radix=2,
# groups=2, bottleneck_width=40).
# Kept as a plain list literal so static analyzers can read the exports.
__all__ = [
    "resnest50_fast_1s1x64d",
    "resnest50_fast_2s1x64d",
    "resnest50_fast_4s1x64d",
    "resnest50_fast_1s2x40d",
    "resnest50_fast_2s2x40d",
    "resnest50_fast_4s2x40d",
    "resnest50_fast_1s4x24d",
]

23 

# Checkpoint download URL template; the two placeholders are filled with
# the model name and its checksum prefix, producing "<name>-<prefix>.pth".
_url_format = "https://s3.us-west-1.wasabisys.com/resnest/torch/{}-{}.pth"

# Model name -> leading hex digits of the checkpoint file's SHA-256 digest.
# The prefix ends up in the downloaded file name, which is what
# ``torch.hub.load_state_dict_from_url(..., check_hash=True)`` verifies.
_model_sha256 = {
    "resnest50_fast_1s1x64d": "d8fbf808",
    "resnest50_fast_2s1x64d": "44938639",
    "resnest50_fast_4s1x64d": "f74f3fc3",
    "resnest50_fast_1s2x40d": "32830b84",
    "resnest50_fast_2s2x40d": "9d126481",
    "resnest50_fast_4s2x40d": "41d14ed0",
    "resnest50_fast_1s4x24d": "d4a4f76f",
}

38 

39 

def short_hash(name):
    """Return the first 8 characters of the checksum registered for ``name``.

    Args:
        name: Model name, looked up as a key of ``_model_sha256``.

    Returns:
        The checksum string truncated to at most 8 characters.

    Raises:
        ValueError: If ``name`` has no entry in ``_model_sha256``.
    """
    try:
        checksum = _model_sha256[name]
    except KeyError:
        # Surface a ValueError (not the internal KeyError) for unknown models.
        raise ValueError(
            "Pretrained model for {name} is not available.".format(name=name)
        ) from None
    return checksum[:8]

46 

47 

# Model name -> checkpoint URL. Each file name embeds the checksum prefix
# returned by ``short_hash`` so the download can be hash-verified.
resnest_model_urls = {
    model_name: _url_format.format(model_name, short_hash(model_name))
    for model_name in _model_sha256
}

52 

53 

def resnest50_fast_1s1x64d(
    pretrained=False, root="~/.encoding/models", **kwargs
):
    """Construct the ``resnest50_fast_1s1x64d`` ablation model.

    Bottleneck ``ResNet`` with stage depths [3, 4, 6, 3], radix 1, 1 group
    and bottleneck width 64, with ``stem_width=32`` and ``deep_stem``,
    ``avg_down``, ``avd`` and ``avd_first`` all enabled.

    Args:
        pretrained: If truthy, fetch the published checkpoint through
            ``torch.hub.load_state_dict_from_url`` (progress bar on, checksum
            prefix verified) and load it into the model.
        root: Not used by this function; ``torch.hub`` chooses where the
            checkpoint is cached.
        **kwargs: Forwarded to ``ResNet``. Repeating one of the architecture
            keys fixed here raises ``TypeError``.

    Returns:
        The constructed ``ResNet`` model.
    """
    arch_name = "resnest50_fast_1s1x64d"
    arch_config = {
        "radix": 1,
        "groups": 1,
        "bottleneck_width": 64,
        "deep_stem": True,
        "stem_width": 32,
        "avg_down": True,
        "avd": True,
        "avd_first": True,
    }
    net = ResNet(Bottleneck, [3, 4, 6, 3], **arch_config, **kwargs)
    if not pretrained:
        return net
    # NOTE(review): if ``kwargs`` change the architecture, the checkpoint
    # presumably no longer matches and ``load_state_dict`` will raise.
    weights = torch.hub.load_state_dict_from_url(
        resnest_model_urls[arch_name], progress=True, check_hash=True
    )
    net.load_state_dict(weights)
    return net

79 

80 

def resnest50_fast_2s1x64d(
    pretrained=False, root="~/.encoding/models", **kwargs
):
    """Construct the ``resnest50_fast_2s1x64d`` ablation model.

    Bottleneck ``ResNet`` with stage depths [3, 4, 6, 3], radix 2, 1 group
    and bottleneck width 64, with ``stem_width=32`` and ``deep_stem``,
    ``avg_down``, ``avd`` and ``avd_first`` all enabled.

    Args:
        pretrained: If truthy, fetch the published checkpoint through
            ``torch.hub.load_state_dict_from_url`` (progress bar on, checksum
            prefix verified) and load it into the model.
        root: Not used by this function; ``torch.hub`` chooses where the
            checkpoint is cached.
        **kwargs: Forwarded to ``ResNet``. Repeating one of the architecture
            keys fixed here raises ``TypeError``.

    Returns:
        The constructed ``ResNet`` model.
    """
    arch_name = "resnest50_fast_2s1x64d"
    arch_config = {
        "radix": 2,
        "groups": 1,
        "bottleneck_width": 64,
        "deep_stem": True,
        "stem_width": 32,
        "avg_down": True,
        "avd": True,
        "avd_first": True,
    }
    net = ResNet(Bottleneck, [3, 4, 6, 3], **arch_config, **kwargs)
    if not pretrained:
        return net
    # NOTE(review): if ``kwargs`` change the architecture, the checkpoint
    # presumably no longer matches and ``load_state_dict`` will raise.
    weights = torch.hub.load_state_dict_from_url(
        resnest_model_urls[arch_name], progress=True, check_hash=True
    )
    net.load_state_dict(weights)
    return net

106 

107 

def resnest50_fast_4s1x64d(
    pretrained=False, root="~/.encoding/models", **kwargs
):
    """Construct the ``resnest50_fast_4s1x64d`` ablation model.

    Bottleneck ``ResNet`` with stage depths [3, 4, 6, 3], radix 4, 1 group
    and bottleneck width 64, with ``stem_width=32`` and ``deep_stem``,
    ``avg_down``, ``avd`` and ``avd_first`` all enabled.

    Args:
        pretrained: If truthy, fetch the published checkpoint through
            ``torch.hub.load_state_dict_from_url`` (progress bar on, checksum
            prefix verified) and load it into the model.
        root: Not used by this function; ``torch.hub`` chooses where the
            checkpoint is cached.
        **kwargs: Forwarded to ``ResNet``. Repeating one of the architecture
            keys fixed here raises ``TypeError``.

    Returns:
        The constructed ``ResNet`` model.
    """
    arch_name = "resnest50_fast_4s1x64d"
    arch_config = {
        "radix": 4,
        "groups": 1,
        "bottleneck_width": 64,
        "deep_stem": True,
        "stem_width": 32,
        "avg_down": True,
        "avd": True,
        "avd_first": True,
    }
    net = ResNet(Bottleneck, [3, 4, 6, 3], **arch_config, **kwargs)
    if not pretrained:
        return net
    # NOTE(review): if ``kwargs`` change the architecture, the checkpoint
    # presumably no longer matches and ``load_state_dict`` will raise.
    weights = torch.hub.load_state_dict_from_url(
        resnest_model_urls[arch_name], progress=True, check_hash=True
    )
    net.load_state_dict(weights)
    return net

133 

134 

def resnest50_fast_1s2x40d(
    pretrained=False, root="~/.encoding/models", **kwargs
):
    """Construct the ``resnest50_fast_1s2x40d`` ablation model.

    Bottleneck ``ResNet`` with stage depths [3, 4, 6, 3], radix 1, 2 groups
    and bottleneck width 40, with ``stem_width=32`` and ``deep_stem``,
    ``avg_down``, ``avd`` and ``avd_first`` all enabled.

    Args:
        pretrained: If truthy, fetch the published checkpoint through
            ``torch.hub.load_state_dict_from_url`` (progress bar on, checksum
            prefix verified) and load it into the model.
        root: Not used by this function; ``torch.hub`` chooses where the
            checkpoint is cached.
        **kwargs: Forwarded to ``ResNet``. Repeating one of the architecture
            keys fixed here raises ``TypeError``.

    Returns:
        The constructed ``ResNet`` model.
    """
    arch_name = "resnest50_fast_1s2x40d"
    arch_config = {
        "radix": 1,
        "groups": 2,
        "bottleneck_width": 40,
        "deep_stem": True,
        "stem_width": 32,
        "avg_down": True,
        "avd": True,
        "avd_first": True,
    }
    net = ResNet(Bottleneck, [3, 4, 6, 3], **arch_config, **kwargs)
    if not pretrained:
        return net
    # NOTE(review): if ``kwargs`` change the architecture, the checkpoint
    # presumably no longer matches and ``load_state_dict`` will raise.
    weights = torch.hub.load_state_dict_from_url(
        resnest_model_urls[arch_name], progress=True, check_hash=True
    )
    net.load_state_dict(weights)
    return net

160 

161 

def resnest50_fast_2s2x40d(
    pretrained=False, root="~/.encoding/models", **kwargs
):
    """Construct the ``resnest50_fast_2s2x40d`` ablation model.

    Bottleneck ``ResNet`` with stage depths [3, 4, 6, 3], radix 2, 2 groups
    and bottleneck width 40, with ``stem_width=32`` and ``deep_stem``,
    ``avg_down``, ``avd`` and ``avd_first`` all enabled.

    Args:
        pretrained: If truthy, fetch the published checkpoint through
            ``torch.hub.load_state_dict_from_url`` (progress bar on, checksum
            prefix verified) and load it into the model.
        root: Not used by this function; ``torch.hub`` chooses where the
            checkpoint is cached.
        **kwargs: Forwarded to ``ResNet``. Repeating one of the architecture
            keys fixed here raises ``TypeError``.

    Returns:
        The constructed ``ResNet`` model.
    """
    arch_name = "resnest50_fast_2s2x40d"
    arch_config = {
        "radix": 2,
        "groups": 2,
        "bottleneck_width": 40,
        "deep_stem": True,
        "stem_width": 32,
        "avg_down": True,
        "avd": True,
        "avd_first": True,
    }
    net = ResNet(Bottleneck, [3, 4, 6, 3], **arch_config, **kwargs)
    if not pretrained:
        return net
    # NOTE(review): if ``kwargs`` change the architecture, the checkpoint
    # presumably no longer matches and ``load_state_dict`` will raise.
    weights = torch.hub.load_state_dict_from_url(
        resnest_model_urls[arch_name], progress=True, check_hash=True
    )
    net.load_state_dict(weights)
    return net

187 

188 

def resnest50_fast_4s2x40d(
    pretrained=False, root="~/.encoding/models", **kwargs
):
    """Construct the ``resnest50_fast_4s2x40d`` ablation model.

    Bottleneck ``ResNet`` with stage depths [3, 4, 6, 3], radix 4, 2 groups
    and bottleneck width 40, with ``stem_width=32`` and ``deep_stem``,
    ``avg_down``, ``avd`` and ``avd_first`` all enabled.

    Args:
        pretrained: If truthy, fetch the published checkpoint through
            ``torch.hub.load_state_dict_from_url`` (progress bar on, checksum
            prefix verified) and load it into the model.
        root: Not used by this function; ``torch.hub`` chooses where the
            checkpoint is cached.
        **kwargs: Forwarded to ``ResNet``. Repeating one of the architecture
            keys fixed here raises ``TypeError``.

    Returns:
        The constructed ``ResNet`` model.
    """
    arch_name = "resnest50_fast_4s2x40d"
    arch_config = {
        "radix": 4,
        "groups": 2,
        "bottleneck_width": 40,
        "deep_stem": True,
        "stem_width": 32,
        "avg_down": True,
        "avd": True,
        "avd_first": True,
    }
    net = ResNet(Bottleneck, [3, 4, 6, 3], **arch_config, **kwargs)
    if not pretrained:
        return net
    # NOTE(review): if ``kwargs`` change the architecture, the checkpoint
    # presumably no longer matches and ``load_state_dict`` will raise.
    weights = torch.hub.load_state_dict_from_url(
        resnest_model_urls[arch_name], progress=True, check_hash=True
    )
    net.load_state_dict(weights)
    return net

214 

215 

def resnest50_fast_1s4x24d(
    pretrained=False, root="~/.encoding/models", **kwargs
):
    """Construct the ``resnest50_fast_1s4x24d`` ablation model.

    Bottleneck ``ResNet`` with stage depths [3, 4, 6, 3], radix 1, 4 groups
    and bottleneck width 24, with ``stem_width=32`` and ``deep_stem``,
    ``avg_down``, ``avd`` and ``avd_first`` all enabled.

    Args:
        pretrained: If truthy, fetch the published checkpoint through
            ``torch.hub.load_state_dict_from_url`` (progress bar on, checksum
            prefix verified) and load it into the model.
        root: Not used by this function; ``torch.hub`` chooses where the
            checkpoint is cached.
        **kwargs: Forwarded to ``ResNet``. Repeating one of the architecture
            keys fixed here raises ``TypeError``.

    Returns:
        The constructed ``ResNet`` model.
    """
    arch_name = "resnest50_fast_1s4x24d"
    arch_config = {
        "radix": 1,
        "groups": 4,
        "bottleneck_width": 24,
        "deep_stem": True,
        "stem_width": 32,
        "avg_down": True,
        "avd": True,
        "avd_first": True,
    }
    net = ResNet(Bottleneck, [3, 4, 6, 3], **arch_config, **kwargs)
    if not pretrained:
        return net
    # NOTE(review): if ``kwargs`` change the architecture, the checkpoint
    # presumably no longer matches and ``load_state_dict`` will raise.
    weights = torch.hub.load_state_dict_from_url(
        resnest_model_urls[arch_name], progress=True, check_hash=True
    )
    net.load_state_dict(weights)
    return net