1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import torch.nn as nn
5import torchvision.models as models
6from collections import OrderedDict
7from .normalizer import TorchVisionNormalizer
8
9
class DensenetRS(nn.Module):
    """
    Densenet121-based module for radiological extraction.

    Wraps torchvision's ``densenet121`` and replaces its final
    ``classifier`` layer with a fresh :py:class:`torch.nn.Linear` head.

    Parameters
    ----------

    num_classes : int
        Number of output features of the replacement classifier head.
        Defaults to 14, the value this module always used previously.

    pretrained : bool
        If ``True`` (default), load ``models.DenseNet121_Weights.DEFAULT``.
        If ``False``, build the backbone with random initial weights, so no
        weight download is needed.

    """

    def __init__(self, num_classes=14, pretrained=True):
        super().__init__()

        # Load the backbone, optionally with torchvision's default weights
        weights = models.DenseNet121_Weights.DEFAULT if pretrained else None
        self.model_ft = models.densenet121(weights=weights)

        # Swap the classifier head, keeping its input width so it still
        # matches the feature extractor's output size.
        num_ftrs = self.model_ft.classifier.in_features
        self.model_ft.classifier = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        """
        Run the wrapped Densenet121 on a batch of inputs.

        Parameters
        ----------

        x : :py:class:`torch.Tensor`
            Input batch, passed unchanged to the torchvision model
            (presumably an image batch shaped ``(N, 3, H, W)`` -- confirm
            against the caller).

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            Output of the linear classifier head (``num_classes`` features
            per sample). No activation such as sigmoid or softmax is applied
            here.

        """
        return self.model_ft(x)
45
46
def build_densenetrs():
    """
    Assemble the DensenetRS network behind a TorchVision input normalizer.

    Returns
    -------

    module : :py:class:`torch.nn.Module`
        An :py:class:`torch.nn.Sequential` whose children are, in order,
        ``normalizer`` (a ``TorchVisionNormalizer``) and ``model`` (a
        ``DensenetRS``). Its ``name`` attribute is set to ``"DensenetRS"``.

    """

    # Instantiate the backbone before the normalizer, as the original code did.
    backbone = DensenetRS()

    stages = OrderedDict()
    stages["normalizer"] = TorchVisionNormalizer()
    stages["model"] = backbone

    network = nn.Sequential(stages)
    network.name = "DensenetRS"
    return network