Coverage for src/bob/measure/script/figure.py: 88%
550 statements
« prev ^ index » next coverage.py v7.0.5, created at 2023-06-16 14:10 +0200
« prev ^ index » next coverage.py v7.0.5, created at 2023-06-16 14:10 +0200
1"""Runs error analysis on score sets, outputs metrics and plots"""
3from __future__ import division, print_function
5import logging
6import math
7import sys
9from abc import ABCMeta, abstractmethod
11import click
12import matplotlib
13import matplotlib.pyplot as mpl
14import numpy
16from matplotlib import gridspec
17from matplotlib.backends.backend_pdf import PdfPages
18from tabulate import tabulate
20from .. import far_threshold, plot, ppndf, utils
# Module-level logger used by every metric/plot class below; it is created
# under the package-wide "bob.measure" name rather than ``__name__``.
LOGGER = logging.getLogger("bob.measure")
def check_list_value(values, desired_number, name, name2="systems"):
    """Broadcast a single-valued option to ``desired_number`` entries.

    Parameters
    ----------
    values : sequence or None
        Values given for the option. ``None`` is returned unchanged.
    desired_number : int
        Number of values expected (e.g. one per system or per histogram).
    name : str
        Name of the option, used in the error message.
    name2 : str
        Name of what is being counted, used in the error message.

    Returns
    -------
    sequence or None
        ``values`` itself when it is ``None`` or already has
        ``desired_number`` entries, otherwise its single entry repeated
        ``desired_number`` times.

    Raises
    ------
    click.BadParameter
        If ``values`` has neither 1 nor ``desired_number`` entries.
    """
    if values is None or len(values) == desired_number:
        return values
    if len(values) != 1:
        message = (
            "#{} ({}) must be either 1 value or the same as "
            "#{} ({} values)"
        )
        raise click.BadParameter(
            message.format(name, values, name2, desired_number)
        )
    return values * desired_number
class MeasureBase(object):
    """Abstract base class for metrics and plots.

    It defines the skeleton used to compute metrics or draw plots from a
    list of score files.

    The files are split into consecutive groups of ``min_arg`` files, one
    group per system. Each group is loaded with ``func_load`` and handed to
    :py:meth:`compute`.

    Attributes
    ----------
    func_load:
        Function that is used to load the input files
    n_systems : int
        Number of systems, i.e. ``len(scores) / min_arg``.
    """

    # Python 2 spelling of ``metaclass=ABCMeta``. Under Python 3 this is a
    # plain class attribute, so the abstract methods are not enforced.
    __metaclass__ = ABCMeta  # for python 2.7 compatibility

    def __init__(self, ctx, scores, evaluation, func_load):
        """
        Parameters
        ----------
        ctx : click context
            Its ``meta`` dictionary provides ``legends`` and ``min_arg``.
        scores : :any:`list`
            Input files, ``min_arg`` consecutive files per system
            (e.g. dev-{1, 2, 3}, {dev,eval}-scores1 {dev,eval}-scores2).
        evaluation : :py:class:`bool`
            True if eval data are used
        func_load : callable
            Function that is used to load each input file.

        Raises
        ------
        click.BadParameter
            If the number of files is not a non-zero multiple of
            ``min_arg``, or if fewer legends than systems are given.
        """
        meta = ctx.meta
        self._scores = scores
        self._ctx = ctx
        self.func_load = func_load
        self._legends = meta.get("legends")
        self._eval = evaluation
        self._min_arg = meta.get("min_arg", 1)

        n_files = len(scores)
        if n_files == 0 or n_files % self._min_arg:
            raise click.BadParameter(
                "Number of argument must be a non-zero multiple of %d"
                % self._min_arg
            )
        self.n_systems = int(n_files / self._min_arg)

        legends = self._legends
        if legends is not None and len(legends) < self.n_systems:
            raise click.BadParameter(
                "Number of legends must be >= to the number of systems"
            )

    def run(self):
        """Generate the outputs (metrics, files, pdf plots, ...).

        Calls :py:meth:`init_process` once, then :py:meth:`compute` for
        every system with that system's loaded scores, and finally
        :py:meth:`end_process`.
        """
        self.init_process()
        group = self._min_arg
        for idx in range(self.n_systems):
            # Files are given system after system:
            #   SysA-dev SysA-eval ... SysA-XX SysB-dev SysB-eval ... SysB-XX
            # so system ``idx`` owns the ``group`` files starting at
            # idx * group.
            first = idx * group
            input_scores, input_names = self._load_files(
                self._scores[first : first + group]
            )

            LOGGER.info("-----Input files for system %d-----", idx + 1)
            for pos, name in enumerate(input_names):
                if not self._eval:
                    LOGGER.info("Dev. score %d: %s", pos + 1, name)
                elif pos % 2:
                    # with eval data, files alternate dev / eval
                    LOGGER.info("Eval. score %d: %s", pos / 2 + 1, name)
                else:
                    LOGGER.info("Dev. score %d: %s", pos / 2 + 1, name)
            LOGGER.info("----------------------------------")

            self.compute(idx, input_scores, input_names)
        self.end_process()

    def init_process(self):
        """Hook called by :py:meth:`run` before iterating over the systems.

        Does nothing by default; derived classes may override it.
        """

    @abstractmethod
    def compute(self, idx, input_scores, input_names):
        """Compute metrics or plots for one system.

        Called by :py:meth:`run` once per system; to be implemented by
        derived classes.

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function, one entry
            per input file of the system
        input_names: :any:`list`
            list of base names for the input file of the system
        """
        # Example layout of ``input_scores`` (vulnerability analysis):
        # with evaluation data
        #   [(dev_licit_neg, dev_licit_pos), (eval_licit_neg, eval_licit_pos),
        #    (dev_spoof_neg, dev_licit_pos), (eval_spoof_neg, eval_licit_pos)]
        # and with development data only
        #   [(dev_licit_neg, dev_licit_pos), (dev_spoof_neg, dev_licit_pos)]

    @abstractmethod
    def end_process(self):
        """Hook called by :py:meth:`run` after all systems are processed.

        To be implemented by derived classes.
        """

    def _load_files(self, filepaths):
        """Load the input files and return them with their names.

        Returns
        -------
        scores: :any:`list`:
            The output of ``func_load`` for each of the given files
        basenames: :any:`list`:
            The given file paths, unchanged
        """
        basenames = list(filepaths)
        scores = [self.func_load(path) for path in basenames]
        return scores, basenames
