Coverage for src/bob/pipelines/sample.py: 92%
147 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-12 21:32 +0200
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-12 21:32 +0200
1"""Base definition of sample."""
3from collections.abc import MutableSequence, Sequence
4from typing import Any
6import numpy as np
8from bob.io.base import vstack_features
# Attribute names that carry a sample's (or sample set's) payload.
# `_copy_attributes` and `DelayedSample.__init__` skip these names, so they are
# never inherited from a parent object or overwritten through keyword arguments.
SAMPLE_DATA_ATTRS = ("data", "samples")
def _copy_attributes(sample, parent, kwargs, exclude_list=None):
    """Set attributes on ``sample`` taken from ``parent`` and from ``kwargs``.

    A name is skipped when it starts with an underscore, when it is listed in
    ``SAMPLE_DATA_ATTRS``, or when it is contained in ``exclude_list``.

    Parameters
    ----------
    sample : object
        The object receiving the attributes.

    parent : object or None
        If not ``None``, every eligible name found in ``parent.__dict__`` is
        read with ``getattr`` and set on ``sample``.

    kwargs : dict
        Name/value pairs set on ``sample`` after the parent's attributes, so
        they win when both define the same name.

    exclude_list : container or None
        Extra names to skip. Only membership tests are performed on it.
    """
    excluded = exclude_list if exclude_list else []

    def _is_copyable(name):
        # Private names, payload attributes and excluded names are not copied.
        return not (
            name.startswith("_")
            or name in SAMPLE_DATA_ATTRS
            or name in excluded
        )

    if parent is not None:
        for name in parent.__dict__:
            if _is_copyable(name):
                setattr(sample, name, getattr(parent, name))

    for name, value in kwargs.items():
        if _is_copyable(name):
            setattr(sample, name, value)
38class _ReprMixin:
39 def __repr__(self):
40 return (
41 f"{self.__class__.__name__}("
42 + ", ".join(
43 f"{k}={v!r}"
44 for k, v in self.__dict__.items()
45 if not k.startswith("_")
46 )
47 + ")"
48 )
50 def __eq__(self, other):
51 sorted_self = {
52 k: v
53 for k, v in sorted(self.__dict__.items(), key=lambda item: item[0])
54 }
55 sorted_other = {
56 k: v
57 for k, v in sorted(other.__dict__.items(), key=lambda item: item[0])
58 }
60 for s, o in zip(sorted_self, sorted_other):
61 # Checking keys
62 if s != o:
63 return False
65 # Checking values
66 if isinstance(sorted_self[s], np.ndarray) and isinstance(
67 sorted_self[o], np.ndarray
68 ):
69 if not np.allclose(sorted_self[s], sorted_other[o]):
70 return False
71 else:
72 if sorted_self[s] != sorted_other[o]:
73 return False
75 return True
class Sample(_ReprMixin):
    """A simple container wrapping one data-point (see
    :ref:`bob.pipelines.sample`).

    Every sample exposes at least:

    * attribute ``data``: the data held by this sample

    Parameters
    ----------

    data : object
        The object stored as ``sample.data``.

    parent : object
        Optional object whose other attributes (everything except ``data``)
        are inherited by this sample.

    kwargs : dict
        Additional attributes to set on this sample.
    """

    def __init__(self, data, parent=None, **kwargs):
        self.data = data
        _copy_attributes(sample=self, parent=parent, kwargs=kwargs)
