# Coverage for src/bob/pipelines/wrappers.py: 85% (472 statements)
1"""Scikit-learn Estimator Wrappers."""
2import logging
3import os
4import tempfile
5import time
6import traceback
8from functools import partial
9from pathlib import Path
11import cloudpickle
12import dask
13import dask.array as da
14import dask.bag
15import numpy as np
17from dask import delayed
18from sklearn.base import BaseEstimator, MetaEstimatorMixin, TransformerMixin
19from sklearn.pipeline import Pipeline
20from sklearn.preprocessing import FunctionTransformer
22import bob.io.base
24from .sample import DelayedSample, Sample, SampleBatch, SampleSet
26logger = logging.getLogger(__name__)
29def _frmt(estimator, limit=30, attr="estimator"):
30 # default value of limit is chosen so the log can be seen in dask graphs
31 def _n(e):
32 return e.__class__.__name__.replace("Wrapper", "")
34 name = ""
35 while hasattr(estimator, attr):
36 name += f"{_n(estimator)}|"
37 estimator = getattr(estimator, attr)
39 if (
40 isinstance(estimator, FunctionTransformer)
41 and type(estimator) is FunctionTransformer
42 ):
43 name += str(estimator.func.__name__)
44 else:
45 name += str(estimator)
47 name = f"{name:.{limit}}"
48 return name
51def copy_learned_attributes(from_estimator, to_estimator):
52 attrs = {k: v for k, v in vars(from_estimator).items() if k.endswith("_")}
54 for k, v in attrs.items():
55 setattr(to_estimator, k, v)
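
# A minimal sketch (not part of the original module): scikit-learn stores
# fitted state in attributes whose names end with an underscore, which is
# exactly what ``copy_learned_attributes`` copies over. The ``Holder`` class
# below is hypothetical::
#
#     import numpy as np
#     from sklearn.preprocessing import StandardScaler
#
#     scaler = StandardScaler().fit(np.array([[0.0], [2.0]]))
#
#     class Holder:
#         pass
#
#     holder = Holder()
#     copy_learned_attributes(scaler, holder)
#     assert np.allclose(holder.mean_, [1.0])  # the fitted mean_ was copied
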
def get_bob_tags(estimator=None, force_tags=None):
    """Returns the default tags of a Transformer unless forced or specified.

    Relies on the tags API of sklearn to set and retrieve the tags.

    Specify an estimator tag value with ``estimator._more_tags``::

        class My_annotator_transformer(sklearn.base.BaseEstimator):
            def _more_tags(self):
                return {"bob_output": "annotations"}

    The returned tags will take their value with the following priority:

    1. key:value in ``force_tags``, if it is present;
    2. key:value in ``estimator`` tags (set with ``estimator._more_tags()``) if it exists;
    3. the default value for that tag if none of the previous exist.

    Examples
    --------
    bob_input: str
        The Sample attribute passed to the first argument of the fit or
        transform method. Default value is ``data``.
        Example::

            {"bob_input": "annotations"}

        will result in::

            estimator.transform(sample.annotations)

    bob_transform_extra_input: tuple of str
        Each element of the tuple is a str representing an attribute of a Sample
        object. Each attribute of the sample will be passed as argument to the
        transform method in that order. Default value is an empty tuple ``()``.
        Example::

            {"bob_transform_extra_input": (("kwarg_1", "annotations"), ("kwarg_2", "gender"))}

        will result in::

            estimator.transform(sample.data, kwarg_1=sample.annotations, kwarg_2=sample.gender)

    bob_fit_extra_input: tuple of str
        Each element of the tuple is a str representing an attribute of a Sample
        object. Each attribute of the sample will be passed as argument to the
        fit method in that order. Default value is an empty tuple ``()``.
        Example::

            {"bob_fit_extra_input": (("y", "annotations"), ("extra", "metadata"))}

        will result in::

            estimator.fit(sample.data, y=sample.annotations, extra=sample.metadata)

    bob_output: str
        The Sample attribute in which the output of the transform is stored.
        Default value is ``data``.

    bob_checkpoint_extension: str
        The extension of each checkpoint file.
        Default value is ``.h5``.

    bob_features_save_fn: func
        The function used to save each checkpoint file.
        Default value is :any:`bob.io.base.save`.

    bob_features_load_fn: func
        The function used to load each checkpoint file.
        Default value is :any:`bob.io.base.load`.

    bob_fit_supports_dask_array: bool
        Indicates that the fit method of that estimator accepts dask arrays as
        input. You may only use this tag if you accept X (N, M) and optionally y
        (N) as input. The fit function may not accept any other input.
        Default value is ``False``.

    bob_fit_supports_dask_bag: bool
        Indicates that the fit method of that estimator accepts dask bags as
        input. If true, each input parameter of the fit will be a dask bag. You
        still can (and normally you should) wrap your estimator with the
        SampleWrapper so the same code runs with and without dask.
        Default value is ``False``.

    bob_checkpoint_features: bool
        If False, the features of the estimator will never be saved.
        Default value is ``True``.

    Parameters
    ----------
    estimator: sklearn.BaseEstimator or None
        An estimator class with tags that will overwrite the default values.
        Setting to None will return the default values of every tag.
    force_tags: dict[str, Any] or None
        Tags with a non-default value that will overwrite the default and the
        estimator tags.

    Returns
    -------
    dict[str, Any]
        The resulting tags with a value (either specified in the estimator,
        forced by the arguments, or default).
    """
    force_tags = force_tags or {}
    default_tags = {
        "bob_input": "data",
        "bob_transform_extra_input": tuple(),
        "bob_fit_extra_input": tuple(),
        "bob_output": "data",
        "bob_checkpoint_extension": ".h5",
        "bob_features_save_fn": bob.io.base.save,
        "bob_features_load_fn": bob.io.base.load,
        "bob_fit_supports_dask_array": False,
        "bob_fit_supports_dask_bag": False,
        "bob_checkpoint_features": True,
    }
    estimator_tags = estimator._get_tags() if estimator is not None else {}
    return {**default_tags, **estimator_tags, **force_tags}
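
# An illustrative sketch (not part of the original module) of the priority
# rules above: a tag set in ``_more_tags`` overrides the default, and
# ``force_tags`` overrides both. The class name is hypothetical::
#
#     from sklearn.base import BaseEstimator
#
#     class MyAnnotator(BaseEstimator):
#         def _more_tags(self):
#             return {"bob_output": "annotations"}
#
#     tags = get_bob_tags(MyAnnotator())
#     assert tags["bob_output"] == "annotations"  # from _more_tags
#     assert tags["bob_input"] == "data"          # default value
#
#     tags = get_bob_tags(MyAnnotator(), force_tags={"bob_output": "data"})
#     assert tags["bob_output"] == "data"         # forced value wins
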
class BaseWrapper(MetaEstimatorMixin, BaseEstimator):
    """The base class for all wrappers."""

    def _more_tags(self):
        return self.estimator._more_tags()


def _make_kwargs_from_samples(samples, arg_attr_list):
    kwargs = {
        arg: [getattr(s, attr) for s in samples] for arg, attr in arg_attr_list
    }
    return kwargs


def _check_n_input_output(samples, output, func_name):
    ls, lo = len(samples), len(output)
    if ls != lo:
        raise RuntimeError(
            f"{func_name} got {ls} samples but returned {lo} samples!"
        )


class DelayedSamplesCall:
    def __init__(
        self, func, func_name, samples, sample_attribute="data", **kwargs
    ):
        super().__init__(**kwargs)
        self.func = func
        self.func_name = func_name
        self.samples = samples
        self.output = None
        self.sample_attribute = sample_attribute

    def __call__(self, index):
        if self.output is None:
            # Isolate invalid samples (when previous transformers returned None)
            invalid_ids = [
                i for i, s in enumerate(self.samples) if s.data is None
            ]
            valid_samples = [s for s in self.samples if s.data is not None]
            # Process only the valid samples
            if len(valid_samples) > 0:
                X = SampleBatch(
                    valid_samples, sample_attribute=self.sample_attribute
                )
                self.output = self.func(X)
                _check_n_input_output(
                    valid_samples, self.output, self.func_name
                )
            if self.output is None:
                self.output = [None] * len(valid_samples)
            # Rebuild the full batch of samples (include the previously failed)
            if len(invalid_ids) > 0:
                self.output = list(self.output)
                for i in invalid_ids:
                    self.output.insert(i, None)
        return self.output[index]
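
# An illustrative sketch (not part of the original module): DelayedSamplesCall
# runs ``func`` once over the whole batch on the first indexed access and
# memoizes the result, so later per-sample accesses are cheap. This assumes
# SampleBatch iterates over the ``data`` attribute of each sample::
#
#     samples = [Sample(1), Sample(None), Sample(3)]
#     call = DelayedSamplesCall(lambda X: [x * 10 for x in X], "demo", samples)
#     call(0)  # -> 10; the batch is computed here, skipping the None sample
#     call(1)  # -> None; invalid samples are re-inserted at their position
#     call(2)  # -> 30; served from the memoized output
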
class SampleWrapper(BaseWrapper, TransformerMixin):
    """Wraps scikit-learn estimators to work with :any:`Sample`-based
    pipelines.

    Do not use this class except for scikit-learn estimators.

    Attributes
    ----------
    estimator
        The scikit-learn estimator that is wrapped.
    fit_extra_arguments : [tuple]
        Use this option if you want to pass extra arguments to the fit method of
        the mixed instance. The format is a list of two-value tuples. The first
        value in each tuple is the name of the argument that fit accepts, like
        ``y``, and the second value is the name of the attribute that samples
        carry. For example, if you are passing samples to the fit method and
        want to pass the ``subject`` attributes of samples as the ``y`` argument
        to the fit method, you can provide ``[("y", "subject")]`` as the value
        for this attribute.
    output_attribute : str
        The name of a Sample attribute where the output of the estimator will be
        saved to [Default is ``data``]. For example, if ``output_attribute`` is
        ``"annotations"``, then ``sample.annotations`` will contain the output of
        the estimator.
    transform_extra_arguments : [tuple]
        Similar to ``fit_extra_arguments`` but for the transform and other
        similar methods.
    delayed_output : bool
        If ``True``, the output will be an instance of ``DelayedSample``;
        otherwise it will be an instance of ``Sample``.
    """

    def __init__(
        self,
        estimator,
        transform_extra_arguments=None,
        fit_extra_arguments=None,
        output_attribute=None,
        input_attribute=None,
        delayed_output=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.estimator = estimator

        bob_tags = get_bob_tags(self.estimator)
        self.input_attribute = input_attribute or bob_tags["bob_input"]
        self.transform_extra_arguments = (
            transform_extra_arguments or bob_tags["bob_transform_extra_input"]
        )
        self.fit_extra_arguments = (
            fit_extra_arguments or bob_tags["bob_fit_extra_input"]
        )
        self.output_attribute = output_attribute or bob_tags["bob_output"]
        self.delayed_output = delayed_output

    def _samples_transform(self, samples, method_name):
        # Transform either samples or samplesets
        method = getattr(self.estimator, method_name)
        func_name = f"{self}.{method_name}"

        if isinstance(samples[0], SampleSet):
            return [
                SampleSet(
                    self._samples_transform(sset.samples, method_name),
                    parent=sset,
                )
                for sset in samples
            ]
        else:
            kwargs = _make_kwargs_from_samples(
                samples, self.transform_extra_arguments
            )
            delayed = DelayedSamplesCall(
                partial(method, **kwargs),
                func_name,
                samples,
                sample_attribute=self.input_attribute,
            )
            if self.output_attribute != "data":
                # Edit the sample.<output_attribute> instead of data
                for i, s in enumerate(samples):
                    setattr(s, self.output_attribute, delayed(i))
                new_samples = samples
            else:
                new_samples = []
                for i, s in enumerate(samples):
                    if self.delayed_output:
                        sample = DelayedSample(
                            partial(delayed, index=i), parent=s
                        )
                    else:
                        sample = Sample(delayed(i), parent=s)
                    new_samples.append(sample)
            return new_samples

    def transform(self, samples):
        logger.debug(f"{_frmt(self)}.transform")
        return self._samples_transform(samples, "transform")

    def decision_function(self, samples):
        logger.debug(f"{_frmt(self)}.decision_function")
        return self._samples_transform(samples, "decision_function")

    def predict(self, samples):
        logger.debug(f"{_frmt(self)}.predict")
        return self._samples_transform(samples, "predict")

    def predict_proba(self, samples):
        logger.debug(f"{_frmt(self)}.predict_proba")
        return self._samples_transform(samples, "predict_proba")

    def score(self, samples):
        logger.debug(f"{_frmt(self)}.score")
        return self._samples_transform(samples, "score")

    def fit(self, samples, y=None, **kwargs):
        # If samples is a dask bag or array, pass the arguments unmodified.
        # The data is already prepared in the DaskWrapper.
        if isinstance(samples, (dask.bag.core.Bag, dask.array.Array)):
            logger.debug(f"{_frmt(self)}.fit")
            self.estimator.fit(samples, y=y, **kwargs)
            return self

        if y is not None:
            raise TypeError(
                "We don't accept `y` in fit arguments because `y` should be part of "
                "the sample. To pass `y` to the wrapped estimator, use the "
                "`fit_extra_arguments` tag."
            )

        if not estimator_requires_fit(self.estimator):
            return self

        # If the estimator needs to be fitted:
        logger.debug(f"{_frmt(self)}.fit")
        # samples is a list of either Sample or DelayedSample objects created
        # with the DelayedSamplesCall function, therefore some elements in the
        # list can be None.
        # Filter out invalid samples (i.e. samples[k] == None), otherwise
        # SampleBatch will fail and throw exceptions.
        samples = [
            s for s in samples if getattr(s, self.input_attribute) is not None
        ]
        X = SampleBatch(samples, sample_attribute=self.input_attribute)

        kwargs = _make_kwargs_from_samples(samples, self.fit_extra_arguments)
        self.estimator = self.estimator.fit(X, **kwargs)
        copy_learned_attributes(self.estimator, self)
        return self
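
# A usage sketch (not part of the original module): wrapping a plain
# scikit-learn transformer so it consumes and produces Sample objects while
# keeping the metadata of each parent sample::
#
#     import numpy as np
#     from sklearn.preprocessing import FunctionTransformer
#
#     transformer = SampleWrapper(
#         FunctionTransformer(lambda X: np.asarray(X) + 1)
#     )
#     samples = [Sample(np.array([1, 2])), Sample(np.array([3, 4]))]
#     transformed = transformer.transform(samples)
#     transformed[0].data  # -> array([2, 3]); computed lazily on first access
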
class CheckpointWrapper(BaseWrapper, TransformerMixin):
    """Wraps :any:`Sample`-based estimators so the results are saved on
    disk.

    Parameters
    ----------

    estimator
        The scikit-learn estimator to be wrapped.

    model_path: str
        Saves the estimator state to this path if the ``estimator`` is stateful.

    features_dir: str
        Saves the transformed data to this directory.

    extension: str
        Default extension of the transformed features.
        If None, will use the ``bob_checkpoint_extension`` tag in the estimator,
        or default to ``.h5``.

    save_func
        Pointer to a customized function that saves transformed features to
        disk. If None, will use the ``bob_features_save_fn`` tag in the
        estimator, or default to ``bob.io.base.save``.

    load_func
        Pointer to a customized function that loads transformed features from
        disk. If None, will use the ``bob_features_load_fn`` tag in the
        estimator, or default to ``bob.io.base.load``.

    sample_attribute: str
        Defines the payload attribute of the sample.
        If None, will use the ``bob_output`` tag in the estimator, or default to
        ``data``.

    hash_fn
        Pointer to a hash function. This hash function maps ``sample.key`` to a
        hash code and this hash code corresponds to a relative directory where a
        single ``sample`` will be checkpointed. This is useful when it is
        desirable to keep directories below a certain number of files.

    attempts
        Number of checkpoint attempts. Sometimes, because of network/disk
        issues, files can't be saved. This argument sets the maximum number of
        attempts to checkpoint a sample.

    force: bool
        If True, will recompute the checkpoints even if they exist.
    """

    def __init__(
        self,
        estimator,
        model_path=None,
        features_dir=None,
        extension=None,
        save_func=None,
        load_func=None,
        sample_attribute=None,
        hash_fn=None,
        attempts=10,
        force=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        bob_tags = get_bob_tags(estimator)
        self.extension = extension or bob_tags["bob_checkpoint_extension"]
        self.save_func = save_func or bob_tags["bob_features_save_fn"]
        self.load_func = load_func or bob_tags["bob_features_load_fn"]
        self.sample_attribute = sample_attribute or bob_tags["bob_output"]

        if not bob_tags["bob_checkpoint_features"]:
            logger.info(
                "Checkpointing is disabled for %s because the bob_checkpoint_features tag is False.",
                estimator,
            )
            features_dir = None

        self.force = force
        self.estimator = estimator
        self.model_path = model_path
        self.features_dir = features_dir
        self.hash_fn = hash_fn
        self.attempts = attempts

        # Paths check
        if model_path is None and features_dir is None:
            logger.warning(
                "Both model_path and features_dir are None. "
                f"Nothing will be checkpointed. From: {self}"
            )

    def _checkpoint_transform(self, samples, method_name):
        # Transform either samples or samplesets
        method = getattr(self.estimator, method_name)

        # if features_dir is None, just transform all samples at once
        if self.features_dir is None:
            return method(samples)

        def _transform_samples(samples):
            paths = [self.make_path(s) for s in samples]
            should_compute_list = [
                self.force or p is None or not os.path.isfile(p) for p in paths
            ]
            # call method on non-checkpointed samples
            non_existing_samples = [
                s
                for s, should_compute in zip(samples, should_compute_list)
                if should_compute
            ]
            # non_existing_samples could be empty
            computed_features = []
            if non_existing_samples:
                computed_features = method(non_existing_samples)
                _check_n_input_output(
                    non_existing_samples, computed_features, method
                )
            # return computed features and checkpointed features
            features, com_feat_index = [], 0

            for s, p, should_compute in zip(
                samples, paths, should_compute_list
            ):
                if should_compute:
                    feat = computed_features[com_feat_index]
                    com_feat_index += 1
                    # save the computed feature when valid (not None)
                    if (
                        p is not None
                        and getattr(feat, self.sample_attribute) is not None
                    ):
                        self.save(feat)
                        # sometimes loading the file fails randomly
                        for _ in range(self.attempts):
                            try:
                                feat = self.load(s, p)
                                break
                            except Exception:
                                error = traceback.format_exc()
                                time.sleep(0.1)
                        else:
                            raise RuntimeError(
                                f"Could not load using: {self.load}({s}, {p}) with the following error: {error}"
                            )
                    features.append(feat)
                else:
                    features.append(self.load(s, p))
            return features

        if isinstance(samples[0], SampleSet):
            return [
                SampleSet(_transform_samples(s.samples), parent=s)
                for s in samples
            ]
        else:
            return _transform_samples(samples)

    def transform(self, samples):
        logger.debug(f"{_frmt(self)}.transform")
        return self._checkpoint_transform(samples, "transform")

    def decision_function(self, samples):
        logger.debug(f"{_frmt(self)}.decision_function")
        return self.estimator.decision_function(samples)

    def predict(self, samples):
        logger.debug(f"{_frmt(self)}.predict")
        return self.estimator.predict(samples)

    def predict_proba(self, samples):
        logger.debug(f"{_frmt(self)}.predict_proba")
        return self.estimator.predict_proba(samples)

    def score(self, samples):
        logger.debug(f"{_frmt(self)}.score")
        return self.estimator.score(samples)

    def fit(self, samples, y=None, **kwargs):
        if not estimator_requires_fit(self.estimator):
            return self

        # If the estimator needs to be fitted:
        logger.debug(f"{_frmt(self)}.fit")

        if self.model_path is not None and os.path.isfile(self.model_path):
            logger.info("Found a checkpoint for model. Loading ...")
            return self.load_model()

        self.estimator = self.estimator.fit(samples, y=y, **kwargs)
        copy_learned_attributes(self.estimator, self)
        return self.save_model()

    def make_path(self, sample):
        if self.features_dir is None:
            return None

        key = str(sample.key)
        if key.startswith(os.sep) or ".." in key:
            raise ValueError(
                "Sample.key values should be relative paths with no "
                f"reference to upper folders. Got: {key}"
            )

        hash_dir_name = self.hash_fn(key) if self.hash_fn is not None else ""

        return os.path.join(
            self.features_dir, hash_dir_name, key + self.extension
        )

    def save(self, sample):
        path = self.make_path(sample)
        # Gets sample.data or sample.<sample_attribute> if specified
        to_save = getattr(sample, self.sample_attribute)
        for _ in range(self.attempts):
            try:
                dirname = os.path.dirname(path)
                os.makedirs(dirname, exist_ok=True)

                # Atomic writing
                extension = "".join(Path(path).suffixes)
                with tempfile.NamedTemporaryFile(
                    dir=dirname, delete=False, suffix=extension
                ) as f:
                    self.save_func(to_save, f.name)
                os.replace(f.name, path)

                # test loading
                self.load_func(path)
                break
            except Exception:
                error = traceback.format_exc()
                time.sleep(0.1)
        else:
            raise RuntimeError(
                f"Could not save {to_save} of type {type(to_save)} using {self.save_func} with the following error: {error}"
            )

    def load(self, sample, path):
        # Because we are checkpointing, we return a DelayedSample instead of a
        # normal (preloaded) sample. This allows the next phase to avoid
        # loading it should that be unnecessary (e.g. the next phase is
        # already checkpointed).
        if self.sample_attribute == "data":
            return DelayedSample(partial(self.load_func, path), parent=sample)
        else:
            loaded = self.load_func(path)
            setattr(sample, self.sample_attribute, loaded)
            return sample

    def load_model(self):
        if not estimator_requires_fit(self.estimator):
            return self
        with open(self.model_path, "rb") as f:
            loaded_estimator = cloudpickle.load(f)
        # We update self.estimator instead of replacing it because
        # self.estimator might be referenced elsewhere.
        _update_estimator(self.estimator, loaded_estimator)
        return self

    def save_model(self):
        if (
            not estimator_requires_fit(self.estimator)
            or self.model_path is None
        ):
            return self
        os.makedirs(os.path.dirname(self.model_path), exist_ok=True)
        with open(self.model_path, "wb") as f:
            cloudpickle.dump(self.estimator, f)
        return self
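
# A usage sketch (not part of the original module): checkpointing the features
# of a sample-wrapped transformer. The path is hypothetical; samples need a
# ``key`` attribute so ``make_path`` can derive a file name::
#
#     import numpy as np
#     from sklearn.preprocessing import FunctionTransformer
#
#     transformer = CheckpointWrapper(
#         SampleWrapper(FunctionTransformer(lambda X: np.asarray(X) + 1)),
#         features_dir="/tmp/features",
#     )
#     samples = [Sample(np.array([1, 2]), key="sample_0")]
#     transformer.transform(samples)  # computes and saves sample_0.h5
#     transformer.transform(samples)  # loads the checkpoint instead
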
def _update_estimator(estimator, loaded_estimator):
    # Recursively update estimator with loaded_estimator without replacing
    # estimator.estimator
    if hasattr(estimator, "estimator"):
        _update_estimator(estimator.estimator, loaded_estimator.estimator)
    for k, v in loaded_estimator.__dict__.items():
        if k != "estimator":
            estimator.__dict__[k] = v


def is_checkpointed(estimator):
    return is_instance_nested(estimator, "estimator", CheckpointWrapper)


def getattr_nested(estimator, attr):
    if hasattr(estimator, attr):
        return getattr(estimator, attr)
    elif hasattr(estimator, "estimator"):
        return getattr_nested(estimator.estimator, attr)
    return None


def _sample_attribute(samples, attribute):
    return [getattr(s, attribute) for s in samples]


def _len_samples(samples):
    return [len(samples)]


def _shape_samples(samples):
    return [[s.shape for s in samples]]


def _array_from_sample_bags(X: dask.bag.Bag, attribute: str, ndim: int = 2):
    if ndim not in (1, 2):
        raise NotImplementedError(f"ndim must be 1 or 2. Got: {ndim}")

    if ndim == 1:
        stack_function = np.concatenate
    else:
        stack_function = np.vstack

    # because samples could be delayed samples, we convert sample bags to
    # sample.attribute bags first and then persist
    X = X.map_partitions(_sample_attribute, attribute=attribute).persist()

    # convert sample.attribute bags to arrays
    delayeds = X.to_delayed()
    lengths = X.map_partitions(_len_samples)
    shapes = X.map_partitions(_shape_samples)
    lengths, shapes = dask.compute(lengths, shapes)
    dtype, X = None, []
    for length_, shape_, delayed_samples_list in zip(lengths, shapes, delayeds):
        delayed_samples_list._length = length_

        if dtype is None:
            dtype = np.array(delayed_samples_list[0].compute()).dtype

        # stack the data in each bag
        stacked_samples = dask.delayed(stack_function)(delayed_samples_list)
        # make sure shapes are at least 2d
        for i, s in enumerate(shape_):
            if len(s) == 1 and ndim == 2:
                shape_[i] = (1,) + s
            elif len(s) == 0:
                # if the shape is empty, the samples are scalars
                if ndim == 1:
                    shape_[i] = (1,)
                else:
                    shape_[i] = (1, 1)
        stacked_shape = sum(s[0] for s in shape_)
        stacked_shape = [stacked_shape] + list(shape_[0][1:])

        darray = da.from_delayed(
            stacked_samples,
            stacked_shape,
            dtype=dtype,
            name=False,
        )
        X.append(darray)

    # stack data from all bags
    X = stack_function(X)
    return X
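
# An illustrative sketch (not part of the original module): converting a dask
# bag of Samples into a single dask array of their ``data`` attributes, one
# row per sample::
#
#     import dask.bag
#     import numpy as np
#
#     bag = dask.bag.from_sequence(
#         [Sample(np.array([1.0, 2.0])), Sample(np.array([3.0, 4.0]))],
#         npartitions=2,
#     )
#     X = _array_from_sample_bags(bag, "data", ndim=2)
#     X.shape      # -> (2, 2)
#     X.compute()  # -> the stacked numpy array
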
class DaskWrapper(BaseWrapper, TransformerMixin):
    """Wraps scikit-learn estimators to handle Dask Bags as input.

    Parameters
    ----------

    fit_tag: str
        Mark the delayed(self.fit) with this value. This can be used in
        a future delayed(self.fit).compute(resources=resource_tape) so the
        dask scheduler can place this task on a particular resource
        (e.g. GPU).

    transform_tag: str
        Mark the delayed(self.transform) with this value. This can be used in
        a future delayed(self.transform).compute(resources=resource_tape) so
        the dask scheduler can place this task on a particular resource
        (e.g. GPU).
    """

    def __init__(
        self,
        estimator,
        fit_tag=None,
        transform_tag=None,
        fit_supports_dask_array=None,
        fit_supports_dask_bag=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.estimator = estimator
        self._dask_state = estimator
        self.resource_tags = dict()
        self.fit_tag = fit_tag
        self.transform_tag = transform_tag
        self.fit_supports_dask_array = (
            fit_supports_dask_array
            or get_bob_tags(self.estimator)["bob_fit_supports_dask_array"]
        )
        self.fit_supports_dask_bag = (
            fit_supports_dask_bag
            or get_bob_tags(self.estimator)["bob_fit_supports_dask_bag"]
        )

    def _make_dask_resource_tag(self, tag):
        return {tag: 1}

    def _dask_transform(self, X, method_name):
        graph_name = f"{_frmt(self)}.{method_name}"
        logger.debug(graph_name)

        def _transf(X_line, dask_state):
            return getattr(dask_state, method_name)(X_line)

        # change the name to have a better name in dask graphs
        _transf.__name__ = graph_name
        # scatter the dask_state to all workers for efficiency
        dask_state = dask.delayed(self._dask_state)
        map_partitions = X.map_partitions(_transf, dask_state)
        if self.transform_tag:
            self.resource_tags[
                tuple(map_partitions.dask.keys())
            ] = self._make_dask_resource_tag(self.transform_tag)

        return map_partitions

    def transform(self, samples):
        return self._dask_transform(samples, "transform")

    def decision_function(self, samples):
        return self._dask_transform(samples, "decision_function")

    def predict(self, samples):
        return self._dask_transform(samples, "predict")

    def predict_proba(self, samples):
        return self._dask_transform(samples, "predict_proba")

    def score(self, samples):
        return self._dask_transform(samples, "score")

    def _get_fit_params_from_sample_bags(self, bags):
        logger.debug("Preparing data as dask arrays for fit")

        input_attribute = getattr_nested(self, "input_attribute")
        fit_extra_arguments = getattr_nested(self, "fit_extra_arguments")

        # convert X, which is a dask bag, to a dask array
        X = _array_from_sample_bags(bags, input_attribute, ndim=2)
        kwargs = dict()
        for arg, attr in fit_extra_arguments:
            # we only create a dask array if the arg is named ``y``
            if arg == "y":
                kwargs[arg] = _array_from_sample_bags(bags, attr, ndim=1)
            else:
                raise NotImplementedError(
                    f"fit_extra_arguments: {arg} is not supported, only ``y`` is supported."
                )

        return X, kwargs

    def _fit_on_dask_array(self, bags, y=None, **fit_params):
        if y is not None or fit_params:
            raise ValueError(
                "y or fit_params should be passed through fit_extra_arguments of the SampleWrapper"
            )

        X, fit_params = self._get_fit_params_from_sample_bags(bags)
        self.estimator.fit(X, **fit_params)
        return self

    def _fit_on_dask_bag(self, bags, y=None, **fit_params):
        # X is a dask bag of Samples; convert it to the required fit parameters
        logger.debug("Converting dask bag of samples to bags of fit parameters")

        def getattr_list(samples, attribute):
            return SampleBatch(samples, sample_attribute=attribute)

        # we prepare the input parameters here instead of doing this in the
        # SampleWrapper. The SampleWrapper class then will pass these dask bags
        # directly to the underlying estimator.
        bob_tags = get_bob_tags(self.estimator)
        input_attribute = bob_tags["bob_input"]
        fit_extra_arguments = bob_tags["bob_fit_extra_input"]

        X = bags.map_partitions(getattr_list, input_attribute)
        kwargs = {
            arg: bags.map_partitions(getattr_list, attr)
            for arg, attr in fit_extra_arguments
        }

        self.estimator.fit(X, **kwargs)
        return self

    def fit(self, X, y=None, **fit_params):
        if not estimator_requires_fit(self.estimator):
            return self
        logger.debug(f"{_frmt(self)}.fit")

        model_path = None
        if is_checkpointed(self):
            model_path = getattr_nested(self, "model_path")
        model_path = model_path or ""
        if os.path.isfile(model_path):
            logger.info(
                f"Checkpointed estimator detected at {model_path}. The estimator ({_frmt(self)}) will be loaded and training will not run."
            )
            # we should load the estimator outside the dask graph to make sure
            # that the estimator loads in the scheduler
            self.estimator.load_model()
            return self

        if self.fit_supports_dask_array:
            return self._fit_on_dask_array(X, y, **fit_params)
        elif self.fit_supports_dask_bag:
            return self._fit_on_dask_bag(X, y, **fit_params)

        def _fit(X, y, **fit_params):
            try:
                self.estimator = self.estimator.fit(X, y, **fit_params)
            except Exception as e:
                raise RuntimeError(
                    f"Something went wrong when fitting {self.estimator} "
                    f"from {self}"
                ) from e
            copy_learned_attributes(self.estimator, self)
            return self.estimator

        # change the name to have a better name in dask graphs
        _fit.__name__ = f"{_frmt(self)}.fit"

        _fit_call = delayed(_fit)(X, y, **fit_params)
        self._dask_state = _fit_call.persist()

        if self.fit_tag is not None:
            # If you do `delayed(_fit)(X, y)`, two tasks are generated:
            # `finalize-TASK` and `TASK`. With this, we make sure
            # that the two are annotated.
            self.resource_tags[
                tuple(
                    f"{k}{str(self._dask_state.key)}"
                    for k in ["", "finalize-"]
                )
            ] = self._make_dask_resource_tag(self.fit_tag)

        return self
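
# A usage sketch (not part of the original module): the typical composition is
# ToDaskBag -> SampleWrapper -> DaskWrapper, so the same estimator code runs
# with and without dask::
#
#     import numpy as np
#     from sklearn.preprocessing import FunctionTransformer
#
#     transformer = DaskWrapper(
#         SampleWrapper(FunctionTransformer(lambda X: np.asarray(X) + 1))
#     )
#     bag = ToDaskBag().transform([Sample(np.array([1, 2]))])
#     transformer.transform(bag).compute()  # -> list of transformed Samples
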
class ToDaskBag(TransformerMixin, BaseEstimator):
    """Transform an arbitrary iterator into a :any:`dask.bag.Bag`

    Example
    -------
    >>> import bob.pipelines
    >>> transformer = bob.pipelines.ToDaskBag()
    >>> dask_bag = transformer.transform([1,2,3])
    >>> # dask_bag.map_partitions(...)

    Attributes
    ----------
    npartitions : int
        Number of partitions used in :any:`dask.bag.from_sequence`
    """

    def __init__(self, npartitions=None, partition_size=None, **kwargs):
        super().__init__(**kwargs)
        self.npartitions = npartitions
        self.partition_size = partition_size

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        logger.debug(f"{_frmt(self)}.transform")
        if self.partition_size is None:
            return dask.bag.from_sequence(X, npartitions=self.npartitions)
        else:
            return dask.bag.from_sequence(X, partition_size=self.partition_size)

    def _more_tags(self):
        return {"requires_fit": False}


def wrap(bases, estimator=None, **kwargs):
    """Wraps several estimators inside each other.

    If ``estimator`` is a pipeline, the estimators in that pipeline are wrapped.

    The default behavior of wrappers can be customized through the tags; see
    :any:`bob.pipelines.get_bob_tags` for more information.

    Parameters
    ----------
    bases : list
        A list of classes to be used to wrap ``estimator``.
    estimator : :any:`object`, optional
        An initial estimator to be wrapped inside other wrappers.
        If None, the first class will be used to initialize the estimator.
    **kwargs
        Extra parameters passed to the init of classes.

    Returns
    -------
    object
        The wrapped estimator

    Raises
    ------
    ValueError
        If not all kwargs are consumed.
    """
    # if wrappers are passed as strings, convert them to classes
    for i, w in enumerate(bases):
        if isinstance(w, str):
            bases[i] = {
                "sample": SampleWrapper,
                "checkpoint": CheckpointWrapper,
                "dask": DaskWrapper,
            }[w.lower()]

    def _wrap(estimator, **kwargs):
        # wrap the object and pass the kwargs
        for w_class in bases:
            valid_params = w_class._get_param_names()
            params = {k: kwargs.pop(k) for k in valid_params if k in kwargs}
            if estimator is None:
                estimator = w_class(**params)
            else:
                estimator = w_class(estimator, **params)
        return estimator, kwargs

    # If the estimator is a pipeline, wrap its steps instead.
    # We don't look for pipelines recursively because most of the time we
    # don't want the inner pipeline's steps to be wrapped.
    if isinstance(estimator, Pipeline):
        # wrap inner steps
        for idx, name, trans in estimator._iter():
            # when checkpointing a pipeline, checkpoint each transformer in its own folder
            new_kwargs = dict(kwargs)
            features_dir, model_path = (
                kwargs.get("features_dir"),
                kwargs.get("model_path"),
            )
            if features_dir is not None:
                new_kwargs["features_dir"] = os.path.join(features_dir, name)
            if model_path is not None:
                new_kwargs["model_path"] = os.path.join(
                    model_path, f"{name}.pkl"
                )

            trans, leftover = _wrap(trans, **new_kwargs)
            estimator.steps[idx] = (name, trans)

        # if being wrapped with DaskWrapper, add ToDaskBag to the steps
        if DaskWrapper in bases:
            valid_params = ToDaskBag._get_param_names()
            params = {k: leftover.pop(k) for k in valid_params if k in leftover}
            dask_bag = ToDaskBag(**params)
            estimator.steps.insert(0, ("ToDaskBag", dask_bag))
    else:
        estimator, leftover = _wrap(estimator, **kwargs)

    if leftover:
        raise ValueError(f"Got extra kwargs that were not consumed: {leftover}")

    return estimator
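
# A usage sketch (not part of the original module): the string shortcuts
# expand to the wrapper classes and are applied innermost first. The
# ``features_dir`` path is hypothetical::
#
#     import numpy as np
#     from sklearn.preprocessing import FunctionTransformer
#
#     transformer = wrap(
#         ["sample", "checkpoint", "dask"],
#         FunctionTransformer(lambda X: np.asarray(X) + 1),
#         features_dir="/tmp/features",  # consumed by CheckpointWrapper
#     )
#     # equivalent to DaskWrapper(CheckpointWrapper(SampleWrapper(...), ...))
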
def dask_tags(estimator):
    """Recursively collects resource_tags in dasked estimators."""
    tags = {}

    if hasattr(estimator, "estimator"):
        tags.update(dask_tags(estimator.estimator))

    if isinstance(estimator, Pipeline):
        for idx, name, trans in estimator._iter():
            tags.update(dask_tags(trans))

    if hasattr(estimator, "resource_tags"):
        tags.update(estimator.resource_tags)

    return tags


def estimator_requires_fit(estimator):
    if not hasattr(estimator, "_get_tags"):
        raise ValueError(
            f"Passed estimator: {estimator} does not have the _get_tags method."
        )

    # If the estimator is wrapped, check the wrapped estimator
    if is_instance_nested(
        estimator, "estimator", (SampleWrapper, CheckpointWrapper, DaskWrapper)
    ):
        return estimator_requires_fit(estimator.estimator)

    # If the estimator is a Pipeline, check if any of the steps requires fit
    if isinstance(estimator, Pipeline):
        return any(estimator_requires_fit(e) for _, e in estimator.steps)

    # We check for the FunctionTransformer since theoretically it
    # does require fit but it does not really need it.
    if is_instance_nested(estimator, "estimator", FunctionTransformer):
        return False

    # if the estimator does not require fit, don't call fit
    # See: https://scikit-learn.org/stable/developers/develop.html
    tags = estimator._get_tags()
    return tags["requires_fit"]


def is_instance_nested(instance, attribute, isinstance_of):
    """Check if an object or any of its nested objects is an instance of a
    class.

    This is useful when using aggregation and it is necessary to check whether
    some functionality was aggregated.

    Parameters
    ----------
    instance
        Object to be searched.

    attribute
        Attribute name to be recursively searched.

    isinstance_of
        Class (or tuple of classes) to be checked against.
    """
    if isinstance(instance, isinstance_of):
        return True

    if not hasattr(instance, attribute):
        return False

    # Recursively search the nested attribute; the recursive call checks
    # the immediate nested object first.
    return is_instance_nested(
        getattr(instance, attribute), attribute, isinstance_of
    )
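
# An illustrative sketch (not part of the original module): the search follows
# the ``estimator`` attribute through nested wrappers::
#
#     from sklearn.preprocessing import FunctionTransformer
#
#     wrapped = DaskWrapper(SampleWrapper(FunctionTransformer()))
#     is_instance_nested(wrapped, "estimator", SampleWrapper)      # -> True
#     is_instance_nested(wrapped, "estimator", CheckpointWrapper)  # -> False
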
def is_pipeline_wrapped(estimator, wrapper):
    """Iterates over the transformers of a :py:class:`sklearn.pipeline.Pipeline`
    and checks whether they were wrapped with the ``wrapper`` class.

    Parameters
    ----------

    estimator: sklearn.pipeline.Pipeline
        Pipeline to be checked.

    wrapper: type
        The wrapper class or a tuple of classes to be checked.

    Returns
    -------
    list
        A list of boolean values, where each value indicates whether the
        corresponding estimator is wrapped or not.
    """

    if not isinstance(estimator, Pipeline):
        raise ValueError(f"{estimator} is not an instance of Pipeline")

    return [
        is_instance_nested(trans, "estimator", wrapper)
        for _, _, trans in estimator._iter()
    ]