1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import torch
5import torch.nn as nn
6import torch.nn.functional as F
7from collections import OrderedDict
8from .normalizer import TorchVisionNormalizer
9
class PASA(nn.Module):
    """
    PASA module

    Based on paper by [PASA-2019]_.

    Five convolution blocks, each made of two conv + batch-norm + ReLU stages
    whose output is averaged with a parallel (1x1 conv + batch-norm + ReLU)
    branch computed from the block input. A 3x3/stride-2 max-pool follows
    each of the first four blocks. The network ends with global average
    pooling and a linear layer producing a single score per sample.

    Note: the ``fc*`` attributes are convolutions, not fully-connected
    layers. Their names (and those of the ``batchNorm2d_*`` layers) are kept
    as-is because they are the ``state_dict`` keys of saved checkpoints.

    """

    def __init__(self):
        super().__init__()
        # First convolution block (stride 2 twice on the main path, stride 4
        # on the parallel 1x1 branch, so both outputs have the same size)
        self.fc1 = nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
        self.fc2 = nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
        self.fc3 = nn.Conv2d(1, 16, (1, 1), (4, 4))

        self.batchNorm2d_4 = nn.BatchNorm2d(4)
        self.batchNorm2d_16 = nn.BatchNorm2d(16)
        self.batchNorm2d_16_2 = nn.BatchNorm2d(16)

        # Second convolution block
        self.fc4 = nn.Conv2d(16, 24, (3, 3), (1, 1), (1, 1))
        self.fc5 = nn.Conv2d(24, 32, (3, 3), (1, 1), (1, 1))
        self.fc6 = nn.Conv2d(16, 32, (1, 1), (1, 1))  # Original stride (2, 2)

        self.batchNorm2d_24 = nn.BatchNorm2d(24)
        self.batchNorm2d_32 = nn.BatchNorm2d(32)
        self.batchNorm2d_32_2 = nn.BatchNorm2d(32)

        # Third convolution block
        self.fc7 = nn.Conv2d(32, 40, (3, 3), (1, 1), (1, 1))
        self.fc8 = nn.Conv2d(40, 48, (3, 3), (1, 1), (1, 1))
        self.fc9 = nn.Conv2d(32, 48, (1, 1), (1, 1))  # Original stride (2, 2)

        self.batchNorm2d_40 = nn.BatchNorm2d(40)
        self.batchNorm2d_48 = nn.BatchNorm2d(48)
        self.batchNorm2d_48_2 = nn.BatchNorm2d(48)

        # Fourth convolution block
        self.fc10 = nn.Conv2d(48, 56, (3, 3), (1, 1), (1, 1))
        self.fc11 = nn.Conv2d(56, 64, (3, 3), (1, 1), (1, 1))
        self.fc12 = nn.Conv2d(48, 64, (1, 1), (1, 1))  # Original stride (2, 2)

        self.batchNorm2d_56 = nn.BatchNorm2d(56)
        self.batchNorm2d_64 = nn.BatchNorm2d(64)
        self.batchNorm2d_64_2 = nn.BatchNorm2d(64)

        # Fifth convolution block
        self.fc13 = nn.Conv2d(64, 72, (3, 3), (1, 1), (1, 1))
        self.fc14 = nn.Conv2d(72, 80, (3, 3), (1, 1), (1, 1))
        self.fc15 = nn.Conv2d(64, 80, (1, 1), (1, 1))  # Original stride (2, 2)

        self.batchNorm2d_72 = nn.BatchNorm2d(72)
        self.batchNorm2d_80 = nn.BatchNorm2d(80)
        self.batchNorm2d_80_2 = nn.BatchNorm2d(80)

        self.pool2d = nn.MaxPool2d((3, 3), (2, 2))  # Pool after conv. block
        self.dense = nn.Linear(80, 1)  # Fully connected layer

    def _conv_block(self, x, conv_a, bn_a, conv_b, bn_b, conv_par, bn_par):
        """Apply one PASA block and return the average of its two branches.

        The main branch is ``conv_a -> bn_a -> ReLU -> conv_b -> bn_b -> ReLU``.
        The parallel branch is ``conv_par -> bn_par -> ReLU``, applied to the
        same block input ``x``.
        """
        main = F.relu(bn_a(conv_a(x)))
        main = F.relu(bn_b(conv_b(main)))
        parallel = F.relu(bn_par(conv_par(x)))
        return (main + parallel) / 2

    def forward(self, x):
        """
        Compute one raw score per input image.

        Parameters
        ----------

        x : :py:class:`torch.Tensor`
            Batch of single-channel images with shape ``(N, 1, H, W)``.
            ``BatchNorm2d`` requires 4D input. ``H`` and ``W`` must each be
            at least 121 so that the four 3x3/stride-2 max-pools still
            produce a non-empty feature map.

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            Tensor of shape ``(N, 1)`` holding the output of the final
            linear layer. No activation (sigmoid/softmax) is applied.

        """
        # Blocks 1-4, each followed by max-pooling
        x = self.pool2d(self._conv_block(
            x, self.fc1, self.batchNorm2d_4, self.fc2, self.batchNorm2d_16,
            self.fc3, self.batchNorm2d_16_2))
        x = self.pool2d(self._conv_block(
            x, self.fc4, self.batchNorm2d_24, self.fc5, self.batchNorm2d_32,
            self.fc6, self.batchNorm2d_32_2))
        x = self.pool2d(self._conv_block(
            x, self.fc7, self.batchNorm2d_40, self.fc8, self.batchNorm2d_48,
            self.fc9, self.batchNorm2d_48_2))
        x = self.pool2d(self._conv_block(
            x, self.fc10, self.batchNorm2d_56, self.fc11, self.batchNorm2d_64,
            self.fc12, self.batchNorm2d_64_2))

        # Block 5, no pooling afterwards
        x = self._conv_block(
            x, self.fc13, self.batchNorm2d_72, self.fc14, self.batchNorm2d_80,
            self.fc15, self.batchNorm2d_80_2)

        # Global average pooling over the spatial dimensions -> (N, 80)
        x = x.flatten(start_dim=2).mean(dim=2)

        # Dense layer -> (N, 1)
        return self.dense(x)
127
def build_pasa():
    """
    Build the PASA CNN, preceded by a single-channel input normalizer.

    Returns
    -------

    module : :py:class:`torch.nn.Module`
        An ``nn.Sequential`` with two named submodules, ``normalizer``
        (a ``TorchVisionNormalizer`` for 1 channel) followed by ``model``
        (a :py:class:`PASA` instance). Its ``name`` attribute is ``"pasa"``.

    """
    # Instantiate PASA first, as before, so parameter initialisation
    # consumes the random number generator in the same order.
    pasa = PASA()
    normalizer = TorchVisionNormalizer(nb_channels=1)

    stages = OrderedDict()
    stages["normalizer"] = normalizer
    stages["model"] = pasa

    network = nn.Sequential(stages)
    network.name = "pasa"
    return network