Coverage for src/bob/pipelines/sample.py: 92%

147 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-12 21:32 +0200

1"""Base definition of sample.""" 

2 

3from collections.abc import MutableSequence, Sequence 

4from typing import Any 

5 

6import numpy as np 

7 

8from bob.io.base import vstack_features 

9 

# Attribute names that carry the payload itself. The attribute-copying code in
# this module skips these names when inheriting from a parent object or from
# keyword arguments.
SAMPLE_DATA_ATTRS = ("data", "samples")

11 

12 

def _copy_attributes(sample, parent, kwargs, exclude_list=None):
    """Transfer attributes from ``parent`` and ``kwargs`` onto ``sample``.

    Attributes of ``parent`` (when it is not ``None``) are applied first,
    followed by the entries of ``kwargs``; a keyword therefore overrides a
    parent attribute of the same name. A name is never transferred when it
    starts with an underscore, is listed in ``SAMPLE_DATA_ATTRS`` or is
    contained in ``exclude_list``.

    Parameters
    ----------

    sample : object
        Object receiving the attributes (through ``setattr``).

    parent : object or None
        Object whose ``__dict__`` keys are inherited.

    kwargs : dict
        Mapping of attribute name to value.

    exclude_list : container or None
        Additional names to skip; ``None`` (or any empty container) means no
        extra exclusion.
    """
    skipped = exclude_list or []

    def _transferable(name):
        # Same three-way test, in the same order, for both attribute sources.
        return not (
            name.startswith("_")
            or name in SAMPLE_DATA_ATTRS
            or name in skipped
        )

    if parent is not None:
        for name in parent.__dict__:
            if _transferable(name):
                setattr(sample, name, getattr(parent, name))

    for name, value in kwargs.items():
        if _transferable(name):
            setattr(sample, name, value)

36 

37 

class _ReprMixin:
    """Mixin providing ``__repr__`` and ``__eq__`` based on ``__dict__``.

    ``__repr__`` shows only the attributes whose name does not start with an
    underscore, whereas ``__eq__`` compares every entry of ``__dict__``.
    Because ``__eq__`` is defined without ``__hash__``, instances are not
    hashable.
    """

    def __repr__(self):
        """Return ``ClassName(name=value, ...)`` for the public attributes."""
        public = ", ".join(
            f"{key}={value!r}"
            for key, value in self.__dict__.items()
            if not key.startswith("_")
        )
        return f"{self.__class__.__name__}({public})"

    def __eq__(self, other):
        """Compare two objects attribute by attribute.

        Two objects are equal when they hold exactly the same attribute names
        and every pair of values matches. When at least one value of a pair
        is a ``numpy.ndarray`` the pair is compared with ``numpy.allclose``
        (falling back to ``numpy.array_equal`` when ``allclose`` cannot handle
        the operands); any other pair is compared with ``!=``.

        Returns ``NotImplemented`` when ``other`` has no ``__dict__``.
        """
        try:
            theirs = other.__dict__
        except AttributeError:
            return NotImplemented
        mine = self.__dict__

        # The attribute names must match exactly; comparing only a common
        # prefix of the sorted names would ignore extra attributes.
        if mine.keys() != theirs.keys():
            return False

        for key, my_value in mine.items():
            their_value = theirs[key]
            if isinstance(my_value, np.ndarray) or isinstance(
                their_value, np.ndarray
            ):
                # Treat "array on either side" the same way so that
                # ``a == b`` and ``b == a`` agree.
                try:
                    same = bool(np.allclose(my_value, their_value))
                except (TypeError, ValueError):
                    # non-numeric content or shapes that do not broadcast
                    same = bool(np.array_equal(my_value, their_value))
            else:
                same = not (my_value != their_value)
            if not same:
                return False

        return True

76 

77 

class Sample(_ReprMixin):
    """Simple container wrapping one data-point and its metadata (see
    :ref:`bob.pipelines.sample`).

    Every sample exposes at least:

    * attribute ``data``: the data held by this sample


    Parameters
    ----------

    data : object
        The data this sample is initialized with.

    parent : object
        Optional object whose attributes are inherited by the new sample.
        ``data`` itself is never inherited.

    kwargs : dict
        Extra attributes to attach to the sample; they are applied after the
        ones inherited from ``parent``.
    """

    def __init__(self, data, parent=None, **kwargs):
        # ``data`` is assigned first so that it is the first entry of
        # ``__dict__`` (and hence of the repr).
        self.data = data
        _copy_attributes(sample=self, parent=parent, kwargs=kwargs)

101 

102 