class Metrics(MeasureBase):
    """Compute metrics from score files

    For each system, a decision threshold is either computed on the
    development scores (``criterion`` / ``far_value``) or taken from the
    user-provided ``thres`` values.

    A table of metrics is then printed to ``log_file``, with a development
    column and, when evaluation data are given, an evaluation column.

    Attributes
    ----------
    log_file: str
        output stream
    names : tuple
        Row labels of the printed table, in the order FPR, FNR, precision,
        recall, F1-score, AUC, AUC (log scale).
    """

    def __init__(
        self,
        ctx,
        scores,
        evaluation,
        func_load,
        names=(
            "False Positive Rate",
            "False Negative Rate",
            "Precision",
            "Recall",
            "F1-score",
            "Area Under ROC Curve",
            "Area Under ROC Curve (log scale)",
        ),
    ):
        super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
        self.names = names
        self._tablefmt = ctx.meta.get("tablefmt")
        self._criterion = ctx.meta.get("criterion")
        self._open_mode = ctx.meta.get("open_mode")
        self._thres = ctx.meta.get("thres")
        self._decimal = ctx.meta.get("decimal", 2)
        if self._thres is not None:
            if len(self._thres) == 1:
                # a single user threshold is applied to every system
                self._thres = self._thres * self.n_systems
            elif len(self._thres) != self.n_systems:
                raise click.BadParameter(
                    "#thresholds must be the same as #systems (%d)"
                    % self.n_systems
                )
        self._far = ctx.meta.get("far_value")
        self._log = ctx.meta.get("log")
        # Output goes to stdout unless a log file is requested; that file
        # is closed in ``end_process``.
        self.log_file = sys.stdout
        if self._log is not None:
            self.log_file = open(self._log, self._open_mode)

    def get_thres(self, criterion, dev_neg, dev_pos, far):
        """Return the threshold selected by ``criterion`` on the given dev
        scores (delegates to ``utils.get_thres``)."""
        return utils.get_thres(criterion, dev_neg, dev_pos, far)

    def _numbers(self, neg, pos, threshold, fta):
        """Compute the raw metrics of one score set at ``threshold``.

        Parameters
        ----------
        neg, pos : array
            Negative and positive scores (``shape[0]`` is used as count).
        threshold : float
            Decision threshold.
        fta : float
            Failure-to-acquire rate (rate of NaN scores reported by
            ``utils.get_fta_list``).

        Returns
        -------
        tuple
            ``(fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc, precision,
            recall, f1_score, auc, auc_log)``. Here ``fm`` and ``fnm`` are
            the numbers of false accepts and false rejects, out of ``ni``
            negatives and ``nc`` positives respectively.
        """
        from .. import f_score, farfrr, precision_recall, roc_auc_score

        # fpr and fnr
        fmr, fnmr = farfrr(neg, pos, threshold)
        hter = (fmr + fnmr) / 2.0
        # FAR/FRR also account for the failure-to-acquire rate
        far = fmr * (1 - fta)
        frr = fta + fnmr * (1 - fta)

        ni = neg.shape[0]  # number of impostors
        fm = int(round(fmr * ni))  # number of false accepts
        nc = pos.shape[0]  # number of clients
        fnm = int(round(fnmr * nc))  # number of false rejects

        # precision and recall
        precision, recall = precision_recall(neg, pos, threshold)

        # f_score
        f1_score = f_score(neg, pos, threshold, 1)

        # AUC ROC
        auc = roc_auc_score(neg, pos)
        auc_log = roc_auc_score(neg, pos, log_scale=True)
        return (
            fta,
            fmr,
            fnmr,
            hter,
            far,
            frr,
            fm,
            ni,
            fnm,
            nc,
            precision,
            recall,
            f1_score,
            auc,
            auc_log,
        )

    def _strings(self, metrics):
        """Format the tuple returned by ``_numbers`` for display.

        Returns
        -------
        tuple of str
            ``(fta, fmr, fnmr, far, frr, hter, precision, recall, f1, auc,
            auc_log)``. Rates are in percent with ``self._decimal``
            decimals; FMR/FNMR also show the ``(errors/total)`` counts.
        """
        n_dec = ".%df" % self._decimal
        fta_str = "%s%%" % format(100 * metrics[0], n_dec)
        fmr_str = "%s%% (%d/%d)" % (
            format(100 * metrics[1], n_dec),
            metrics[6],
            metrics[7],
        )
        fnmr_str = "%s%% (%d/%d)" % (
            format(100 * metrics[2], n_dec),
            metrics[8],
            metrics[9],
        )
        far_str = "%s%%" % format(100 * metrics[4], n_dec)
        frr_str = "%s%%" % format(100 * metrics[5], n_dec)
        hter_str = "%s%%" % format(100 * metrics[3], n_dec)
        prec_str = "%s" % format(metrics[10], n_dec)
        recall_str = "%s" % format(metrics[11], n_dec)
        f1_str = "%s" % format(metrics[12], n_dec)
        auc_str = "%s" % format(metrics[13], n_dec)
        auc_log_str = "%s" % format(metrics[14], n_dec)

        return (
            fta_str,
            fmr_str,
            fnmr_str,
            far_str,
            frr_str,
            hter_str,
            prec_str,
            recall_str,
            f1_str,
            auc_str,
            auc_log_str,
        )

    def _get_all_metrics(self, idx, input_scores, input_names):
        """Compute all metrics for dev and eval scores

        The threshold is computed on the dev scores (or taken from the
        user-provided values) and echoed to ``log_file``.

        Returns
        -------
        list
            ``[dev_strings, eval_strings]``: each entry is the output of
            ``_strings``, and ``eval_strings`` is ``None`` without eval data.
        """
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        dev_neg, dev_pos, dev_fta = neg_list[0], pos_list[0], fta_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos, eval_fta = neg_list[1], pos_list[1], fta_list[1]

        threshold = (
            self.get_thres(self._criterion, dev_neg, dev_pos, self._far)
            if self._thres is None
            else self._thres[idx]
        )

        title = self._legends[idx] if self._legends is not None else None
        if self._thres is None:
            far_str = ""
            if self._criterion == "far" and self._far is not None:
                far_str = str(self._far)
            click.echo(
                "[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                % (
                    self._criterion.upper(),
                    far_str,
                    title or dev_file,
                    threshold,
                ),
                file=self.log_file,
            )
        else:
            # Prefer the legend over the file name, as in the branch above.
            click.echo(
                "[Min. criterion: user provided] Threshold on "
                "Development set `%s`: %e" % (title or dev_file, threshold),
                file=self.log_file,
            )

        res = [
            self._strings(self._numbers(dev_neg, dev_pos, threshold, dev_fta))
        ]
        if self._eval:
            # eval statistics use the threshold chosen a priori on dev
            res.append(
                self._strings(
                    self._numbers(eval_neg, eval_pos, threshold, eval_fta)
                )
            )
        else:
            res.append(None)
        return res

    def compute(self, idx, input_scores, input_names):
        """Compute metrics thresholds and tables (FPR, FNR, precision, recall,
        f1_score, AUC) for given system inputs and print the table to
        ``log_file``."""
        dev_file = input_names[0]
        title = self._legends[idx] if self._legends is not None else None
        all_metrics = self._get_all_metrics(idx, input_scores, input_names)
        # Positions, in the ``_strings`` tuple, of the values displayed for
        # self.names[0], self.names[1], ... respectively.
        columns = (1, 2, 6, 7, 8, 9, 10)

        fta_dev = float(all_metrics[0][0].replace("%", ""))
        if fta_dev > 0.0:
            LOGGER.warning(
                "NaNs scores (%s) were found in %s and removed.",
                all_metrics[0][0],
                dev_file,
            )
        # The first header cell shows the system legend when one is given.
        headers = [title or " ", "Development"]
        rows = [
            [self.names[row], all_metrics[0][col]]
            for row, col in enumerate(columns)
        ]

        if self._eval:
            eval_file = input_names[1]
            fta_eval = float(all_metrics[1][0].replace("%", ""))
            if fta_eval > 0.0:
                LOGGER.warning(
                    "NaNs scores (%s) were found in %s and removed.",
                    all_metrics[1][0],
                    eval_file,
                )
            # eval statistics were computed with the threshold chosen a
            # priori on dev
            headers.append("Evaluation")
            for row, col in zip(rows, columns):
                row.append(all_metrics[1][col])

        click.echo(tabulate(rows, headers, self._tablefmt), file=self.log_file)

    def end_process(self):
        """Close log file if needed"""
        if self._log is not None:
            self.log_file.close()
