Coverage for src/bob/pipelines/xarray.py: 95%
306 statements
coverage.py v7.6.0, created at 2024-07-12 21:32 +0200
import logging
import os
import random
import string

from functools import partial

import cloudpickle
import dask
import h5py
import numpy as np
import xarray as xr

from sklearn.base import BaseEstimator
from sklearn.pipeline import _name_estimators
from sklearn.utils.metaestimators import _BaseComposition

from .sample import SAMPLE_DATA_ATTRS, _ReprMixin
from .wrappers import estimator_requires_fit

logger = logging.getLogger(__name__)


def save(data, path):
    array = np.require(data, requirements=("C_CONTIGUOUS", "ALIGNED"))
    with h5py.File(path, "w") as f:
        f.create_dataset("array", data=array)


def load(path):
    with h5py.File(path, "r") as f:
        data = f["array"][()]
    return data
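
# A tiny round-trip sketch of the default checkpoint format (not part of the
# module); the path below is an assumption for illustration only:
#
#     import numpy as np
#     x = np.arange(6).reshape(2, 3)
#     save(x, "/tmp/array.hdf5")
#     assert np.array_equal(load("/tmp/array.hdf5"), x)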


def _load_fn_to_xarray(load_fn, meta=None):
    if meta is None:
        meta = np.array(load_fn())

    da = dask.array.from_delayed(
        dask.delayed(load_fn)(), meta.shape, dtype=meta.dtype, name=False
    )
    try:
        dims = meta.dims
    except Exception:
        dims = None

    xa = xr.DataArray(da, dims=dims)
    return xa, meta


def _one_sample_to_dataset(sample, meta=None):
    dataset = {}
    delayed_attributes = getattr(sample, "_delayed_attributes", None) or {}
    for k in sample.__dict__:
        if (
            k in SAMPLE_DATA_ATTRS
            or k in delayed_attributes
            or k.startswith("_")
        ):
            continue
        dataset[k] = getattr(sample, k)

    meta = meta or {}

    for k in ["data"] + list(delayed_attributes.keys()):
        attr_meta = meta.get(k)
        attr_array, attr_meta = _load_fn_to_xarray(
            partial(getattr, sample, k), meta=attr_meta
        )
        meta[k] = attr_meta
        dataset[k] = attr_array

    return xr.Dataset(dataset).chunk(), meta


def samples_to_dataset(samples, meta=None, npartitions=48, shuffle=False):
    """Converts a list of samples to a dataset.

    See :ref:`bob.pipelines.dataset_pipeline`.

    Parameters
    ----------
    samples : list
        A list of :any:`Sample` or :any:`DelayedSample` objects.
    meta : ``xarray.DataArray``, optional
        An ``xarray.DataArray`` (or a dict of them, keyed by attribute name) to be
        used as a template for the data inside samples.
    npartitions : :obj:`int`, optional
        The number of partitions to split the samples into.
    shuffle : :obj:`bool`, optional
        If True, shuffles the samples (in-place) before constructing the dataset.

    Returns
    -------
    ``xarray.Dataset``
        The constructed dataset with at least a ``data`` variable.
    """
    if meta is not None and not isinstance(meta, dict):
        meta = dict(data=meta)

    delayed_attributes = getattr(samples[0], "delayed_attributes", None) or {}
    if meta is None or not all(
        k in meta for k in ["data"] + list(delayed_attributes.keys())
    ):
        dataset, meta = _one_sample_to_dataset(samples[0])

    if shuffle:
        random.shuffle(samples)

    dataset = xr.concat(
        [_one_sample_to_dataset(s, meta=meta)[0] for s in samples], dim="sample"
    )
    if npartitions is not None:
        dataset = dataset.chunk({"sample": max(1, len(samples) // npartitions)})
    return dataset
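
# A minimal usage sketch (not part of the module): the toy arrays and sizes
# below are assumptions for illustration only.
#
#     import numpy as np
#     from bob.pipelines import Sample
#
#     samples = [Sample(np.full((3, 4), i), key=str(i)) for i in range(8)]
#     dataset = samples_to_dataset(samples, npartitions=2)
#     # one "sample" dimension plus the two data dimensions
#     assert dataset.data.shape == (8, 3, 4)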


class Block(_ReprMixin):
    """A block representation in a graph.

    This class is meant to be used with :any:`DatasetPipeline`.

    Attributes
    ----------
    dataset_map : ``callable``
        A callable that transforms the input dataset into another dataset.
    estimator : object
        A scikit-learn estimator.
    estimator_name : str
        Name of the estimator.
    extension : str
        The extension of checkpointed features.
    features_dir : str
        The directory to save the features.
    fit_input : str or list
        A str or list of str of column names of the dataset to be given to the
        ``.fit`` method.
    fit_kwargs : None or dict
        A dict of ``fit_kwargs`` to be passed to the ``.fit`` method of the
        estimator.
    input_dask_array : bool
        Whether the estimator takes dask arrays in its fit method or not.
    load_func : ``callable``
        A function to load checkpointed features. Defaults to the HDF5-based
        :any:`load` defined in this module.
    model_path : str or None
        If given, the estimator will be pickled here.
    output_dims : list
        A list of ``(dim_name, dim_size)`` tuples. If ``dim_name`` is ``None``, a new
        name is automatically generated, otherwise it should be a string. ``dim_size``
        should be a positive integer or nan for new dimensions or ``None`` for existing
        dimensions.
    output_dtype : object
        The dtype of the output of the transformer. Defaults to ``float``.
    save_func : ``callable``
        A function to save the features. Defaults to the HDF5-based :any:`save`
        defined in this module.
    transform_input : str or list
        A str or list of str of column names of the dataset to be given to the
        ``.transform`` method.
    """

    def __init__(
        self,
        estimator=None,
        output_dtype=float,
        output_dims=((None, np.nan),),
        fit_input="data",
        transform_input="data",
        estimator_name=None,
        model_path=None,
        features_dir=None,
        extension=".hdf5",
        save_func=None,
        load_func=None,
        dataset_map=None,
        input_dask_array=False,
        fit_kwargs=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.estimator = estimator
        self.output_dtype = output_dtype
        if not all(len(d) == 2 for d in output_dims):
            raise ValueError(
                "output_dims must be an iterable of size 2 tuples "
                f"(dim_name, dim_size), not {output_dims}"
            )
        self.output_dims = output_dims
        self.fit_input = fit_input
        self.transform_input = transform_input
        if estimator_name is None:
            estimator_name = _name_estimators([estimator])[0][0]
        self.estimator_name = estimator_name
        self.model_path = model_path
        self.features_dir = features_dir
        self.extension = extension
        estimator_save_fn = (
            None
            if estimator is None
            else estimator._get_tags().get("bob_features_save_fn")
        )
        estimator_load_fn = (
            None
            if estimator is None
            else estimator._get_tags().get("bob_features_load_fn")
        )
        self.save_func = save_func or estimator_save_fn or save
        self.load_func = load_func or estimator_load_fn or load
        self.dataset_map = dataset_map
        self.input_dask_array = input_dask_array
        self.fit_kwargs = fit_kwargs or {}

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        setattr(self, key, value)

    @property
    def output_ndim(self):
        return len(self.output_dims) + 1

    def make_path(self, key):
        key = str(key)
        if key.startswith(os.sep) or ".." in key:
            raise ValueError(
                "Sample.key values should be relative paths with no "
                f"reference to upper folders. Got: {key}"
            )
        return os.path.join(self.features_dir, key + self.extension)

    def save(self, key, data):
        path = self.make_path(key)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        # this should be save_func(data, path) so it's compatible with bob.io.base.save
        return self.save_func(data, path)

    def load(self, key):
        path = self.make_path(key)
        return self.load_func(path)
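
# A minimal sketch of constructing a Block (not part of the module): the PCA
# estimator, the dimension name, and the paths below are assumptions for
# illustration only.
#
#     from sklearn.decomposition import PCA
#
#     block = Block(
#         estimator=PCA(n_components=5),
#         # the transformer adds one new dimension of size 5 to each sample
#         output_dims=[("feature", 5)],
#         model_path="/tmp/pca.pkl",
#         features_dir="/tmp/pca_features",
#     )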


def _fit(*args, block):
    logger.info(f"Calling {block.estimator_name}.fit")
    block.estimator.fit(*args, **block.fit_kwargs)
    if block.model_path is not None:
        logger.info(f"Saving {block.estimator_name} in {block.model_path}")
        os.makedirs(os.path.dirname(block.model_path), exist_ok=True)
        with open(block.model_path, "wb") as f:
            cloudpickle.dump(block.estimator, f)
    return block.estimator


class _TokenStableTransform:
    def __init__(self, block, method_name=None, input_has_keys=False, **kwargs):
        super().__init__(**kwargs)
        self.block = block
        self.method_name = method_name or "transform"
        self.input_has_keys = input_has_keys

    def __dask_tokenize__(self):
        return (self.method_name, self.block.features_dir)

    def __call__(self, *args, estimator):
        block, method_name = self.block, self.method_name
        logger.info(f"Calling {block.estimator_name}.{method_name}")

        input_args = args[:-1] if self.input_has_keys else args
        try:
            features = getattr(estimator, self.method_name)(*input_args)
        except Exception as e:
            raise RuntimeError(
                f"Failed to transform data: {estimator}.{self.method_name}(*{input_args})"
            ) from e

        # if keys are provided, checkpoint features
        if self.input_has_keys:
            data = args[0]
            key = args[-1]

            l1, l2 = len(data), len(features)
            if l1 != l2:
                raise ValueError(
                    f"Got {l2} features from processing {l1} samples!"
                )

            # save computed_features
            logger.info(f"Saving {l2} features in {block.features_dir}")
            for feat, k in zip(features, key):
                block.save(k, feat)

        return features


def _populate_graph(graph):
    new_graph = []
    for block in graph:
        if isinstance(block, BaseEstimator):
            block = {"estimator": block}
        if isinstance(block, dict):
            block = Block(**block)
        new_graph.append(block)
    return new_graph


def _get_dask_args_from_ds(ds, columns):
    if isinstance(columns, str):
        args = [(ds[columns].data, ds[columns].dims)]
    else:
        args = []
        for c in columns:
            args.extend(_get_dask_args_from_ds(ds, c))
        args = tuple(args)
    return args


def _blockwise_with_block_args(args, block, method_name=None):
    meta = []
    for _ in range(1, block.output_ndim):
        meta = [meta]
    meta = np.array(meta, dtype=block.output_dtype)

    ascii_letters = list(string.ascii_lowercase)
    dim_map = {}

    input_arg_pairs = []
    for array, dims in args:
        dim_name = []
        for dim, dim_size in zip(dims, array.shape):
            if dim not in dim_map:
                dim_map[dim] = (ascii_letters.pop(0), dim_size)
            dim_name.append(dim_map[dim][0])
        input_arg_pairs.extend((array, "".join(dim_name)))

    # the sample dimension is always kept the same
    output_dim_name = f"{input_arg_pairs[1][0]}"
    new_axes = dict()
    for dim_name, dim_size in block.output_dims:
        if dim_name in dim_map:
            output_dim_name += dim_map[dim_name][0]
        else:
            try:
                dim_size = float(dim_size)
            except Exception:
                raise ValueError(
                    "Expected a float dim_size (positive integers or nan) for new "
                    f"dimension: {dim_name} but got: {dim_size}"
                )

            new_letter = ascii_letters.pop(0)
            if dim_name is None:
                dim_name = new_letter
            dim_map[dim_name] = (new_letter, dim_size)
            output_dim_name += new_letter
            new_axes[new_letter] = dim_size

    dims = []
    inv_map = {v[0]: k for k, v in dim_map.items()}
    for dim_name in output_dim_name:
        dims.append(inv_map[dim_name])

    output_shape = [dim_map[d][1] for d in dims]

    return output_dim_name, new_axes, input_arg_pairs, dims, meta, output_shape


def _blockwise_with_block(args, block, method_name=None, input_has_keys=False):
    (
        output_dim_name,
        new_axes,
        input_arg_pairs,
        dims,
        meta,
        _,
    ) = _blockwise_with_block_args(args, block, method_name=None)
    transform_func = _TokenStableTransform(
        block, method_name, input_has_keys=input_has_keys
    )
    transform_func.__name__ = f"{block.estimator_name}.{method_name}"

    data = dask.array.blockwise(
        transform_func,
        output_dim_name,
        *input_arg_pairs,
        meta=meta,
        new_axes=new_axes,
        concatenate=True,
        estimator=block.estimator_,
    )

    return dims, data


def _load_estimator(block):
    logger.info(f"Loading {block.estimator_name} from {block.model_path}")
    with open(block.model_path, "rb") as f:
        block.estimator = cloudpickle.load(f)
    return block.estimator


def _transform_or_load(block, ds, input_columns, mn):
    if isinstance(input_columns, str):
        input_columns = [input_columns]
    input_columns = list(input_columns) + ["key"]

    # filter dataset based on existing checkpoints
    key = np.asarray(ds["key"])
    paths = [block.make_path(k) for k in key]
    saved_samples = np.asarray([os.path.isfile(p) for p in paths])
    # compute/load features per chunk
    chunksize = ds.data.data.chunksize[0]
    for i in range(0, len(saved_samples), chunksize):
        if not np.all(saved_samples[i : i + chunksize]):
            saved_samples[i : i + chunksize] = False

    nonsaved_samples = np.logical_not(saved_samples)
    total_samples_n, saved_samples_n = len(key), saved_samples.sum()
    saved_ds = ds.sel({"sample": saved_samples})
    nonsaved_ds = ds.sel({"sample": nonsaved_samples})

    computed_data = loaded_data = None
    # compute non-saved data
    if total_samples_n - saved_samples_n > 0:
        args = _get_dask_args_from_ds(nonsaved_ds, input_columns)
        dims, computed_data = _blockwise_with_block(
            args, block, mn, input_has_keys=True
        )

    # load saved data
    if saved_samples_n > 0:
        logger.info(
            f"Might load {saved_samples_n} features of {block.estimator_name}.{mn} from disk."
        )
        args = _get_dask_args_from_ds(saved_ds, input_columns)
        dims, meta, shape = _blockwise_with_block_args(args, block, mn)[-3:]
        loaded_data = [
            dask.array.from_delayed(
                dask.delayed(block.load)(k),
                shape=shape[1:],
                meta=meta,
                name=False,
            )[None, ...]
            for k in key[saved_samples]
        ]
        loaded_data = dask.array.concatenate(loaded_data, axis=0)

    # merge loaded and computed data
    if computed_data is None:
        data = loaded_data
    elif loaded_data is None:
        data = computed_data
    else:
        # merge data chunk-based
        data = []
        i, j = 0, 0
        for k in range(0, len(saved_samples), chunksize):
            saved = saved_samples[k]
            if saved:
                pick = loaded_data[j : j + chunksize]
                j += chunksize
            else:
                pick = computed_data[i : i + chunksize]
                i += chunksize
            data.append(pick)
        data = dask.array.concatenate(data, axis=0)

    data = dask.array.rechunk(data, {0: chunksize})
    return dims, data


class DatasetPipeline(_BaseComposition):
    """A dataset-based scikit-learn pipeline.

    See :ref:`bob.pipelines.dataset_pipeline`.

    Attributes
    ----------
    graph : list
        A list of :any:`Block`'s to be applied to the input dataset.
    """

    def __init__(self, graph, **kwargs):
        super().__init__(**kwargs)
        self.graph = _populate_graph(graph)

    def _transform(self, ds, do_fit=False, method_name=None):
        for i, block in enumerate(self.graph):
            if block.dataset_map is not None:
                try:
                    ds = block.dataset_map(ds)
                except Exception as e:
                    raise RuntimeError(
                        f"Could not map ds {ds}\n with {block.dataset_map}"
                    ) from e
                continue

            if do_fit:
                args = _get_dask_args_from_ds(ds, block.fit_input)
                args = [d for d, dims in args]
                estimator = block.estimator
                if not estimator_requires_fit(estimator):
                    block.estimator_ = estimator
                elif block.model_path is not None and os.path.isfile(
                    block.model_path
                ):
                    _load_estimator.__name__ = f"load_{block.estimator_name}"
                    block.estimator_ = dask.delayed(_load_estimator)(block)
                elif block.input_dask_array:
                    ds = ds.persist()
                    args = _get_dask_args_from_ds(ds, block.fit_input)
                    args = [d for d, dims in args]
                    block.estimator_ = _fit(*args, block=block)
                else:
                    _fit.__name__ = f"{block.estimator_name}.fit"
                    block.estimator_ = dask.delayed(_fit)(
                        *args,
                        block=block,
                    )

            mn = "transform"
            if i == len(self.graph) - 1:
                if do_fit:
                    break
                mn = method_name

            if block.features_dir is None:
                args = _get_dask_args_from_ds(ds, block.transform_input)
                dims, data = _blockwise_with_block(
                    args, block, mn, input_has_keys=False
                )
            else:
                dims, data = _transform_or_load(
                    block, ds, block.transform_input, mn
                )

            # replace data inside dataset
            ds = ds.copy(deep=False)
            del ds["data"]
            persisted = False
            if not np.all(np.isfinite(data.shape)):
                block.estimator_, data = dask.persist(block.estimator_, data)
                data = data.compute_chunk_sizes()
                persisted = True
            ds["data"] = (dims, data)
            if persisted:
                ds = ds.persist()

        return ds

    def fit(self, ds, y=None):
        if y is not None:
            raise ValueError("y should not be passed to DatasetPipeline.fit; it is not used.")
        self._transform(ds, do_fit=True)
        return self

    def transform(self, ds):
        return self._transform(ds, method_name="transform")

    def decision_function(self, ds):
        return self._transform(ds, method_name="decision_function")

    def predict(self, ds):
        return self._transform(ds, method_name="predict")

    def predict_proba(self, ds):
        return self._transform(ds, method_name="predict_proba")

    def score(self, ds):
        return self._transform(ds, method_name="score")