class DelayedSample(Sample):
    """Representation of sample that can be loaded via a callable.

    The optional ``**kwargs`` argument allows you to attach more attributes to
    this sample instance.

    Reading ``sample.data`` calls ``load()``; reading any name registered in
    ``delayed_attributes`` calls the matching callable. Nothing is cached by
    this class: the callable runs again on every access.

    Parameters
    ----------

    load
        A python function that can be called parameterlessly, to load the
        sample in question from whatever medium

    parent : :any:`DelayedSample`, :any:`Sample`, None
        If passed, consider this as a parent of this sample, to copy
        information

    delayed_attributes : dict or None
        A dictionary of name : load_fn pairs that will be used to create
        attributes of name : load_fn() in this class. Use this option
        to create more delayed attributes than just ``sample.data``.

    kwargs : dict
        Further attributes of this sample, to be stored and eventually
        transmitted to transformed versions of the sample

    Raises
    ------
    ValueError
        If a name is present in both ``delayed_attributes`` and ``kwargs``.
    """

    def __init__(self, load, parent=None, delayed_attributes=None, **kwargs):
        # Flag checked by ``__setattr__``: while it is present in the instance
        # ``__dict__``, assignments never prune ``_delayed_attributes``.
        # ``Sample.__init__`` is not called; ``data`` is served by the property
        # defined below.
        self.__running_init__ = True
        # Merge parent's and param's delayed_attributes. Copies are taken so
        # the parent's and the caller's dictionaries are never mutated.
        parent_attr = getattr(parent, "_delayed_attributes", None)
        self._delayed_attributes = None
        if parent_attr is not None:
            self._delayed_attributes = parent_attr.copy()

        if delayed_attributes is not None:
            # Sanity check, `delayed_attributes` can not be present in `kwargs`
            # as well
            for name, attr in delayed_attributes.items():
                if name in kwargs:
                    raise ValueError(
                        "`{}` can not be in both `delayed_attributes` and "
                        "`kwargs` inputs".format(name)
                    )
            if self._delayed_attributes is None:
                self._delayed_attributes = delayed_attributes.copy()
            else:
                # Entries passed explicitly override the ones inherited from
                # the parent.
                self._delayed_attributes.update(delayed_attributes)

        # Inherit attributes from parent, without calling delayed_attributes:
        # names registered as delayed are skipped so `getattr(parent, key)`
        # never triggers one of the parent's loaders here.
        for key in getattr(parent, "__dict__", []):
            if key.startswith("_"):
                continue
            if key in SAMPLE_DATA_ATTRS:
                continue
            if self._delayed_attributes is not None:
                if key in self._delayed_attributes:
                    continue
            setattr(self, key, getattr(parent, key))

        # Create the delayed attributes, but leave their values as None for now.
        # The placeholder makes each name show up in the instance `__dict__`.
        if self._delayed_attributes is not None:
            update = {}
            # Iterate over a snapshot of the keys because entries may be
            # deleted inside the loop.
            for k in list(self._delayed_attributes):
                if k not in kwargs:
                    update[k] = None
                else:
                    # k is not a delayed attribute anymore: the concrete value
                    # given in kwargs replaces the inherited loader.
                    del self._delayed_attributes[k]
            if len(self._delayed_attributes) == 0:
                self._delayed_attributes = None
            kwargs.update(update)
            # kwargs.update({k: None for k in self._delayed_attributes})
        # Set attribute from kwargs
        _copy_attributes(self, None, kwargs)
        self._load = load
        # From here on, `__setattr__` prunes delayed attributes on assignment.
        del self.__running_init__

    def __getattribute__(self, name: str) -> Any:
        """Return the attribute ``name``, calling its loader if it is delayed."""
        # `_delayed_attributes` may not exist yet at the very start of
        # `__init__`; treat that the same as "no delayed attributes".
        try:
            delayed_attributes = super().__getattribute__("_delayed_attributes")
        except AttributeError:
            delayed_attributes = None
        if delayed_attributes is None or name not in delayed_attributes:
            return super().__getattribute__(name)
        # The loader is invoked on every access; the result is not stored.
        return delayed_attributes[name]()

    def __setattr__(self, name: str, value: Any) -> None:
        """Set ``name``; outside ``__init__`` this also un-delays the attribute."""
        # NOTE(review): the private attribute is spelled `_delayed_attributes`
        # everywhere else, so the comparison against "delayed_attributes" below
        # possibly misses an underscore - confirm the intent.
        if (
            name != "delayed_attributes"
            and "__running_init__" not in self.__dict__
        ):
            delayed_attributes = getattr(self, "_delayed_attributes", None)
            # if setting an attribute which was delayed, remove it from delayed_attributes
            if delayed_attributes is not None and name in delayed_attributes:
                del delayed_attributes[name]

        super().__setattr__(name, value)

    @property
    def data(self):
        """Return the result of calling the ``load`` callable (on every access)."""
        return self._load()

    @classmethod
    def from_sample(cls, sample: Sample, **kwargs):
        """Creates a DelayedSample from another DelayedSample or a Sample.
        If the sample is a DelayedSample, its data will not be loaded.

        Parameters
        ----------

        sample : :any:`Sample`
            The sample to convert to a DelayedSample

        kwargs : dict
            Forwarded unchanged to the constructor of ``cls``.
        """
        if hasattr(sample, "_load"):
            # Reuse the existing loader so nothing is loaded here.
            data = sample._load
        else:

            # Wrap the already available data in a parameterless callable.
            def data():
                return sample.data

        return cls(data, parent=sample, **kwargs)
