Coverage for src/bob/bio/face/pytorch/facexzoo/AttentionNets.py: 100%

152 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-13 00:04 +0200

1""" 

2@author: Jun Wang 

3@date: 20201019 

4@contact: jun21wangustc@gmail.com 

5""" 

6 

7# based on: 

8# https://github.com/tengshaofeng/ResidualAttentionNetwork-pytorch/tree/master/Residual-Attention-Network/model 

9 

10import torch.nn as nn 

11 

12 

class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) one into a single one.

    An input of shape ``(N, d1, d2, ...)`` becomes ``(N, d1 * d2 * ...)``;
    a 1-D input of shape ``(N,)`` becomes ``(N, 1)``.
    """

    def forward(self, x):
        # The trailing size is computed explicitly instead of passing -1:
        # ``view(0, -1)`` is ambiguous and raises for an empty batch.
        # ``reshape`` returns a view for contiguous inputs (exactly like
        # ``view``) and copies instead of raising for non-contiguous ones.
        return x.reshape(x.size(0), x.shape[1:].numel())

16 

17 

class ResidualBlock(nn.Module):
    """Pre-activation bottleneck residual block.

    The residual path is BN-ReLU-conv1x1 -> BN-ReLU-conv3x3 -> BN-ReLU-conv1x1
    with an inner width of ``output_channels // 4``.  The shortcut is the
    input itself when the shape is unchanged; otherwise it is a strided 1x1
    projection (``conv4``) of the pre-activated input ``relu(bn1(x))``.

    Args:
        input_channels: number of input channels.
        output_channels: number of output channels.
        stride: stride of the 3x3 convolution and of the projection shortcut.
    """

    def __init__(self, input_channels, output_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.stride = stride
        bottleneck = output_channels // 4
        self.bn1 = nn.BatchNorm2d(input_channels)
        # A single stateless ReLU instance is reused for all three activations.
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(input_channels, bottleneck, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(bottleneck)
        self.conv2 = nn.Conv2d(
            bottleneck,
            bottleneck,
            3,
            stride,
            padding=1,
            bias=False,
        )
        self.bn3 = nn.BatchNorm2d(bottleneck)
        self.conv3 = nn.Conv2d(bottleneck, output_channels, 1, 1, bias=False)
        # conv4 is always built (so its weight is always in the state_dict),
        # but forward only uses it when the channel count or stride changes.
        self.conv4 = nn.Conv2d(
            input_channels, output_channels, 1, stride, bias=False
        )

    def forward(self, x):
        # Pre-activated input; also feeds the projection shortcut.
        preact = self.relu(self.bn1(x))
        out = self.conv1(preact)
        out = self.conv2(self.relu(self.bn2(out)))
        out = self.conv3(self.relu(self.bn3(out)))
        if self.input_channels != self.output_channels or self.stride != 1:
            shortcut = self.conv4(preact)
        else:
            shortcut = x
        out += shortcut
        return out

63 

64 

class AttentionModule_stage1(nn.Module):
    """Residual attention module whose mask branch has three resolution levels.

    The trunk branch is ``first_residual_blocks`` followed by two more
    residual blocks.  The soft-mask branch max-pools (kernel 3, stride 2,
    padding 1) three times, runs residual blocks at every scale, then
    up-samples bilinearly to ``size3``, ``size2`` and ``size1`` in turn,
    adding the same-scale encoder output and skip-connection output on the
    way up.  A two-layer BN-ReLU-1x1-conv head followed by a sigmoid turns
    this into a mask ``M`` in (0, 1); the module returns
    ``last_blocks((1 + M) * trunk)``.

    Args:
        in_channels: input channels.  Every residual block is built as
            ``ResidualBlock(in_channels, out_channels)`` but they are chained,
            so in practice ``in_channels`` must equal ``out_channels``.
        out_channels: output channels.
        size1: spatial size of the module input (the trunk output keeps it).
        size2: spatial size after one max-pool step.
        size3: spatial size after two max-pool steps.

    The defaults correspond to a 56x56 input; mismatched sizes make the
    additions in ``forward`` fail.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        size1=(56, 56),
        size2=(28, 28),
        size3=(14, 14),
    ):
        super(AttentionModule_stage1, self).__init__()
        self.first_residual_blocks = ResidualBlock(in_channels, out_channels)
        self.trunk_branches = nn.Sequential(
            ResidualBlock(in_channels, out_channels),
            ResidualBlock(in_channels, out_channels),
        )
        self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.softmax1_blocks = ResidualBlock(in_channels, out_channels)
        self.skip1_connection_residual_block = ResidualBlock(
            in_channels, out_channels
        )
        self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.softmax2_blocks = ResidualBlock(in_channels, out_channels)
        self.skip2_connection_residual_block = ResidualBlock(
            in_channels, out_channels
        )
        self.mpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.softmax3_blocks = nn.Sequential(
            ResidualBlock(in_channels, out_channels),
            ResidualBlock(in_channels, out_channels),
        )
        self.interpolation3 = nn.UpsamplingBilinear2d(size=size3)
        self.softmax4_blocks = ResidualBlock(in_channels, out_channels)
        self.interpolation2 = nn.UpsamplingBilinear2d(size=size2)
        self.softmax5_blocks = ResidualBlock(in_channels, out_channels)
        self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)
        # Mask head: (BN -> ReLU -> 1x1 conv) x 2, then a sigmoid.
        self.softmax6_blocks = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels, out_channels, kernel_size=1, stride=1, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels, out_channels, kernel_size=1, stride=1, bias=False
            ),
            nn.Sigmoid(),
        )
        self.last_blocks = ResidualBlock(in_channels, out_channels)

    def forward(self, x):
        x = self.first_residual_blocks(x)
        out_trunk = self.trunk_branches(x)

        # Mask encoder: three max-pool steps, with skip connections kept at
        # the first two scales.
        out_mpool1 = self.mpool1(x)
        out_softmax1 = self.softmax1_blocks(out_mpool1)
        out_skip1_connection = self.skip1_connection_residual_block(
            out_softmax1
        )
        out_mpool2 = self.mpool2(out_softmax1)
        out_softmax2 = self.softmax2_blocks(out_mpool2)
        out_skip2_connection = self.skip2_connection_residual_block(
            out_softmax2
        )
        out_mpool3 = self.mpool3(out_softmax2)
        out_softmax3 = self.softmax3_blocks(out_mpool3)

        # Mask decoder: up-sample and merge with the same-scale encoder
        # output and skip connection.
        out_interp3 = self.interpolation3(out_softmax3) + out_softmax2
        out = out_interp3 + out_skip2_connection
        out_softmax4 = self.softmax4_blocks(out)
        out_interp2 = self.interpolation2(out_softmax4) + out_softmax1
        out = out_interp2 + out_skip1_connection
        out_softmax5 = self.softmax5_blocks(out)
        out_interp1 = self.interpolation1(out_softmax5) + out_trunk

        # Soft mask in (0, 1); using (1 + mask) keeps the trunk unchanged
        # where the mask is near 0 instead of zeroing it out.
        out_softmax6 = self.softmax6_blocks(out_interp1)
        out = (1 + out_softmax6) * out_trunk
        out_last = self.last_blocks(out)

        return out_last

