Coverage for src/bob/bio/base/wrappers.py: 74%
38 statements
« prev ^ index » next coverage.py v7.6.5, created at 2024-11-14 21:41 +0100
« prev ^ index » next coverage.py v7.6.5, created at 2024-11-14 21:41 +0100
1#!/usr/bin/env python
2# vim: set fileencoding=utf-8 :
4import os
6import bob.pipelines
8from bob.bio.base.extractor import Extractor
9from bob.bio.base.preprocessor import Preprocessor
10from bob.bio.base.transformers import (
11 ExtractorTransformer,
12 PreprocessorTransformer,
13)
14from bob.bio.base.utils import is_argument_available
def wrap_bob_legacy(
    bob_object,
    dir_name,
    fit_extra_arguments=None,
    transform_extra_arguments=None,
    dask_it=False,
    **kwargs,
):
    """
    Wraps either :any:`bob.bio.base.preprocessor.Preprocessor` or
    :any:`bob.bio.base.extractor.Extractor` with
    :any:`sklearn.base.TransformerMixin`,
    :any:`bob.pipelines.wrappers.CheckpointWrapper` and
    :any:`bob.pipelines.wrappers.SampleWrapper`

    Parameters
    ----------

    bob_object: object
        Instance of :any:`bob.bio.base.preprocessor.Preprocessor` or
        :any:`bob.bio.base.extractor.Extractor`

    dir_name: str
        Directory name for the checkpoints. Preprocessed data goes to
        ``<dir_name>/preprocessor`` and extracted features to
        ``<dir_name>/extractor``; for extractors, ``dir_name`` is also used as
        the ``model_path`` given to :any:`wrap_checkpoint_extractor`.

    fit_extra_arguments: [tuple]
        Same behavior as in
        :any:`bob.pipelines.wrappers.fit_extra_arguments`.
        Only used for extractors; ignored for preprocessors.

    transform_extra_arguments: [tuple]
        Same behavior as in
        :any:`bob.pipelines.wrappers.transform_extra_arguments`.
        If ``None``, the default of the selected wrapper function is kept
        (``(("annotations", "annotations"),)`` for preprocessors).

    dask_it: bool
        If True, the transformer will be a dask graph

    **kwargs:
        Forwarded to :any:`wrap_checkpoint_preprocessor` or
        :any:`wrap_checkpoint_extractor`.

    Returns
    -------
    The wrapped transformer (additionally dask-wrapped if ``dask_it``).

    Raises
    ------
    ValueError
        If ``bob_object`` is neither a Preprocessor nor an Extractor.
    """

    if isinstance(bob_object, Preprocessor):
        preprocessor_kwargs = dict(kwargs)
        # Only override the preprocessor wrapper's default (annotations) when
        # the caller explicitly asked for something, so ``None`` keeps it.
        if transform_extra_arguments is not None:
            preprocessor_kwargs["transform_extra_arguments"] = transform_extra_arguments
        transformer = wrap_checkpoint_preprocessor(
            bob_object,
            features_dir=os.path.join(dir_name, "preprocessor"),
            **preprocessor_kwargs,
        )
    elif isinstance(bob_object, Extractor):
        transformer = wrap_checkpoint_extractor(
            bob_object,
            features_dir=os.path.join(dir_name, "extractor"),
            model_path=dir_name,
            fit_extra_arguments=fit_extra_arguments,
            transform_extra_arguments=transform_extra_arguments,
            **kwargs,
        )
    else:
        raise ValueError(
            "`bob_object` should be an instance of `Preprocessor` or "
            f"`Extractor`, got `{type(bob_object).__name__}`"
        )

    if dask_it:
        transformer = bob.pipelines.wrap(["dask"], transformer)

    return transformer