class SampleSet(MutableSequence, _ReprMixin):
    """A mutable sequence of samples that also carries extra attributes.

    All sequence operations are forwarded to the wrapped ``samples`` object.
    Attributes are taken from ``parent`` and ``kwargs``; names listed in the
    parent's ``_delayed_attributes`` (if any) are not copied.
    """

    def __init__(self, samples, parent=None, **kwargs):
        self.samples = samples
        excluded = getattr(parent, "_delayed_attributes", None)
        _copy_attributes(self, parent, kwargs, exclude_list=excluded)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, item):
        return self.samples[item]

    def __setitem__(self, key, item):
        self.samples[key] = item

    def __delitem__(self, item):
        del self.samples[item]

    def insert(self, index, item):
        """Insert ``item`` before position ``index`` of the wrapped samples."""
        self.samples.insert(index, item)
class DelayedSampleSet(SampleSet):
    """A sample set whose samples are produced by calling ``load``.

    ``load`` is called each time ``samples`` is read. Extra attributes come
    from ``parent`` and ``kwargs``; names listed in the parent's
    ``_delayed_attributes`` (if any) are not copied.
    """

    def __init__(self, load, parent=None, **kwargs):
        self._load = load
        excluded = getattr(parent, "_delayed_attributes", None)
        _copy_attributes(self, parent, kwargs, exclude_list=excluded)

    @property
    def samples(self):
        """The samples returned by a fresh call to ``load``."""
        return self._load()
class DelayedSampleSetCached(DelayedSampleSet):
    """A cached version of DelayedSampleSet.

    ``load`` is called on the first read of ``samples`` and its result is kept
    for later reads. A result of ``None`` is not treated as cached, so ``load``
    is called again on the next read in that case.

    Parameters
    ----------

    load
        A parameterless callable returning the samples.

    parent : object or None
        Optional object to inherit attributes from.

    kwargs : dict
        Further attributes to set on this sample set.
    """

    def __init__(self, load, parent=None, **kwargs):
        # Unpack the keyword arguments: forwarding them as ``kwargs=kwargs``
        # makes the base class store the whole dict as an attribute literally
        # named ``kwargs``. The base initializer already copies the attributes
        # from ``parent`` and ``kwargs``, so no second copy is needed here.
        super().__init__(load, parent=parent, **kwargs)
        self._data = None

    @property
    def samples(self):
        """The loaded samples, fetched through ``load`` on first access."""
        if self._data is None:
            self._data = self._load()
        return self._data
class SampleBatch(Sequence, _ReprMixin):
    """A batch of samples that looks like [s.data for s in samples]

    However, when you call np.array(SampleBatch), the numpy array is built
    from the chosen attribute of each sample through ``vstack_features``
    instead of from an intermediate list (for attributes that have a
    ``shape``).

    Parameters
    ----------

    samples
        The sequence of samples to expose.

    sample_attribute : str
        Name of the attribute read from each sample (``"data"`` by default).
    """

    def __init__(self, samples, sample_attribute="data"):
        self.samples = samples
        self.sample_attribute = sample_attribute

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, item):
        return getattr(self.samples[item], self.sample_attribute)

    def __array__(self, dtype=None, *args, **kwargs):
        """Convert the batch to a numpy array.

        Extra positional and keyword arguments are forwarded to
        ``np.asarray``.

        Raises
        ------
        Exception
            Any error raised while stacking. If reading the attribute of the
            first sample fails as well, that error is raised instead, chained
            to the stacking error.
        """

        def _reader(s):
            # adding one more dimension to data so they get stacked sample-wise
            return getattr(s, self.sample_attribute)[None, ...]

        if self.samples and hasattr(
            getattr(self.samples[0], self.sample_attribute), "shape"
        ):
            try:
                arr = vstack_features(_reader, self.samples, dtype=dtype)
            except Exception as e:
                try:
                    # try computing one feature to show a better traceback
                    _ = getattr(self.samples[0], self.sample_attribute)
                except Exception as e2:
                    raise e2 from e
                # The probe succeeded: re-raise the stacking error untouched
                # rather than chaining the exception to itself.
                raise
        else:
            # to handle string data
            arr = [getattr(s, self.sample_attribute) for s in self.samples]
        return np.asarray(arr, dtype, *args, **kwargs)