Coverage for src/bob/bio/base/wrappers.py: 74%

38 statements  

« prev     ^ index     » next       coverage.py v7.6.5, created at 2024-11-14 21:41 +0100

1#!/usr/bin/env python 

2# vim: set fileencoding=utf-8 : 

3 

4import os 

5 

6import bob.pipelines 

7 

8from bob.bio.base.extractor import Extractor 

9from bob.bio.base.preprocessor import Preprocessor 

10from bob.bio.base.transformers import ( 

11 ExtractorTransformer, 

12 PreprocessorTransformer, 

13) 

14from bob.bio.base.utils import is_argument_available 

15 

16 

def wrap_bob_legacy(
    bob_object,
    dir_name,
    fit_extra_arguments=None,
    transform_extra_arguments=None,
    dask_it=False,
    **kwargs,
):
    """
    Wraps either :any:`bob.bio.base.preprocessor.Preprocessor` or
    :any:`bob.bio.base.extractor.Extractor` with a scikit-learn compatible
    transformer, :any:`bob.pipelines.wrappers.SampleWrapper` and
    :any:`bob.pipelines.wrappers.CheckpointWrapper` (and optionally dask).


    Parameters
    ----------

    bob_object: object
        Instance of :any:`bob.bio.base.preprocessor.Preprocessor` or
        :any:`bob.bio.base.extractor.Extractor`

    dir_name: str
        Root directory for the checkpoints. Preprocessed data goes to
        ``<dir_name>/preprocessor``; extracted features go to
        ``<dir_name>/extractor`` and extractor model files are placed
        directly in ``dir_name``.

    fit_extra_arguments: [tuple]
        Same semantics as ``fit_extra_arguments`` in
        :any:`bob.pipelines.wrappers.SampleWrapper`. Only used when
        ``bob_object`` is an Extractor.

    transform_extra_arguments: [tuple]
        Same semantics as ``transform_extra_arguments`` in
        :any:`bob.pipelines.wrappers.SampleWrapper`. When ``None``, the
        default of the underlying ``wrap_checkpoint_*`` function is used.

    dask_it: bool
        If True, the transformer will be a dask graph

    **kwargs
        Forwarded to ``wrap_checkpoint_preprocessor`` or
        ``wrap_checkpoint_extractor``.

    Returns
    -------
    The wrapped transformer.

    Raises
    ------
    ValueError
        If ``bob_object`` is neither a Preprocessor nor an Extractor.
    """

    if isinstance(bob_object, Preprocessor):
        preprocessor_kwargs = dict(kwargs)
        # Only forward ``transform_extra_arguments`` when it was given, so that
        # ``wrap_checkpoint_preprocessor`` keeps its own default otherwise
        # (passing ``None`` explicitly would override that default).
        if transform_extra_arguments is not None:
            preprocessor_kwargs[
                "transform_extra_arguments"
            ] = transform_extra_arguments
        transformer = wrap_checkpoint_preprocessor(
            bob_object,
            features_dir=os.path.join(dir_name, "preprocessor"),
            **preprocessor_kwargs,
        )
    elif isinstance(bob_object, Extractor):
        transformer = wrap_checkpoint_extractor(
            bob_object,
            features_dir=os.path.join(dir_name, "extractor"),
            model_path=dir_name,
            fit_extra_arguments=fit_extra_arguments,
            transform_extra_arguments=transform_extra_arguments,
            **kwargs,
        )
    else:
        # Only Preprocessor and Extractor are supported by this function.
        raise ValueError(
            "`bob_object` should be an instance of `Preprocessor` or "
            f"`Extractor`, not `{type(bob_object).__name__}`"
        )

    if dask_it:
        transformer = bob.pipelines.wrap(["dask"], transformer)

    return transformer

79 

80 

def wrap_sample_preprocessor(
    preprocessor,
    transform_extra_arguments=(("annotations", "annotations"),),
    **kwargs,
):
    """
    Wraps :any:`bob.bio.base.preprocessor.Preprocessor` with
    :any:`bob.bio.base.transformers.PreprocessorTransformer` and
    :any:`bob.pipelines.wrappers.SampleWrapper`

    .. warning::
        This wrapper doesn't checkpoint data
        (use ``wrap_checkpoint_preprocessor`` for that).

    Parameters
    ----------

    preprocessor: :any:`bob.bio.base.preprocessor.Preprocessor`
        Preprocessor to be wrapped; it is first turned into a
        :any:`bob.bio.base.transformers.PreprocessorTransformer`.

    transform_extra_arguments: [tuple]
        Same semantics as ``transform_extra_arguments`` in
        :any:`bob.pipelines.wrappers.SampleWrapper`. Defaults to forwarding
        each sample's ``annotations``.

    **kwargs
        Forwarded to :any:`bob.pipelines.wrap` (previously they were
        accepted but silently discarded).

    Returns
    -------
    The sample-wrapped transformer.
    """

    transformer = PreprocessorTransformer(preprocessor)
    return bob.pipelines.wrap(
        ["sample"],
        transformer,
        transform_extra_arguments=transform_extra_arguments,
        **kwargs,
    )