146 

147 

class AttentionModule_stage2(nn.Module):
    """Residual attention module whose mask branch has two resolution levels.

    The trunk branch is ``first_residual_blocks`` followed by two more
    residual blocks.  The soft-mask branch max-pools (kernel 3, stride 2,
    padding 1) twice, runs residual blocks at each scale, then up-samples
    bilinearly to ``size2`` and ``size1`` in turn, adding the same-scale
    encoder output and skip-connection output on the way up.  A two-layer
    BN-ReLU-1x1-conv head followed by a sigmoid turns this into a mask ``M``
    in (0, 1); the module returns ``last_blocks((1 + M) * trunk)``.

    Args:
        in_channels: input channels.  Every residual block is built as
            ``ResidualBlock(in_channels, out_channels)`` but they are chained,
            so in practice ``in_channels`` must equal ``out_channels``.
        out_channels: output channels.
        size1: spatial size of the module input (the trunk output keeps it).
        size2: spatial size after one max-pool step.

    The defaults correspond to a 28x28 input; mismatched sizes make the
    additions in ``forward`` fail.
    """

    def __init__(
        self, in_channels, out_channels, size1=(28, 28), size2=(14, 14)
    ):
        super(AttentionModule_stage2, self).__init__()
        self.first_residual_blocks = ResidualBlock(in_channels, out_channels)
        self.trunk_branches = nn.Sequential(
            ResidualBlock(in_channels, out_channels),
            ResidualBlock(in_channels, out_channels),
        )
        self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.softmax1_blocks = ResidualBlock(in_channels, out_channels)
        self.skip1_connection_residual_block = ResidualBlock(
            in_channels, out_channels
        )
        self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.softmax2_blocks = nn.Sequential(
            ResidualBlock(in_channels, out_channels),
            ResidualBlock(in_channels, out_channels),
        )
        self.interpolation2 = nn.UpsamplingBilinear2d(size=size2)
        self.softmax3_blocks = ResidualBlock(in_channels, out_channels)
        self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)
        # Mask head: (BN -> ReLU -> 1x1 conv) x 2, then a sigmoid.
        self.softmax4_blocks = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels, out_channels, kernel_size=1, stride=1, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels, out_channels, kernel_size=1, stride=1, bias=False
            ),
            nn.Sigmoid(),
        )
        self.last_blocks = ResidualBlock(in_channels, out_channels)

    def forward(self, x):
        x = self.first_residual_blocks(x)
        out_trunk = self.trunk_branches(x)

        # Mask encoder: two max-pool steps, with a skip connection kept at
        # the first scale.
        out_mpool1 = self.mpool1(x)
        out_softmax1 = self.softmax1_blocks(out_mpool1)
        out_skip1_connection = self.skip1_connection_residual_block(
            out_softmax1
        )
        out_mpool2 = self.mpool2(out_softmax1)
        out_softmax2 = self.softmax2_blocks(out_mpool2)

        # Mask decoder: up-sample and merge with the same-scale encoder
        # output and skip connection.
        out_interp2 = self.interpolation2(out_softmax2) + out_softmax1
        out = out_interp2 + out_skip1_connection
        out_softmax3 = self.softmax3_blocks(out)
        out_interp1 = self.interpolation1(out_softmax3) + out_trunk

        # Soft mask in (0, 1); using (1 + mask) keeps the trunk unchanged
        # where the mask is near 0 instead of zeroing it out.
        out_softmax4 = self.softmax4_blocks(out_interp1)
        out = (1 + out_softmax4) * out_trunk
        out_last = self.last_blocks(out)
        return out_last

