Coverage for src/deepdraw/models/backbones/resnet.py: 65%
26 statements
« prev ^ index » next coverage.py v7.4.2, created at 2024-03-29 22:17 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5import torchvision.models
7try:
8 # pytorch >= 1.12
9 from torch.hub import load_state_dict_from_url
10except ImportError:
11 # pytorch < 1.12
12 from torchvision.models.utils import load_state_dict_from_url
class ResNet4Segmentation(torchvision.models.resnet.ResNet):
    """Adaptation of base ResNet functionality to U-Net style segmentation.

    This version of ResNet is slightly modified so it can be used through
    torchvision's API.  It outputs intermediate features which are normally
    not output by the base ResNet implementation, but are required for
    segmentation operations.


    Parameters
    ==========

    return_features : :py:class:`list`
        **Required** keyword argument.  A list of integers indicating the
        feature layers to be returned from the original module.  Indexes
        refer to the order in which the top-level sub-modules are registered
        on the model (in torchvision's ResNet: ``conv1``, ``bn1``, ``relu``,
        ``maxpool``, ``layer1`` ... ``layer4``).  All other positional and
        keyword arguments are forwarded unchanged to
        :py:class:`torchvision.models.resnet.ResNet`.
    """

    def __init__(self, *args, **kwargs):
        # Must be removed before calling the parent constructor, which does
        # not accept this keyword.  Raises KeyError when it is missing.
        self._return_features = kwargs.pop("return_features")
        super().__init__(*args, **kwargs)

    def forward(self, x):
        """Run the backbone and collect the requested intermediate features.

        Parameters
        ----------
        x : torch.Tensor
            Input batch.  Dimensions 2 and 3 are reported back as the input
            spatial size (assumes an ``(N, C, H, W)`` layout — TODO confirm
            against callers).

        Returns
        -------
        list
            First item is ``x.shape[2:4]``.  The following items are the
            outputs of the sub-modules whose enumeration index appears in
            ``return_features``, in execution order.
        """
        # hardwiring of input: decoders need the original spatial size
        outputs = [x.shape[2:4]]
        # torchvision's ResNet has no ``features`` container (unlike VGG), so
        # walk the top-level children in registration order instead.
        for index, (name, module) in enumerate(self.named_children()):
            if name in ("avgpool", "fc"):
                # Classification head: not needed for segmentation, and the
                # parent ResNet flattens before ``fc``, so chaining it here
                # would fail if the head was not removed.
                break
            x = module(x)
            # extract layers
            if index in self._return_features:
                outputs.append(x)
        return outputs
def resnet50_for_segmentation(pretrained=False, progress=True, **kwargs):
    """Build a ResNet-50 backbone that returns intermediate features.

    The architecture is torchvision's ResNet-50 (``Bottleneck`` blocks with
    ``[3, 4, 6, 3]`` layers per stage), instantiated as
    :py:class:`ResNet4Segmentation`.  The classification head (``avgpool``
    and ``fc``) is removed before the model is returned.


    Parameters
    ==========

    pretrained : bool
        If ``True``, download and load the weights referenced by
        ``torchvision.models.resnet.ResNet50_Weights.DEFAULT`` before the
        classification head is removed.

    progress : bool
        Passed to ``load_state_dict_from_url``; controls whether a download
        progress bar is displayed when fetching pretrained weights.

    **kwargs
        Forwarded to :py:class:`ResNet4Segmentation`.  Must contain
        ``return_features``.


    Returns
    =======

    model : ResNet4Segmentation
        The backbone without its classification head.
    """
    model = ResNet4Segmentation(
        torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], **kwargs
    )

    if pretrained:
        # The checkpoint still contains the head's parameters, so it must be
        # loaded while ``avgpool``/``fc`` are present for the keys to match.
        state_dict = load_state_dict_from_url(
            torchvision.models.resnet.ResNet50_Weights.DEFAULT.url,
            progress=progress,
        )
        model.load_state_dict(state_dict)

    # erase ResNet head (for classification), not used for segmentation
    del model.avgpool
    del model.fc

    return model