class DelayedSample(Sample):
    """Representation of sample that can be loaded via a callable.

    The optional ``**kwargs`` argument allows you to attach more attributes to
    this sample instance.

    Reading ``data`` or any delayed attribute calls the corresponding
    callable on every access; the result is not stored on the instance.


    Parameters
    ----------

    load
        A python function that can be called parameterlessly, to load the
        sample in question from whatever medium

    parent : :any:`DelayedSample`, :any:`Sample`, None
        If passed, consider this as a parent of this sample, to copy
        information

    delayed_attributes : dict or None
        A dictionary of name : load_fn pairs that will be used to create
        attributes of name : load_fn() in this class. Use this option
        to create more delayed attributes than just ``sample.data``.

    kwargs : dict
        Further attributes of this sample, to be stored and eventually
        transmitted to transformed versions of the sample

    Raises
    ------

    ValueError
        If a name is present in both ``delayed_attributes`` and ``kwargs``.
    """

    def __init__(self, load, parent=None, delayed_attributes=None, **kwargs):
        # Marker read by ``__setattr__``: while it is present in ``__dict__``
        # assignments do not touch the delayed-attribute bookkeeping.
        self.__running_init__ = True
        # Merge parent's and param's delayed_attributes (both are copied, so
        # neither the parent's nor the caller's dictionary is modified).
        parent_attr = getattr(parent, "_delayed_attributes", None)
        self._delayed_attributes = None
        if parent_attr is not None:
            self._delayed_attributes = parent_attr.copy()

        if delayed_attributes is not None:
            # Sanity check, `delayed_attributes` can not be present in `kwargs`
            # as well
            for name, attr in delayed_attributes.items():
                if name in kwargs:
                    raise ValueError(
                        "`{}` can not be in both `delayed_attributes` and "
                        "`kwargs` inputs".format(name)
                    )
            if self._delayed_attributes is None:
                self._delayed_attributes = delayed_attributes.copy()
            else:
                # entries given explicitly override the ones from the parent
                self._delayed_attributes.update(delayed_attributes)

        # Inherit attributes from parent, without calling delayed_attributes
        for key in getattr(parent, "__dict__", []):
            if key.startswith("_"):
                continue
            if key in SAMPLE_DATA_ATTRS:
                continue
            if self._delayed_attributes is not None:
                if key in self._delayed_attributes:
                    continue
            setattr(self, key, getattr(parent, key))

        # Create the delayed attributes, but leave their values as None for now.
        # The None placeholders make the names appear in ``__dict__``; reads
        # are intercepted by ``__getattribute__``.
        if self._delayed_attributes is not None:
            update = {}
            for k in list(self._delayed_attributes):
                if k not in kwargs:
                    update[k] = None
                else:
                    # k is not a delay_attribute anymore: a plain value was
                    # given for it in kwargs
                    del self._delayed_attributes[k]
            if len(self._delayed_attributes) == 0:
                self._delayed_attributes = None
            kwargs.update(update)
        # Set attribute from kwargs
        _copy_attributes(self, None, kwargs)
        self._load = load
        # Initialization is over: from now on ``__setattr__`` keeps the
        # delayed-attribute mapping in sync with assignments.
        del self.__running_init__

    def __getattribute__(self, name: str) -> Any:
        """Return attribute ``name``, calling its loader if it is delayed."""
        try:
            delayed_attributes = super().__getattribute__("_delayed_attributes")
        except AttributeError:
            # ``_delayed_attributes`` is not set yet (start of ``__init__``)
            delayed_attributes = None
        if delayed_attributes is None or name not in delayed_attributes:
            return super().__getattribute__(name)
        # The loader is invoked on each access; nothing is cached here.
        return delayed_attributes[name]()

    def __setattr__(self, name: str, value: Any) -> None:
        """Set ``name``; a delayed attribute that is assigned stops being
        delayed."""
        # NOTE(review): the name compared here is "delayed_attributes" while
        # the mapping is stored as "_delayed_attributes" -- confirm which one
        # is intended.
        if (
            name != "delayed_attributes"
            and "__running_init__" not in self.__dict__
        ):
            delayed_attributes = getattr(self, "_delayed_attributes", None)
            # if setting an attribute which was delayed, remove it from delayed_attributes
            if delayed_attributes is not None and name in delayed_attributes:
                del delayed_attributes[name]

        super().__setattr__(name, value)

    @property
    def data(self):
        """Return the result of calling the ``load`` callable given at
        construction (called on every access)."""
        return self._load()

    @classmethod
    def from_sample(cls, sample: Sample, **kwargs):
        """Creates a DelayedSample from another DelayedSample or a Sample.
        If the sample is a DelayedSample, its data will not be loaded.

        Parameters
        ----------

        sample : :any:`Sample`
            The sample to convert to a DelayedSample

        kwargs : dict
            Forwarded to the constructor of ``cls``.
        """
        if hasattr(sample, "_load"):
            # reuse the existing loader instead of evaluating ``sample.data``
            data = sample._load
        else:

            def data():
                return sample.data

        return cls(data, parent=sample, **kwargs)