207 

208 

class AttentionModule_stage3(nn.Module):
    """Residual attention module whose mask branch has a single pooled level.

    The trunk branch is ``first_residual_blocks`` followed by two more
    residual blocks.  The soft-mask branch max-pools once (kernel 3,
    stride 2, padding 1), applies two residual blocks, up-samples bilinearly
    back to ``size1`` and adds the trunk output.  A two-layer
    BN-ReLU-1x1-conv head followed by a sigmoid turns this into a mask ``M``
    in (0, 1); the module returns ``last_blocks((1 + M) * trunk)``.

    Args:
        in_channels: input channels.  Every residual block is built as
            ``ResidualBlock(in_channels, out_channels)`` but they are chained,
            so in practice ``in_channels`` must equal ``out_channels``.
        out_channels: output channels.
        size1: spatial size of the module input; the default corresponds to
            a 14x14 input.
    """

    def __init__(self, in_channels, out_channels, size1=(14, 14)):
        super(AttentionModule_stage3, self).__init__()

        def make_block():
            return ResidualBlock(in_channels, out_channels)

        # Sub-modules are created in the same order as attribute declaration
        # so parameter registration (and initialisation) order is stable.
        self.first_residual_blocks = make_block()
        self.trunk_branches = nn.Sequential(make_block(), make_block())
        self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.softmax1_blocks = nn.Sequential(make_block(), make_block())
        self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)
        self.softmax2_blocks = self._mask_head(out_channels)
        self.last_blocks = make_block()

    @staticmethod
    def _mask_head(channels):
        """Build the (BN -> ReLU -> 1x1 conv) x 2 -> sigmoid mask head."""
        layers = []
        for _ in range(2):
            layers.append(nn.BatchNorm2d(channels))
            layers.append(nn.ReLU(inplace=True))
            layers.append(
                nn.Conv2d(
                    channels, channels, kernel_size=1, stride=1, bias=False
                )
            )
        layers.append(nn.Sigmoid())
        return nn.Sequential(*layers)

    def forward(self, x):
        stem = self.first_residual_blocks(x)
        trunk = self.trunk_branches(stem)
        coarse = self.softmax1_blocks(self.mpool1(stem))
        mask = self.softmax2_blocks(self.interpolation1(coarse) + trunk)
        # (1 + mask) leaves the trunk unchanged where the mask is near 0.
        return self.last_blocks((1 + mask) * trunk)

