Coverage for src/bob/pad/base/pipelines/abstract_classes.py: 47%
19 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-12 23:40 +0200
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-12 23:40 +0200
1from __future__ import annotations
3from abc import ABCMeta, abstractmethod
5from bob.pipelines import Sample
class Database(metaclass=ABCMeta):
    """Base database class for PAD experiments.

    Subclasses implement :meth:`fit_samples` (training samples) and
    :meth:`predict_samples` (samples to be scored). :meth:`all_samples`
    combines the two into a single list.
    """

    @abstractmethod
    def fit_samples(self) -> list[Sample]:
        """Returns :any:`bob.pipelines.Sample`'s to train a PAD model.

        Returns
        -------
        samples : list
            List of samples for model training.
        """
        pass

    @abstractmethod
    def predict_samples(self, group: str | None = "dev") -> list[Sample]:
        """Returns :any:`bob.pipelines.Sample`'s to be scored.

        Parameters
        ----------
        group : :py:class:`str` or None, optional
            Limits samples to this group. :meth:`all_samples` passes ``None``
            when the caller did not request specific groups; implementations
            are expected to treat ``None`` as "no group restriction".

        Returns
        -------
        samples : list
            List of samples to be scored.
        """
        pass

    def all_samples(
        self, groups: str | list[str] | None = None
    ) -> list[Sample]:
        """Returns all the samples of the database in one list.

        Giving ``groups`` will restrict the ``predict_samples`` to those groups.

        Parameters
        ----------
        groups : :py:class:`str`, list of :py:class:`str` or None, optional
            A single group name, several group names, or ``None`` to apply no
            group restriction (``predict_samples`` is then called once with
            ``group=None``).

        Returns
        -------
        samples : list
            The samples of ``fit_samples`` followed by the samples of
            ``predict_samples`` for each requested group, in the given order.
        """
        # Copy first: extending the object returned by ``fit_samples`` in place
        # would silently grow a list that the subclass may be holding on to.
        samples = list(self.fit_samples())

        if groups is None:
            # The previous code referenced the loop variable ``group`` here,
            # which is never bound on this path (UnboundLocalError).
            samples.extend(self.predict_samples(group=None))
            return samples

        # Accept a bare group name as shorthand for a one-element list.
        if isinstance(groups, str):
            groups = [groups]
        for group in groups:
            samples.extend(self.predict_samples(group=group))
        return samples