# Coverage report header (extraction artifact, kept for provenance):
# src/bob/pipelines/dataset/database.py — 89% of 116 statements covered,
# coverage.py v7.6.0, created at 2024-07-12 21:32 +0200.

1""" 

2The principles of this module are: 

3 

4* one csv file -> one set 

5* one row -> one sample 

6* csv files could exist in a tarball or inside a folder 

7* scikit-learn transformers are used to further transform samples 

8* several csv files (sets) compose a protocol 

9* several protocols compose a database 

10""" 

11import csv 

12import itertools 

13import os 

14 

15from collections.abc import Iterable 

16from pathlib import Path 

17from typing import Any, Optional, TextIO, Union 

18 

19import sklearn.pipeline 

20 

21from bob.pipelines.dataset.protocols.retrieve import ( 

22 list_group_names, 

23 list_protocol_names, 

24 open_definition_file, 

25 retrieve_protocols, 

26) 

27 

28from ..sample import Sample 

29from ..utils import check_parameter_for_validity, check_parameters_for_validity 

30 

31 

32def _maybe_open_file(path, **kwargs): 

33 if isinstance(path, (str, bytes, Path)): 

34 path = open(path, **kwargs) 

35 return path 

36 

37 

class FileListToSamples(Iterable):
    """Converts a list of paths and metadata to a list of samples.

    This class reads a file containing paths and optionally metadata and returns a list
    of :py:class:`bob.pipelines.Sample`\\ s when called.

    The first row of the file names the columns; every following row is one
    sample whose metadata fields take those column names. This mirrors the
    behavior of :py:class:`csv.DictReader` used by :py:class:`CSVToSamples`.

    A separator character can be set (defaults is space) to split the rows.
    No escaping is done (no quotes).

    A Transformer can be given to apply a transform on each sample. (Keep in mind this
    will not be distributed on Dask; Prefer applying Transformer in a
    ``bob.pipelines.Pipeline``.)
    """

    def __init__(
        self,
        list_file: str,
        separator: str = " ",
        transformer: "Optional[sklearn.pipeline.Pipeline]" = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Path to the definition file; (re)opened on each access of `rows`.
        self.list_file = list_file
        self.transformer = transformer
        self.separator = separator

    def __iter__(self):
        for row_dict in self.rows:
            sample = Sample(None, **row_dict)
            if self.transformer is not None:
                # The transformer might convert one sample to several samples
                for s in self.transformer.transform([sample]):
                    yield s
            else:
                yield sample

    @property
    def rows(self) -> "Iterator[dict[str, str]]":
        """Yields one dict per data row, keyed by the header (first) row.

        Fix over the previous revision: ``dict(line.split(sep))`` raised
        ``ValueError`` on any realistic row (``dict()`` needs key/value
        pairs, not a flat token list), and the trailing newline was kept in
        the last field. The header-row convention matches ``csv.DictReader``
        as used by :py:class:`CSVToSamples`.
        """
        with open(self.list_file, "rt") as f:
            fieldnames = None
            for line in f:
                cells = line.rstrip("\n").split(self.separator)
                if fieldnames is None:
                    # First row declares the column names.
                    fieldnames = cells
                    continue
                yield dict(zip(fieldnames, cells))

79 

80 

class CSVToSamples(FileListToSamples):
    """Converts a csv file to a list of samples"""

    def __init__(
        self,
        list_file: str,
        transformer: Optional[sklearn.pipeline.Pipeline] = None,
        dict_reader_kwargs: Optional[dict[str, Any]] = None,
        **kwargs,
    ):
        # A path (str/bytes/Path) is opened here; an already-open file object
        # is used as-is. ``newline=""`` is what the csv module expects.
        list_file = _maybe_open_file(list_file, newline="")
        super().__init__(
            list_file=list_file,
            transformer=transformer,
            **kwargs,
        )
        self.dict_reader_kwargs = dict_reader_kwargs

    @property
    def rows(self):
        """Returns a ``csv.DictReader`` over the (rewound) definition file."""
        # Rewind first so that iterating `rows` repeatedly always starts
        # from the top of the file.
        self.list_file.seek(0)
        return csv.DictReader(self.list_file, **(self.dict_reader_kwargs or {}))

105 

106 

class FileListDatabase:
    """A generic database interface.

    Use this class to convert csv files to a database that outputs samples. The
    format is simple, the files must be inside a folder (or a compressed
    tarball) with the following format::

        dataset_protocols_path/<protocol>/<group>.csv

    The top folders are the name of the protocols (if you only have one, you may
    name it ``default``). Inside each protocol folder, there are `<group>.csv`
    files where the name of the file specifies the name of the group. We
    recommend using the names ``train``, ``dev``, ``eval`` for your typical
    training, development, and test sets.

    Subclasses may define the class attributes ``name``,
    ``dataset_protocols_path``, ``dataset_protocols_urls``,
    ``dataset_protocols_name``, ``dataset_protocols_checksum``, and
    ``category``; :py:meth:`retrieve_dataset_protocols` reads them to locate
    (or download) the protocol definition files.
    """

    def __init__(
        self,
        *,
        name: str,
        protocol: str,
        dataset_protocols_path: Union[os.PathLike[str], str, None] = None,
        # NOTE(review): annotated `Iterable`, but this is actually a class
        # (a constructor producing an Iterable), e.g. CSVToSamples.
        reader_cls: Iterable = CSVToSamples,
        transformer: Optional[sklearn.pipeline.Pipeline] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        name
            The database name. Only stored when the (sub)class does not
            already define a non-None ``name`` attribute.
        protocol
            The name of the protocol to be used for samples. If None, the first
            protocol found will be used.
        dataset_protocols_path
            Path to a folder or a tarball where the csv protocol files are located.
            If None, :py:meth:`retrieve_dataset_protocols` is used to locate
            (or download) them.
        reader_cls
            An iterable that returns created Sample objects from a list file.
        transformer
            A scikit-learn transformer that further changes the samples.

        Raises
        ------
        ValueError
            If the dataset_protocols_path does not exist.
        """

        # Tricksy trick to make protocols non-classmethod when instantiated:
        # the instance attribute shadows the `protocols` classmethod below, so
        # `self.protocols()` dispatches to `_instance_protocols`. Must happen
        # before the `len(self.protocols())` check and the protocol setter.
        self.protocols = self._instance_protocols

        # Respect a `name` already defined on the class by a subclass.
        if getattr(self, "name", None) is None:
            self.name = name

        if dataset_protocols_path is None:
            # Fall back to the locally cached (or downloaded) definition files.
            dataset_protocols_path = self.retrieve_dataset_protocols()

        self.dataset_protocols_path = Path(dataset_protocols_path)

        # Fail early when no protocol definitions can be found at that path.
        if len(self.protocols()) < 1:
            raise ValueError(
                f"No protocols found at `{dataset_protocols_path}`!"
            )
        self.reader_cls = reader_cls
        self._transformer = transformer
        # Reader cache; keyed by (protocol, group) — see get_reader().
        self.readers: dict[tuple[str, str], Iterable] = {}
        self._protocol = None
        self.protocol = protocol  # validated/normalized by the setter below
        super().__init__(**kwargs)

    @property
    def protocol(self) -> str:
        """The currently selected protocol name."""
        return self._protocol

    @protocol.setter
    def protocol(self, value: str):
        # Validated against the available protocols; presumably a None value
        # falls back to the supplied default (the first protocol) — the exact
        # semantics live in check_parameter_for_validity. TODO confirm.
        value = check_parameter_for_validity(
            value, "protocol", self.protocols(), self.protocols()[0]
        )
        self._protocol = value

    @property
    def transformer(self) -> sklearn.pipeline.Pipeline:
        """The transformer handed to readers for per-sample transformation."""
        return self._transformer

    @transformer.setter
    def transformer(self, value: sklearn.pipeline.Pipeline):
        self._transformer = value
        # Propagate to readers already created and cached by get_reader().
        for reader in self.readers.values():
            reader.transformer = value

    def groups(self) -> list[str]:
        """Returns all the available groups of the current protocol."""
        return list_group_names(
            database_name=self.name,
            protocol=self.protocol,
            database_filename=self.dataset_protocols_path.name,
            base_dir=self.dataset_protocols_path.parent,
            subdir=".",
        )

    def _instance_protocols(self) -> list[str]:
        """Returns all the available protocols (instance-level variant).

        Bound to ``self.protocols`` in ``__init__`` to shadow the classmethod.
        """
        return list_protocol_names(
            database_name=self.name,
            database_filename=self.dataset_protocols_path.name,
            base_dir=self.dataset_protocols_path.parent,
            subdir=".",
        )

    @classmethod
    def protocols(cls) -> list[str]:  # pylint: disable=method-hidden
        """Returns all the available protocols.

        Hidden on instances by the attribute assigned in ``__init__`` (hence
        the ``method-hidden`` pylint pragma).
        """
        # Ensure the definition file exists locally (this may download it).
        loc = cls.retrieve_dataset_protocols()
        if not hasattr(cls, "name"):
            raise ValueError(f"{cls} has no attribute 'name'.")
        return list_protocol_names(
            database_name=getattr(cls, "name"),
            database_filename=loc.name,
            base_dir=loc.parent,
            subdir=".",
        )

    @classmethod
    def retrieve_dataset_protocols(cls) -> Path:
        """Return a path to the protocols definition files.

        If the files are not present locally in ``bob_data/<subdir>/<category>``, they
        will be downloaded.

        The class inheriting from CSVDatabase must have a ``name`` and a
        ``dataset_protocols_urls`` attribute.

        A ``dataset_protocols_checksum`` attribute can be used to verify the
        file and ensure the correct version is used.
        """

        # When the path is specified, just return it.
        if getattr(cls, "dataset_protocols_path", None) is not None:
            return getattr(cls, "dataset_protocols_path")

        # Save to bob_data/protocols, or if present, in a category sub directory.
        subdir = Path("protocols")
        if hasattr(cls, "category"):
            subdir = subdir / getattr(cls, "category")

        # Retrieve the file from the server (or use the local version).
        return retrieve_protocols(
            urls=getattr(cls, "dataset_protocols_urls"),
            destination_filename=getattr(cls, "dataset_protocols_name", None),
            base_dir=None,
            subdir=subdir,
            checksum=getattr(cls, "dataset_protocols_checksum", None),
        )

    def list_file(self, group: str) -> TextIO:
        """Returns the corresponding definition file of a group
        (for the current protocol)."""
        list_file = open_definition_file(
            search_pattern=group + ".csv",
            database_name=self.name,
            protocol=self.protocol,
            database_filename=self.dataset_protocols_path.name,
            base_dir=self.dataset_protocols_path.parent,
            subdir=".",
        )
        return list_file

    def get_reader(self, group: str) -> Iterable:
        """Returns an :any:`Iterable` of :any:`Sample` objects.

        Readers are cached per ``(protocol, group)``; repeated calls reuse the
        same reader instance (and therefore the same open definition file).
        """
        key = (self.protocol, group)
        if key not in self.readers:
            self.readers[key] = self.reader_cls(
                list_file=self.list_file(group), transformer=self.transformer
            )

        reader = self.readers[key]
        return reader

    def samples(self, groups: Union[str, list[str], None] = None) -> list[Sample]:
        """Get samples of a certain group

        Parameters
        ----------
        groups : :obj:`str`, optional
            A str or list of str to be used for filtering samples, by default None
            (meaning all available groups).

        Returns
        -------
        list
            A list containing the samples loaded from csv files.
        """

        # None (or a subset) is normalized and validated against self.groups().
        groups = check_parameters_for_validity(
            groups, "groups", self.groups(), self.groups()
        )
        all_samples = []
        for grp in groups:
            for sample in self.get_reader(grp):
                all_samples.append(sample)

        return all_samples

    @staticmethod
    def sort(samples: list[Sample], unique: bool = True) -> list[Sample]:
        """Sorts samples by their ``key`` and removes duplicates by default."""

        def key_func(x):
            return x.key

        samples = sorted(samples, key=key_func)

        if unique:
            # Input is sorted by key, so groupby collapses every run of equal
            # keys; keep only the first sample of each run.
            samples = [
                next(iter(v))
                for _, v in itertools.groupby(samples, key=key_func)
            ]

        return samples