class MultiMetrics(Metrics):
    """Computes average of metrics based on several protocols (cross
    validation)

    Each system's input files hold one (dev[, eval]) group per protocol. The
    mean and standard deviation across protocols are reported, one table
    row per system, printed in ``end_process``.

    Attributes
    ----------
    log_file : str
        output stream
    names : tuple
        List of names for the metrics.
    """

    def __init__(
        self,
        ctx,
        scores,
        evaluation,
        func_load,
        names=(
            "NaNs Rate",
            "False Positive Rate",
            "False Negative Rate",
            "False Accept Rate",
            "False Reject Rate",
            "Half Total Error Rate",
        ),
    ):
        super(MultiMetrics, self).__init__(
            ctx, scores, evaluation, func_load, names=names
        )

        self.headers = ["Methods"] + list(self.names)
        if self._eval:
            # With eval data, each row shows the dev HTER and then all eval
            # metrics.
            self.headers.insert(1, self.names[5] + " (dev)")
        self.rows = []

    def _strings(self, metrics):
        """Format mean and standard deviation of the first six metrics.

        Parameters
        ----------
        metrics : numpy.ndarray
            One row per protocol, each row being the tuple returned by
            ``_numbers``.

        Returns
        -------
        tuple of str
            ``(fta, fmr, fnmr, far, frr, hter)``, each formatted as
            ``"mean% (std%)"``.
        """
        # ``_numbers`` returns more than these six values (counts, precision,
        # AUC, ...). Slicing keeps only fta, fmr, fnmr, hter, far and frr;
        # unpacking the full rows failed with "too many values to unpack".
        ftam, fmrm, fnmrm, hterm, farm, frrm = metrics.mean(axis=0)[:6]
        ftas, fmrs, fnmrs, hters, fars, frrs = metrics.std(axis=0)[:6]
        n_dec = ".%df" % self._decimal

        def _mean_std(mean, std):
            return "%s%% (%s%%)" % (
                format(100 * mean, n_dec),
                format(100 * std, n_dec),
            )

        return (
            _mean_std(ftam, ftas),
            _mean_std(fmrm, fmrs),
            _mean_std(fnmrm, fnmrs),
            _mean_std(farm, fars),
            _mean_std(frrm, frrs),
            _mean_std(hterm, hters),
        )

    def compute(self, idx, input_scores, input_names):
        """Computes the average of metrics over several protocols.

        A threshold is computed on every protocol's dev scores (unless given
        by the user) and reused on that protocol's eval scores. One row is
        appended to ``self.rows``.
        """
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        # with eval data, files alternate dev / eval for each protocol
        step = 2 if self._eval else 1
        self._dev_metrics = []
        self._thresholds = []
        for i in range(0, len(input_scores), step):
            neg, pos, fta = neg_list[i], pos_list[i], fta_list[i]
            threshold = (
                self.get_thres(self._criterion, neg, pos, self._far)
                if self._thres is None
                else self._thres[idx]
            )
            self._thresholds.append(threshold)
            self._dev_metrics.append(self._numbers(neg, pos, threshold, fta))
        self._dev_metrics = numpy.array(self._dev_metrics)

        if self._eval:
            self._eval_metrics = []
            for i in range(1, len(input_scores), step):
                neg, pos, fta = neg_list[i], pos_list[i], fta_list[i]
                # eval file i belongs to protocol i // 2: reuse its threshold
                threshold = self._thresholds[i // 2]
                self._eval_metrics.append(
                    self._numbers(neg, pos, threshold, fta)
                )
            self._eval_metrics = numpy.array(self._eval_metrics)

        title = self._legends[idx] if self._legends is not None else None

        dev_strings = self._strings(self._dev_metrics)
        if self._eval:
            # only the dev HTER, followed by all eval metrics (computed with
            # the thresholds chosen a priori on dev)
            self.rows.append(
                [title, dev_strings[5]] + list(self._strings(self._eval_metrics))
            )
        else:
            self.rows.append([title] + list(dev_strings))

    def end_process(self):
        """Print the table of all systems, then close the log file."""
        click.echo(
            tabulate(self.rows, self.headers, self._tablefmt),
            file=self.log_file,
        )
        super(MultiMetrics, self).end_process()
class PlotBase(MeasureBase):
    """Base class for plots. Regroup several options and code
    shared by the different plots

    Figures are saved to a multi-page PDF. This is either the ``PdfPages``
    object found in ``ctx.meta["PdfPages"]`` or one created on the
    ``output`` path.
    """

    def __init__(self, ctx, scores, evaluation, func_load):
        super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
        self._output = ctx.meta.get("output")
        self._points = ctx.meta.get("points", 2000)
        self._split = ctx.meta.get("split")
        self._axlim = ctx.meta.get("axlim")
        self._alpha = ctx.meta.get("alpha")
        self._disp_legend = ctx.meta.get("disp_legend", True)
        self._legend_loc = ctx.meta.get("legend_loc")
        # Base-10 exponent of the smallest FPR shown. It comes from
        # ``min_far_value`` or, failing that, from the lower x-axis limit.
        self._min_dig = None
        if "min_far_value" in ctx.meta:
            self._min_dig = int(math.log10(ctx.meta["min_far_value"]))
        elif self._axlim is not None and self._axlim[0] is not None:
            self._min_dig = int(
                math.log10(self._axlim[0]) if self._axlim[0] != 0 else 0
            )
        self._clayout = ctx.meta.get("clayout")
        self._far_at = ctx.meta.get("lines_at")
        # x positions of the vertical lines (subclasses may transform them)
        self._trans_far_val = self._far_at
        if self._far_at is not None:
            self._eval_points = {line: [] for line in self._far_at}
            self._lines_val = []
        self._print_fn = ctx.meta.get("show_fn", True)
        self._x_rotation = ctx.meta.get("x_rotation")
        if "style" in ctx.meta:
            mpl.style.use(ctx.meta["style"])
        self._nb_figs = 2 if self._eval and self._split else 1
        self._colors = utils.get_colors(self.n_systems)
        self._line_linestyles = ctx.meta.get("line_styles", False)
        self._linestyles = utils.get_linestyles(
            self.n_systems, self._line_linestyles
        )
        # Titles for the dev and eval figures. A missing or ``None`` value
        # means "no titles".
        self._titles = (ctx.meta.get("titles") or []) * 2
        # for compatibility
        self._title = ctx.meta.get("title")
        if not self._titles and self._title is not None:
            self._titles = [self._title] * 2

        self._x_label = ctx.meta.get("x_label")
        self._y_label = ctx.meta.get("y_label")
        self._grid_color = "silver"
        self._pdf_page = None
        # True when ``init_process`` created the PDF itself (and so must
        # close it)
        self._own_pdf_page = False
        self._end_setup_plot = True

    def init_process(self):
        """Open the pdf and create (and clear) the figures to draw on"""
        if not hasattr(matplotlib, "backends"):
            matplotlib.use("pdf")

        if "PdfPages" in self._ctx.meta:
            self._pdf_page = self._ctx.meta["PdfPages"]
            self._own_pdf_page = False
        else:
            self._pdf_page = PdfPages(self._output)
            self._own_pdf_page = True

        for i in range(self._nb_figs):
            fs = self._ctx.meta.get("figsize")
            fig = mpl.figure(i + 1, figsize=fs)
            fig.set_constrained_layout(self._clayout)
            fig.clear()

    def end_process(self):
        """Draw the vertical lines and finish the figures (title, legend,
        axis labels, grid). Then save them to the pdf and close it if
        needed."""
        # draw vertical lines
        if self._far_at is not None:
            for line, line_trans in zip(self._far_at, self._trans_far_val):
                mpl.figure(1)
                mpl.plot(
                    [line_trans, line_trans],
                    [-100.0, 100.0],
                    "--",
                    color="black",
                )
                if self._eval and self._split:
                    # On the eval figure, join the eval operating points
                    # recorded for this line, sorted by x.
                    mpl.figure(2)
                    points = sorted(
                        self._eval_points[line], key=lambda point: point[0]
                    )
                    x_values = [x for x, _ in points]
                    y_values = [y for _, y in points]
                    mpl.plot(x_values, y_values, "--", color="black")
        # only for plots
        if self._end_setup_plot:
            for i in range(self._nb_figs):
                fig = mpl.figure(i + 1)
                title = "" if not self._titles else self._titles[i]
                mpl.title(title if title.replace(" ", "") else "")
                mpl.xlabel(self._x_label)
                mpl.ylabel(self._y_label)
                mpl.grid(True, color=self._grid_color)
                if self._disp_legend:
                    self.plot_legends()
                self._set_axis()
                mpl.xticks(rotation=self._x_rotation)
                self._pdf_page.savefig(fig)

        # A pdf created in ``init_process`` has no other owner and is always
        # closed, which finalizes the file. A pdf shared through ``ctx.meta``
        # is closed unless ``closef`` is False; ``evaluate`` keeps it open.
        if self._own_pdf_page:
            self._pdf_page.close()
        elif "PdfPages" in self._ctx.meta and self._ctx.meta.get(
            "closef", True
        ):
            self._pdf_page.close()

    def plot_legends(self):
        """Print legend on current plot"""
        if not self._disp_legend:
            return

        lines = []
        labels = []
        for ax in mpl.gcf().get_axes():
            ali, ala = ax.get_legend_handles_labels()
            # avoid duplicates in legend
            for li, la in zip(ali, ala):
                if la not in labels:
                    lines.append(li)
                    labels.append(la)

        leg = mpl.legend(
            lines,
            labels,
            loc=self._legend_loc,
            ncol=1,
        )

        return leg

    def _label(self, base, idx):
        """Return the legend of system ``idx``, or ``base`` (numbered when
        there are several systems)."""
        if self._legends is not None and len(self._legends) > idx:
            return self._legends[idx]
        if self.n_systems > 1:
            return base + (" %d" % (idx + 1))
        return base

    def _set_axis(self):
        """Apply the user axis limits, if any."""
        if self._axlim is not None:
            mpl.axis(self._axlim)
class Roc(PlotBase):
    """Handles the plotting of ROC

    Dev curves are drawn on figure 1. Eval curves go to figure 2 when
    ``split`` is set; otherwise they are overlaid, dashed, on figure 1.
    """

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Roc, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._titles = self._titles or ["ROC dev.", "ROC eval."]
        self._x_label = self._x_label or "FPR"
        self._semilogx = meta.get("semilogx", True)
        self._tpr = meta.get("tpr", True)
        # the y axis shows either the true positive or false negative rate
        self._y_label = self._y_label or ("TPR" if self._tpr else "FNR")
        if not self._legend_loc:
            self._legend_loc = "lower right" if self._semilogx else "upper right"
        # custom defaults
        if self._axlim is None:
            self._axlim = [None, None, -0.05, 1.05]
        if self._min_dig is None:
            self._min_dig = -4

    def compute(self, idx, input_scores, input_names):
        """Plot ROC for dev and eval data using
        :py:func:`bob.measure.plot.roc`

        When ``lines_at`` values are given, each line's FPR is converted to
        a threshold on the dev scores. The eval operating point at that
        threshold is marked and recorded in ``self._eval_points``.
        """
        negatives, positives, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = negatives[0], positives[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = negatives[1], positives[1]
            eval_file = input_names[1]

        color = self._colors[idx]
        mpl.figure(1)
        LOGGER.info("ROC dev. curve using %s", dev_file)
        plot.roc(
            dev_neg,
            dev_pos,
            npoints=self._points,
            semilogx=self._semilogx,
            tpr=self._tpr,
            min_far=self._min_dig,
            color=color,
            linestyle=self._linestyles[idx],
            label=self._label("dev", idx),
            alpha=self._alpha,
        )
        if not self._eval:
            return

        if self._split:
            mpl.figure(2)
        LOGGER.info("ROC eval. curve using %s", eval_file)
        plot.roc(
            eval_neg,
            eval_pos,
            linestyle=self._linestyles[idx] if self._split else "--",
            npoints=self._points,
            semilogx=self._semilogx,
            tpr=self._tpr,
            min_far=self._min_dig,
            color=color,
            label=self._label("eval.", idx),
            alpha=self._alpha,
        )
        if self._far_at is None:
            return

        from .. import fprfnr

        for line in self._far_at:
            thres_line = far_threshold(dev_neg, dev_pos, line)
            x_value, fnr_value = fprfnr(eval_neg, eval_pos, thres_line)
            y_value = 1 - fnr_value if self._tpr else fnr_value
            mpl.scatter(x_value, y_value, c=color, s=30)
            self._eval_points[line].append((x_value, y_value))
class Det(PlotBase):
    """Handles the plotting of DET

    Dev curves are drawn on figure 1. Eval curves go to figure 2 when
    ``split`` is set; otherwise they are overlaid, dashed, on figure 1.
    """

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Det, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ["DET dev.", "DET eval."]
        self._x_label = self._x_label or "FPR (%)"
        self._y_label = self._y_label or "FNR (%)"
        self._legend_loc = self._legend_loc or "upper right"
        if self._far_at is not None:
            # vertical lines are drawn at the ppndf-transformed positions
            self._trans_far_val = ppndf(self._far_at)
        # custom defaults here
        if self._x_rotation is None:
            self._x_rotation = 50
        if self._axlim is None:
            self._axlim = [0.01, 99, 0.01, 99]
        if self._min_dig is None:
            self._min_dig = -4
        else:
            # lower x limit follows the requested minimum FPR
            self._axlim[0] = math.pow(10, self._min_dig) * 100

    def compute(self, idx, input_scores, input_names):
        """Plot DET for dev and eval data using
        :py:func:`bob.measure.plot.det`

        When ``lines_at`` values are given, each line's FPR is converted to
        a threshold on the dev scores. The ppndf-transformed eval operating
        point at that threshold is marked and recorded.
        """
        negatives, positives, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = negatives[0], positives[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = negatives[1], positives[1]
            eval_file = input_names[1]

        color = self._colors[idx]
        mpl.figure(1)
        LOGGER.info("DET dev. curve using %s", dev_file)
        plot.det(
            dev_neg,
            dev_pos,
            self._points,
            min_far=self._min_dig,
            color=color,
            linestyle=self._linestyles[idx],
            label=self._label("dev.", idx),
            alpha=self._alpha,
        )
        if not self._eval or eval_neg is None:
            return

        if self._split:
            mpl.figure(2)
        LOGGER.info("DET eval. curve using %s", eval_file)
        plot.det(
            eval_neg,
            eval_pos,
            self._points,
            min_far=self._min_dig,
            color=color,
            linestyle=self._linestyles[idx] if self._split else "--",
            label=self._label("eval.", idx),
            alpha=self._alpha,
        )
        if self._far_at is None:
            return

        from .. import farfrr

        for line in self._far_at:
            thres_line = far_threshold(dev_neg, dev_pos, line)
            fmr_value, fnmr_value = farfrr(eval_neg, eval_pos, thres_line)
            point = (ppndf(fmr_value), ppndf(fnmr_value))
            mpl.scatter(point[0], point[1], c=color, s=30)
            self._eval_points[line].append(point)

    def _set_axis(self):
        """Apply the DET axis limits through ``plot.det_axis``."""
        plot.det_axis(self._axlim)
class Epc(PlotBase):
    """Handles the plotting of EPC

    Every system must come with both dev and eval score files. All curves
    are drawn on a single figure.
    """

    def __init__(self, ctx, scores, evaluation, func_load, hter="HTER"):
        super(Epc, self).__init__(ctx, scores, evaluation, func_load)
        if self._min_arg != 2:
            raise click.UsageError("EPC requires dev. and eval. score files")
        self._titles = self._titles or ["EPC"] * 2
        self._x_label = self._x_label or r"$\alpha$"
        self._y_label = self._y_label or hter + " (%)"
        self._legend_loc = self._legend_loc or "upper center"
        # always eval data with EPC, on one unsplit figure, without lines
        self._eval, self._split = True, False
        self._nb_figs, self._far_at = 1, None

    def compute(self, idx, input_scores, input_names):
        """Plot EPC using :py:func:`bob.measure.plot.epc`

        The dev and eval scores of system ``idx`` are both passed to
        ``plot.epc``.
        """
        negatives, positives, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = negatives[0], positives[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = negatives[1], positives[1]
            eval_file = input_names[1]

        LOGGER.info("EPC using %s", dev_file + "_" + eval_file)
        plot.epc(
            dev_neg,
            dev_pos,
            eval_neg,
            eval_pos,
            self._points,
            color=self._colors[idx],
            linestyle=self._linestyles[idx],
            label=self._label("curve", idx),
            alpha=self._alpha,
        )
class GridSubplot(PlotBase):
    """A base class for plots that contain subplots and legends.

    Call `create_subplot` in `compute` each time a new axis is needed, and
    call `finalize_one_page` in `compute` when a page is finished rendering.
    """

    def __init__(self, ctx, scores, evaluation, func_load):
        super(GridSubplot, self).__init__(ctx, scores, evaluation, func_load)

        # The legend is placed above or below the whole grid, so only
        # "upper*" and "lower*" locations are meaningful. "best" (or no
        # value) maps to the default.
        if not self._legend_loc or self._legend_loc == "best":
            self._legend_loc = "upper center"
        if not any(side in self._legend_loc for side in ("upper", "lower")):
            raise ValueError(
                "Only best, upper-*, and lower-* legend locations are supported!"
            )
        self._nlegends = ctx.meta.get("legends_ncol", 3)

        # subplot grid
        self._nrows = ctx.meta.get("n_row", 1)
        self._ncols = ctx.meta.get("n_col", 1)

    def init_process(self):
        """Open the pdf and figures, then lay a grid on the current figure."""
        super(GridSubplot, self).init_process()
        self._create_grid_spec()

    def _create_grid_spec(self):
        """Create an ``n_row`` x ``n_col`` GridSpec on the current figure."""
        self._gs = gridspec.GridSpec(
            self._nrows,
            self._ncols,
            figure=mpl.gcf(),
        )

    def create_subplot(self, n, shared_axis=None):
        """Add and return the axis of grid cell ``n`` (row-major order),
        sharing its x axis with ``shared_axis`` if given."""
        row, col = numpy.unravel_index(n, (self._nrows, self._ncols))
        cell = self._gs[row : row + 1, col : col + 1]
        return mpl.gcf().add_subplot(cell, sharex=shared_axis)

    def finalize_one_page(self):
        """Add the legend and save the current page to the pdf, then start
        a new figure with a fresh grid."""
        self.plot_legends()
        figure = mpl.gcf()
        all_axes = figure.get_axes()

        LOGGER.debug("%s contains %d axes:", figure, len(all_axes))
        for number, ax in enumerate(all_axes, start=1):
            LOGGER.debug("Axes %d: %s", number, ax)

        self._pdf_page.savefig(bbox_inches="tight")
        mpl.clf()
        mpl.figure()
        self._create_grid_spec()

    def plot_legends(self):
        """Print legend on current page

        Entries of all axes are merged, keeping the first handle of each
        label. The legend is anchored to the top or bottom edge of the
        figure, depending on ``legend_loc``.
        """
        if not self._disp_legend:
            return

        handles = []
        labels = []
        for ax in mpl.gcf().get_axes():
            for handle, label in zip(*ax.get_legend_handles_labels()):
                if label in labels:
                    continue  # avoid duplicates in legend
                handles.append(handle)
                labels.append(label)

        figure = mpl.gcf()
        if "upper" in self._legend_loc:
            # Anchor at the top of the figure; the legend sits on it with its
            # bottom side, hence the swapped location.
            bbox_to_anchor = (0.0, 1.0, 1.0, 0.0)
            anchored_loc = self._legend_loc.replace("upper", "lower")
        else:
            # Anchor at the bottom of the figure; the legend hangs from it by
            # its top side.
            bbox_to_anchor = (0.0, 0.0, 1.0, 0.0)
            anchored_loc = self._legend_loc.replace("lower", "upper")
        return figure.legend(
            handles,
            labels,
            loc=anchored_loc,
            ncol=self._nlegends,
            bbox_to_anchor=bbox_to_anchor,
        )
1019class Hist(GridSubplot):
1020 """Functional base class for histograms"""
1022 def __init__(self, ctx, scores, evaluation, func_load, nhist_per_system=2):
1023 super(Hist, self).__init__(ctx, scores, evaluation, func_load)
1024 self._nbins = ctx.meta.get("n_bins", ["doane"])
1025 self._nhist_per_system = nhist_per_system
1026 self._nbins = check_list_value(
1027 self._nbins, nhist_per_system, "n_bins", "histograms"
1028 )
1029 self._thres = ctx.meta.get("thres")
1030 self._thres = check_list_value(
1031 self._thres, self.n_systems, "thresholds"
1032 )
1033 self._criterion = ctx.meta.get("criterion")
1034 # no vertical (threshold) is displayed
1035 self._no_line = ctx.meta.get("no_line", False)
1036 # do not display dev histo
1037 self._hide_dev = ctx.meta.get("hide_dev", False)
1038 if self._hide_dev and not self._eval:
1039 raise click.BadParameter(
1040 "You can only use --hide-dev along with --eval"
1041 )
1042 # dev hist are displayed next to eval hist
1043 self._nrows *= 1 if self._hide_dev or not self._eval else 2
1044 self._nlegends = ctx.meta.get("legends_ncol", 3)
1046 # number of subplot on one page
1047 self._step_print = int(self._nrows * self._ncols)
1048 self._title_base = "Scores"
1049 self._y_label = self._y_label or "Probability density"
1050 self._x_label = self._x_label or "Score values"
1051 self._end_setup_plot = False
1052 # overide _titles of PlotBase
1053 self._titles = ctx.meta.get("titles", []) * 2
def compute(self, idx, input_scores, input_names):
    """Draw the histograms of negative and positive scores of one system.

    Parameters
    ----------
    idx : int
        Index of the system to draw.
    input_scores : list
        Loaded scores of that system.
    input_names : list
        Names of the input files of that system.

    The dev histogram is drawn unless ``--hide-dev`` is used together with
    ``--eval``; the eval histogram is drawn whenever eval data are given.
    When both are shown, the eval subplot shares its axis with the dev one
    and sits on the row directly below it.
    """
    (
        dev_neg,
        dev_pos,
        eval_neg,
        eval_pos,
        threshold,
    ) = self._get_neg_pos_thres(idx, input_scores, input_names)

    # index of the system being drawn; ``idx`` becomes the subplot index.
    # Named ``system`` so the module-level ``sys`` import is not shadowed.
    system = idx
    if not self._hide_dev and self._eval:
        # dev and eval histograms are interleaved row by row, so each row
        # of systems occupies two rows of subplots
        row = (idx // self._ncols) * 2
        col = idx % self._ncols
        idx = col + self._ncols * row

    dev_axis = None
    if not self._hide_dev or not self._eval:
        dev_axis = self._print_subplot(
            idx,
            system,
            dev_neg,
            dev_pos,
            threshold,
            not self._no_line,
            False,
        )

    if self._eval:
        if not self._hide_dev:
            # the eval subplot goes on the row below its dev subplot
            idx += self._ncols
        self._print_subplot(
            idx,
            system,
            eval_neg,
            eval_pos,
            threshold,
            not self._no_line,
            True,
            shared_axis=dev_axis,
        )
def _print_subplot(
    self,
    idx,
    sys,
    neg,
    pos,
    threshold,
    draw_line,
    evaluation,
    shared_axis=None,
):
    """Draw one histogram subplot and flush the page once it is complete.

    Parameters
    ----------
    idx : int
        Global subplot index; its remainder modulo the number of subplots
        per page gives the cell on the current page.
    sys : int
        Index of the system whose scores are drawn (name kept for
        compatibility, even though it shadows the ``sys`` module here).
    neg, pos :
        Negative and positive scores, forwarded to ``_setup_hist`` and
        ``_lines``.
    threshold :
        Position of the vertical threshold line.
    draw_line : bool
        Whether the threshold line is drawn.
    evaluation : bool
        True for eval scores; selects the default title "Eval. scores"
        instead of "Dev. scores".
    shared_axis :
        Axis forwarded to ``create_subplot`` (None when nothing is shared).

    Returns
    -------
    The axis returned by ``create_subplot``.
    """
    cell = idx % self._step_print
    column = cell % self._ncols
    position = cell + 1
    axis = self.create_subplot(cell, shared_axis)
    self._setup_hist(neg, pos)
    if column == 0:
        # only the left-most column carries the y label
        axis.set_ylabel(self._y_label)

    # a system uses two rows (dev above eval) when both are displayed
    rows_per_system = 1 if self._hide_dev or not self._eval else 2
    per_page = self._step_print / rows_per_system
    slot_on_page = sys % per_page
    # systems still to draw, counted from the first system of this page
    remaining = self.n_systems - int(sys / per_page) * per_page
    # the lower histogram is the eval one, or the only one without eval
    lower = evaluation or not self._eval
    if lower and slot_on_page + self._ncols >= min(per_page, remaining):
        # bottom row of the page: add the x label
        axis.set_xlabel(self._x_label)

    if self.n_systems == 1 and (not self._eval or self._hide_dev):
        fallback_title = " "
    elif evaluation:
        fallback_title = "Eval. scores"
    else:
        fallback_title = "Dev. scores"
    title_offset = self.n_systems if lower else 0
    axis.set_title(self._get_title(sys + title_offset, fallback_title))

    criterion_text = (
        "" if self._criterion is None else self._criterion.upper()
    )
    dev_suffix = " (dev)" if self._eval else ""
    label = "%s threshold%s" % (criterion_text, dev_suffix)
    if draw_line:
        self._lines(threshold, label, neg, pos, idx)

    # grid enabled and kept underneath the other elements
    axis.set_axisbelow(True)
    axis.grid(True, color=self._grid_color)

    # save the page when its last cell is used or the last system is drawn
    if position == self._step_print or (
        lower and sys == self.n_systems - 1
    ):
        self.finalize_one_page()
    return axis
def _get_title(self, idx, dflt=None):
    """Return the histogram title stored at position ``idx``.

    Falls back to ``dflt`` when ``self._titles`` is None or too short, then
    to ``self._title_base`` when the result is falsy. A title made only of
    spaces (or a missing title) yields the empty string.
    """
    titles = self._titles
    if titles is not None and idx < len(titles):
        chosen = titles[idx]
    else:
        chosen = dflt
    if not chosen:
        chosen = self._title_base
    # spaces-only titles are deliberately blanked
    if chosen is None or not chosen.replace(" ", ""):
        return ""
    return chosen
def _get_neg_pos_thres(self, idx, input_scores, input_names):
    """Get the scores and the threshold of the system at index ``idx``.

    Parameters
    ----------
    idx : int
        Index of the system; selects the user-given threshold, if any.
    input_scores : list
        Loaded scores of the system, passed to ``utils.get_fta_list``.
    input_names : list
        Names of the input files (not used here).

    Returns
    -------
    tuple
        ``(dev_neg, dev_pos, eval_neg, eval_pos, threshold)``. The score
        entries are lists, since a system can have several files;
        ``eval_neg``/``eval_pos`` are None without eval data. ``threshold``
        is ``self._thres[idx]`` when thresholds were given, otherwise it is
        computed with ``self._criterion`` on the first dev negatives and
        positives.
    """
    neg_list, pos_list, _ = utils.get_fta_list(input_scores)
    # Layout of the lists returned by get_fta_list:
    #   bio or measure without eval:          [dev]
    #   bio or measure with eval:             [dev, eval]
    #   vuln with {licit,spoof} without eval: [licit_dev, spoof_dev]
    #   vuln with {licit,spoof} with eval:    [licit_dev, licit_eval,
    #                                          spoof_dev, spoof_eval]
    # With eval data, dev entries sit at even positions and eval entries at
    # odd positions.
    step = 2 if self._eval else 1
    dev_neg = list(neg_list[0::step])
    dev_pos = list(pos_list[0::step])
    eval_neg = eval_pos = None
    if self._eval:
        eval_neg = list(neg_list[1::step])
        eval_pos = list(pos_list[1::step])

    if self._thres is None:
        threshold = utils.get_thres(self._criterion, dev_neg[0], dev_pos[0])
    else:
        threshold = self._thres[idx]
    return dev_neg, dev_pos, eval_neg, eval_pos, threshold
def _density_hist(self, scores, n, **kwargs):
    """Draw one density-normalised histogram on the current axes.

    Parameters
    ----------
    scores : array-like
        Values to histogram.
    n : int
        Position of this histogram within the subplot; selects the binning
        specification ``self._nbins[n]``.
    **kwargs
        Extra keyword arguments forwarded to ``matplotlib.pyplot.hist``.

    Returns
    -------
    tuple
        The three values unpacked from ``pyplot.hist``: counts, bin edges
        and patches.
    """
    bin_spec = self._nbins[n]
    counts, edges, patches = mpl.hist(
        scores, density=True, bins=bin_spec, **kwargs
    )
    return counts, edges, patches
def _lines(
    self, threshold, label=None, neg=None, pos=None, idx=None, **kwargs
):
    """Draw a vertical line at ``threshold`` spanning the whole axis.

    ``label`` defaults to "Threshold". ``color`` and ``linestyle`` default
    to "C3" and "--" unless given in ``kwargs``, which are forwarded to
    ``matplotlib.pyplot.axvline``. ``neg``, ``pos`` and ``idx`` are not
    used here; presumably they are accepted so that overriding
    implementations can use them.
    """
    line_kwargs = dict(kwargs)
    defaults = (
        ("color", "C3"),
        ("linestyle", "--"),
        ("label", label or "Threshold"),
    )
    for key, value in defaults:
        line_kwargs.setdefault(key, value)
    mpl.axvline(x=threshold, ymin=0, ymax=1, **line_kwargs)
def _setup_hist(self, neg, pos):
    """Draw every density histogram of one subplot.

    Derived classes can override this. Here, the first entry of the
    negative scores and the first entry of the positive scores are each
    drawn as a semi-transparent density histogram, using binning slots 0
    and 1 respectively.
    """
    histograms = (
        (neg, "Negatives", "C3"),
        (pos, "Positives", "C0"),
    )
    for slot, (scores, name, colour) in enumerate(histograms):
        self._density_hist(
            scores[0], n=slot, label=name, alpha=0.5, color=colour
        )