249 

250 

class ResidualAttentionNet(nn.Module):
    """Residual Attention Network backbone that outputs a ``feat_dim`` embedding.

    Layout:

    * stem: 7x7 stride-2 conv (3 -> 64 channels), BN, ReLU;
    * ``ResidualBlock(64, 256)`` then ``stage1_modules`` x
      ``AttentionModule_stage1(256, 256)``;
    * stride-2 ``ResidualBlock`` to 512 channels, then ``stage2_modules`` x
      ``AttentionModule_stage2(512, 512)``;
    * stride-2 ``ResidualBlock`` to 1024 channels, then ``stage3_modules`` x
      ``AttentionModule_stage3(1024, 1024)``;
    * stride-2 ``ResidualBlock`` to 2048 channels plus two 2048-channel blocks;
    * flatten, bias-free linear projection to ``feat_dim``, BatchNorm1d.

    Args:
        stage1_modules: number of stage-1 attention modules.
        stage2_modules: number of stage-2 attention modules.
        stage3_modules: number of stage-3 attention modules.
        feat_dim: size of the output embedding.
        out_h: height of the final feature map.
        out_w: width of the final feature map; the linear layer expects
            exactly ``2048 * out_h * out_w`` inputs.

    The attention modules use their default interpolation sizes, which fit
    a 112x112 input: the stem gives 56x56, each stride-2 block halves the
    map (28, 14), and the last one ends at 7x7 (``out_h = out_w = 7``).
    """

    def __init__(
        self,
        stage1_modules,
        stage2_modules,
        stage3_modules,
        feat_dim,
        out_h,
        out_w,
    ):
        super(ResidualAttentionNet, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        body = [ResidualBlock(64, 256)]
        body.extend(
            AttentionModule_stage1(256, 256) for _ in range(stage1_modules)
        )
        body.append(ResidualBlock(256, 512, 2))
        body.extend(
            AttentionModule_stage2(512, 512) for _ in range(stage2_modules)
        )
        body.append(ResidualBlock(512, 1024, 2))
        body.extend(
            AttentionModule_stage3(1024, 1024) for _ in range(stage3_modules)
        )
        # Final residual stage at 2048 channels.
        body += [
            ResidualBlock(1024, 2048, 2),
            ResidualBlock(2048, 2048),
            ResidualBlock(2048, 2048),
        ]
        self.attention_body = nn.Sequential(*body)

        self.output_layer = nn.Sequential(
            Flatten(),
            nn.Linear(2048 * out_h * out_w, feat_dim, bias=False),
            nn.BatchNorm1d(feat_dim),
        )

    def forward(self, x):
        features = self.attention_body(self.conv1(x))
        return self.output_layer(features)