110 

111 

def wrap_checkpoint_preprocessor(
    preprocessor,
    features_dir=None,
    transform_extra_arguments=(("annotations", "annotations"),),
    load_func=None,
    save_func=None,
    extension=".hdf5",
    **kwargs,
):
    """
    Wraps :any:`bob.bio.base.preprocessor.Preprocessor` with
    :any:`bob.pipelines.wrappers.CheckpointWrapper` and :any:`bob.pipelines.wrappers.SampleWrapper`

    Parameters
    ----------

    preprocessor: :any:`bob.bio.base.preprocessor.Preprocessor`
        Preprocessor to be wrapped; it is first turned into a
        :any:`bob.bio.base.transformers.PreprocessorTransformer`.

    features_dir: str
        Directory where preprocessed data is checkpointed
        (see :any:`bob.pipelines.wrappers.CheckpointWrapper`).

    extension : str, optional
        Extension of the preprocessed files
        (see :any:`bob.pipelines.wrappers.CheckpointWrapper`).

    load_func : None, optional
        Function that loads checkpointed preprocessed data.
        The default is :any:`bob.bio.base.preprocessor.Preprocessor.read_data`

    save_func : None, optional
        Function that saves preprocessed data.
        The default is :any:`bob.bio.base.preprocessor.Preprocessor.write_data`

    transform_extra_arguments: [tuple]
        Same semantics as ``transform_extra_arguments`` in
        :any:`bob.pipelines.wrappers.SampleWrapper`. Defaults to forwarding
        each sample's ``annotations``.

    **kwargs
        Extra keyword arguments forwarded to :any:`bob.pipelines.wrap`.

    Returns
    -------
    The sample- and checkpoint-wrapped transformer.
    """

    transformer = PreprocessorTransformer(preprocessor)
    return bob.pipelines.wrap(
        ["sample", "checkpoint"],
        transformer,
        # Fall back to the preprocessor's own (de)serialization methods.
        load_func=load_func or preprocessor.read_data,
        save_func=save_func or preprocessor.write_data,
        features_dir=features_dir,
        transform_extra_arguments=transform_extra_arguments,
        extension=extension,
        **kwargs,
    )

159 

160 

def _prepare_extractor_sample_args(
    extractor, transform_extra_arguments, fit_extra_arguments
):
    """Fill in default SampleWrapper arguments for ``extractor``.

    Parameters
    ----------

    extractor: :any:`bob.bio.base.extractor.Extractor`
        The extractor that is going to be wrapped.

    transform_extra_arguments: [tuple] or None
        When ``None`` and ``extractor.__call__`` accepts a ``metadata``
        argument, it is replaced by ``(("metadata", "metadata"),)``.

    fit_extra_arguments: [tuple] or None
        When ``None`` and the extractor both requires training and splits its
        training data by client, it is replaced by ``(("y", "subject"),)``.

    Returns
    -------
    tuple
        ``(transform_extra_arguments, fit_extra_arguments)``; explicitly
        provided values are returned unchanged.
    """
    transform_args = transform_extra_arguments
    if transform_args is None:
        if is_argument_available("metadata", extractor.__call__):
            transform_args = (("metadata", "metadata"),)

    fit_args = fit_extra_arguments
    if fit_args is None and extractor.requires_training:
        # Client-split training needs each sample's subject as the label.
        if extractor.split_training_data_by_client:
            fit_args = (("y", "subject"),)

    return transform_args, fit_args

177 

178 

