Coverage for src/bob/pad/base/pipelines/abstract_classes.py: 47%

19 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-12 23:40 +0200

1from __future__ import annotations 

2 

3from abc import ABCMeta, abstractmethod 

4 

5from bob.pipelines import Sample 

6 

7 

class Database(metaclass=ABCMeta):
    """Base database class for PAD experiments.

    Subclasses must implement :py:meth:`fit_samples` and
    :py:meth:`predict_samples`; :py:meth:`all_samples` is built on top of them.
    """

    @abstractmethod
    def fit_samples(self) -> list[Sample]:
        """Returns :any:`bob.pipelines.Sample`'s to train a PAD model.

        Returns
        -------
        samples : list
            List of samples for model training.
        """
        pass

    @abstractmethod
    def predict_samples(self, group: str = "dev") -> list[Sample]:
        """Returns :any:`bob.pipelines.Sample`'s to be scored.

        Parameters
        ----------
        group : :py:class:`str`, optional
            Limits samples to this group

        Returns
        -------
        samples : list
            List of samples to be scored.
        """
        pass

    def all_samples(
        self, groups: str | list[str] | None = None
    ) -> list[Sample]:
        """Returns all the samples of the database in one list.

        The samples from :py:meth:`fit_samples` come first, followed by the
        samples from :py:meth:`predict_samples` for each requested group, in
        the order the groups are given.

        Parameters
        ----------
        groups : :py:class:`str` or list of :py:class:`str`, optional
            Restricts ``predict_samples`` to these groups. A single string is
            treated as a one-element list. If ``None``, ``predict_samples`` is
            called once without arguments, i.e. with its own default group.

        Returns
        -------
        samples : list
            A new list containing the fit samples followed by the predict
            samples.
        """
        # Copy first: extending the list returned by fit_samples() in place
        # would mutate it if a subclass returns a cached/attribute list.
        samples = list(self.fit_samples())

        if groups is None:
            # No explicit groups: defer to predict_samples' own default group.
            samples.extend(self.predict_samples())
            return samples

        # isinstance (not ``type(...) is str``) so str subclasses are not
        # iterated character by character.
        if isinstance(groups, str):
            groups = [groups]
        for group in groups:
            samples.extend(self.predict_samples(group=group))
        return samples