def wrap_sample_preprocessor(
    preprocessor,
    transform_extra_arguments=(("annotations", "annotations"),),
    **kwargs,
):
    """
    Wraps :any:`bob.bio.base.preprocessor.Preprocessor` in a
    :any:`bob.bio.base.transformers.PreprocessorTransformer` and then in a
    :any:`bob.pipelines.wrappers.SampleWrapper`.

    .. warning::
        This wrapper doesn't checkpoint data: no
        :any:`bob.pipelines.wrappers.CheckpointWrapper` is applied.

    Parameters
    ----------

    preprocessor: :any:`bob.bio.base.preprocessor.Preprocessor`
        Legacy preprocessor to be wrapped

    transform_extra_arguments: [tuple]
        Same behavior as in
        :any:`bob.pipelines.wrappers.transform_extra_arguments`

    **kwargs:
        Accepted but not used.

    Returns
    -------
    The sample-wrapped transformer.
    """
    # NOTE(review): ``kwargs`` is silently ignored here, unlike
    # ``wrap_sample_extractor`` which forwards it; confirm this is intended.
    sklearn_preprocessor = PreprocessorTransformer(preprocessor)
    wrapper_names = ["sample"]
    return bob.pipelines.wrap(
        wrapper_names,
        sklearn_preprocessor,
        transform_extra_arguments=transform_extra_arguments,
    )
def wrap_checkpoint_preprocessor(
    preprocessor,
    features_dir=None,
    transform_extra_arguments=(("annotations", "annotations"),),
    load_func=None,
    save_func=None,
    extension=".hdf5",
):
    """
    Wraps :any:`bob.bio.base.preprocessor.Preprocessor` in a
    :any:`bob.bio.base.transformers.PreprocessorTransformer`, then in
    :any:`bob.pipelines.wrappers.SampleWrapper` and
    :any:`bob.pipelines.wrappers.CheckpointWrapper`.

    Parameters
    ----------

    preprocessor: :any:`bob.bio.base.preprocessor.Preprocessor`
        Legacy preprocessor to be wrapped

    features_dir: str
        Directory where preprocessed data is checkpointed
        (see :any:`bob.pipelines.CheckpointWrapper`).

    transform_extra_arguments: [tuple]
        Same behavior as in
        :any:`bob.pipelines.wrappers.transform_extra_arguments`

    load_func : None, optional
        Function that loads checkpointed data. Falls back to
        :any:`bob.bio.base.preprocessor.Preprocessor.read_data`.

    save_func : None, optional
        Function that saves preprocessed data. Falls back to
        :any:`bob.bio.base.preprocessor.Preprocessor.write_data`.

    extension : str, optional
        Extension of the checkpointed files
        (see :any:`bob.pipelines.CheckpointWrapper`).

    Returns
    -------
    The sample- and checkpoint-wrapped transformer.
    """
    # Build the sklearn adapter first so construction errors surface before
    # the loader/saver attributes are looked up.
    sklearn_preprocessor = PreprocessorTransformer(preprocessor)
    loader = load_func or preprocessor.read_data
    saver = save_func or preprocessor.write_data

    return bob.pipelines.wrap(
        ["sample", "checkpoint"],
        sklearn_preprocessor,
        load_func=loader,
        save_func=saver,
        features_dir=features_dir,
        transform_extra_arguments=transform_extra_arguments,
        extension=extension,
    )
161def _prepare_extractor_sample_args(
162 extractor, transform_extra_arguments, fit_extra_arguments
163):
164 if transform_extra_arguments is None and is_argument_available(
165 "metadata", extractor.__call__
166 ):
167 transform_extra_arguments = (("metadata", "metadata"),)
169 if (
170 fit_extra_arguments is None
171 and extractor.requires_training
172 and extractor.split_training_data_by_client
173 ):
174 fit_extra_arguments = (("y", "subject"),)
176 return transform_extra_arguments, fit_extra_arguments
def wrap_sample_extractor(
    extractor,
    fit_extra_arguments=None,
    transform_extra_arguments=None,
    model_path=None,
    **kwargs,
):
    """
    Wraps :any:`bob.bio.base.extractor.Extractor` in a
    :any:`bob.bio.base.transformers.ExtractorTransformer` and then in a
    :any:`bob.pipelines.wrappers.SampleWrapper` (no checkpointing).

    Parameters
    ----------

    extractor: :any:`bob.bio.base.extractor.Extractor`
        Legacy extractor to be wrapped

    fit_extra_arguments: [tuple], optional
        Same behavior as in :any:`bob.pipelines.wrappers.fit_extra_arguments`.
        If ``None``, defaults to ``(("y", "subject"),)`` when the extractor
        requires training and splits training data by client.

    transform_extra_arguments: [tuple], optional
        Same behavior as in
        :any:`bob.pipelines.wrappers.transform_extra_arguments`.
        If ``None``, defaults to ``(("metadata", "metadata"),)`` when the
        extractor's ``__call__`` accepts a ``metadata`` argument.

    model_path: str, optional
        Directory; if given, ``<model_path>/Extractor.hdf5`` is passed as
        ``model_path`` to :any:`bob.bio.base.transformers.ExtractorTransformer`.

    **kwargs:
        Forwarded to :any:`bob.pipelines.wrap`.

    Returns
    -------
    The sample-wrapped transformer.
    """
    if model_path is None:
        extractor_file = None
    else:
        extractor_file = os.path.join(model_path, "Extractor.hdf5")

    sklearn_extractor = ExtractorTransformer(extractor, model_path=extractor_file)

    transform_args, fit_args = _prepare_extractor_sample_args(
        extractor, transform_extra_arguments, fit_extra_arguments
    )

    return bob.pipelines.wrap(
        ["sample"],
        sklearn_extractor,
        transform_extra_arguments=transform_args,
        fit_extra_arguments=fit_args,
        **kwargs,
    )
