Coverage for src/bob/measure/script/figure.py: 88%

550 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-06-16 14:10 +0200

1"""Runs error analysis on score sets, outputs metrics and plots""" 

2 

3from __future__ import division, print_function 

4 

5import logging 

6import math 

7import sys 

8 

9from abc import ABCMeta, abstractmethod 

10 

11import click 

12import matplotlib 

13import matplotlib.pyplot as mpl 

14import numpy 

15 

16from matplotlib import gridspec 

17from matplotlib.backends.backend_pdf import PdfPages 

18from tabulate import tabulate 

19 

20from .. import far_threshold, plot, ppndf, utils 

21 

# Module-level logger used by the classes below for info/debug/warning messages.
LOGGER = logging.getLogger("bob.measure")

23 

24 

def check_list_value(values, desired_number, name, name2="systems"):
    """Broadcast or validate a per-item option list.

    Parameters
    ----------
    values : sequence or None
        Values given for the option. ``None`` is returned untouched.
    desired_number : int
        Number of values expected.
    name : str
        Name of the option, used in the error message.
    name2 : str
        Name of the counted items, used in the error message.

    Returns
    -------
    sequence or None
        ``values`` itself when it is ``None`` or already has
        ``desired_number`` entries; a single-entry ``values`` repeated
        ``desired_number`` times otherwise.

    Raises
    ------
    click.BadParameter
        If ``values`` has neither 1 nor ``desired_number`` entries.
    """
    if values is None or len(values) == desired_number:
        return values
    if len(values) != 1:
        raise click.BadParameter(
            "#{} ({}) must be either 1 value or the same as "
            "#{} ({} values)".format(name, values, name2, desired_number)
        )
    # a single value is shared by every item
    return values * desired_number

36 

37 

