Coverage for src/bob/bio/face/pytorch/facexzoo/resnest/splat.py: 88% (73 statements)
coverage.py v7.6.0, created at 2024-07-13 00:04 +0200

1"""Split-Attention"""
3import torch
4import torch.nn.functional as F
6from torch import nn
7from torch.nn import Conv2d, Module, ReLU
8from torch.nn.modules.utils import _pair
10__all__ = ["SplAtConv2d"]

class SplAtConv2d(Module):
    """Split-Attention Conv2d"""

    def __init__(
        self,
        in_channels,
        channels,
        kernel_size,
        stride=(1, 1),
        padding=(0, 0),
        dilation=(1, 1),
        groups=1,
        bias=True,
        radix=2,
        reduction_factor=4,
        rectify=False,
        rectify_avg=False,
        norm_layer=None,
        dropblock_prob=0.0,
        **kwargs,
    ):
        super(SplAtConv2d, self).__init__()
        padding = _pair(padding)
        self.rectify = rectify and (padding[0] > 0 or padding[1] > 0)
        self.rectify_avg = rectify_avg
        # Width of the channel-attention bottleneck, floored at 32 channels.
        inter_channels = max(in_channels * radix // reduction_factor, 32)
        self.radix = radix
        self.cardinality = groups
        self.channels = channels
        self.dropblock_prob = dropblock_prob
        if self.rectify:
            from rfconv import RFConv2d

            self.conv = RFConv2d(
                in_channels,
                channels * radix,
                kernel_size,
                stride,
                padding,
                dilation,
                groups=groups * radix,
                bias=bias,
                average_mode=rectify_avg,
                **kwargs,
            )
        else:
            self.conv = Conv2d(
                in_channels,
                channels * radix,
                kernel_size,
                stride,
                padding,
                dilation,
                groups=groups * radix,
                bias=bias,
                **kwargs,
            )
        self.use_bn = norm_layer is not None
        if self.use_bn:
            self.bn0 = norm_layer(channels * radix)
        self.relu = ReLU(inplace=True)
        self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality)
        if self.use_bn:
            self.bn1 = norm_layer(inter_channels)
        self.fc2 = Conv2d(
            inter_channels, channels * radix, 1, groups=self.cardinality
        )
        if dropblock_prob > 0.0:
            raise NotImplementedError("DropBlock2D was not imported!")
            # self.dropblock = DropBlock2D(dropblock_prob, 3)
        self.rsoftmax = rSoftMax(radix, groups)

    def forward(self, x):
        x = self.conv(x)
        if self.use_bn:
            x = self.bn0(x)
        if self.dropblock_prob > 0.0:
            x = self.dropblock(x)
        x = self.relu(x)

        batch, rchannel = x.shape[:2]
        if self.radix > 1:
            # Split into radix groups along the channel axis; the two
            # branches are equivalent (the int() cast only matters for the
            # pre-1.5 torch.split signature).
            if torch.__version__ < "1.5":
                splited = torch.split(x, int(rchannel // self.radix), dim=1)
            else:
                splited = torch.split(x, rchannel // self.radix, dim=1)
            gap = sum(splited)
        else:
            gap = x
        # Global average pooling plus two 1x1 convolutions produce the
        # per-split channel-attention logits.
        gap = F.adaptive_avg_pool2d(gap, 1)
        gap = self.fc1(gap)

        if self.use_bn:
            gap = self.bn1(gap)
        gap = self.relu(gap)

        atten = self.fc2(gap)
        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)

        if self.radix > 1:
            if torch.__version__ < "1.5":
                attens = torch.split(atten, int(rchannel // self.radix), dim=1)
            else:
                attens = torch.split(atten, rchannel // self.radix, dim=1)
            # Attention-weighted sum over the radix splits.
            out = sum([att * split for (att, split) in zip(attens, splited)])
        else:
            out = atten * x
        return out.contiguous()
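
# Shape walk-through (illustrative note, not part of the original file):
# with in_channels=64, channels=128, radix=2, groups=1, an input of shape
# (B, 64, H, W) becomes (B, 256, H, W) after self.conv, is split into two
# (B, 128, H, W) halves, and the attention-weighted sum returns a
# (B, 128, H, W) tensor.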

class rSoftMax(nn.Module):
    def __init__(self, radix, cardinality):
        super().__init__()
        self.radix = radix
        self.cardinality = cardinality

    def forward(self, x):
        batch = x.size(0)
        if self.radix > 1:
            # Softmax across the radix splits, computed per cardinal group.
            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
            x = F.softmax(x, dim=1)
            x = x.reshape(batch, -1)
        else:
            x = torch.sigmoid(x)
        return x
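
# A minimal smoke-test sketch (an assumption, not part of the original
# module); it picks nn.BatchNorm2d as the norm_layer and keeps the default
# radix=2, groups=1.
if __name__ == "__main__":
    layer = SplAtConv2d(
        64, 128, kernel_size=3, padding=1, norm_layer=nn.BatchNorm2d
    )
    out = layer(torch.randn(2, 64, 32, 32))
    print(out.shape)  # expected: torch.Size([2, 128, 32, 32])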