Coverage for src/bob/pipelines/dataset/database.py: 89%
116 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-12 21:32 +0200
1"""
2The principles of this module are:
4* one csv file -> one set
5* one row -> one sample
6* csv files could exist in a tarball or inside a folder
7* scikit-learn transformers are used to further transform samples
8* several csv files (sets) compose a protocol
9* several protocols compose a database
10"""
11import csv
12import itertools
13import os
15from collections.abc import Iterable
16from pathlib import Path
17from typing import Any, Optional, TextIO, Union
19import sklearn.pipeline
21from bob.pipelines.dataset.protocols.retrieve import (
22 list_group_names,
23 list_protocol_names,
24 open_definition_file,
25 retrieve_protocols,
26)
28from ..sample import Sample
29from ..utils import check_parameter_for_validity, check_parameters_for_validity
32def _maybe_open_file(path, **kwargs):
33 if isinstance(path, (str, bytes, Path)):
34 path = open(path, **kwargs)
35 return path
class FileListToSamples(Iterable):
    """Converts a list of paths and metadata to a list of samples.

    This class reads a file containing paths and optionally metadata and
    returns a list of :py:class:`bob.pipelines.Sample`\\ s when called.

    A separator character can be set (default is space) to split the rows.
    No escaping is done (no quotes).

    A Transformer can be given to apply a transform on each sample. (Keep in
    mind this will not be distributed on Dask; Prefer applying Transformer in
    a ``bob.pipelines.Pipeline``.)
    """

    def __init__(
        self,
        list_file: str,
        separator: str = " ",
        transformer: Optional[sklearn.pipeline.Pipeline] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        list_file
            Path to the text file listing one sample per row.
        separator
            Character used to split each row into fields.
        transformer
            Optional scikit-learn pipeline applied to every sample during
            iteration.
        """
        super().__init__(**kwargs)
        self.list_file = list_file
        self.transformer = transformer
        self.separator = separator

    def __iter__(self):
        # Lazily build one Sample per row; the data payload is None and every
        # row field becomes a sample attribute.
        for row_dict in self.rows:
            sample = Sample(None, **row_dict)
            if self.transformer is not None:
                # The transformer might convert one sample to several samples
                for s in self.transformer.transform([sample]):
                    yield s
            else:
                yield sample

    @property
    def rows(self) -> Iterable[dict[str, Any]]:
        # Generator over the rows of ``list_file``, one dict per line.
        # NOTE(review): ``dict(line.split(self.separator))`` only succeeds
        # when every field is exactly two characters long (each string is
        # then consumed as a (key, value) pair of single characters), and the
        # trailing newline is never stripped.  This looks like a latent bug;
        # subclasses such as CSVToSamples override ``rows`` and avoid it.
        # TODO confirm the intended row format before relying on this class
        # directly.
        with open(self.list_file, "rt") as f:
            for line in f:
                yield dict(line.split(self.separator))
class CSVToSamples(FileListToSamples):
    """Converts a csv file to a list of samples.

    Each csv row becomes one sample; the csv header provides the attribute
    names (parsing is delegated to :class:`csv.DictReader`).
    """

    def __init__(
        self,
        list_file: str,
        transformer: Optional[sklearn.pipeline.Pipeline] = None,
        dict_reader_kwargs: Optional[dict[str, Any]] = None,
        **kwargs,
    ):
        # Accept either a path or an already-open file object; the csv module
        # requires the underlying file to be opened with newline="".
        opened = _maybe_open_file(list_file, newline="")
        super().__init__(
            list_file=opened,
            transformer=transformer,
            **kwargs,
        )
        self.dict_reader_kwargs = dict_reader_kwargs

    @property
    def rows(self):
        # Rewind first so iterating the reader repeatedly always starts from
        # the top of the file.
        self.list_file.seek(0)
        return csv.DictReader(
            self.list_file, **(self.dict_reader_kwargs or {})
        )
class FileListDatabase:
    """A generic database interface.

    Use this class to convert csv files to a database that outputs samples.
    The format is simple, the files must be inside a folder (or a compressed
    tarball) with the following format::

        dataset_protocols_path/<protocol>/<group>.csv

    The top folders are the name of the protocols (if you only have one, you
    may name it ``default``). Inside each protocol folder, there are
    ``<group>.csv`` files where the name of the file specifies the name of
    the group. We recommend using the names ``train``, ``dev``, ``eval`` for
    your typical training, development, and test sets.
    """

    def __init__(
        self,
        *,
        name: str,
        protocol: str,
        dataset_protocols_path: Union[os.PathLike[str], str, None] = None,
        reader_cls: type[Iterable] = CSVToSamples,
        transformer: Optional[sklearn.pipeline.Pipeline] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        name
            The name of the database. Only stored if the (sub)class does not
            already define a ``name`` attribute.
        protocol
            The name of the protocol to be used for samples. If None, the
            first protocol found will be used.
        dataset_protocols_path
            Path to a folder or a tarball where the csv protocol files are
            located. If None, the files are retrieved (downloaded if needed)
            via :meth:`retrieve_dataset_protocols`.
        reader_cls
            A class (iterable over created Sample objects) instantiated with
            a list file and transformer for each requested group.
        transformer
            A scikit-learn transformer that further changes the samples.

        Raises
        ------
        ValueError
            If no protocols are found at ``dataset_protocols_path``.
        """

        # Tricksy trick to make protocols non-classmethod when instantiated:
        # after this assignment, ``self.protocols()`` resolves to the bound
        # ``_instance_protocols`` instead of the classmethod ``protocols``.
        self.protocols = self._instance_protocols

        # Only set the name if a subclass did not already declare one as a
        # class attribute.
        if getattr(self, "name", None) is None:
            self.name = name

        if dataset_protocols_path is None:
            dataset_protocols_path = self.retrieve_dataset_protocols()

        self.dataset_protocols_path = Path(dataset_protocols_path)

        if len(self.protocols()) < 1:
            raise ValueError(
                f"No protocols found at `{dataset_protocols_path}`!"
            )
        self.reader_cls = reader_cls
        self._transformer = transformer
        # Reader cache, keyed by (protocol, group) — see get_reader.
        self.readers: dict[tuple[str, str], Iterable] = {}
        self._protocol = None
        self.protocol = protocol  # goes through the validating setter below
        super().__init__(**kwargs)

    @property
    def protocol(self) -> str:
        """The currently selected protocol name."""
        return self._protocol

    @protocol.setter
    def protocol(self, value: str):
        # Validates against the available protocols; a None value falls back
        # to the first available protocol (the default argument of
        # check_parameter_for_validity).
        value = check_parameter_for_validity(
            value, "protocol", self.protocols(), self.protocols()[0]
        )
        self._protocol = value

    @property
    def transformer(self) -> sklearn.pipeline.Pipeline:
        """The scikit-learn transformer applied to each sample."""
        return self._transformer

    @transformer.setter
    def transformer(self, value: sklearn.pipeline.Pipeline):
        self._transformer = value
        # Propagate to readers already created, so cached readers pick up the
        # new transformer too.
        for reader in self.readers.values():
            reader.transformer = value

    def groups(self) -> list[str]:
        """Returns all the available groups."""
        return list_group_names(
            database_name=self.name,
            protocol=self.protocol,
            database_filename=self.dataset_protocols_path.name,
            base_dir=self.dataset_protocols_path.parent,
            subdir=".",
        )

    def _instance_protocols(self) -> list[str]:
        """Returns all the available protocols.

        Instance-level replacement for the classmethod ``protocols``; it uses
        the already-resolved ``dataset_protocols_path`` instead of triggering
        a retrieval.
        """
        return list_protocol_names(
            database_name=self.name,
            database_filename=self.dataset_protocols_path.name,
            base_dir=self.dataset_protocols_path.parent,
            subdir=".",
        )

    @classmethod
    def protocols(cls) -> list[str]:  # pylint: disable=method-hidden
        """Returns all the available protocols.

        This classmethod is shadowed on instances by ``_instance_protocols``
        (see ``__init__``), hence the ``method-hidden`` pylint disable.
        """
        # Ensure the definition file exists locally
        loc = cls.retrieve_dataset_protocols()
        if not hasattr(cls, "name"):
            raise ValueError(f"{cls} has no attribute 'name'.")
        return list_protocol_names(
            database_name=getattr(cls, "name"),
            database_filename=loc.name,
            base_dir=loc.parent,
            subdir=".",
        )

    @classmethod
    def retrieve_dataset_protocols(cls) -> Path:
        """Return a path to the protocols definition files.

        If the files are not present locally in ``bob_data/<subdir>/<category>``,
        they will be downloaded.

        The class inheriting from CSVDatabase must have a ``name`` and a
        ``dataset_protocols_urls`` attribute.

        A ``checksum`` attribute can be used to verify the file and ensure
        the correct version is used.
        """

        # When the path is specified, just return it.
        if getattr(cls, "dataset_protocols_path", None) is not None:
            return getattr(cls, "dataset_protocols_path")

        # Save to bob_data/protocols, or if present, in a category sub directory.
        subdir = Path("protocols")
        if hasattr(cls, "category"):
            subdir = subdir / getattr(cls, "category")

        # Retrieve the file from the server (or use the local version).
        return retrieve_protocols(
            urls=getattr(cls, "dataset_protocols_urls"),
            destination_filename=getattr(cls, "dataset_protocols_name", None),
            base_dir=None,
            subdir=subdir,
            checksum=getattr(cls, "dataset_protocols_checksum", None),
        )

    def list_file(self, group: str) -> TextIO:
        """Returns the corresponding definition file of a group."""
        list_file = open_definition_file(
            search_pattern=group + ".csv",
            database_name=self.name,
            protocol=self.protocol,
            database_filename=self.dataset_protocols_path.name,
            base_dir=self.dataset_protocols_path.parent,
            subdir=".",
        )
        return list_file

    def get_reader(self, group: str) -> Iterable:
        """Returns an :any:`Iterable` of :any:`Sample` objects.

        Readers are cached per ``(protocol, group)`` pair, so each definition
        file is opened at most once.
        """
        key = (self.protocol, group)
        if key not in self.readers:
            self.readers[key] = self.reader_cls(
                list_file=self.list_file(group), transformer=self.transformer
            )

        reader = self.readers[key]
        return reader

    def samples(self, groups=None):
        """Get samples of a certain group

        Parameters
        ----------
        groups : :obj:`str`, optional
            A str or list of str to be used for filtering samples, by default
            None (all available groups are used).

        Returns
        -------
        list
            A list containing the samples loaded from csv files.
        """

        groups = check_parameters_for_validity(
            groups, "groups", self.groups(), self.groups()
        )
        all_samples = []
        for grp in groups:
            for sample in self.get_reader(grp):
                all_samples.append(sample)

        return all_samples

    @staticmethod
    def sort(samples: list[Sample], unique: bool = True):
        """Sorts samples and removes duplicates by default.

        Samples are ordered by their ``key`` attribute; when ``unique`` is
        True, only the first sample of each run of equal keys is kept.
        """

        def key_func(x):
            return x.key

        samples = sorted(samples, key=key_func)

        if unique:
            # groupby requires the input to be sorted by the same key, which
            # the sorted() call above guarantees.
            samples = [
                next(iter(v))
                for _, v in itertools.groupby(samples, key=key_func)
            ]

        return samples