Coverage for src/bob/pipelines/xarray.py: 95%

306 statements  

coverage.py v7.6.0, created at 2024-07-12 21:32 +0200

import logging
import os
import random
import string

from functools import partial

import cloudpickle
import dask
import dask.array  # needed for dask.array.from_delayed/blockwise used below
import h5py
import numpy as np
import xarray as xr

from sklearn.base import BaseEstimator
from sklearn.pipeline import _name_estimators
from sklearn.utils.metaestimators import _BaseComposition

from .sample import SAMPLE_DATA_ATTRS, _ReprMixin
from .wrappers import estimator_requires_fit

logger = logging.getLogger(__name__)


def save(data, path):
    array = np.require(data, requirements=("C_CONTIGUOUS", "ALIGNED"))
    with h5py.File(path, "w") as f:
        f.create_dataset("array", data=array)


def load(path):
    with h5py.File(path, "r") as f:
        data = f["array"][()]
    return data
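

# Illustrative sketch (hypothetical helper, added for clarity, not part of the
# original module): the HDF5 helpers above store a single C-contiguous array
# under the "array" key and read it back unchanged.
def _example_save_load_roundtrip():
    import tempfile

    data = np.arange(12).reshape(3, 4)
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "feature.hdf5")
        save(data, path)
        loaded = load(path)
    assert np.array_equal(loaded, data)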



def _load_fn_to_xarray(load_fn, meta=None):
    if meta is None:
        meta = np.array(load_fn())

    da = dask.array.from_delayed(
        dask.delayed(load_fn)(), meta.shape, dtype=meta.dtype, name=False
    )
    try:
        dims = meta.dims
    except Exception:
        dims = None

    xa = xr.DataArray(da, dims=dims)
    return xa, meta


def _one_sample_to_dataset(sample, meta=None):
    dataset = {}
    delayed_attributes = getattr(sample, "_delayed_attributes", None) or {}
    for k in sample.__dict__:
        if (
            k in SAMPLE_DATA_ATTRS
            or k in delayed_attributes
            or k.startswith("_")
        ):
            continue
        dataset[k] = getattr(sample, k)

    meta = meta or {}

    for k in ["data"] + list(delayed_attributes.keys()):
        attr_meta = meta.get(k)
        attr_array, attr_meta = _load_fn_to_xarray(
            partial(getattr, sample, k), meta=attr_meta
        )
        meta[k] = attr_meta
        dataset[k] = attr_array

    return xr.Dataset(dataset).chunk(), meta
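

# Illustrative sketch (hypothetical helper): _load_fn_to_xarray wraps a
# deferred load function into a lazy, dask-backed DataArray; the data is only
# read when the array is actually computed.
def _example_lazy_data_array():
    def load_fn():
        return np.zeros((3, 4))

    xa, meta = _load_fn_to_xarray(load_fn)
    # xa is a lazy xr.DataArray; meta is the eagerly loaded template used for
    # shape and dtype inference.
    assert xa.shape == (3, 4) and xa.dtype == meta.dtype
    return xa.compute()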



def samples_to_dataset(samples, meta=None, npartitions=48, shuffle=False):
    """Converts a list of samples to a dataset.

    See :ref:`bob.pipelines.dataset_pipeline`.

    Parameters
    ----------
    samples : list
        A list of :any:`Sample` or :any:`DelayedSample` objects.
    meta : ``xarray.DataArray``, optional
        An ``xarray.DataArray`` to be used as a template for the data inside the
        samples.
    npartitions : :obj:`int`, optional
        The number of partitions to split the samples into.
    shuffle : :obj:`bool`, optional
        If True, shuffles the samples (in-place) before constructing the dataset.

    Returns
    -------
    ``xarray.Dataset``
        The constructed dataset with at least a ``data`` variable.
    """
    if meta is not None and not isinstance(meta, dict):
        meta = dict(data=meta)

    delayed_attributes = getattr(samples[0], "delayed_attributes", None) or {}
    if meta is None or not all(
        k in meta for k in ["data"] + list(delayed_attributes.keys())
    ):
        dataset, meta = _one_sample_to_dataset(samples[0])

    if shuffle:
        random.shuffle(samples)

    dataset = xr.concat(
        [_one_sample_to_dataset(s, meta=meta)[0] for s in samples], dim="sample"
    )
    if npartitions is not None:
        dataset = dataset.chunk({"sample": max(1, len(samples) // npartitions)})
    return dataset
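

# Illustrative usage sketch (assumes bob.pipelines.Sample, which stores
# ``data`` plus arbitrary keyword attributes such as ``key``):
def _example_samples_to_dataset():
    from bob.pipelines import Sample

    samples = [Sample(np.random.rand(10), key=f"sample-{i}") for i in range(4)]
    ds = samples_to_dataset(samples, npartitions=2)
    # ds is an xr.Dataset with a lazy "data" variable of dims
    # ("sample", "dim_0") and a per-sample "key" variable.
    return ds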



class Block(_ReprMixin):
    """A block representation in a graph.
    This class is meant to be used with :any:`DatasetPipeline`.

    Attributes
    ----------
    dataset_map : ``callable``
        A callable that transforms the input dataset into another dataset.
    estimator : object
        A scikit-learn estimator.
    estimator_name : str
        Name of the estimator.
    extension : str
        The extension of checkpointed features.
    features_dir : str
        The directory to save the features.
    fit_input : str or list
        A str or list of str of column names of the dataset to be given to the ``.fit``
        method.
    fit_kwargs : None or dict
        A dict of ``fit_kwargs`` to be passed to the ``.fit`` method of the estimator.
    input_dask_array : bool
        Whether the estimator takes dask arrays in its fit method or not.
    load_func : ``callable``
        A function to load the checkpointed features. Defaults to the
        ``bob_features_load_fn`` estimator tag if set, otherwise to the HDF5-based
        ``load`` function of this module.
    model_path : str or None
        If given, the estimator will be pickled here.
    output_dims : list
        A list of ``(dim_name, dim_size)`` tuples. If ``dim_name`` is ``None``, a new
        name is automatically generated, otherwise it should be a string. ``dim_size``
        should be a positive integer or nan for new dimensions or ``None`` for existing
        dimensions.
    output_dtype : object
        The dtype of the output of the transformer. Defaults to ``float``.
    save_func : ``callable``
        A function to save the features. Defaults to the ``bob_features_save_fn``
        estimator tag if set, otherwise to the HDF5-based ``save`` function of this
        module.
    transform_input : str or list
        A str or list of str of column names of the dataset to be given to the
        ``.transform`` method.
    """

    def __init__(
        self,
        estimator=None,
        output_dtype=float,
        output_dims=((None, np.nan),),
        fit_input="data",
        transform_input="data",
        estimator_name=None,
        model_path=None,
        features_dir=None,
        extension=".hdf5",
        save_func=None,
        load_func=None,
        dataset_map=None,
        input_dask_array=False,
        fit_kwargs=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.estimator = estimator
        self.output_dtype = output_dtype
        if not all(len(d) == 2 for d in output_dims):
            raise ValueError(
                "output_dims must be an iterable of size 2 tuples "
                f"(dim_name, dim_size), not {output_dims}"
            )
        self.output_dims = output_dims
        self.fit_input = fit_input
        self.transform_input = transform_input
        if estimator_name is None:
            estimator_name = _name_estimators([estimator])[0][0]
        self.estimator_name = estimator_name
        self.model_path = model_path
        self.features_dir = features_dir
        self.extension = extension
        estimator_save_fn = (
            None
            if estimator is None
            else estimator._get_tags().get("bob_features_save_fn")
        )
        estimator_load_fn = (
            None
            if estimator is None
            else estimator._get_tags().get("bob_features_load_fn")
        )
        self.save_func = save_func or estimator_save_fn or save
        self.load_func = load_func or estimator_load_fn or load
        self.dataset_map = dataset_map
        self.input_dask_array = input_dask_array
        self.fit_kwargs = fit_kwargs or {}

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        setattr(self, key, value)

    @property
    def output_ndim(self):
        return len(self.output_dims) + 1

    def make_path(self, key):
        key = str(key)
        if key.startswith(os.sep) or ".." in key:
            raise ValueError(
                "Sample.key values should be relative paths with no "
                f"reference to upper folders. Got: {key}"
            )
        return os.path.join(self.features_dir, key + self.extension)

    def save(self, key, data):
        path = self.make_path(key)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        # this should be save_func(data, path) so it's compatible with bob.io.base.save
        return self.save_func(data, path)

    def load(self, key):
        path = self.make_path(key)
        return self.load_func(path)
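

# Illustrative sketch: a Block wrapping a scikit-learn transformer whose
# output size is known ahead of time, with checkpointing of the fitted model
# and of the transformed features. Directory names are placeholders.
def _example_block_config():
    from sklearn.decomposition import PCA

    block = Block(
        estimator=PCA(n_components=5),
        # one new output dimension of known size replaces the input features
        output_dims=[("feature", 5)],
        model_path="models/pca.pkl",
        features_dir="features/pca",
    )
    # Checkpoint paths are derived from each sample's key, e.g.
    # block.make_path("subject1/sample1") -> "features/pca/subject1/sample1.hdf5"
    return block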



def _fit(*args, block):
    logger.info(f"Calling {block.estimator_name}.fit")
    block.estimator.fit(*args, **block.fit_kwargs)
    if block.model_path is not None:
        logger.info(f"Saving {block.estimator_name} in {block.model_path}")
        os.makedirs(os.path.dirname(block.model_path), exist_ok=True)
        with open(block.model_path, "wb") as f:
            cloudpickle.dump(block.estimator, f)
    return block.estimator


class _TokenStableTransform:
    def __init__(self, block, method_name=None, input_has_keys=False, **kwargs):
        super().__init__(**kwargs)
        self.block = block
        self.method_name = method_name or "transform"
        self.input_has_keys = input_has_keys

    def __dask_tokenize__(self):
        return (self.method_name, self.block.features_dir)

    def __call__(self, *args, estimator):
        block, method_name = self.block, self.method_name
        logger.info(f"Calling {block.estimator_name}.{method_name}")

        input_args = args[:-1] if self.input_has_keys else args
        try:
            features = getattr(estimator, self.method_name)(*input_args)
        except Exception as e:
            raise RuntimeError(
                f"Failed to transform data: {estimator}.{self.method_name}(*{input_args})"
            ) from e

        # if keys are provided, checkpoint features
        if self.input_has_keys:
            data = args[0]
            key = args[-1]

            l1, l2 = len(data), len(features)
            if l1 != l2:
                raise ValueError(
                    f"Got {l2} features from processing {l1} samples!"
                )

            # save computed_features
            logger.info(f"Saving {l2} features in {block.features_dir}")
            for feat, k in zip(features, key):
                block.save(k, feat)

        return features
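

# Illustrative sketch: __dask_tokenize__ makes the dask token of a transform
# task depend only on the method name and the checkpoint directory, so task
# names stay stable across runs and processes.
def _example_token_stability():
    from dask.base import tokenize

    block = Block(estimator_name="dummy", features_dir="feats")
    t1 = _TokenStableTransform(block, "transform")
    t2 = _TokenStableTransform(block, "transform")
    # two distinct instances hash to the same dask token
    assert tokenize(t1) == tokenize(t2)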



def _populate_graph(graph):
    new_graph = []
    for block in graph:
        if isinstance(block, BaseEstimator):
            block = {"estimator": block}
        if isinstance(block, dict):
            block = Block(**block)
        new_graph.append(block)
    return new_graph


def _get_dask_args_from_ds(ds, columns):
    if isinstance(columns, str):
        args = [(ds[columns].data, ds[columns].dims)]
    else:
        args = []
        for c in columns:
            args.extend(_get_dask_args_from_ds(ds, c))
        args = tuple(args)
    return args


def _blockwise_with_block_args(args, block, method_name=None):
    meta = []
    for _ in range(1, block.output_ndim):
        meta = [meta]
    meta = np.array(meta, dtype=block.output_dtype)

    ascii_letters = list(string.ascii_lowercase)
    dim_map = {}

    input_arg_pairs = []
    for array, dims in args:
        dim_name = []
        for dim, dim_size in zip(dims, array.shape):
            if dim not in dim_map:
                dim_map[dim] = (ascii_letters.pop(0), dim_size)
            dim_name.append(dim_map[dim][0])
        input_arg_pairs.extend((array, "".join(dim_name)))

    # the sample dimension is always kept the same
    output_dim_name = f"{input_arg_pairs[1][0]}"
    new_axes = dict()
    for dim_name, dim_size in block.output_dims:
        if dim_name in dim_map:
            output_dim_name += dim_map[dim_name][0]
        else:
            try:
                dim_size = float(dim_size)
            except Exception:
                raise ValueError(
                    "Expected a float dim_size (positive integers or nan) for new "
                    f"dimension: {dim_name} but got: {dim_size}"
                )

            new_letter = ascii_letters.pop(0)
            if dim_name is None:
                dim_name = new_letter
            dim_map[dim_name] = (new_letter, dim_size)
            output_dim_name += new_letter
            new_axes[new_letter] = dim_size

    dims = []
    inv_map = {v[0]: k for k, v in dim_map.items()}
    for dim_name in output_dim_name:
        dims.append(inv_map[dim_name])

    output_shape = [dim_map[d][1] for d in dims]

    return output_dim_name, new_axes, input_arg_pairs, dims, meta, output_shape


def _blockwise_with_block(args, block, method_name=None, input_has_keys=False):
    (
        output_dim_name,
        new_axes,
        input_arg_pairs,
        dims,
        meta,
        _,
    ) = _blockwise_with_block_args(args, block, method_name=None)
    transform_func = _TokenStableTransform(
        block, method_name, input_has_keys=input_has_keys
    )
    transform_func.__name__ = f"{block.estimator_name}.{method_name}"

    data = dask.array.blockwise(
        transform_func,
        output_dim_name,
        *input_arg_pairs,
        meta=meta,
        new_axes=new_axes,
        concatenate=True,
        estimator=block.estimator_,
    )

    return dims, data
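

# Worked example (illustrative): for a single input variable with dims
# ("sample", "dim_0") and the default ``output_dims=((None, nan),)``, the
# helpers above map dimensions to blockwise index letters as
# sample -> "a", dim_0 -> "b", and a new output axis -> "c" of size nan.
# The input signature is "ab", the output signature "ac",
# new_axes == {"c": nan}, and the output dims are ["sample", "c"]; the sample
# axis always stays first, so the per-sample structure of the dataset is
# preserved through the transform.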



def _load_estimator(block):
    logger.info(f"Loading {block.estimator_name} from {block.model_path}")
    with open(block.model_path, "rb") as f:
        block.estimator = cloudpickle.load(f)
    return block.estimator


def _transform_or_load(block, ds, input_columns, mn):
    if isinstance(input_columns, str):
        input_columns = [input_columns]
    input_columns = list(input_columns) + ["key"]

    # filter dataset based on existing checkpoints
    key = np.asarray(ds["key"])
    paths = [block.make_path(k) for k in key]
    saved_samples = np.asarray([os.path.isfile(p) for p in paths])
    # compute/load features per chunk
    chunksize = ds.data.data.chunksize[0]
    for i in range(0, len(saved_samples), chunksize):
        if not np.all(saved_samples[i : i + chunksize]):
            saved_samples[i : i + chunksize] = False

    nonsaved_samples = np.logical_not(saved_samples)
    total_samples_n, saved_samples_n = len(key), saved_samples.sum()
    saved_ds = ds.sel({"sample": saved_samples})
    nonsaved_ds = ds.sel({"sample": nonsaved_samples})

    computed_data = loaded_data = None
    # compute non-saved data
    if total_samples_n - saved_samples_n > 0:
        args = _get_dask_args_from_ds(nonsaved_ds, input_columns)
        dims, computed_data = _blockwise_with_block(
            args, block, mn, input_has_keys=True
        )

    # load saved data
    if saved_samples_n > 0:
        logger.info(
            f"Might load {saved_samples_n} features of {block.estimator_name}.{mn} from disk."
        )
        args = _get_dask_args_from_ds(saved_ds, input_columns)
        dims, meta, shape = _blockwise_with_block_args(args, block, mn)[-3:]
        loaded_data = [
            dask.array.from_delayed(
                dask.delayed(block.load)(k),
                shape=shape[1:],
                meta=meta,
                name=False,
            )[None, ...]
            for k in key[saved_samples]
        ]
        loaded_data = dask.array.concatenate(loaded_data, axis=0)

    # merge loaded and computed data
    if computed_data is None:
        data = loaded_data
    elif loaded_data is None:
        data = computed_data
    else:
        # merge data chunk-based
        data = []
        i, j = 0, 0
        for k in range(0, len(saved_samples), chunksize):
            saved = saved_samples[k]
            if saved:
                pick = loaded_data[j : j + chunksize]
                j += chunksize
            else:
                pick = computed_data[i : i + chunksize]
                i += chunksize
            data.append(pick)
        data = dask.array.concatenate(data, axis=0)

    data = dask.array.rechunk(data, {0: chunksize})
    return dims, data



class DatasetPipeline(_BaseComposition):
    """A dataset-based scikit-learn pipeline.
    See :ref:`bob.pipelines.dataset_pipeline`.

    Attributes
    ----------
    graph : list
        A list of :any:`Block` objects to be applied to the input dataset.
    """

    def __init__(self, graph, **kwargs):
        super().__init__(**kwargs)
        self.graph = _populate_graph(graph)

    def _transform(self, ds, do_fit=False, method_name=None):
        for i, block in enumerate(self.graph):
            if block.dataset_map is not None:
                try:
                    ds = block.dataset_map(ds)
                except Exception as e:
                    raise RuntimeError(
                        f"Could not map ds {ds}\n with {block.dataset_map}"
                    ) from e
                continue

            if do_fit:
                args = _get_dask_args_from_ds(ds, block.fit_input)
                args = [d for d, dims in args]
                estimator = block.estimator
                if not estimator_requires_fit(estimator):
                    block.estimator_ = estimator
                elif block.model_path is not None and os.path.isfile(
                    block.model_path
                ):
                    _load_estimator.__name__ = f"load_{block.estimator_name}"
                    block.estimator_ = dask.delayed(_load_estimator)(block)
                elif block.input_dask_array:
                    ds = ds.persist()
                    args = _get_dask_args_from_ds(ds, block.fit_input)
                    args = [d for d, dims in args]
                    block.estimator_ = _fit(*args, block=block)
                else:
                    _fit.__name__ = f"{block.estimator_name}.fit"
                    block.estimator_ = dask.delayed(_fit)(
                        *args,
                        block=block,
                    )

            mn = "transform"
            if i == len(self.graph) - 1:
                if do_fit:
                    break
                mn = method_name

            if block.features_dir is None:
                args = _get_dask_args_from_ds(ds, block.transform_input)
                dims, data = _blockwise_with_block(
                    args, block, mn, input_has_keys=False
                )
            else:
                dims, data = _transform_or_load(
                    block, ds, block.transform_input, mn
                )

            # replace data inside dataset
            ds = ds.copy(deep=False)
            del ds["data"]
            persisted = False
            if not np.all(np.isfinite(data.shape)):
                block.estimator_, data = dask.persist(block.estimator_, data)
                data = data.compute_chunk_sizes()
                persisted = True
            ds["data"] = (dims, data)
            if persisted:
                ds = ds.persist()

        return ds

    def fit(self, ds, y=None):
        if y is not None:
            raise ValueError()
        self._transform(ds, do_fit=True)
        return self

    def transform(self, ds):
        return self._transform(ds, method_name="transform")

    def decision_function(self, ds):
        return self._transform(ds, method_name="decision_function")

    def predict(self, ds):
        return self._transform(ds, method_name="predict")

    def predict_proba(self, ds):
        return self._transform(ds, method_name="predict_proba")

    def score(self, ds):
        return self._transform(ds, method_name="score")
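

# Illustrative end-to-end sketch (hypothetical usage; dataset construction
# assumes bob.pipelines.Sample as above): build a pipeline from plain
# scikit-learn estimators, fit it, and obtain a transformed dataset.
def _example_dataset_pipeline():
    from sklearn.decomposition import PCA
    from sklearn.preprocessing import StandardScaler

    from bob.pipelines import Sample

    samples = [Sample(np.random.rand(10), key=f"sample-{i}") for i in range(8)]
    ds = samples_to_dataset(samples, npartitions=2)

    pipeline = DatasetPipeline(
        [
            StandardScaler(),
            {"estimator": PCA(n_components=3), "output_dims": [("pca", 3)]},
        ]
    )
    pipeline.fit(ds)
    transformed = pipeline.transform(ds)
    # ``transformed`` is still a lazy xr.Dataset; .compute() materializes the
    # transformed "data" variable.
    return transformed.compute()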