227 

228 

class SampleSet(MutableSequence, _ReprMixin):
    """Mutable sequence of samples that also carries its own attributes.

    All sequence operations are delegated to the ``samples`` container.

    Parameters
    ----------

    samples
        Container holding the samples.

    parent : object or None
        Object to inherit attributes from. Names found in the parent's
        ``_delayed_attributes`` (if it has any) are not inherited.

    kwargs : dict
        Extra attributes to attach to the set.
    """

    def __init__(self, samples, parent=None, **kwargs):
        self.samples = samples
        parent_delayed = getattr(parent, "_delayed_attributes", None)
        _copy_attributes(self, parent, kwargs, exclude_list=parent_delayed)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, item):
        return self.samples[item]

    def __setitem__(self, key, item):
        self.samples[key] = item

    def __delitem__(self, item):
        del self.samples[item]

    def insert(self, index, item):
        """Insert ``item`` before position ``index`` of ``samples``."""
        self.samples.insert(index, item)

256 

257 

class DelayedSampleSet(SampleSet):
    """A set of samples, with extra attributes, obtained through a callable.

    Parameters
    ----------

    load
        Callable taking no argument and returning the samples. It is called
        each time ``samples`` is read.

    parent : object or None
        Object to inherit attributes from. Names found in the parent's
        ``_delayed_attributes`` (if it has any) are not inherited.

    kwargs : dict
        Extra attributes to attach to the set.
    """

    def __init__(self, load, parent=None, **kwargs):
        self._load = load
        parent_delayed = getattr(parent, "_delayed_attributes", None)
        _copy_attributes(self, parent, kwargs, exclude_list=parent_delayed)

    @property
    def samples(self):
        """The samples, as returned by a fresh call to ``load``."""
        return self._load()

273 

274 

class DelayedSampleSetCached(DelayedSampleSet):
    """A cached version of DelayedSampleSet.

    The ``load`` callable is called the first time ``samples`` is read and
    its result is kept for later accesses. A result of ``None`` is not
    considered cached, so ``load`` is called again in that case.

    Parameters
    ----------

    load
        Callable taking no argument and returning the samples.

    parent : object or None
        Object to inherit attributes from.

    kwargs : dict
        Extra attributes to attach to the set.
    """

    def __init__(self, load, parent=None, **kwargs):
        # The keyword arguments must be unpacked: forwarding them as
        # ``kwargs=kwargs`` would create one attribute literally named
        # ``kwargs`` holding the whole dictionary. The parent constructor
        # already copies the attributes from ``parent`` and ``kwargs``.
        super().__init__(load, parent=parent, **kwargs)
        self._data = None

    @property
    def samples(self):
        """The samples, loaded on first access and reused afterwards."""
        if self._data is None:
            self._data = self._load()
        return self._data

293 

294 

class SampleBatch(Sequence, _ReprMixin):
    """A batch of samples that looks like [s.data for s in samples]

    However, when you call np.array(SampleBatch), it will construct a numpy array from
    sample.data attributes in a memory efficient way.

    Parameters
    ----------

    samples
        Sequence of sample objects.

    sample_attribute : str
        Name of the attribute read from each sample (``"data"`` by default).
    """

    def __init__(self, samples, sample_attribute="data"):
        self.samples = samples
        self.sample_attribute = sample_attribute

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, item):
        return getattr(self.samples[item], self.sample_attribute)

    def __array__(self, dtype=None, *args, **kwargs):
        """Build a numpy array out of the selected attribute of all samples.

        When the attribute of the first sample has a ``shape``, the values
        are stacked with ``vstack_features`` (presumably sample-wise along
        the first axis -- see ``bob.io.base``); otherwise they are gathered
        in a plain list. The result goes through ``np.asarray`` together with
        ``dtype`` and any extra arguments.
        """

        def _reader(s):
            # adding one more dimension to data so they get stacked sample-wise
            return getattr(s, self.sample_attribute)[None, ...]

        if self.samples and hasattr(
            getattr(self.samples[0], self.sample_attribute), "shape"
        ):
            try:
                arr = vstack_features(_reader, self.samples, dtype=dtype)
            except Exception as stacking_error:
                try:
                    # try computing one feature to show a better traceback
                    getattr(self.samples[0], self.sample_attribute)
                except Exception as loading_error:
                    raise loading_error from stacking_error
                # Reading one feature worked, so the stacking error is the
                # real problem: re-raise it untouched instead of chaining it
                # to itself.
                raise

        else:
            # to handle string data
            arr = [getattr(s, self.sample_attribute) for s in self.samples]
        return np.asarray(arr, dtype, *args, **kwargs)