Coverage for src/bob/pad/base/script/pad_figure.py: 58% (213 statements)
coverage.py v7.6.0, created at 2024-07-12 23:40 +0200

"""Runs error analysis on score sets, outputs metrics and plots"""

import click
import numpy as np

from tabulate import tabulate

import bob.bio.base.script.figure as bio_figure
import bob.measure.script.figure as measure_figure

from bob.measure import f_score, farfrr, precision_recall, roc_auc_score
from bob.measure.utils import get_fta_list

from ..error_utils import apcer_bpcer, calc_threshold


def _normalize_input_scores(input_score, input_name):
    pos, negs = input_score
    # convert scores to sorted numpy arrays and keep a copy of all negatives
    pos = np.ascontiguousarray(pos)
    pos.sort()
    all_negs = np.ascontiguousarray([s for neg in negs.values() for s in neg])
    all_negs.sort()
    # FTA is calculated on pos and all_negs so we remove nans from negs
    for k, v in negs.items():
        v = np.ascontiguousarray(v)
        v.sort()
        negs[k] = v[~np.isnan(v)]
    neg_list, pos_list, fta_list = get_fta_list([(all_negs, pos)])
    all_negs, pos, fta = neg_list[0], pos_list[0], fta_list[0]
    return input_name, pos, negs, all_negs, fta
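

# The sketch below (hypothetical, not part of the original module) illustrates the
# score layout _normalize_input_scores expects: a (positives, negatives) pair where
# negatives is a dict keyed by PAI name, plus a label for the score file.
def _example_normalize_input():  # hypothetical helper, for illustration only
    bona_fide = [0.9, 0.8, float("nan"), 0.7]  # positive (bona fide) scores
    attacks = {"print": [0.1, 0.2], "replay": [0.3]}  # negatives grouped per PAI
    name, pos, negs, all_negs, fta = _normalize_input_scores(
        (bona_fide, attacks), "scores-dev.csv"
    )
    # pos and all_negs come back sorted with NaNs removed; fta is (roughly) the
    # fraction of NaN scores, i.e. failures to acquire.
    return name, pos, negs, all_negs, fta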


class Metrics(bio_figure.Metrics):
    """Compute metrics from score files"""

    def __init__(self, ctx, scores, evaluation, func_load, names):
        if isinstance(names, str):
            names = names.split(",")
        super(Metrics, self).__init__(ctx, scores, evaluation, func_load, names)

    def get_thres(self, criterion, pos, negs, all_negs, far_value):
        return calc_threshold(
            criterion,
            pos=pos,
            negs=negs.values(),
            all_negs=all_negs,
            far_value=far_value,
            is_sorted=True,
        )

    def _numbers(self, threshold, pos, negs, all_negs, fta):
        pais = list(negs.keys())
        apcer_pais, apcer_ap, bpcer = apcer_bpcer(
            threshold, pos, *[negs[k] for k in pais]
        )
        apcer_pais = {k: apcer_pais[i] for i, k in enumerate(pais)}
        acer = (apcer_ap + bpcer) / 2.0
        fpr, fnr = farfrr(all_negs, pos, threshold)
        hter = (fpr + fnr) / 2.0
        far = fpr * (1 - fta)
        frr = fta + fnr * (1 - fta)

        nn = all_negs.shape[0]  # number of attack samples
        fp = int(round(fpr * nn))  # number of false positives
        np = pos.shape[0]  # number of bona fide samples
        fn = int(round(fnr * np))  # number of false negatives

        # precision and recall
        precision, recall = precision_recall(all_negs, pos, threshold)

        # f_score
        f1_score = f_score(all_negs, pos, threshold, 1)

        # area under the ROC curve, in linear and log scale
        auc = roc_auc_score(all_negs, pos)
        auc_log = roc_auc_score(all_negs, pos, log_scale=True)

        metrics = dict(
            apcer_pais=apcer_pais,
            apcer_ap=apcer_ap,
            bpcer=bpcer,
            acer=acer,
            fta=fta,
            fpr=fpr,
            fnr=fnr,
            hter=hter,
            far=far,
            frr=frr,
            fp=fp,
            nn=nn,
            fn=fn,
            np=np,
            precision=precision,
            recall=recall,
            f1_score=f1_score,
            auc=auc,
        )
        metrics["auc-log-scale"] = auc_log
        return metrics
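
    # Worked illustration (hypothetical numbers) of how the rates above combine:
    # with APCER_AP = 10%, BPCER = 4%, FPR = 8%, FNR = 4% and FTA = 5%,
    #   ACER = (0.10 + 0.04) / 2        = 0.07  (7%)
    #   HTER = (0.08 + 0.04) / 2        = 0.06  (6%)
    #   FAR  = 0.08 * (1 - 0.05)        = 0.076 (7.6%)
    #   FRR  = 0.05 + 0.04 * (1 - 0.05) = 0.088 (8.8%)
    # i.e. FAR/FRR fold the failure-to-acquire rate back into the raw FPR/FNR.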

    def _strings(self, metrics):
        n_dec = ".%df" % self._decimal
        for k, v in metrics.items():
            if k in ("precision", "recall", "f1_score", "auc", "auc-log-scale"):
                metrics[k] = "%s" % format(v, n_dec)
            elif k in ("np", "nn", "fp", "fn"):
                continue
            elif k in ("fpr", "fnr"):
                if "fp" in metrics:
                    metrics[k] = "%s%% (%d/%d)" % (
                        format(100 * v, n_dec),
                        metrics["fp" if k == "fpr" else "fn"],
                        metrics["nn" if k == "fpr" else "np"],
                    )
                else:
                    metrics[k] = "%s%%" % format(100 * v, n_dec)
            elif k == "apcer_pais":
                metrics[k] = {
                    k1: "%s%%" % format(100 * v1, n_dec) for k1, v1 in v.items()
                }
            else:
                metrics[k] = "%s%%" % format(100 * v, n_dec)

        return metrics

    def _get_all_metrics(self, idx, input_scores, input_names):
        """Compute all metrics for dev and eval scores"""
        for i, (score, name) in enumerate(zip(input_scores, input_names)):
            input_scores[i] = _normalize_input_scores(score, name)

        dev_file, dev_pos, dev_negs, dev_all_negs, dev_fta = input_scores[0]
        if self._eval:
            (
                eval_file,
                eval_pos,
                eval_negs,
                eval_all_negs,
                eval_fta,
            ) = input_scores[1]

        threshold = (
            self.get_thres(
                self._criterion, dev_pos, dev_negs, dev_all_negs, self._far
            )
            if self._thres is None
            else self._thres[idx]
        )

        title = self._legends[idx] if self._legends is not None else None
        if self._thres is None:
            far_str = ""
            if self._criterion == "far" and self._far is not None:
                far_str = str(self._far)
            click.echo(
                "[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                % (
                    self._criterion.upper(),
                    far_str,
                    title or dev_file,
                    threshold,
                ),
                file=self.log_file,
            )
        else:
            click.echo(
                "[Min. criterion: user provided] Threshold on "
                "Development set `%s`: %e" % (dev_file or title, threshold),
                file=self.log_file,
            )

        res = []
        res.append(
            self._strings(
                self._numbers(
                    threshold, dev_pos, dev_negs, dev_all_negs, dev_fta
                )
            )
        )

        if self._eval:
            # compute statistics for the eval set at the threshold fixed
            # a priori on the development set
            res.append(
                self._strings(
                    self._numbers(
                        threshold, eval_pos, eval_negs, eval_all_negs, eval_fta
                    )
                )
            )
        else:
            res.append(None)

        return res

    def compute(self, idx, input_scores, input_names):
        """Compute metrics for the given criteria"""
        title = self._legends[idx] if self._legends is not None else None
        all_metrics = self._get_all_metrics(idx, input_scores, input_names)
        headers = [title or " ", "Development"]
        if self._eval:
            headers.append("Evaluation")
        rows = []

        for name in self.names:
            if name == "apcer_pais":
                for k, v in all_metrics[0][name].items():
                    print_name = f"APCER ({k})"
                    rows += [[print_name, v]]
                    if self._eval:
                        rows[-1].append(all_metrics[1][name][k])
                continue
            print_name = name.upper()
            rows += [[print_name, all_metrics[0][name]]]
            if self._eval:
                rows[-1].append(all_metrics[1][name])

        click.echo(tabulate(rows, headers, self._tablefmt), file=self.log_file)
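

# Minimal sketch (hypothetical values; the real rows are built from all_metrics
# above) of the rows/headers structure that Metrics.compute hands to tabulate.
def _example_metrics_table():  # hypothetical helper, for illustration only
    headers = [" ", "Development", "Evaluation"]
    rows = [
        ["APCER (print)", "5.0%", "6.1%"],
        ["BPCER", "2.0%", "2.5%"],
        ["ACER", "3.5%", "4.3%"],
    ]
    return tabulate(rows, headers, "rst")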


class MultiMetrics(Metrics):
    """Compute averaged metrics over several score files (protocols/folds)"""

    def __init__(self, ctx, scores, evaluation, func_load, names):
        super(MultiMetrics, self).__init__(
            ctx, scores, evaluation, func_load, names=names
        )
        self.rows = []
        self.headers = None
        self.pais = None

    def _compute_headers(self, pais):
        names = list(self.names)
        if "apcer_pais" in names:
            idx = names.index("apcer_pais")
            names = (
                [n.upper() for n in names[:idx]]
                + self.pais
                + [n.upper() for n in names[idx + 1 :]]
            )
        self.headers = ["Methods"] + names
        if self._eval and "hter" in self.names:
            self.headers.insert(1, "HTER (dev)")

    def _strings(self, metrics):
        formatted_metrics = dict()
        for name in self.names:
            if name == "apcer_pais":
                for pai in self.pais:
                    mean = metrics[pai].mean()
                    std = metrics[pai].std()
                    mean = super()._strings({pai: mean})[pai]
                    std = super()._strings({pai: std})[pai]
                    formatted_metrics[pai] = f"{mean} ({std})"
            else:
                mean = metrics[name].mean()
                std = metrics[name].std()
                mean = super()._strings({name: mean})[name]
                std = super()._strings({name: std})[name]
                formatted_metrics[name] = f"{mean} ({std})"

        return formatted_metrics

    def _structured_array(self, metrics):
        names = list(metrics[0].keys())
        if "apcer_pais" in names:
            idx = names.index("apcer_pais")
            pais = list(
                f"APCER ({pai})" for pai in metrics[0]["apcer_pais"].keys()
            )
            names = names[:idx] + pais + names[idx + 1 :]
            self.pais = self.pais or pais
        formats = [float] * len(names)
        dtype = dict(names=names, formats=formats)
        array = []
        for each in metrics:
            array.append([])
            for k, v in each.items():
                if k == "apcer_pais":
                    array[-1].extend(list(v.values()))
                else:
                    array[-1].append(v)
        array = [tuple(a) for a in array]
        return np.array(array, dtype=dtype)
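
    # Minimal sketch (hypothetical values) of the record array built above: one
    # record per protocol/fold and one named float field per metric, so the
    # field-wise .mean()/.std() used in _strings is a plain lookup, e.g.
    #   dtype = dict(names=["APCER (print)", "BPCER"], formats=[float, float])
    #   per_protocol = np.array([(0.05, 0.02), (0.07, 0.03)], dtype=dtype)
    #   per_protocol["APCER (print)"].mean()  # -> 0.06
    #   per_protocol["BPCER"].std()           # -> 0.005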

    def compute(self, idx, input_scores, input_names):
        """Computes the average of metrics over several protocols."""
        for i, (score, name) in enumerate(zip(input_scores, input_names)):
            input_scores[i] = _normalize_input_scores(score, name)

        step = 2 if self._eval else 1
        self._dev_metrics = []
        self._thresholds = []
        for scores in input_scores[::step]:
            name, pos, negs, all_negs, fta = scores
            threshold = (
                self.get_thres(self._criterion, pos, negs, all_negs, self._far)
                if self._thres is None
                else self._thres[idx]
            )
            self._thresholds.append(threshold)
            self._dev_metrics.append(
                self._numbers(threshold, pos, negs, all_negs, fta)
            )
        self._dev_metrics = self._structured_array(self._dev_metrics)

        if self._eval:
            self._eval_metrics = []
            for i, scores in enumerate(input_scores[1::step]):
                name, pos, negs, all_negs, fta = scores
                threshold = self._thresholds[i]
                self._eval_metrics.append(
                    self._numbers(threshold, pos, negs, all_negs, fta)
                )
            self._eval_metrics = self._structured_array(self._eval_metrics)

        title = self._legends[idx] if self._legends is not None else name

        dev_metrics = self._strings(self._dev_metrics)

        if self._eval and "hter" in dev_metrics:
            self.rows.append([title, dev_metrics["hter"]])
        elif not self._eval:
            row = [title]
            for name in self.names:
                if name == "apcer_pais":
                    for pai in self.pais:
                        row += [dev_metrics[pai]]
                else:
                    row += [dev_metrics[name]]
            self.rows.append(row)
        else:
            self.rows.append([title])

        if self._eval:
            eval_metrics = self._strings(self._eval_metrics)
            row = []
            for name in self.names:
                if name == "apcer_pais":
                    for pai in self.pais:
                        row += [eval_metrics[pai]]
                else:
                    row += [eval_metrics[name]]

            self.rows[-1].extend(row)

        # compute header based on found PAI names
        if self.headers is None:
            self._compute_headers(self.pais)

    def end_process(self):
        click.echo(
            tabulate(self.rows, self.headers, self._tablefmt),
            file=self.log_file,
        )
        super(MultiMetrics, self).end_process()


class Roc(bio_figure.Roc):
    """ROC for PAD"""

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Roc, self).__init__(ctx, scores, evaluation, func_load)
        self._x_label = ctx.meta.get("x_label") or "APCER"
        default_y_label = "1-BPCER" if self._tpr else "BPCER"
        self._y_label = ctx.meta.get("y_label") or default_y_label


class Det(bio_figure.Det):
    """DET curve for PAD"""

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Det, self).__init__(ctx, scores, evaluation, func_load)
        self._x_label = ctx.meta.get("x_label") or "APCER (%)"
        self._y_label = ctx.meta.get("y_label") or "BPCER (%)"


class Hist(measure_figure.Hist):
    """Histograms for PAD"""

    def _setup_hist(self, neg, pos):
        self._title_base = "PAD"
        self._density_hist(pos[0], n=0, label="Bona-fide", color="C1")
        self._density_hist(
            neg[0],
            n=1,
            label="Presentation attack",
            alpha=0.4,
            color="C7",
            hatch="\\\\",
        )
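

# A standalone sketch (illustration only, not part of the original module) of the
# kind of figure Hist._setup_hist configures: overlaid score densities for bona
# fide samples and presentation attacks. Data and styling below are hypothetical.
def _example_pad_histogram():  # hypothetical helper, for illustration only
    import matplotlib.pyplot as plt

    rng = np.random.default_rng(0)
    bona_fide = rng.normal(1.0, 0.3, 1000)  # stand-in bona fide scores
    attacks = rng.normal(-1.0, 0.5, 1000)  # stand-in attack scores
    plt.hist(bona_fide, bins=50, density=True, label="Bona-fide", color="C1")
    plt.hist(
        attacks,
        bins=50,
        density=True,
        alpha=0.4,
        label="Presentation attack",
        color="C7",
        hatch="\\\\",
    )
    plt.title("PAD")
    plt.legend()
    return plt.gcf()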