# src/bob/pipelines/wrappers.py

1"""Scikit-learn Estimator Wrappers.""" 

2import logging 

3import os 

4import tempfile 

5import time 

6import traceback 

7 

8from functools import partial 

9from pathlib import Path 

10 

11import cloudpickle 

12import dask 

13import dask.array as da 

14import dask.bag 

15import numpy as np 

16 

17from dask import delayed 

18from sklearn.base import BaseEstimator, MetaEstimatorMixin, TransformerMixin 

19from sklearn.pipeline import Pipeline 

20from sklearn.preprocessing import FunctionTransformer 

21 

22import bob.io.base 

23 

24from .sample import DelayedSample, Sample, SampleBatch, SampleSet 

25 

26logger = logging.getLogger(__name__) 

27 

28 

29def _frmt(estimator, limit=30, attr="estimator"): 

30 # default value of limit is chosen so the log can be seen in dask graphs 

31 def _n(e): 

32 return e.__class__.__name__.replace("Wrapper", "") 

33 

34 name = "" 

35 while hasattr(estimator, attr): 

36 name += f"{_n(estimator)}|" 

37 estimator = getattr(estimator, attr) 

38 

39 if ( 

40 isinstance(estimator, FunctionTransformer) 

41 and type(estimator) is FunctionTransformer 

42 ): 

43 name += str(estimator.func.__name__) 

44 else: 

45 name += str(estimator) 

46 

47 name = f"{name:.{limit}}" 

48 return name 

49 

50 

51def copy_learned_attributes(from_estimator, to_estimator): 

52 attrs = {k: v for k, v in vars(from_estimator).items() if k.endswith("_")} 

53 

54 for k, v in attrs.items(): 

55 setattr(to_estimator, k, v) 
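
# Example (illustrative sketch, not part of the module): sklearn estimators
# expose learned attributes with a trailing underscore after fit; those are
# what gets copied.
#
#     from sklearn.preprocessing import StandardScaler
#
#     scaler = StandardScaler().fit([[0.0], [2.0]])
#     holder = type("Holder", (), {})()
#     copy_learned_attributes(scaler, holder)
#     holder.mean_  # array([1.])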



def get_bob_tags(estimator=None, force_tags=None):
    """Returns the default tags of a Transformer unless forced or specified.

    Relies on the tags API of sklearn to set and retrieve the tags.

    Specify an estimator's tag values with ``estimator._more_tags``::

        class My_annotator_transformer(sklearn.base.BaseEstimator):
            def _more_tags(self):
                return {"bob_output": "annotations"}

    The returned tags will take their value with the following priority:

    1. key:value in `force_tags`, if it is present;
    2. key:value in `estimator` tags (set with `estimator._more_tags()`) if it exists;
    3. the default value for that tag if none of the previous exist.

    Examples
    --------
    bob_input: str
        The Sample attribute passed to the first argument of the fit or transform method.
        Default value is ``data``.
        Example::

            {"bob_input": "annotations"}

        will result in::

            estimator.transform(sample.annotations)

    bob_transform_extra_input: tuple of str
        Each element of the tuple is a str representing an attribute of a Sample
        object. Each attribute of the sample will be passed as argument to the transform
        method in that order. Default value is an empty tuple ``()``.
        Example::

            {"bob_transform_extra_input": (("kwarg_1", "annotations"), ("kwarg_2", "gender"))}

        will result in::

            estimator.transform(sample.data, kwarg_1=sample.annotations, kwarg_2=sample.gender)

    bob_fit_extra_input: tuple of str
        Each element of the tuple is a str representing an attribute of a Sample
        object. Each attribute of the sample will be passed as argument to the fit
        method in that order. Default value is an empty tuple ``()``.
        Example::

            {"bob_fit_extra_input": (("y", "annotations"), ("extra", "metadata"))}

        will result in::

            estimator.fit(sample.data, y=sample.annotations, extra=sample.metadata)

    bob_output: str
        The Sample attribute in which the output of the transform is stored.
        Default value is ``data``.

    bob_checkpoint_extension: str
        The extension of each checkpoint file.
        Default value is ``.h5``.

    bob_features_save_fn: func
        The function used to save each checkpoint file.
        Default value is :any:`bob.io.base.save`.

    bob_features_load_fn: func
        The function used to load each checkpoint file.
        Default value is :any:`bob.io.base.load`.

    bob_fit_supports_dask_array: bool
        Indicates that the fit method of that estimator accepts dask arrays as
        input. You may only use this tag if you accept X (N, M) and optionally y
        (N) as input. The fit function may not accept any other input.
        Default value is ``False``.

    bob_fit_supports_dask_bag: bool
        Indicates that the fit method of that estimator accepts dask bags as
        input. If true, each input parameter of the fit will be a dask bag. You
        can still (and normally should) wrap your estimator with the
        SampleWrapper so the same code runs with and without dask.
        Default value is ``False``.

    bob_checkpoint_features: bool
        If False, the features of the estimator will never be saved.
        Default value is ``True``.

    Parameters
    ----------
    estimator: sklearn.BaseEstimator or None
        An estimator class with tags that will overwrite the default values. Setting to
        None will return the default values of every tag.
    force_tags: dict[str, Any] or None
        Tags with a non-default value that will overwrite the default and the estimator
        tags.

    Returns
    -------
    dict[str, Any]
        The resulting tags with a value (either specified in the estimator, forced by
        the arguments, or the default).
    """
    force_tags = force_tags or {}
    default_tags = {
        "bob_input": "data",
        "bob_transform_extra_input": tuple(),
        "bob_fit_extra_input": tuple(),
        "bob_output": "data",
        "bob_checkpoint_extension": ".h5",
        "bob_features_save_fn": bob.io.base.save,
        "bob_features_load_fn": bob.io.base.load,
        "bob_fit_supports_dask_array": False,
        "bob_fit_supports_dask_bag": False,
        "bob_checkpoint_features": True,
    }
    estimator_tags = estimator._get_tags() if estimator is not None else {}
    return {**default_tags, **estimator_tags, **force_tags}
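
# Example (illustrative sketch): tag resolution priority. ``MyAnnotator`` is a
# hypothetical transformer, not part of this module.
#
#     from sklearn.base import BaseEstimator
#
#     class MyAnnotator(BaseEstimator):
#         def _more_tags(self):
#             return {"bob_output": "annotations"}
#
#     tags = get_bob_tags(MyAnnotator(), force_tags={"bob_input": "key"})
#     tags["bob_output"]                # "annotations" (from _more_tags)
#     tags["bob_input"]                 # "key" (forced)
#     tags["bob_checkpoint_extension"]  # ".h5" (default)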


class BaseWrapper(MetaEstimatorMixin, BaseEstimator):
    """The base class for all wrappers."""

    def _more_tags(self):
        return self.estimator._more_tags()


def _make_kwargs_from_samples(samples, arg_attr_list):
    kwargs = {
        arg: [getattr(s, attr) for s in samples] for arg, attr in arg_attr_list
    }
    return kwargs


def _check_n_input_output(samples, output, func_name):
    ls, lo = len(samples), len(output)
    if ls != lo:
        raise RuntimeError(
            f"{func_name} got {ls} samples but returned {lo} samples!"
        )


class DelayedSamplesCall:
    def __init__(
        self, func, func_name, samples, sample_attribute="data", **kwargs
    ):
        super().__init__(**kwargs)
        self.func = func
        self.func_name = func_name
        self.samples = samples
        self.output = None
        self.sample_attribute = sample_attribute

    def __call__(self, index):
        if self.output is None:
            # Isolate invalid samples (when previous transformers returned None)
            invalid_ids = [
                i for i, s in enumerate(self.samples) if s.data is None
            ]
            valid_samples = [s for s in self.samples if s.data is not None]
            # Process only the valid samples
            if len(valid_samples) > 0:
                X = SampleBatch(
                    valid_samples, sample_attribute=self.sample_attribute
                )
                self.output = self.func(X)
                _check_n_input_output(
                    valid_samples, self.output, self.func_name
                )
            if self.output is None:
                self.output = [None] * len(valid_samples)
            # Rebuild the full batch of samples (including the previously
            # failed ones)
            if len(invalid_ids) > 0:
                self.output = list(self.output)
                for i in invalid_ids:
                    self.output.insert(i, None)
        return self.output[index]


class SampleWrapper(BaseWrapper, TransformerMixin):
    """Wraps scikit-learn estimators to work with :any:`Sample`-based
    pipelines.

    Do not use this class except with scikit-learn estimators.

    Attributes
    ----------
    estimator
        The scikit-learn estimator that is wrapped.
    fit_extra_arguments : [tuple]
        Use this option if you want to pass extra arguments to the fit method of the
        mixed instance. The format is a list of two-value tuples. The first value in
        the tuples is the name of the argument that fit accepts, like ``y``, and the
        second value is the name of the attribute that samples carry. For example, if
        you are passing samples to the fit method and want to pass the ``subject``
        attributes of samples as the ``y`` argument to the fit method, you can provide
        ``[("y", "subject")]`` as the value for this attribute.
    output_attribute : str
        The name of a Sample attribute where the output of the estimator will be
        saved to [Default is ``data``]. For example, if ``output_attribute`` is
        ``"annotations"``, then ``sample.annotations`` will contain the output of
        the estimator.
    transform_extra_arguments : [tuple]
        Similar to ``fit_extra_arguments`` but for the transform and other similar
        methods.
    delayed_output : bool
        If ``True``, the output will be an instance of ``DelayedSample``; otherwise it
        will be an instance of ``Sample``.
    """

    def __init__(
        self,
        estimator,
        transform_extra_arguments=None,
        fit_extra_arguments=None,
        output_attribute=None,
        input_attribute=None,
        delayed_output=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.estimator = estimator

        bob_tags = get_bob_tags(self.estimator)
        self.input_attribute = input_attribute or bob_tags["bob_input"]
        self.transform_extra_arguments = (
            transform_extra_arguments or bob_tags["bob_transform_extra_input"]
        )
        self.fit_extra_arguments = (
            fit_extra_arguments or bob_tags["bob_fit_extra_input"]
        )
        self.output_attribute = output_attribute or bob_tags["bob_output"]
        self.delayed_output = delayed_output

    def _samples_transform(self, samples, method_name):
        # Transform either samples or samplesets
        method = getattr(self.estimator, method_name)
        func_name = f"{self}.{method_name}"

        if isinstance(samples[0], SampleSet):
            return [
                SampleSet(
                    self._samples_transform(sset.samples, method_name),
                    parent=sset,
                )
                for sset in samples
            ]
        else:
            kwargs = _make_kwargs_from_samples(
                samples, self.transform_extra_arguments
            )
            delayed = DelayedSamplesCall(
                partial(method, **kwargs),
                func_name,
                samples,
                sample_attribute=self.input_attribute,
            )
            if self.output_attribute != "data":
                # Edit sample.<output_attribute> instead of sample.data
                for i, s in enumerate(samples):
                    setattr(s, self.output_attribute, delayed(i))
                new_samples = samples
            else:
                new_samples = []
                for i, s in enumerate(samples):
                    if self.delayed_output:
                        sample = DelayedSample(
                            partial(delayed, index=i), parent=s
                        )
                    else:
                        sample = Sample(delayed(i), parent=s)
                    new_samples.append(sample)
            return new_samples

    def transform(self, samples):
        logger.debug(f"{_frmt(self)}.transform")
        return self._samples_transform(samples, "transform")

    def decision_function(self, samples):
        logger.debug(f"{_frmt(self)}.decision_function")
        return self._samples_transform(samples, "decision_function")

    def predict(self, samples):
        logger.debug(f"{_frmt(self)}.predict")
        return self._samples_transform(samples, "predict")

    def predict_proba(self, samples):
        logger.debug(f"{_frmt(self)}.predict_proba")
        return self._samples_transform(samples, "predict_proba")

    def score(self, samples):
        logger.debug(f"{_frmt(self)}.score")
        return self._samples_transform(samples, "score")

    def fit(self, samples, y=None, **kwargs):
        # If samples is a dask bag or array, pass the arguments unmodified.
        # The data is already prepared in the DaskWrapper.
        if isinstance(samples, (dask.bag.core.Bag, dask.array.Array)):
            logger.debug(f"{_frmt(self)}.fit")
            self.estimator.fit(samples, y=y, **kwargs)
            return self

        if y is not None:
            raise TypeError(
                "We don't accept `y` in fit arguments because `y` should be part of "
                "the sample. To pass `y` to the wrapped estimator, use "
                "the `fit_extra_arguments` tag."
            )

        if not estimator_requires_fit(self.estimator):
            return self

        # if the estimator needs to be fitted.
        logger.debug(f"{_frmt(self)}.fit")
        # samples is a list of either Sample or DelayedSample created with the
        # DelayedSamplesCall function; therefore, some elements in the list
        # can be None.
        # Filter out invalid samples (i.e. samples[k] is None), otherwise
        # SampleBatch will fail and throw exceptions.
        samples = [
            s for s in samples if getattr(s, self.input_attribute) is not None
        ]
        X = SampleBatch(samples, sample_attribute=self.input_attribute)

        kwargs = _make_kwargs_from_samples(samples, self.fit_extra_arguments)
        self.estimator = self.estimator.fit(X, **kwargs)
        copy_learned_attributes(self.estimator, self)
        return self
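
# Example (hedged sketch): wrapping a plain scikit-learn transformer so it
# consumes and produces Sample objects. The lambda-based FunctionTransformer
# is illustrative only; SampleBatch converts to a numpy array on demand.
#
#     import numpy as np
#     from sklearn.preprocessing import FunctionTransformer
#
#     transformer = SampleWrapper(FunctionTransformer(lambda x: np.asarray(x) * 2))
#     samples = [Sample(1), Sample(2)]
#     out = transformer.transform(samples)
#     [s.data for s in out]  # [2, 4]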


class CheckpointWrapper(BaseWrapper, TransformerMixin):
    """Wraps :any:`Sample`-based estimators so the results are saved on
    disk.

    Parameters
    ----------

    estimator
        The scikit-learn estimator to be wrapped.

    model_path: str
        Saves the estimator state to this path if the ``estimator`` is stateful.

    features_dir: str
        Saves the transformed data into this directory.

    extension: str
        Default extension of the transformed features.
        If None, will use the ``bob_checkpoint_extension`` tag in the estimator, or
        default to ``.h5``.

    save_func
        Pointer to a customized function that saves transformed features to disk.
        If None, will use the ``bob_features_save_fn`` tag in the estimator, or default
        to ``bob.io.base.save``.

    load_func
        Pointer to a customized function that loads transformed features from disk.
        If None, will use the ``bob_features_load_fn`` tag in the estimator, or default
        to ``bob.io.base.load``.

    sample_attribute: str
        Defines the payload attribute of the sample.
        If None, will use the ``bob_output`` tag in the estimator, or default to
        ``data``.

    hash_fn
        Pointer to a hash function. This hash function maps
        ``sample.key`` to a hash code and this hash code corresponds to a relative
        directory where a single ``sample`` will be checkpointed.
        This is useful when it is desirable to keep the number of files per
        directory below a certain limit.

    attempts
        Number of checkpoint attempts. Sometimes, because of network/disk issues,
        files can't be saved. This argument sets the maximum number of attempts
        to checkpoint a sample.

    force: bool
        If True, will recompute the checkpoints even if they exist.

    """

    def __init__(
        self,
        estimator,
        model_path=None,
        features_dir=None,
        extension=None,
        save_func=None,
        load_func=None,
        sample_attribute=None,
        hash_fn=None,
        attempts=10,
        force=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        bob_tags = get_bob_tags(estimator)
        self.extension = extension or bob_tags["bob_checkpoint_extension"]
        self.save_func = save_func or bob_tags["bob_features_save_fn"]
        self.load_func = load_func or bob_tags["bob_features_load_fn"]
        self.sample_attribute = sample_attribute or bob_tags["bob_output"]

        if not bob_tags["bob_checkpoint_features"]:
            logger.info(
                "Checkpointing is disabled for %s because the bob_checkpoint_features tag is False.",
                estimator,
            )
            features_dir = None

        self.force = force
        self.estimator = estimator
        self.model_path = model_path
        self.features_dir = features_dir
        self.hash_fn = hash_fn
        self.attempts = attempts

        # Paths check
        if model_path is None and features_dir is None:
            logger.warning(
                "Both model_path and features_dir are None. "
                f"Nothing will be checkpointed. From: {self}"
            )

    def _checkpoint_transform(self, samples, method_name):
        # Transform either samples or samplesets
        method = getattr(self.estimator, method_name)

        # if features_dir is None, just transform all samples at once
        if self.features_dir is None:
            return method(samples)

        def _transform_samples(samples):
            paths = [self.make_path(s) for s in samples]
            should_compute_list = [
                self.force or p is None or not os.path.isfile(p) for p in paths
            ]
            # call method on non-checkpointed samples
            non_existing_samples = [
                s
                for s, should_compute in zip(samples, should_compute_list)
                if should_compute
            ]
            # non_existing_samples could be empty
            computed_features = []
            if non_existing_samples:
                computed_features = method(non_existing_samples)
                _check_n_input_output(
                    non_existing_samples, computed_features, method
                )
            # return computed features and checkpointed features
            features, com_feat_index = [], 0

            for s, p, should_compute in zip(
                samples, paths, should_compute_list
            ):
                if should_compute:
                    feat = computed_features[com_feat_index]
                    com_feat_index += 1
                    # save the computed feature when valid (not None)
                    if (
                        p is not None
                        and getattr(feat, self.sample_attribute) is not None
                    ):
                        self.save(feat)
                        # sometimes loading the file fails randomly
                        for _ in range(self.attempts):
                            try:
                                feat = self.load(s, p)
                                break
                            except Exception:
                                error = traceback.format_exc()
                                time.sleep(0.1)
                        else:
                            raise RuntimeError(
                                f"Could not load using: {self.load}({s}, {p}) with the following error: {error}"
                            )
                    features.append(feat)
                else:
                    features.append(self.load(s, p))
            return features

        if isinstance(samples[0], SampleSet):
            return [
                SampleSet(_transform_samples(s.samples), parent=s)
                for s in samples
            ]
        else:
            return _transform_samples(samples)

    def transform(self, samples):
        logger.debug(f"{_frmt(self)}.transform")
        return self._checkpoint_transform(samples, "transform")

    def decision_function(self, samples):
        logger.debug(f"{_frmt(self)}.decision_function")
        return self.estimator.decision_function(samples)

    def predict(self, samples):
        logger.debug(f"{_frmt(self)}.predict")
        return self.estimator.predict(samples)

    def predict_proba(self, samples):
        logger.debug(f"{_frmt(self)}.predict_proba")
        return self.estimator.predict_proba(samples)

    def score(self, samples):
        logger.debug(f"{_frmt(self)}.score")
        return self.estimator.score(samples)

    def fit(self, samples, y=None, **kwargs):
        if not estimator_requires_fit(self.estimator):
            return self

        # if the estimator needs to be fitted.
        logger.debug(f"{_frmt(self)}.fit")

        if self.model_path is not None and os.path.isfile(self.model_path):
            logger.info("Found a checkpoint for model. Loading ...")
            return self.load_model()

        self.estimator = self.estimator.fit(samples, y=y, **kwargs)
        copy_learned_attributes(self.estimator, self)
        return self.save_model()

    def make_path(self, sample):
        if self.features_dir is None:
            return None

        key = str(sample.key)
        if key.startswith(os.sep) or ".." in key:
            raise ValueError(
                "Sample.key values should be relative paths with no "
                f"reference to upper folders. Got: {key}"
            )

        hash_dir_name = self.hash_fn(key) if self.hash_fn is not None else ""

        return os.path.join(
            self.features_dir, hash_dir_name, key + self.extension
        )

    def save(self, sample):
        path = self.make_path(sample)
        # Gets sample.data or sample.<sample_attribute> if specified
        to_save = getattr(sample, self.sample_attribute)
        for _ in range(self.attempts):
            try:
                dirname = os.path.dirname(path)
                os.makedirs(dirname, exist_ok=True)

                # Atomic writing
                extension = "".join(Path(path).suffixes)
                with tempfile.NamedTemporaryFile(
                    dir=dirname, delete=False, suffix=extension
                ) as f:
                    self.save_func(to_save, f.name)
                os.replace(f.name, path)

                # test loading
                self.load_func(path)
                break
            except Exception:
                error = traceback.format_exc()
                time.sleep(0.1)
        else:
            raise RuntimeError(
                f"Could not save {to_save} of type {type(to_save)} using {self.save_func} with the following error: {error}"
            )

    def load(self, sample, path):
        # Because we are checkpointing, we return a DelayedSample instead of a
        # normal (preloaded) sample. This allows the next phase to avoid
        # loading it when unnecessary (e.g. when the next phase is already
        # checkpointed).
        if self.sample_attribute == "data":
            return DelayedSample(partial(self.load_func, path), parent=sample)
        else:
            loaded = self.load_func(path)
            setattr(sample, self.sample_attribute, loaded)
            return sample

    def load_model(self):
        if not estimator_requires_fit(self.estimator):
            return self
        with open(self.model_path, "rb") as f:
            loaded_estimator = cloudpickle.load(f)
        # We update self.estimator instead of replacing it because
        # self.estimator might be referenced elsewhere.
        _update_estimator(self.estimator, loaded_estimator)
        return self

    def save_model(self):
        if (
            not estimator_requires_fit(self.estimator)
            or self.model_path is None
        ):
            return self
        os.makedirs(os.path.dirname(self.model_path), exist_ok=True)
        with open(self.model_path, "wb") as f:
            cloudpickle.dump(self.estimator, f)
        return self
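
# Example (hedged sketch; paths and keys are hypothetical): checkpointing the
# output of a sample transformer. The first call computes and writes the
# feature; the second call loads it from disk instead of recomputing.
#
#     import numpy as np
#     from sklearn.preprocessing import FunctionTransformer
#
#     transformer = SampleWrapper(FunctionTransformer(lambda x: np.asarray(x) * 2))
#     transformer = CheckpointWrapper(transformer, features_dir="/tmp/features")
#     samples = [Sample(np.array([1.0]), key="s1")]
#     out = transformer.transform(samples)  # writes /tmp/features/s1.h5
#     out = transformer.transform(samples)  # loads the checkpoint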


def _update_estimator(estimator, loaded_estimator):
    # recursively update estimator with loaded_estimator without replacing
    # estimator.estimator
    if hasattr(estimator, "estimator"):
        _update_estimator(estimator.estimator, loaded_estimator.estimator)
    for k, v in loaded_estimator.__dict__.items():
        if k != "estimator":
            estimator.__dict__[k] = v


def is_checkpointed(estimator):
    return is_instance_nested(estimator, "estimator", CheckpointWrapper)


def getattr_nested(estimator, attr):
    if hasattr(estimator, attr):
        return getattr(estimator, attr)
    elif hasattr(estimator, "estimator"):
        return getattr_nested(estimator.estimator, attr)
    return None


def _sample_attribute(samples, attribute):
    return [getattr(s, attribute) for s in samples]


def _len_samples(samples):
    return [len(samples)]


def _shape_samples(samples):
    return [[s.shape for s in samples]]


def _array_from_sample_bags(X: dask.bag.Bag, attribute: str, ndim: int = 2):
    if ndim not in (1, 2):
        raise NotImplementedError(f"ndim must be 1 or 2. Got: {ndim}")

    if ndim == 1:
        stack_function = np.concatenate
    else:
        stack_function = np.vstack

    # because samples could be delayed samples, we convert sample bags to
    # sample.attribute bags first and then persist
    X = X.map_partitions(_sample_attribute, attribute=attribute).persist()

    # convert sample.attribute bags to arrays
    delayeds = X.to_delayed()
    lengths = X.map_partitions(_len_samples)
    shapes = X.map_partitions(_shape_samples)
    lengths, shapes = dask.compute(lengths, shapes)
    dtype, X = None, []
    for length_, shape_, delayed_samples_list in zip(lengths, shapes, delayeds):
        delayed_samples_list._length = length_

        if dtype is None:
            dtype = np.array(delayed_samples_list[0].compute()).dtype

        # stack the data in each bag
        stacked_samples = dask.delayed(stack_function)(delayed_samples_list)
        # make sure shapes are at least 2d
        for i, s in enumerate(shape_):
            if len(s) == 1 and ndim == 2:
                shape_[i] = (1,) + s
            elif len(s) == 0:
                # if shape is empty, it means that the samples are scalars
                if ndim == 1:
                    shape_[i] = (1,)
                else:
                    shape_[i] = (1, 1)
        stacked_shape = sum(s[0] for s in shape_)
        stacked_shape = [stacked_shape] + list(shape_[0][1:])

        darray = da.from_delayed(
            stacked_samples,
            stacked_shape,
            dtype=dtype,
            name=False,
        )
        X.append(darray)

    # stack data from all bags
    X = stack_function(X)
    return X
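
# Example (hedged sketch, assumes a local dask scheduler): converting a bag of
# two 1D samples into a single 2D dask array.
#
#     import dask.bag
#
#     bag = dask.bag.from_sequence(
#         [Sample(np.array([1.0, 2.0])), Sample(np.array([3.0, 4.0]))],
#         npartitions=2,
#     )
#     X = _array_from_sample_bags(bag, "data", ndim=2)
#     X.shape      # (2, 2)
#     X.compute()  # array([[1., 2.], [3., 4.]])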


class DaskWrapper(BaseWrapper, TransformerMixin):
    """Wraps scikit-learn estimators to handle dask bags as input.

    Parameters
    ----------

    fit_tag: str
        Mark the delayed(self.fit) with this value. This can be used in
        a future delayed(self.fit).compute(resources=resource_tape) so
        the dask scheduler can place this task on a particular resource
        (e.g. GPU).

    transform_tag: str
        Mark the delayed(self.transform) with this value. This can be used in
        a future delayed(self.transform).compute(resources=resource_tape) so
        the dask scheduler can place this task on a particular resource
        (e.g. GPU).
    """

    def __init__(
        self,
        estimator,
        fit_tag=None,
        transform_tag=None,
        fit_supports_dask_array=None,
        fit_supports_dask_bag=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.estimator = estimator
        self._dask_state = estimator
        self.resource_tags = dict()
        self.fit_tag = fit_tag
        self.transform_tag = transform_tag
        self.fit_supports_dask_array = (
            fit_supports_dask_array
            or get_bob_tags(self.estimator)["bob_fit_supports_dask_array"]
        )
        self.fit_supports_dask_bag = (
            fit_supports_dask_bag
            or get_bob_tags(self.estimator)["bob_fit_supports_dask_bag"]
        )

    def _make_dask_resource_tag(self, tag):
        return {tag: 1}

    def _dask_transform(self, X, method_name):
        graph_name = f"{_frmt(self)}.{method_name}"
        logger.debug(graph_name)

        def _transf(X_line, dask_state):
            return getattr(dask_state, method_name)(X_line)

        # change the name to have a better name in dask graphs
        _transf.__name__ = graph_name
        # scatter the dask_state to all workers for efficiency
        dask_state = dask.delayed(self._dask_state)
        map_partitions = X.map_partitions(_transf, dask_state)
        if self.transform_tag:
            self.resource_tags[
                tuple(map_partitions.dask.keys())
            ] = self._make_dask_resource_tag(self.transform_tag)

        return map_partitions

    def transform(self, samples):
        return self._dask_transform(samples, "transform")

    def decision_function(self, samples):
        return self._dask_transform(samples, "decision_function")

    def predict(self, samples):
        return self._dask_transform(samples, "predict")

    def predict_proba(self, samples):
        return self._dask_transform(samples, "predict_proba")

    def score(self, samples):
        return self._dask_transform(samples, "score")

    def _get_fit_params_from_sample_bags(self, bags):
        logger.debug("Preparing data as dask arrays for fit")

        input_attribute = getattr_nested(self, "input_attribute")
        fit_extra_arguments = getattr_nested(self, "fit_extra_arguments")

        # convert X, which is a dask bag, to a dask array
        X = _array_from_sample_bags(bags, input_attribute, ndim=2)
        kwargs = dict()
        for arg, attr in fit_extra_arguments:
            # we only create a dask array if the arg is named ``y``
            if arg == "y":
                kwargs[arg] = _array_from_sample_bags(bags, attr, ndim=1)
            else:
                raise NotImplementedError(
                    f"fit_extra_arguments: {arg} is not supported, only ``y`` is supported."
                )

        return X, kwargs

    def _fit_on_dask_array(self, bags, y=None, **fit_params):
        if y is not None or fit_params:
            raise ValueError(
                "y or fit_params should be passed through fit_extra_arguments of the SampleWrapper"
            )

        X, fit_params = self._get_fit_params_from_sample_bags(bags)
        self.estimator.fit(X, **fit_params)
        return self

    def _fit_on_dask_bag(self, bags, y=None, **fit_params):
        # X is a dask bag of Samples; convert it to the required fit parameters
        logger.debug("Converting dask bag of samples to bags of fit parameters")

        def getattr_list(samples, attribute):
            return SampleBatch(samples, sample_attribute=attribute)

        # We prepare the input parameters here instead of doing this in the
        # SampleWrapper. The SampleWrapper class will then pass these dask bags
        # directly to the underlying estimator.
        bob_tags = get_bob_tags(self.estimator)
        input_attribute = bob_tags["bob_input"]
        fit_extra_arguments = bob_tags["bob_fit_extra_input"]

        X = bags.map_partitions(getattr_list, input_attribute)
        kwargs = {
            arg: bags.map_partitions(getattr_list, attr)
            for arg, attr in fit_extra_arguments
        }

        self.estimator.fit(X, **kwargs)
        return self

    def fit(self, X, y=None, **fit_params):
        if not estimator_requires_fit(self.estimator):
            return self
        logger.debug(f"{_frmt(self)}.fit")

        model_path = None
        if is_checkpointed(self):
            model_path = getattr_nested(self, "model_path")
        model_path = model_path or ""
        if os.path.isfile(model_path):
            logger.info(
                f"Checkpointed estimator detected at {model_path}. The estimator ({_frmt(self)}) will be loaded and training will not run."
            )
            # we should load the estimator outside the dask graph to make sure
            # that the estimator loads in the scheduler
            self.estimator.load_model()
            return self

        if self.fit_supports_dask_array:
            return self._fit_on_dask_array(X, y, **fit_params)
        elif self.fit_supports_dask_bag:
            return self._fit_on_dask_bag(X, y, **fit_params)

        def _fit(X, y, **fit_params):
            try:
                self.estimator = self.estimator.fit(X, y, **fit_params)
            except Exception as e:
                raise RuntimeError(
                    f"Something went wrong when fitting {self.estimator} "
                    f"from {self}"
                ) from e
            copy_learned_attributes(self.estimator, self)
            return self.estimator

        # change the name to have a better name in dask graphs
        _fit.__name__ = f"{_frmt(self)}.fit"

        _fit_call = delayed(_fit)(X, y, **fit_params)
        self._dask_state = _fit_call.persist()

        if self.fit_tag is not None:
            # If you do `delayed(_fit)(X, y)`, two tasks are generated:
            # `finalize-TASK` and `TASK`. With this, we make sure
            # that the two are annotated.
            self.resource_tags[
                tuple(
                    [
                        f"{k}{str(self._dask_state.key)}"
                        for k in ["", "finalize-"]
                    ]
                )
            ] = self._make_dask_resource_tag(self.fit_tag)

        return self
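
# Example (hedged sketch, assumes a dask scheduler is available): an estimator
# wrapped with "sample" and "dask" transforms a dask bag of samples lazily.
#
#     import dask.bag
#     import numpy as np
#     from sklearn.preprocessing import FunctionTransformer
#
#     transformer = wrap(
#         ["sample", "dask"], FunctionTransformer(lambda x: np.asarray(x) + 1)
#     )
#     bag = dask.bag.from_sequence([Sample(1), Sample(2)])
#     result = transformer.transform(bag).compute()
#     [s.data for s in result]  # [2, 3]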


class ToDaskBag(TransformerMixin, BaseEstimator):
    """Transform an arbitrary iterator into a :any:`dask.bag.Bag`

    Example
    -------
    >>> import bob.pipelines
    >>> transformer = bob.pipelines.ToDaskBag()
    >>> dask_bag = transformer.transform([1,2,3])
    >>> # dask_bag.map_partitions(...)

    Attributes
    ----------
    npartitions : int
        Number of partitions used in :any:`dask.bag.from_sequence`
    """

    def __init__(self, npartitions=None, partition_size=None, **kwargs):
        super().__init__(**kwargs)
        self.npartitions = npartitions
        self.partition_size = partition_size

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        logger.debug(f"{_frmt(self)}.transform")
        if self.partition_size is None:
            return dask.bag.from_sequence(X, npartitions=self.npartitions)
        else:
            return dask.bag.from_sequence(X, partition_size=self.partition_size)

    def _more_tags(self):
        return {"requires_fit": False}


def wrap(bases, estimator=None, **kwargs):
    """Wraps several estimators inside each other.

    If ``estimator`` is a pipeline, the estimators in that pipeline are wrapped.

    The default behavior of the wrappers can be customized through the tags; see
    :any:`bob.pipelines.get_bob_tags` for more information.

    Parameters
    ----------
    bases : list
        A list of classes to be used to wrap ``estimator``.
    estimator : :any:`object`, optional
        An initial estimator to be wrapped inside other wrappers.
        If None, the first class will be used to initialize the estimator.
    **kwargs
        Extra parameters passed to the init of classes.

    Returns
    -------
    object
        The wrapped estimator.

    Raises
    ------
    ValueError
        If not all kwargs are consumed.
    """
    # if the wrappers are passed as strings, convert them to classes
    for i, w in enumerate(bases):
        if isinstance(w, str):
            bases[i] = {
                "sample": SampleWrapper,
                "checkpoint": CheckpointWrapper,
                "dask": DaskWrapper,
            }[w.lower()]

    def _wrap(estimator, **kwargs):
        # wrap the object and pass the kwargs
        for w_class in bases:
            valid_params = w_class._get_param_names()
            params = {k: kwargs.pop(k) for k in valid_params if k in kwargs}
            if estimator is None:
                estimator = w_class(**params)
            else:
                estimator = w_class(estimator, **params)
        return estimator, kwargs

    # If the estimator is a pipeline, wrap its steps instead.
    # We don't look for pipelines recursively because most of the time we
    # don't want the inner pipeline's steps to be wrapped.
    if isinstance(estimator, Pipeline):
        # wrap inner steps
        for idx, name, trans in estimator._iter():
            # when checkpointing a pipeline, checkpoint each transformer in
            # its own folder
            new_kwargs = dict(kwargs)
            features_dir, model_path = (
                kwargs.get("features_dir"),
                kwargs.get("model_path"),
            )
            if features_dir is not None:
                new_kwargs["features_dir"] = os.path.join(features_dir, name)
            if model_path is not None:
                new_kwargs["model_path"] = os.path.join(
                    model_path, f"{name}.pkl"
                )

            trans, leftover = _wrap(trans, **new_kwargs)
            estimator.steps[idx] = (name, trans)

        # if being wrapped with DaskWrapper, add ToDaskBag to the steps
        if DaskWrapper in bases:
            valid_params = ToDaskBag._get_param_names()
            params = {k: leftover.pop(k) for k in valid_params if k in leftover}
            dask_bag = ToDaskBag(**params)
            estimator.steps.insert(0, ("ToDaskBag", dask_bag))
    else:
        estimator, leftover = _wrap(estimator, **kwargs)

    if leftover:
        raise ValueError(f"Got extra kwargs that were not consumed: {leftover}")

    return estimator
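
# Example (hedged sketch; the checkpoint directory is hypothetical): wrapping a
# pipeline with "sample" and "checkpoint". Each step gets its own SampleWrapper
# and CheckpointWrapper, checkpointing into a subfolder named after the step.
#
#     import numpy as np
#     from sklearn.pipeline import make_pipeline
#     from sklearn.preprocessing import FunctionTransformer
#
#     pipeline = make_pipeline(
#         FunctionTransformer(lambda x: np.asarray(x) * 2),
#         FunctionTransformer(lambda x: np.asarray(x) + 1),
#     )
#     pipeline = wrap(
#         ["sample", "checkpoint"], pipeline, features_dir="/tmp/checkpoints"
#     )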


def dask_tags(estimator):
    """Recursively collects resource_tags in dasked estimators."""
    tags = {}

    if hasattr(estimator, "estimator"):
        tags.update(dask_tags(estimator.estimator))

    if isinstance(estimator, Pipeline):
        for idx, name, trans in estimator._iter():
            tags.update(dask_tags(trans))

    if hasattr(estimator, "resource_tags"):
        tags.update(estimator.resource_tags)

    return tags


def estimator_requires_fit(estimator):
    if not hasattr(estimator, "_get_tags"):
        raise ValueError(
            f"Passed estimator: {estimator} does not have the _get_tags method."
        )

    # If the estimator is wrapped, check the wrapped estimator
    if is_instance_nested(
        estimator, "estimator", (SampleWrapper, CheckpointWrapper, DaskWrapper)
    ):
        return estimator_requires_fit(estimator.estimator)

    # If the estimator is a Pipeline, check if any of the steps requires fit
    if isinstance(estimator, Pipeline):
        return any(estimator_requires_fit(e) for _, e in estimator.steps)

    # We check for the FunctionTransformer since theoretically it
    # does require fit but it does not really need it.
    if is_instance_nested(estimator, "estimator", FunctionTransformer):
        return False

    # if the estimator does not require fit, don't call fit
    # See: https://scikit-learn.org/stable/developers/develop.html
    tags = estimator._get_tags()
    return tags["requires_fit"]
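
# Example (illustrative; the values follow from the rules above):
#
#     from sklearn.preprocessing import FunctionTransformer, StandardScaler
#
#     estimator_requires_fit(StandardScaler())                    # True
#     estimator_requires_fit(FunctionTransformer(lambda x: x))    # False
#     estimator_requires_fit(wrap(["sample"], StandardScaler()))  # True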


def is_instance_nested(instance, attribute, isinstance_of):
    """
    Checks if an object, or any of its nested objects, is an instance of a class.

    This is useful with aggregation, when it is necessary to check if some
    functionality was aggregated.

    Parameters
    ----------
    instance:
        Object to be searched

    attribute:
        Attribute name to be recursively searched

    isinstance_of:
        Instance class to be searched

    """
    if isinstance(instance, isinstance_of):
        return True

    if not hasattr(instance, attribute):
        return False

    # Check the immediately nested object, then search recursively
    if isinstance(getattr(instance, attribute), isinstance_of):
        return True
    else:
        return is_instance_nested(
            getattr(instance, attribute), attribute, isinstance_of
        )
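
# Example (illustrative): searching a chain of wrappers.
#
#     from sklearn.preprocessing import StandardScaler
#
#     est = wrap(["sample", "dask"], StandardScaler())  # Dask(Sample(scaler))
#     is_instance_nested(est, "estimator", SampleWrapper)      # True
#     is_instance_nested(est, "estimator", CheckpointWrapper)  # False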


def is_pipeline_wrapped(estimator, wrapper):
    """
    Iterates over the transformers of a :py:class:`sklearn.pipeline.Pipeline`
    and checks whether each of them was wrapped with the ``wrapper`` class.

    Parameters
    ----------

    estimator: sklearn.pipeline.Pipeline
        Pipeline to be checked

    wrapper: type
        The wrapper class or a tuple of classes to be checked

    Returns
    -------
    list
        A list of boolean values, where each value indicates whether the
        corresponding estimator is wrapped or not
    """

    if not isinstance(estimator, Pipeline):
        raise ValueError(f"{estimator} is not an instance of Pipeline")

    return [
        is_instance_nested(trans, "estimator", wrapper)
        for _, _, trans in estimator._iter()
    ]
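
# Example (illustrative): checking which steps of a dask-wrapped pipeline are
# wrapped with DaskWrapper. The leading ToDaskBag step added by ``wrap`` is
# not, hence the first False.
#
#     from sklearn.pipeline import make_pipeline
#     from sklearn.preprocessing import FunctionTransformer
#
#     pipeline = make_pipeline(FunctionTransformer(lambda x: x))
#     pipeline = wrap(["sample", "dask"], pipeline)
#     is_pipeline_wrapped(pipeline, DaskWrapper)  # [False, True]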