class MeasureBase(object):
    """Base class for metrics and plots.

    This abstract class defines the framework to plot or compute metrics from
    score files: :py:meth:`run` loads the files of each system with
    ``func_load`` and hands the loaded scores to :py:meth:`compute`.

    Attributes
    ----------
    func_load:
        Function that is used to load the input files
    n_systems : int
        Number of systems, i.e. the number of score files divided by
        ``ctx.meta["min_arg"]`` (the number of files per system).
    """

    # NOTE: this attribute only has an effect on Python 2; on Python 3 the
    # ``abstractmethod`` decorators below are therefore not enforced.
    __metaclass__ = ABCMeta  # for python 2.7 compatibility

    def __init__(self, ctx, scores, evaluation, func_load):
        """
        Parameters
        ----------
        ctx : :py:class:`click.Context`
            Click context; options are read from its ``meta`` dictionary.
        scores : :any:`list`
            List of input files (e.g. dev-{1, 2, 3}, {dev,eval}-scores1
            {dev,eval}-scores2)
        evaluation : :py:class:`bool`
            True if eval data are used
        func_load : callable
            Function that is used to load the input files

        Raises
        ------
        click.BadParameter
            If no score file is given, if the number of score files is not a
            multiple of ``ctx.meta["min_arg"]`` or if fewer legends than
            systems are given.
        """
        self._scores = scores
        self._ctx = ctx
        self.func_load = func_load
        self._legends = ctx.meta.get("legends")
        self._eval = evaluation
        # number of consecutive score files belonging to one system
        self._min_arg = ctx.meta.get("min_arg", 1)
        if len(scores) < 1 or len(scores) % self._min_arg != 0:
            raise click.BadParameter(
                "Number of argument must be a non-zero multiple of %d"
                % self._min_arg
            )
        # exact, thanks to the divisibility check above
        self.n_systems = int(len(scores) // self._min_arg)
        if self._legends is not None and len(self._legends) < self.n_systems:
            raise click.BadParameter(
                "Number of legends must be >= to the number of systems"
            )

    def run(self):
        """Generate outputs (e.g. metrics, files, pdf plots).

        This function calls abstract methods
        :func:`~bob.measure.script.figure.MeasureBase.init_process` (before
        loop), :py:func:`~bob.measure.script.figure.MeasureBase.compute`
        (in the loop iterating through the different
        systems) and :py:func:`~bob.measure.script.figure.MeasureBase.end_process`
        (after the loop).
        """
        # init matplotlib, log files, ...
        self.init_process()
        # Scores are given as followed:
        # SysA-dev SysA-eval ... SysA-XX SysB-dev SysB-eval ... SysB-XX
        # ------------------------------ ------------------------------
        # First set of `self._min_arg`   Second set of input files
        # input files starting at        for SysB
        # index idx * self._min_arg
        for idx in range(self.n_systems):
            start = idx * self._min_arg
            input_scores, input_names = self._load_files(
                self._scores[start : start + self._min_arg]
            )
            LOGGER.info("-----Input files for system %d-----", idx + 1)
            for i, name in enumerate(input_names):
                if not self._eval:
                    LOGGER.info("Dev. score %d: %s", i + 1, name)
                elif i % 2 == 0:
                    # with eval data, files alternate dev / eval
                    LOGGER.info("Dev. score %d: %s", i // 2 + 1, name)
                else:
                    LOGGER.info("Eval. score %d: %s", i // 2 + 1, name)
            LOGGER.info("----------------------------------")

            self.compute(idx, input_scores, input_names)
        # setup final configuration, plotting properties, ...
        self.end_process()

    # protected functions that need to be overwritten
    def init_process(self):
        """Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        before iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # Main computations are done here in the subclasses
    @abstractmethod
    def compute(self, idx, input_scores, input_names):
        """Compute metrics or plots from the given scores provided by
        :py:func:`~bob.measure.script.figure.MeasureBase.run`.
        Should be reimplemented in derived classes

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function
        input_names: :any:`list`
            list of base names for the input file of the system
        """
        pass
        # structure of input is (vuln example):
        # if evaluation is provided
        # [ (dev_licit_neg, dev_licit_pos), (eval_licit_neg, eval_licit_pos),
        #   (dev_spoof_neg, dev_licit_pos), (eval_spoof_neg, eval_licit_pos)]
        # and if only dev:
        # [ (dev_licit_neg, dev_licit_pos), (dev_spoof_neg, dev_licit_pos)]

    # Things to do after the main iterative computations are done
    @abstractmethod
    def end_process(self):
        """Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        after iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # common protected functions

    def _load_files(self, filepaths):
        """Load the input files and return them with their names.

        Returns
        -------
        scores: :any:`list`:
            A list that contains the output of
            ``func_load`` for the given files
        basenames: :any:`list`:
            A list of the given files (as passed, not stripped of directories)
        """
        basenames = list(filepaths)
        scores = [self.func_load(filename) for filename in basenames]
        return scores, basenames

182 

183 

class Metrics(MeasureBase):
    """Compute metrics from score files

    Attributes
    ----------
    log_file : file-like
        Output stream: ``sys.stdout``, or the file opened from
        ``ctx.meta["log"]`` when that option is set.
    names : tuple
        Labels of the rows of the printed table.
    """

    def __init__(
        self,
        ctx,
        scores,
        evaluation,
        func_load,
        names=(
            "False Positive Rate",
            "False Negative Rate",
            "Precision",
            "Recall",
            "F1-score",
            "Area Under ROC Curve",
            "Area Under ROC Curve (log scale)",
        ),
    ):
        super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
        self.names = names
        self._tablefmt = ctx.meta.get("tablefmt")
        self._criterion = ctx.meta.get("criterion")
        # ``open`` rejects a ``None`` mode, so fall back to overwriting
        self._open_mode = ctx.meta.get("open_mode") or "w"
        self._thres = ctx.meta.get("thres")
        self._decimal = ctx.meta.get("decimal", 2)
        if self._thres is not None:
            if len(self._thres) == 1:
                # a single user threshold is shared by all systems
                self._thres = self._thres * self.n_systems
            elif len(self._thres) != self.n_systems:
                raise click.BadParameter(
                    "#thresholds must be the same as #systems (%d)"
                    % self.n_systems
                )
        self._far = ctx.meta.get("far_value")
        self._log = ctx.meta.get("log")
        self.log_file = sys.stdout
        if self._log is not None:
            self.log_file = open(self._log, self._open_mode)

    def get_thres(self, criterion, dev_neg, dev_pos, far):
        """Return the threshold for ``criterion`` computed on the
        development scores (delegates to ``utils.get_thres``)."""
        return utils.get_thres(criterion, dev_neg, dev_pos, far)

    def _numbers(self, neg, pos, threshold, fta):
        """Compute raw metrics for one set of scores at ``threshold``.

        Parameters
        ----------
        neg, pos : array
            Negative and positive scores; ``.shape[0]`` is used as the
            number of samples.
        threshold : float
            Decision threshold.
        fta : float
            Ratio of NaN scores (failure to acquire) of this set.

        Returns
        -------
        tuple
            ``(fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc, precision,
            recall, f1_score, auc, auc_log)`` where ``fm``/``fnm`` are the
            numbers of false accepts/rejects and ``ni``/``nc`` the numbers
            of negative/positive scores.
        """
        from .. import f_score, farfrr, precision_recall, roc_auc_score

        # fpr and fnr
        fmr, fnmr = farfrr(neg, pos, threshold)
        hter = (fmr + fnmr) / 2.0
        # FAR / FRR also account for the NaN (failure to acquire) ratio
        far = fmr * (1 - fta)
        frr = fta + fnmr * (1 - fta)

        ni = neg.shape[0]  # number of impostors
        fm = int(round(fmr * ni))  # number of false accepts
        nc = pos.shape[0]  # number of clients
        fnm = int(round(fnmr * nc))  # number of false rejects

        # precision and recall
        precision, recall = precision_recall(neg, pos, threshold)

        # f_score
        f1_score = f_score(neg, pos, threshold, 1)

        # AUC ROC
        auc = roc_auc_score(neg, pos)
        auc_log = roc_auc_score(neg, pos, log_scale=True)
        return (
            fta,
            fmr,
            fnmr,
            hter,
            far,
            frr,
            fm,
            ni,
            fnm,
            nc,
            precision,
            recall,
            f1_score,
            auc,
            auc_log,
        )

    def _strings(self, metrics):
        """Format the tuple returned by ``_numbers`` for display.

        Rates are shown as percentages, precision/recall/F1/AUC as plain
        numbers, all with ``self._decimal`` decimals; FPR and FNR also show
        the raw counts.

        Returns
        -------
        tuple of str
            ``(fta, fmr, fnmr, far, frr, hter, precision, recall, f1, auc,
            auc_log)``
        """
        n_dec = ".%df" % self._decimal

        def pct(value):
            return "%s%%" % format(100 * value, n_dec)

        def num(value):
            return format(value, n_dec)

        return (
            pct(metrics[0]),
            "%s (%d/%d)" % (pct(metrics[1]), metrics[6], metrics[7]),
            "%s (%d/%d)" % (pct(metrics[2]), metrics[8], metrics[9]),
            pct(metrics[4]),
            pct(metrics[5]),
            pct(metrics[3]),
            num(metrics[10]),
            num(metrics[11]),
            num(metrics[12]),
            num(metrics[13]),
            num(metrics[14]),
        )

    def _get_all_metrics(self, idx, input_scores, input_names):
        """Compute all metrics for dev and eval scores of system ``idx``.

        The threshold is the user-provided one (``ctx.meta["thres"]``) or is
        computed on the development set with ``self._criterion``; it is
        echoed to ``self.log_file`` and applied a priori to the eval set.

        Returns
        -------
        list
            ``[dev_strings, eval_strings]``, each being the tuple returned by
            ``_strings``; ``eval_strings`` is ``None`` without eval data.
        """
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        dev_neg, dev_pos, dev_fta = neg_list[0], pos_list[0], fta_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos, eval_fta = neg_list[1], pos_list[1], fta_list[1]

        if self._thres is None:
            threshold = self.get_thres(
                self._criterion, dev_neg, dev_pos, self._far
            )
        else:
            threshold = self._thres[idx]

        # prefer the legend (if any) over the file name in messages
        title = self._legends[idx] if self._legends is not None else None
        if self._thres is None:
            far_str = ""
            if self._criterion == "far" and self._far is not None:
                far_str = str(self._far)
            click.echo(
                "[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                % (
                    self._criterion.upper(),
                    far_str,
                    title or dev_file,
                    threshold,
                ),
                file=self.log_file,
            )
        else:
            click.echo(
                "[Min. criterion: user provided] Threshold on "
                "Development set `%s`: %e" % (title or dev_file, threshold),
                file=self.log_file,
            )

        res = [
            self._strings(self._numbers(dev_neg, dev_pos, threshold, dev_fta))
        ]
        if self._eval:
            # computes statistics for the eval set based on the threshold a
            # priori
            res.append(
                self._strings(
                    self._numbers(eval_neg, eval_pos, threshold, eval_fta)
                )
            )
        else:
            res.append(None)

        return res

    def compute(self, idx, input_scores, input_names):
        """Compute metrics thresholds and tables (FPR, FNR, precision, recall,
        f1_score, AUCs) for given system inputs and print the table to
        ``self.log_file``."""
        dev_file = input_names[0]
        title = self._legends[idx] if self._legends is not None else None
        all_metrics = self._get_all_metrics(idx, input_scores, input_names)
        # first formatted value is the NaN ratio in percent, e.g. "1.50%"
        fta_dev = float(all_metrics[0][0].replace("%", ""))
        if fta_dev > 0.0:
            LOGGER.warning(
                "NaNs scores (%s) were found in %s and removed.",
                all_metrics[0][0],
                dev_file,
            )
        # positions, in the tuple returned by ``_strings``, of the values
        # shown in the rows labelled self.names[0], self.names[1], ...
        columns = (1, 2, 6, 7, 8, 9, 10)
        headers = [title or " ", "Development"]
        rows = [
            [self.names[row], all_metrics[0][col]]
            for row, col in enumerate(columns)
        ]

        if self._eval:
            eval_file = input_names[1]
            fta_eval = float(all_metrics[1][0].replace("%", ""))
            if fta_eval > 0.0:
                LOGGER.warning(
                    "NaNs scores (%s) were found in %s and removed.",
                    all_metrics[1][0],
                    eval_file,
                )
            # computes statistics for the eval set based on the threshold a
            # priori
            headers.append("Evaluation")
            for row, col in zip(rows, columns):
                row.append(all_metrics[1][col])

        click.echo(tabulate(rows, headers, self._tablefmt), file=self.log_file)

    def end_process(self):
        """Close log file if needed"""
        if self._log is not None:
            self.log_file.close()

414 

415 

class MultiMetrics(Metrics):
    """Computes average of metrics based on several protocols (cross
    validation)

    Each system is given several score files (one dev file, or one dev and
    one eval file, per protocol). Metrics are computed per protocol and
    reported as ``mean (std)`` over protocols.

    Attributes
    ----------
    log_file : file-like
        output stream
    names : tuple
        List of names for the metrics.
    headers : list
        Table headers.
    rows : list
        One table row per system, appended by :py:meth:`compute`.
    """

    def __init__(
        self,
        ctx,
        scores,
        evaluation,
        func_load,
        names=(
            "NaNs Rate",
            "False Positive Rate",
            "False Negative Rate",
            "False Accept Rate",
            "False Reject Rate",
            "Half Total Error Rate",
        ),
    ):
        super(MultiMetrics, self).__init__(
            ctx, scores, evaluation, func_load, names=names
        )

        self.headers = ["Methods"] + list(self.names)
        if self._eval:
            # with eval data, the dev HTER comes right after the method name
            self.headers.insert(1, self.names[5] + " (dev)")
        self.rows = []

    def _strings(self, metrics):
        """Format ``mean (std)`` over protocols of the first six metrics.

        Parameters
        ----------
        metrics : numpy.ndarray
            2D array with one row per protocol, each row being the tuple
            returned by ``_numbers``. Only the first six columns (NaN rate,
            FPR, FNR, HTER, FAR, FRR) are used, so this works whatever the
            number of trailing columns.

        Returns
        -------
        tuple of str
            ``(fta, fmr, fnmr, far, frr, hter)`` formatted as percentages.
        """
        # slice instead of unpacking the whole row: ``_numbers`` returns more
        # columns than the six reported here
        ftam, fmrm, fnmrm, hterm, farm, frrm = metrics.mean(axis=0)[:6]
        ftas, fmrs, fnmrs, hters, fars, frrs = metrics.std(axis=0)[:6]
        n_dec = ".%df" % self._decimal

        def mean_std(mean, std):
            return "%s%% (%s%%)" % (
                format(100 * mean, n_dec),
                format(100 * std, n_dec),
            )

        return (
            mean_std(ftam, ftas),
            mean_std(fmrm, fmrs),
            mean_std(fnmrm, fnmrs),
            mean_std(farm, fars),
            mean_std(frrm, frrs),
            mean_std(hterm, hters),
        )

    def compute(self, idx, input_scores, input_names):
        """Computes the average of metrics over several protocols.

        ``input_scores`` holds every protocol of system ``idx``: dev scores
        only, or alternating dev / eval scores when evaluation data is used.
        One threshold is computed per protocol on its dev scores (unless
        user-provided) and applied to the matching eval scores.
        """
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        step = 2 if self._eval else 1
        self._dev_metrics = []
        self._thresholds = []
        for i in range(0, len(input_scores), step):
            neg, pos, fta = neg_list[i], pos_list[i], fta_list[i]
            if self._thres is None:
                threshold = self.get_thres(self._criterion, neg, pos, self._far)
            else:
                threshold = self._thres[idx]
            self._thresholds.append(threshold)
            self._dev_metrics.append(self._numbers(neg, pos, threshold, fta))
        self._dev_metrics = numpy.array(self._dev_metrics)

        if self._eval:
            self._eval_metrics = []
            for i in range(1, len(input_scores), step):
                neg, pos, fta = neg_list[i], pos_list[i], fta_list[i]
                # eval file i pairs with dev file i - 1, i.e. protocol i // 2
                threshold = self._thresholds[i // 2]
                self._eval_metrics.append(
                    self._numbers(neg, pos, threshold, fta)
                )
            self._eval_metrics = numpy.array(self._eval_metrics)

        title = self._legends[idx] if self._legends is not None else None

        fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str = self._strings(
            self._dev_metrics
        )

        if self._eval:
            self.rows.append([title, hter_str])
        else:
            self.rows.append(
                [title, fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str]
            )

        if self._eval:
            # computes statistics for the eval set based on the threshold a
            # priori
            (
                fta_str,
                fmr_str,
                fnmr_str,
                far_str,
                frr_str,
                hter_str,
            ) = self._strings(self._eval_metrics)

            self.rows[-1].extend(
                [fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str]
            )

    def end_process(self):
        """Print the table of all systems, then close the log file if needed."""
        click.echo(
            tabulate(self.rows, self.headers, self._tablefmt),
            file=self.log_file,
        )
        super(MultiMetrics, self).end_process()

560 

561 

class PlotBase(MeasureBase):
    """Base class for plots. Regroup several options and code
    shared by the different plots
    """

    def __init__(self, ctx, scores, evaluation, func_load):
        super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
        self._output = ctx.meta.get("output")
        self._points = ctx.meta.get("points", 2000)
        self._split = ctx.meta.get("split")
        self._axlim = ctx.meta.get("axlim")
        self._alpha = ctx.meta.get("alpha")
        self._disp_legend = ctx.meta.get("disp_legend", True)
        self._legend_loc = ctx.meta.get("legend_loc")
        # power of ten of the smallest FPR to plot, from the explicit option
        # or from the lower x-axis limit
        self._min_dig = None
        if "min_far_value" in ctx.meta:
            self._min_dig = int(math.log10(ctx.meta["min_far_value"]))
        elif self._axlim is not None and self._axlim[0] is not None:
            self._min_dig = int(
                math.log10(self._axlim[0]) if self._axlim[0] != 0 else 0
            )
        self._clayout = ctx.meta.get("clayout")
        # x positions of the dashed vertical lines drawn in end_process;
        # subclasses may map them to plot coordinates in _trans_far_val
        self._far_at = ctx.meta.get("lines_at")
        self._trans_far_val = self._far_at
        if self._far_at is not None:
            self._eval_points = {line: [] for line in self._far_at}
        self._lines_val = []
        self._print_fn = ctx.meta.get("show_fn", True)
        self._x_rotation = ctx.meta.get("x_rotation")
        if "style" in ctx.meta:
            mpl.style.use(ctx.meta["style"])
        self._nb_figs = 2 if self._eval and self._split else 1
        self._colors = utils.get_colors(self.n_systems)
        self._line_linestyles = ctx.meta.get("line_styles", False)
        self._linestyles = utils.get_linestyles(
            self.n_systems, self._line_linestyles
        )
        # ``or []`` also covers an explicit ``None`` stored in ctx.meta
        self._titles = (ctx.meta.get("titles") or []) * 2
        # for compatibility
        self._title = ctx.meta.get("title")
        if not self._titles and self._title is not None:
            self._titles = [self._title] * 2

        self._x_label = ctx.meta.get("x_label")
        self._y_label = ctx.meta.get("y_label")
        self._grid_color = "silver"
        self._pdf_page = None
        # True when init_process created the PdfPages itself: nobody else
        # holds it, so end_process must close (finalize) it
        self._own_pdf = False
        self._end_setup_plot = True

    def init_process(self):
        """Open the pdf (from ``ctx.meta["PdfPages"]`` or ``self._output``)
        and create the figures, with ``ctx.meta["figsize"]`` if provided."""
        if not hasattr(matplotlib, "backends"):
            matplotlib.use("pdf")

        if "PdfPages" in self._ctx.meta:
            self._pdf_page = self._ctx.meta["PdfPages"]
            self._own_pdf = False
        else:
            self._pdf_page = PdfPages(self._output)
            self._own_pdf = True

        for i in range(self._nb_figs):
            fs = self._ctx.meta.get("figsize")
            fig = mpl.figure(i + 1, figsize=fs)
            fig.set_constrained_layout(self._clayout)
            fig.clear()

    def end_process(self):
        """Set title, legend, axis labels, grid colors, save figures, draw
        lines and close pdf if needed"""
        # draw vertical lines
        if self._far_at is not None:
            for line, line_trans in zip(self._far_at, self._trans_far_val):
                mpl.figure(1)
                mpl.plot(
                    [line_trans, line_trans],
                    [-100.0, 100.0],
                    "--",
                    color="black",
                )
                if self._eval and self._split:
                    # on the eval figure, join the operating points obtained
                    # with the dev thresholds, ordered by x coordinate
                    mpl.figure(2)
                    points = sorted(self._eval_points[line], key=lambda p: p[0])
                    x_values = [x for x, _ in points]
                    y_values = [y for _, y in points]
                    mpl.plot(x_values, y_values, "--", color="black")
        # only for plots
        if self._end_setup_plot:
            for i in range(self._nb_figs):
                fig = mpl.figure(i + 1)
                title = "" if not self._titles else self._titles[i]
                mpl.title(title if title.replace(" ", "") else "")
                mpl.xlabel(self._x_label)
                mpl.ylabel(self._y_label)
                mpl.grid(True, color=self._grid_color)
                if self._disp_legend:
                    self.plot_legends()
                self._set_axis()
                mpl.xticks(rotation=self._x_rotation)
                self._pdf_page.savefig(fig)

        # A PdfPages coming from ctx.meta is not closed when running evaluate
        # (closef=False); one created by init_process is always closed since
        # nothing else can finalize it.
        if self._own_pdf or (
            "PdfPages" in self._ctx.meta and self._ctx.meta.get("closef", True)
        ):
            self._pdf_page.close()

    def plot_legends(self):
        """Print legend on current plot"""
        if not self._disp_legend:
            return

        lines = []
        labels = []
        for ax in mpl.gcf().get_axes():
            ali, ala = ax.get_legend_handles_labels()
            # avoid duplicates in legend
            for li, la in zip(ali, ala):
                if la not in labels:
                    lines.append(li)
                    labels.append(la)

        leg = mpl.legend(
            lines,
            labels,
            loc=self._legend_loc,
            ncol=1,
        )

        return leg

    # common protected functions

    def _label(self, base, idx):
        """Legend of system ``idx``, or ``base`` numbered when there are
        several systems."""
        if self._legends is not None and len(self._legends) > idx:
            return self._legends[idx]
        if self.n_systems > 1:
            return base + (" %d" % (idx + 1))
        return base

    def _set_axis(self):
        """Apply ``self._axlim`` to the current axes if set."""
        if self._axlim is not None:
            mpl.axis(self._axlim)

709 

710 

class Roc(PlotBase):
    """Handles the plotting of ROC"""

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Roc, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ["ROC dev.", "ROC eval."]
        self._x_label = self._x_label or "FPR"
        self._semilogx = ctx.meta.get("semilogx", True)
        self._tpr = ctx.meta.get("tpr", True)
        self._y_label = self._y_label or ("TPR" if self._tpr else "FNR")
        self._legend_loc = self._legend_loc or (
            "lower right" if self._semilogx else "upper right"
        )
        # custom defaults
        if self._axlim is None:
            self._axlim = [None, None, -0.05, 1.05]
        if self._min_dig is None:
            self._min_dig = -4

    def compute(self, idx, input_scores, input_names):
        """Plot ROC for dev and eval data using
        :py:func:`bob.measure.plot.roc`

        The dev curve always goes to figure 1. The eval curve goes to
        figure 2 when splitting, otherwise it is dashed on figure 1; eval
        operating points at the ``lines_at`` FPRs are scattered as well.
        """
        negatives, positives, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = negatives[0], positives[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = negatives[1], positives[1]
            eval_file = input_names[1]

        mpl.figure(1)
        LOGGER.info("ROC dev. curve using %s", dev_file)
        plot.roc(
            dev_neg,
            dev_pos,
            npoints=self._points,
            semilogx=self._semilogx,
            tpr=self._tpr,
            min_far=self._min_dig,
            color=self._colors[idx],
            linestyle=self._linestyles[idx],
            label=self._label("dev", idx),
            alpha=self._alpha,
        )
        if not self._eval:
            return

        if self._split:
            mpl.figure(2)
            eval_style = self._linestyles[idx]
        else:
            eval_style = "--"
        LOGGER.info("ROC eval. curve using %s", eval_file)
        plot.roc(
            eval_neg,
            eval_pos,
            linestyle=eval_style,
            npoints=self._points,
            semilogx=self._semilogx,
            tpr=self._tpr,
            min_far=self._min_dig,
            color=self._colors[idx],
            label=self._label("eval.", idx),
            alpha=self._alpha,
        )
        if self._far_at is None:
            return

        from .. import fprfnr

        for far_value in self._far_at:
            # threshold picked on dev at this FPR, evaluated on eval scores
            dev_threshold = far_threshold(dev_neg, dev_pos, far_value)
            x_point, y_point = fprfnr(eval_neg, eval_pos, dev_threshold)
            if self._tpr:
                y_point = 1 - y_point
            mpl.scatter(x_point, y_point, c=self._colors[idx], s=30)
            self._eval_points[far_value].append((x_point, y_point))

795 

796 

class Det(PlotBase):
    """Handles the plotting of DET"""

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Det, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ["DET dev.", "DET eval."]
        self._x_label = self._x_label or "FPR (%)"
        self._y_label = self._y_label or "FNR (%)"
        self._legend_loc = self._legend_loc or "upper right"
        if self._far_at is not None:
            # vertical lines are drawn in ppndf-transformed coordinates
            self._trans_far_val = ppndf(self._far_at)
        # custom defaults here
        if self._x_rotation is None:
            self._x_rotation = 50

        if self._axlim is None:
            self._axlim = [0.01, 99, 0.01, 99]

        if self._min_dig is not None:
            # Work on a copy: a user-provided axlim is the very object stored
            # in ctx.meta (possibly shared with other plots, possibly a tuple).
            self._axlim = list(self._axlim)
            # the DET x axis is expressed in percent
            self._axlim[0] = math.pow(10, self._min_dig) * 100

        self._min_dig = -4 if self._min_dig is None else self._min_dig

    def compute(self, idx, input_scores, input_names):
        """Plot DET for dev and eval data using
        :py:func:`bob.measure.plot.det`

        The dev curve always goes to figure 1. The eval curve goes to
        figure 2 when splitting, otherwise it is dashed on figure 1; eval
        operating points at the ``lines_at`` FPRs are scattered as well.
        """
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        eval_neg = eval_pos = eval_file = None
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]

        mpl.figure(1)
        LOGGER.info("DET dev. curve using %s", dev_file)
        plot.det(
            dev_neg,
            dev_pos,
            self._points,
            min_far=self._min_dig,
            color=self._colors[idx],
            linestyle=self._linestyles[idx],
            label=self._label("dev.", idx),
            alpha=self._alpha,
        )
        # no eval data (or none loaded): only the dev curve is drawn
        if eval_neg is None:
            return

        if self._split:
            mpl.figure(2)
        linestyle = "--" if not self._split else self._linestyles[idx]
        LOGGER.info("DET eval. curve using %s", eval_file)
        plot.det(
            eval_neg,
            eval_pos,
            self._points,
            min_far=self._min_dig,
            color=self._colors[idx],
            linestyle=linestyle,
            label=self._label("eval.", idx),
            alpha=self._alpha,
        )
        if self._far_at is not None:
            from .. import farfrr

            for line in self._far_at:
                # threshold picked on dev at this FPR, evaluated on eval scores
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, thres_line)
                eval_fmr, eval_fnmr = ppndf(eval_fmr), ppndf(eval_fnmr)
                mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

    def _set_axis(self):
        """Apply the axis limits with the DET-specific axis helper."""
        plot.det_axis(self._axlim)

881 

882 

class Epc(PlotBase):
    """Handles the plotting of EPC

    EPC needs both dev and eval score files for each system; everything is
    drawn on a single figure and no vertical FPR lines are drawn.
    """

    def __init__(self, ctx, scores, evaluation, func_load, hter="HTER"):
        super(Epc, self).__init__(ctx, scores, evaluation, func_load)
        if self._min_arg != 2:
            raise click.UsageError("EPC requires dev. and eval. score files")
        self._titles = self._titles or ["EPC"] * 2
        self._x_label = self._x_label or r"$\alpha$"
        self._y_label = self._y_label or hter + " (%)"
        self._legend_loc = self._legend_loc or "upper center"
        # always eval data with EPC, on one single figure, without lines
        self._eval = True
        self._split = False
        self._nb_figs = 1
        self._far_at = None

    def compute(self, idx, input_scores, input_names):
        """Plot EPC using :py:func:`bob.measure.plot.epc`"""
        negatives, positives, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos, dev_file = negatives[0], positives[0], input_names[0]
        if self._eval:
            eval_neg, eval_pos, eval_file = (
                negatives[1],
                positives[1],
                input_names[1],
            )

        LOGGER.info("EPC using %s", dev_file + "_" + eval_file)
        plot.epc(
            dev_neg,
            dev_pos,
            eval_neg,
            eval_pos,
            self._points,
            color=self._colors[idx],
            linestyle=self._linestyles[idx],
            label=self._label("curve", idx),
            alpha=self._alpha,
        )

920 

921 

class GridSubplot(PlotBase):
    """A base class for plots that contain subplots and legends.

    To use this class, call `create_subplot` in `compute` each time you need
    a new axis, and call `finalize_one_page` in `compute` when a page is
    finished rendering.
    """

    def __init__(self, ctx, scores, evaluation, func_load):
        super(GridSubplot, self).__init__(ctx, scores, evaluation, func_load)

        # Check legend: "best" is mapped to "upper center"; only upper-* and
        # lower-* locations can be anchored to the figure edge
        location = self._legend_loc or "upper center"
        if location == "best":
            location = "upper center"
        self._legend_loc = location
        if not ("upper" in location or "lower" in location):
            raise ValueError(
                "Only best, upper-*, and lower-* legend locations are supported!"
            )
        self._nlegends = ctx.meta.get("legends_ncol", 3)

        # subplot grid
        self._nrows = ctx.meta.get("n_row", 1)
        self._ncols = ctx.meta.get("n_col", 1)

    def init_process(self):
        """Open the pdf and figures, then build the grid on the current
        figure."""
        super(GridSubplot, self).init_process()
        self._create_grid_spec()

    def _create_grid_spec(self):
        """Create a ``n_row`` x ``n_col`` GridSpec bound to the current
        figure."""
        self._gs = gridspec.GridSpec(
            self._nrows,
            self._ncols,
            figure=mpl.gcf(),
        )

    def create_subplot(self, n, shared_axis=None):
        """Add and return the axis of grid cell ``n`` (cells numbered row by
        row), sharing its x axis with ``shared_axis`` if given."""
        row, col = numpy.unravel_index(n, (self._nrows, self._ncols))
        return mpl.gcf().add_subplot(
            self._gs[row : row + 1, col : col + 1], sharex=shared_axis
        )

    def finalize_one_page(self):
        """Add the legend, save the current page to the pdf and start a new
        figure with a fresh grid."""
        self.plot_legends()
        figure = mpl.gcf()
        all_axes = figure.get_axes()

        LOGGER.debug("%s contains %d axes:", figure, len(all_axes))
        for number, axis in enumerate(all_axes, start=1):
            LOGGER.debug("Axes %d: %s", number, axis)

        self._pdf_page.savefig(bbox_inches="tight")
        mpl.clf()
        mpl.figure()
        self._create_grid_spec()

    def plot_legends(self):
        """Print legend on current page"""
        if not self._disp_legend:
            return

        # one handle per label over all axes of the page, first occurrence
        # wins (dicts keep insertion order)
        handle_by_label = {}
        for axis in mpl.gcf().get_axes():
            handles, names = axis.get_legend_handles_labels()
            for handle, name in zip(handles, names):
                handle_by_label.setdefault(name, handle)

        if "upper" in self._legend_loc:
            # anchor on the top edge of the figure; the legend rests on it by
            # its bottom side, hence upper -> lower
            anchor_box = (0.0, 1.0, 1.0, 0.0)
            location = self._legend_loc.replace("upper", "lower")
        else:
            # anchor on the bottom edge; the legend hangs from it by its top
            anchor_box = (0.0, 0.0, 1.0, 0.0)
            location = self._legend_loc.replace("lower", "upper")

        return mpl.gcf().legend(
            list(handle_by_label.values()),
            list(handle_by_label.keys()),
            loc=location,
            ncol=self._nlegends,
            bbox_to_anchor=anchor_box,
        )

1017 

1018 

1019class Hist(GridSubplot): 

1020 """Functional base class for histograms""" 

1021 

def __init__(self, ctx, scores, evaluation, func_load, nhist_per_system=2):
    """Configure the histogram plotter from the click context.

    Parameters
    ----------
    ctx :
        Click context. Its ``meta`` mapping provides ``n_bins``, ``thres``,
        ``criterion``, ``no_line``, ``hide_dev``, ``legends_ncol`` and
        ``titles``.
    scores, evaluation, func_load :
        Forwarded unchanged to the parent class constructor.
    nhist_per_system : int
        Number of histograms per subplot. ``n_bins`` must hold either one
        value or exactly this many values.

    Raises
    ------
    click.BadParameter
        If ``--hide-dev`` is used without ``--eval``, or if ``n_bins`` /
        ``thres`` do not have an acceptable number of values.
    """
    super(Hist, self).__init__(ctx, scores, evaluation, func_load)
    self._nbins = ctx.meta.get("n_bins", ["doane"])
    self._nhist_per_system = nhist_per_system
    self._nbins = check_list_value(
        self._nbins, nhist_per_system, "n_bins", "histograms"
    )
    self._thres = ctx.meta.get("thres")
    self._thres = check_list_value(
        self._thres, self.n_systems, "thresholds"
    )
    self._criterion = ctx.meta.get("criterion")
    # when True, no vertical (threshold) line is drawn
    self._no_line = ctx.meta.get("no_line", False)
    # when True, the dev histograms are not displayed
    self._hide_dev = ctx.meta.get("hide_dev", False)
    if self._hide_dev and not self._eval:
        raise click.BadParameter(
            "You can only use --hide-dev along with --eval"
        )
    # dev histograms are displayed above the eval ones, so each row of
    # systems needs two grid rows
    self._nrows *= 1 if self._hide_dev or not self._eval else 2
    self._nlegends = ctx.meta.get("legends_ncol", 3)

    # number of subplots on one page
    self._step_print = int(self._nrows * self._ncols)
    self._title_base = "Scores"
    self._y_label = self._y_label or "Probability density"
    self._x_label = self._x_label or "Score values"
    self._end_setup_plot = False
    # override the _titles of PlotBase. The list is duplicated because
    # _print_subplot reads index ``sys`` for upper (dev) subplots and
    # ``sys + n_systems`` for lower ones (presumably one title per system).
    # A ``titles`` entry explicitly set to None is treated like a missing
    # one instead of failing on ``None * 2``.
    self._titles = (ctx.meta.get("titles") or []) * 2

1054 

def compute(self, idx, input_scores, input_names):
    """Draw the dev and/or eval histograms of the system at index ``idx``.

    When both dev and eval histograms are shown, every row of systems uses
    two grid rows. The dev subplot goes on the upper row, and the eval
    subplot goes directly below it, with the dev axis passed as its
    ``shared_axis``. With ``--hide-dev`` (or without eval data) the system
    index is used directly as the grid position.
    """
    dev_neg, dev_pos, eval_neg, eval_pos, threshold = (
        self._get_neg_pos_thres(idx, input_scores, input_names)
    )

    system = idx
    draw_line = not self._no_line
    stacked = self._eval and not self._hide_dev
    show_dev = not (self._hide_dev and self._eval)

    slot = idx
    if stacked:
        # translate the system index into the grid position of its dev
        # subplot: every row of systems occupies two grid rows
        grid_row = int(system / self._ncols) * 2
        slot = system % self._ncols + self._ncols * grid_row

    dev_axis = None
    if show_dev:
        dev_axis = self._print_subplot(
            slot, system, dev_neg, dev_pos, threshold, draw_line, False
        )

    if self._eval:
        if not self._hide_dev:
            # eval subplot sits one grid row below the dev one
            slot += self._ncols
        self._print_subplot(
            slot,
            system,
            eval_neg,
            eval_pos,
            threshold,
            draw_line,
            True,
            shared_axis=dev_axis,
        )

1099 

def _print_subplot(
    self,
    idx,
    sys,
    neg,
    pos,
    threshold,
    draw_line,
    evaluation,
    shared_axis=None,
):
    """Draw one histogram subplot and save the page once it is complete.

    Parameters
    ----------
    idx : int
        Grid position of the subplot. It is taken modulo the number of
        subplots per page (``self._step_print``).
    sys : int
        Index of the system being drawn. It shadows the ``sys`` module
        inside this method.
    neg, pos :
        Negative / positive scores forwarded to ``self._setup_hist``.
    threshold :
        Position of the threshold line.
    draw_line : bool
        Whether the threshold line is drawn.
    evaluation : bool
        True for an eval histogram, False for a dev one.
    shared_axis : optional
        Forwarded to ``self.create_subplot``.

    Returns
    -------
    The axis created for this subplot.
    """
    slot = idx % self._step_print
    column = slot % self._ncols
    axis = self.create_subplot(slot, shared_axis)
    self._setup_hist(neg, pos)
    if column == 0:
        # y label only on the first column
        axis.set_ylabel(self._y_label)

    two_rows_per_system = self._eval and not self._hide_dev
    systems_per_page = self._step_print / (2 if two_rows_per_system else 1)
    page_slot = sys % systems_per_page
    remaining = (
        self.n_systems - int(sys / systems_per_page) * systems_per_page
    )
    # eval histograms, or dev ones when there is no eval data
    is_lower = evaluation or not self._eval
    # x label only on the last row of systems of the page
    if is_lower and page_slot + self._ncols >= min(
        systems_per_page, remaining
    ):
        axis.set_xlabel(self._x_label)

    if self.n_systems == 1 and (not self._eval or self._hide_dev):
        default_title = " "
    elif evaluation:
        default_title = "Eval. scores"
    else:
        default_title = "Dev. scores"
    title_idx = sys + (self.n_systems if is_lower else 0)
    axis.set_title(self._get_title(title_idx, default_title))

    criterion = "" if self._criterion is None else self._criterion.upper()
    suffix = " (dev)" if self._eval else ""
    line_label = "%s threshold%s" % (criterion, suffix)
    if draw_line:
        self._lines(threshold, line_label, neg, pos, idx)

    # grid enabled and drawn below the other elements
    axis.set_axisbelow(True)
    axis.grid(True, color=self._grid_color)

    # save when the page is full or when the last subplot was drawn
    page_full = slot + 1 == self._step_print
    if page_full or (is_lower and sys == self.n_systems - 1):
        self.finalize_one_page()
    return axis

1153 

def _get_title(self, idx, dflt=None):
    """Return the title to display for subplot ``idx``.

    The candidate is ``self._titles[idx]`` when that index exists, and
    ``dflt`` otherwise. An empty or ``None`` candidate is replaced by
    ``self._title_base``. A title made only of spaces means "no title", so
    ``""`` is returned. ``""`` is also returned when nothing is left.
    """
    titles = self._titles
    has_custom = titles is not None and idx < len(titles)
    candidate = titles[idx] if has_custom else dflt
    if not candidate:
        candidate = self._title_base
    if candidate is not None and not candidate.replace(" ", ""):
        return ""
    return candidate or ""

1166 

def _get_neg_pos_thres(self, idx, input_scores, input_names):
    """Get the scores and the threshold for the system at index ``idx``.

    Returns ``(dev_neg, dev_pos, eval_neg, eval_pos, threshold)``. The
    ``dev_*`` entries are lists (a system may have several files). The
    ``eval_*`` entries are lists too, or ``None`` when there is no eval data.
    The threshold is ``self._thres[idx]`` when thresholds were given.
    Otherwise it is computed from ``self._criterion`` on the first dev
    negatives/positives. ``input_names`` is not used here.
    """
    negatives, positives, _ = utils.get_fta_list(input_scores)
    count = len(negatives)
    # Lists returned by get_fta_list contain, in order:
    #   bio or measure without eval:          [dev]
    #   bio or measure with eval:             [dev, eval]
    #   vuln with {licit,spoof} without eval: [licit_dev, spoof_dev]
    #   vuln with {licit,spoof} with eval:
    #       [licit_dev, licit_eval, spoof_dev, spoof_eval]
    # so with eval data, dev entries are at even positions and eval
    # entries at odd positions.
    stride = 2 if self._eval else 1
    dev_positions = range(0, count, stride)
    dev_neg = [negatives[i] for i in dev_positions]
    dev_pos = [positives[i] for i in dev_positions]
    if self._eval:
        eval_positions = range(1, count, stride)
        eval_neg = [negatives[i] for i in eval_positions]
        eval_pos = [positives[i] for i in eval_positions]
    else:
        eval_neg = eval_pos = None

    if self._thres is None:
        threshold = utils.get_thres(self._criterion, dev_neg[0], dev_pos[0])
    else:
        threshold = self._thres[idx]
    return dev_neg, dev_pos, eval_neg, eval_pos, threshold

1195 

def _density_hist(self, scores, n, **kwargs):
    """Plot one normalised (density) histogram of ``scores``.

    ``n`` selects the bin setting in ``self._nbins``. Extra keyword
    arguments are forwarded to ``mpl.hist``. Returns its
    ``(values, bins, patches)`` triple as a tuple.
    """
    values, edges, patches = mpl.hist(
        scores, bins=self._nbins[n], density=True, **kwargs
    )
    return (values, edges, patches)

1202 

def _lines(
    self, threshold, label=None, neg=None, pos=None, idx=None, **kwargs
):
    """Draw a vertical threshold line on the current axes.

    Parameters
    ----------
    threshold :
        x position of the line.
    label : str, optional
        Legend label. ``"Threshold"`` is used when it is empty or ``None``.
    neg, pos, idx :
        Not used here (presumably accepted so that overriding hooks in
        subclasses share the signature).
    **kwargs :
        Extra ``mpl.axvline`` arguments. ``color`` defaults to ``"C3"``
        and ``linestyle`` to ``"--"``.
    """
    line_kwargs = dict(kwargs)
    defaults = (
        ("color", "C3"),
        ("linestyle", "--"),
        ("label", label or "Threshold"),
    )
    for key, value in defaults:
        line_kwargs.setdefault(key, value)
    # ymin=0 / ymax=1: the line spans the whole height of the axes
    mpl.axvline(x=threshold, ymin=0, ymax=1, **line_kwargs)

1213 

def _setup_hist(self, neg, pos):
    """Draw all density histograms of one subplot (meant to be overridden).

    This base version overlays the first entry of ``neg`` (label
    ``"Negatives"``, colour ``"C3"``, bin setting 0) and the first entry of
    ``pos`` (label ``"Positives"``, colour ``"C0"``, bin setting 1), both
    with ``alpha=0.5``.
    """
    series = (
        (neg, "Negatives", "C3"),
        (pos, "Positives", "C0"),
    )
    for bin_index, (scores, name, colour) in enumerate(series):
        self._density_hist(
            scores[0], n=bin_index, label=name, alpha=0.5, color=colour
        )