#!/usr/bin/env python
# coding=utf-8


import tabulate
import numpy as np
import torch
from sklearn.metrics import (
    auc,
    precision_recall_curve as pr_curve,
    roc_curve as r_curve,
    f1_score,
    accuracy_score,
)
from ..engine.evaluator import posneg
from ..utils.measure import bayesian_measures, base_measures


def performance_table(data, fmt):
    """Tabulates a comparison of performance results in a given format


    Parameters
    ----------

    data : dict
        A dictionary in which keys are strings defining plot labels and values
        are dictionaries with two entries:

        * ``df``: :py:class:`pandas.DataFrame`

          A dataframe that is produced by our predictor engine containing
          the following columns: ``filename``, ``likelihood``,
          ``ground_truth``.

        * ``threshold``: :py:class:`float`

          The threshold at which the fixed-threshold measures are computed.


    fmt : str
        One of the formats supported by :py:mod:`tabulate`.


    Returns
    -------

    table : str
        A table formatted according to ``fmt``
45 """

    headers = [
        "Dataset",
        "T",
        "F1 (95% CI)",
        "Prec (95% CI)",
        "Recall/Sen (95% CI)",
        "Spec (95% CI)",
        "Acc (95% CI)",
        "AUC (PRC)",
        "AUC (ROC)",
    ]

    table = []
    for k, v in data.items():
        entry = [k, v["threshold"]]

        df = v["df"]

        gt = torch.tensor(df["ground_truth"].values)
        pred = torch.tensor(df["likelihood"].values)
        threshold = v["threshold"]

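        # posneg splits the thresholded predictions into per-sample true/false
        # positive/negative indicator tensors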
        tp_tensor, fp_tensor, tn_tensor, fn_tensor = posneg(pred, gt, threshold)

        # reduce the per-sample indicators to scalar counts
        tp_count = torch.sum(tp_tensor).item()
        fp_count = torch.sum(fp_tensor).item()
        tn_count = torch.sum(tn_tensor).item()
        fn_count = torch.sum(fn_tensor).item()

        base_m = base_measures(
            tp_count,
            fp_count,
            tn_count,
            fn_count,
        )

        bayes_m = bayesian_measures(
            tp_count,
            fp_count,
            tn_count,
            fn_count,
            lambda_=1,
            coverage=0.95,
        )

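        # each bayes_m entry is assumed to carry the lower/upper bounds of its
        # 95% credible interval at positions 2 and 3, matching the CI headers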
        # statistics based on the "assigned" threshold (a priori, less biased)
        entry.append("{:.2f} ({:.2f}, {:.2f})".format(base_m[5], bayes_m[5][2], bayes_m[5][3]))  # f1
        entry.append("{:.2f} ({:.2f}, {:.2f})".format(base_m[0], bayes_m[0][2], bayes_m[0][3]))  # precision
        entry.append("{:.2f} ({:.2f}, {:.2f})".format(base_m[1], bayes_m[1][2], bayes_m[1][3]))  # recall/sensitivity
        entry.append("{:.2f} ({:.2f}, {:.2f})".format(base_m[2], bayes_m[2][2], bayes_m[2][3]))  # specificity
        entry.append("{:.2f} ({:.2f}, {:.2f})".format(base_m[3], bayes_m[3][2], bayes_m[3][3]))  # accuracy

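        # threshold-independent summaries: areas under the precision-recall
        # and ROC curves computed over the full range of likelihoods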
        prec, recall, _ = pr_curve(gt, pred)
        fpr, tpr, _ = r_curve(gt, pred)

        entry.append(auc(recall, prec))
        entry.append(auc(fpr, tpr))

        table.append(entry)

    return tabulate.tabulate(table, headers, tablefmt=fmt, floatfmt=".3f")