Coverage for src/bob/pad/base/database/csv_dataset.py: 33%
33 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-12 23:40 +0200
1#!/usr/bin/env python
2# vim: set fileencoding=utf-8 :
5from bob.bio.base.database.legacy import check_parameters_for_validity
6from bob.pad.base.pipelines.abstract_classes import Database
7from bob.pipelines.dataset import FileListDatabase
def validate_pad_sample(sample):
    """Check and normalize a PAD sample in place, then return it.

    The sample object is mutated (not copied):

    * an empty-string ``attack_type`` is replaced by ``None``;
    * ``is_bonafide`` is set to ``True`` exactly when ``attack_type`` is
      ``None`` (after the replacement above);
    * ``key`` is set to ``sample.filename`` when the sample has no ``key``.

    Parameters
    ----------
    sample : object
        Any object exposing at least ``subject`` and ``attack_type``
        attributes, plus ``filename`` when it has no ``key``.

    Returns
    -------
    object
        The same ``sample`` object, after normalization.

    Raises
    ------
    RuntimeError
        If ``subject`` or ``attack_type`` is missing.
    AttributeError
        If the sample has neither ``key`` nor ``filename``.
    """
    if not hasattr(sample, "subject"):
        raise RuntimeError(
            "PAD samples should contain a `subject` attribute which "
            "identifies the person from whom the sample is created."
        )
    if not hasattr(sample, "attack_type"):
        raise RuntimeError(
            "PAD samples should contain an `attack_type` attribute which "
            "should be '' for bona fide samples and something like "
            "print, replay, mask, etc. for attacks. This attribute is "
            "considered the PAI type of each attack and is used to compute APCER."
        )
    # CSV readers yield "" for an empty column; normalize it to None so that
    # bona fide samples are identified by `attack_type is None` only.
    if sample.attack_type == "":
        sample.attack_type = None
    sample.is_bonafide = sample.attack_type is None
    # Fall back to the file name as the unique sample key.
    if not hasattr(sample, "key"):
        sample.key = sample.filename
    return sample
class FileListPadDatabase(Database, FileListDatabase):
    """A PAD database interface from CSV files.

    Every sample returned by :meth:`samples` is passed through
    ``validate_pad_sample``. That function guarantees that the sample carries
    ``is_bonafide`` and ``key``, and that ``attack_type`` is ``None`` for
    bona fide samples.

    Parameters
    ----------
    name
        Name of the database, forwarded to the parent constructors.
    dataset_protocols_path
        Location of the protocol definitions, forwarded unchanged to the
        parent constructors.
    protocol
        Protocol to load, forwarded to the parent constructors.
    transformer
        Optional transformer (``None`` by default), forwarded to the parent
        constructors.
    **kwargs
        Any extra keyword arguments, forwarded to the parent constructors.
    """

    def __init__(
        self,
        name,
        dataset_protocols_path,
        protocol,
        transformer=None,
        **kwargs,
    ):
        super().__init__(
            name=name,
            dataset_protocols_path=dataset_protocols_path,
            protocol=protocol,
            transformer=transformer,
            **kwargs,
        )

    def __repr__(self) -> str:
        # Use the runtime class name so subclasses are reported correctly;
        # the output for FileListPadDatabase itself is unchanged.
        return (
            f"{type(self).__name__}("
            f"dataset_protocols_path='{self.dataset_protocols_path}', "
            f"protocol='{self.protocol}', "
            f"transformer={self.transformer})"
        )

    def purposes(self):
        """Return the valid sample purposes: ``("real", "attack")``."""
        return ("real", "attack")

    def samples(self, groups=None, purposes=None):
        """Return validated samples of the given groups and purposes.

        Parameters
        ----------
        groups
            Forwarded as-is to the parent ``samples`` method.
        purposes
            Subset of :meth:`purposes`. It is checked with
            ``check_parameters_for_validity``, which falls back to all
            purposes when nothing is given (presumably when ``None``; this
            also presumably raises on unknown values — confirm against
            bob.bio.base).

        Returns
        -------
        list
            The samples whose bona fide / attack status matches the requested
            purposes, in the order produced by the parent class.
        """
        results = super().samples(groups=groups)
        purposes = check_parameters_for_validity(
            purposes, "purposes", self.purposes(), self.purposes()
        )
        # Resolve the purpose selection once rather than once per sample.
        want_real = "real" in purposes
        want_attack = "attack" in purposes

        # Validate every sample first (this also sets `is_bonafide`), so a
        # malformed sample raises even if it would be filtered out.
        validated = [validate_pad_sample(sample) for sample in results]
        return [
            sample
            for sample in validated
            if (want_real if sample.is_bonafide else want_attack)
        ]

    def fit_samples(self):
        """Return all samples of the ``"train"`` group."""
        return self.samples(groups="train")

    def predict_samples(self, group="dev"):
        """Return all samples of ``group`` (``"dev"`` by default)."""
        return self.samples(groups=group)