Coverage for src/bob/bio/face/pytorch/facexzoo/resnest/splat.py: 88%

73 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-13 00:04 +0200

1"""Split-Attention""" 

2 

3import torch 

4import torch.nn.functional as F 

5 

6from torch import nn 

7from torch.nn import Conv2d, Module, ReLU 

8from torch.nn.modules.utils import _pair 

9 

10__all__ = ["SplAtConv2d"] 

11 

12 

class SplAtConv2d(Module):
    """Split-Attention Conv2d.

    A grouped convolution produces ``radix`` splits of ``channels`` channels
    each. The splits are summed, globally average pooled, and passed through
    a two-layer 1x1 bottleneck (``fc1`` -> ``fc2``). ``rSoftMax`` then turns
    the result into per-split attention weights, which re-weight the splits
    before they are summed again.

    Args:
        in_channels: Number of input channels.
        channels: Number of channels in each split (and in the output).
        kernel_size, stride, padding, dilation: Forwarded to the main
            convolution.
        groups: Cardinality; the main convolution uses ``groups * radix``
            groups and the bottleneck convolutions use ``groups``.
        bias: Whether the main convolution has a bias term.
        radix: Number of splits attended over.
        reduction_factor: Divisor used to size the attention bottleneck,
            which never has fewer than 32 channels.
        rectify: Request ``rfconv.RFConv2d`` instead of ``Conv2d``; only
            honoured when some padding is positive.
        rectify_avg: Passed to ``RFConv2d`` as ``average_mode``.
        norm_layer: Optional callable that builds a normalisation layer from
            a channel count; ``None`` disables normalisation.
        dropblock_prob: Must be ``0.0``; a positive value raises
            ``NotImplementedError``.
        **kwargs: Extra keyword arguments forwarded to the main convolution.
    """

    def __init__(
        self,
        in_channels,
        channels,
        kernel_size,
        stride=(1, 1),
        padding=(0, 0),
        dilation=(1, 1),
        groups=1,
        bias=True,
        radix=2,
        reduction_factor=4,
        rectify=False,
        rectify_avg=False,
        norm_layer=None,
        dropblock_prob=0.0,
        **kwargs,
    ):
        super().__init__()
        padding = _pair(padding)
        # Rectification is only meaningful when the input is actually padded.
        self.rectify = rectify and (padding[0] > 0 or padding[1] > 0)
        self.rectify_avg = rectify_avg
        bottleneck = max(in_channels * radix // reduction_factor, 32)
        self.radix = radix
        self.cardinality = groups
        self.channels = channels
        self.dropblock_prob = dropblock_prob

        # Choose the main convolution class. RFConv2d is an optional
        # dependency, imported only when rectification is active.
        if self.rectify:
            from rfconv import RFConv2d as conv_cls

            extra = {"average_mode": rectify_avg}
        else:
            conv_cls = Conv2d
            extra = {}
        self.conv = conv_cls(
            in_channels,
            channels * radix,
            kernel_size,
            stride,
            padding,
            dilation,
            groups=groups * radix,
            bias=bias,
            **extra,
            **kwargs,
        )

        # Sub-modules are created in a fixed order (conv, bn0, relu, fc1,
        # bn1, fc2, rsoftmax). That order determines parameter registration
        # and weight-initialisation order.
        self.use_bn = norm_layer is not None
        if self.use_bn:
            self.bn0 = norm_layer(channels * radix)
        self.relu = ReLU(inplace=True)
        self.fc1 = Conv2d(channels, bottleneck, 1, groups=self.cardinality)
        if self.use_bn:
            self.bn1 = norm_layer(bottleneck)
        self.fc2 = Conv2d(
            bottleneck, channels * radix, 1, groups=self.cardinality
        )
        if dropblock_prob > 0.0:
            # DropBlock2D is not available in this module.
            raise NotImplementedError("DropBlock2D was not imported!")
        self.rsoftmax = rSoftMax(radix, groups)

    def forward(self, x):
        """Apply split attention to ``x``.

        Returns the attention-weighted sum of the ``radix`` splits (or
        ``x`` gated by a sigmoid attention when ``radix == 1``) as a
        contiguous tensor.
        """
        feats = self.conv(x)
        if self.use_bn:
            feats = self.bn0(feats)
        if self.dropblock_prob > 0.0:
            # Unreachable through __init__, which rejects dropblock_prob > 0;
            # no ``self.dropblock`` module is ever created there.
            feats = self.dropblock(feats)
        feats = self.relu(feats)

        batch, rchannel = feats.shape[:2]
        has_splits = self.radix > 1
        if has_splits:
            split_size = rchannel // self.radix
            if torch.__version__ < "1.5":
                split_size = int(split_size)
            branches = torch.split(feats, split_size, dim=1)
            fused = sum(branches)
        else:
            fused = feats

        # Channel descriptor: global average pool, then the 1x1 bottleneck.
        gap = self.fc1(F.adaptive_avg_pool2d(fused, 1))
        if self.use_bn:
            gap = self.bn1(gap)
        gap = self.relu(gap)

        atten = self.rsoftmax(self.fc2(gap)).view(batch, -1, 1, 1)

        if not has_splits:
            return (atten * feats).contiguous()
        attens = torch.split(atten, split_size, dim=1)
        return sum([a * b for a, b in zip(attens, branches)]).contiguous()

121 

122 

class rSoftMax(nn.Module):
    """Normalise split-attention logits.

    With ``radix > 1``, each sample's logits are regrouped as
    ``(cardinality, radix, -1)``. A softmax is taken over the radix axis,
    and the weights are flattened back to ``(batch, -1)`` in radix-major
    order. Otherwise an element-wise sigmoid is applied and the input shape
    is preserved.
    """

    def __init__(self, radix, cardinality):
        super().__init__()
        self.radix = radix
        self.cardinality = cardinality

    def forward(self, x):
        n = x.size(0)
        if not self.radix > 1:
            return torch.sigmoid(x)
        grouped = x.view(n, self.cardinality, self.radix, -1)
        # Move the radix axis to dim 1 so the softmax normalises across splits.
        weights = F.softmax(grouped.transpose(1, 2), dim=1)
        return weights.reshape(n, -1)