def wrap_checkpoint_extractor(
    extractor,
    features_dir=None,
    fit_extra_arguments=None,
    transform_extra_arguments=None,
    load_func=None,
    save_func=None,
    extension=".hdf5",
    model_path=None,
    **kwargs,
):
    """
    Wraps :any:`bob.bio.base.extractor.Extractor` in a
    :any:`bob.bio.base.transformers.ExtractorTransformer`, then in
    :any:`bob.pipelines.wrappers.SampleWrapper` and
    :any:`bob.pipelines.wrappers.CheckpointWrapper`.

    Parameters
    ----------

    extractor: :any:`bob.bio.base.extractor.Extractor`
        Legacy extractor to be wrapped

    features_dir: str
        Directory where extracted features are checkpointed
        (see :any:`bob.pipelines.CheckpointWrapper`).

    fit_extra_arguments: [tuple]
        Same behavior as in :any:`bob.pipelines.wrappers.fit_extra_arguments`.
        If ``None``, defaults to ``(("y", "subject"),)`` when the extractor
        requires training and splits training data by client.

    transform_extra_arguments: [tuple], optional
        Same behavior as in
        :any:`bob.pipelines.wrappers.transform_extra_arguments`.
        If ``None``, defaults to ``(("metadata", "metadata"),)`` when the
        extractor's ``__call__`` accepts a ``metadata`` argument.

    load_func : None, optional
        Function that loads checkpointed features. Falls back to
        :any:`bob.bio.base.extractor.Extractor.read_feature`.

    save_func : None, optional
        Function that saves extracted features. Falls back to
        :any:`bob.bio.base.extractor.Extractor.write_feature`.

    extension : str, optional
        Extension of the checkpointed feature files
        (see :any:`bob.pipelines.CheckpointWrapper`).

    model_path: str, optional
        Directory for model files: ``<model_path>/Extractor.hdf5`` is passed
        to :any:`bob.bio.base.transformers.ExtractorTransformer` and
        ``<model_path>/Extractor.pkl`` is passed as ``model_path`` to the
        checkpoint wrapper.

    **kwargs:
        Forwarded to :any:`bob.pipelines.wrap`.

    Returns
    -------
    The sample- and checkpoint-wrapped transformer.
    """
    if model_path is not None:
        extractor_file = os.path.join(model_path, "Extractor.hdf5")
        model_file = os.path.join(model_path, "Extractor.pkl")
    else:
        extractor_file = model_file = None

    transformer = ExtractorTransformer(extractor, model_path=extractor_file)

    (
        transform_extra_arguments,
        fit_extra_arguments,
    ) = _prepare_extractor_sample_args(
        extractor, transform_extra_arguments, fit_extra_arguments
    )

    return bob.pipelines.wrap(
        ["sample", "checkpoint"],
        transformer,
        load_func=load_func or extractor.read_feature,
        save_func=save_func or extractor.write_feature,
        model_path=model_file,
        features_dir=features_dir,
        # ``extension`` is a named parameter, so it can never reach the
        # checkpoint wrapper through ``kwargs``; forward it explicitly, as
        # ``wrap_checkpoint_preprocessor`` does.
        # NOTE(review): before this was forwarded, the checkpoint wrapper's
        # own default extension applied; existing feature checkpoints written
        # with that extension will not be reused.
        extension=extension,
        transform_extra_arguments=transform_extra_arguments,
        fit_extra_arguments=fit_extra_arguments,
        **kwargs,
    )