def wrap_sample_extractor(
    extractor,
    fit_extra_arguments=None,
    transform_extra_arguments=None,
    model_path=None,
    **kwargs,
):
    """
    Wraps :any:`bob.bio.base.extractor.Extractor` with
    :any:`bob.bio.base.transformers.ExtractorTransformer` and
    :any:`bob.pipelines.wrappers.SampleWrapper` (no checkpointing).

    Parameters
    ----------

    extractor: :any:`bob.bio.base.extractor.Extractor`
        Extractor to be wrapped; it is first turned into a
        :any:`bob.bio.base.transformers.ExtractorTransformer`.

    fit_extra_arguments: [tuple], optional
        Same semantics as ``fit_extra_arguments`` in
        :any:`bob.pipelines.wrappers.SampleWrapper`. When ``None``, a
        default may be derived from the extractor's training settings.

    transform_extra_arguments: [tuple], optional
        Same semantics as ``transform_extra_arguments`` in
        :any:`bob.pipelines.wrappers.SampleWrapper`. When ``None``, a
        default may be derived from the extractor's call signature.

    model_path: str
        Directory holding the extractor model; ``Extractor.hdf5`` inside it
        is used as ``extractor_file`` of
        :any:`bob.bio.base.extractor.Extractor`.

    **kwargs
        Forwarded to :any:`bob.pipelines.wrap`.

    Returns
    -------
    The sample-wrapped transformer.
    """

    extractor_file = None
    if model_path is not None:
        extractor_file = os.path.join(model_path, "Extractor.hdf5")

    sample_transformer = ExtractorTransformer(
        extractor, model_path=extractor_file
    )

    transform_args, fit_args = _prepare_extractor_sample_args(
        extractor, transform_extra_arguments, fit_extra_arguments
    )

    return bob.pipelines.wrap(
        ["sample"],
        sample_transformer,
        transform_extra_arguments=transform_args,
        fit_extra_arguments=fit_args,
        **kwargs,
    )

226 

227 

def wrap_checkpoint_extractor(
    extractor,
    features_dir=None,
    fit_extra_arguments=None,
    transform_extra_arguments=None,
    load_func=None,
    save_func=None,
    extension=".hdf5",
    model_path=None,
    **kwargs,
):
    """
    Wraps :any:`bob.bio.base.extractor.Extractor` with
    :any:`bob.pipelines.wrappers.CheckpointWrapper` and :any:`bob.pipelines.wrappers.SampleWrapper`

    Parameters
    ----------

    extractor: :any:`bob.bio.base.extractor.Extractor`
        Extractor to be wrapped; it is first turned into a
        :any:`bob.bio.base.transformers.ExtractorTransformer`.

    features_dir: str
        Directory where extracted features are checkpointed
        (see :any:`bob.pipelines.wrappers.CheckpointWrapper`).

    extension : str, optional
        Extension of the checkpointed feature files
        (see :any:`bob.pipelines.wrappers.CheckpointWrapper`).

    load_func : None, optional
        Function that loads checkpointed features.
        The default is :any:`bob.bio.base.extractor.Extractor.read_feature`

    save_func : None, optional
        Function that saves extracted features.
        The default is :any:`bob.bio.base.extractor.Extractor.write_feature`

    fit_extra_arguments: [tuple]
        Same semantics as ``fit_extra_arguments`` in
        :any:`bob.pipelines.wrappers.SampleWrapper`. When ``None``, a
        default may be derived from the extractor's training settings.

    transform_extra_arguments: [tuple], optional
        Same semantics as ``transform_extra_arguments`` in
        :any:`bob.pipelines.wrappers.SampleWrapper`. When ``None``, a
        default may be derived from the extractor's call signature.

    model_path: str
        Directory for the extractor model: ``Extractor.hdf5`` inside it is
        given to :any:`bob.bio.base.transformers.ExtractorTransformer`, and
        ``Extractor.pkl`` is given as ``model_path`` to the checkpoint
        wrapper.

    **kwargs
        Forwarded to :any:`bob.pipelines.wrap`.

    Returns
    -------
    The sample- and checkpoint-wrapped transformer.
    """

    if model_path is not None:
        extractor_file = os.path.join(model_path, "Extractor.hdf5")
        model_file = os.path.join(model_path, "Extractor.pkl")
    else:
        extractor_file = None
        model_file = None

    transformer = ExtractorTransformer(extractor, model_path=extractor_file)

    (
        transform_extra_arguments,
        fit_extra_arguments,
    ) = _prepare_extractor_sample_args(
        extractor, transform_extra_arguments, fit_extra_arguments
    )

    return bob.pipelines.wrap(
        ["sample", "checkpoint"],
        transformer,
        # Fall back to the extractor's own (de)serialization methods.
        load_func=load_func or extractor.read_feature,
        save_func=save_func or extractor.write_feature,
        model_path=model_file,
        features_dir=features_dir,
        transform_extra_arguments=transform_extra_arguments,
        fit_extra_arguments=fit_extra_arguments,
        # ``extension`` is a named parameter of this function, so it never
        # reaches ``**kwargs``; it must be forwarded explicitly or it is lost.
        extension=extension,
        **kwargs,
    )