1#!/usr/bin/env python
2# coding=utf-8
3
4
5"""Common utilities"""
6
7import contextlib
8
9import torch
10import torch.utils.data
11import PIL
12import numpy as np
13from torchvision.transforms import Compose, ToTensor
14
15
class SampleListDataset(torch.utils.data.Dataset):
    """PyTorch dataset wrapper around Sample lists

    A list of transforms can be passed.  It is applied to the ``"data"`` entry
    of a sample only when that entry is a :py:class:`PIL.Image.Image`.
    ``"data"`` entries that are plain lists (e.g. feature vectors) are
    converted to :py:class:`torch.FloatTensor` instead, and labels are never
    transformed (list labels are only converted to
    :py:class:`torch.FloatTensor`).

    It supports indexing such that ``dataset[i]`` can be used to get the ith
    sample, and slicing (``dataset[a:b]``) to get a list of samples.

    Parameters
    ----------

    samples : list
        A list of :py:class:`bob.med.tb.data.sample.Sample` objects (anything
        exposing ``.key`` and ``.data``), or raw ``dict`` objects holding the
        sample data directly.  For raw dicts, the integer index is used as
        the sample key.

    transforms : :py:class:`list`, Optional
        a list of transformations to be applied to image data.  Notice a last
        transform (:py:class:`torchvision.transforms.transforms.ToTensor`) is
        always appended unless the list already contains one - you do not
        need to add that.

    """

    def __init__(self, samples, transforms=None):

        self._samples = samples
        # ``None`` instead of a shared mutable ``[]`` default; the setter
        # copies whatever sequence it receives.
        self.transforms = [] if transforms is None else transforms

    @property
    def transforms(self):
        """The user-supplied transforms, without the auto-appended ToTensor"""
        pipeline = self._transforms.transforms
        if self._auto_to_tensor:
            # strip only the ToTensor that the setter appended itself
            return pipeline[:-1]
        return list(pipeline)

    @transforms.setter
    def transforms(self, transforms):
        # copy, so later mutations of the caller's list do not leak in
        pipeline = list(transforms)
        self._auto_to_tensor = not any(
            isinstance(t, ToTensor) for t in pipeline
        )
        if self._auto_to_tensor:
            pipeline.append(ToTensor())
        self._transforms = Compose(pipeline)

    @staticmethod
    def _sample_data(sample):
        """Returns the data dictionary of ``sample`` (raw dicts as-is)"""
        return sample if isinstance(sample, dict) else sample.data

    def copy(self, transforms=None):
        """Returns a new dataset over the same samples, optionally resetting
        transforms

        The underlying sample list is shared, not copied: in-place changes
        (e.g. :py:meth:`random_permute`) are visible from both datasets.

        Parameters
        ----------

        transforms : :py:class:`list`, Optional
            An optional list of transforms to set in the copy.  If ``None``
            (the default), use ``self.transforms``.  Pass ``[]`` to drop all
            user transforms.
        """

        if transforms is None:
            transforms = self.transforms
        return SampleListDataset(self._samples, transforms)

    def random_permute(self, feature):
        """Randomly permute feature values from all samples, in place

        Useful for permutation feature importance computation.  Only samples
        whose ``"data"`` entry is a list take part in the permutation; other
        samples (e.g. images) are left untouched.  The shuffle uses the global
        :py:mod:`numpy.random` state, so ``np.random.seed`` controls it.

        Parameters
        ----------

        feature : int
            The position of the feature
        """

        # Load each sample once and keep a reference to its feature list, so
        # the permuted values are written into the very lists that were read.
        # NOTE(review): persistence across later accesses assumes
        # ``Sample.data`` returns the same (cached) dict on every access.
        feature_lists = []
        for sample in self._samples:
            features = self._sample_data(sample)["data"]
            if isinstance(features, list):
                feature_lists.append(features)

        values = np.array([f[feature] for f in feature_lists], dtype=float)
        np.random.shuffle(values)

        for features, value in zip(feature_lists, values):
            features[feature] = value

    def __len__(self):
        """

        Returns
        -------

        size : int
            size of the dataset

        """
        return len(self._samples)

    def __getitem__(self, key):
        """

        Parameters
        ----------

        key : int, slice

        Returns
        -------

        sample : list
            The sample data: ``[key, data, label]``, or ``[key, data]`` when
            the sample has no ``"label"`` entry.  ``key`` is the Sample key,
            or the integer index for raw-dict samples.  For a slice, a list
            of such lists is returned.

        """

        if isinstance(key, slice):
            return [self[k] for k in range(*key.indices(len(self)))]

        sample = self._samples[key]
        if not isinstance(sample, dict):
            key = sample.key  # Sample objects carry their own key
        data = self._sample_data(sample)  # may trigger data loading

        retval = data["data"]
        if isinstance(retval, list):
            retval = torch.FloatTensor(retval)
        elif self._transforms and isinstance(retval, PIL.Image.Image):
            retval = self._transforms(retval)

        if "label" not in data:
            # use the resolved ``key``: raw dicts have no ``.key`` attribute
            return [key, retval]

        label = data["label"]
        if isinstance(label, list):
            label = torch.FloatTensor(label)
        return [key, retval, label]