Coverage for src/bob/pad/base/script/cross.py: 0%
125 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-12 23:40 +0200
1"""Prints Cross-db metrics analysis
2"""
3import itertools
4import json
5import logging
6import math
7import os
9import click
10import jinja2
11import yaml
13from clapper.click import log_parameters, verbosity_option
14from tabulate import tabulate
16from bob.bio.base.score.load import get_negatives_positives, load_score
17from bob.measure import farfrr
18from bob.measure.script import common_options
19from bob.measure.utils import get_fta
21from ..error_utils import calc_threshold
22from .pad_commands import CRITERIA
24logger = logging.getLogger(__name__)
def bool_option(name, short_name, desc, dflt=False, **kwargs):
    """Generic provider for boolean on/off click options.

    Builds a ``--name/--no-name`` (and ``-x/-nx``) flag whose parsed value
    is also recorded in ``ctx.meta`` under the option name (dashes replaced
    by underscores) so sibling callbacks can read it eagerly.

    Parameters
    ----------
    name : str
        name of the option
    short_name : str
        short name for the option
    desc : str
        short description for the option
    dflt : bool or None
        Default value
    **kwargs
        All kwargs are passed to click.option.

    Returns
    -------
    ``callable``
        A decorator to be used for adding this option.
    """
    meta_key = name.replace("-", "_")

    def decorator(func):
        def record_in_meta(ctx, param, value):
            # Stash the parsed flag in ctx.meta so other eager callbacks
            # (and the command body) can look it up by name.
            ctx.meta[meta_key] = value
            return value

        flag_short = "-{0}/-n{0}".format(short_name)
        flag_long = "--{0}/--no-{0}".format(name)
        return click.option(
            flag_short,
            flag_long,
            default=dflt,
            help=desc,
            show_default=True,
            callback=record_in_meta,
            is_eager=True,
            **kwargs,
        )(func)

    return decorator
def _ordered_load(stream, Loader=yaml.Loader, object_pairs_hook=dict):
    """Load a YAML stream, building every mapping with *object_pairs_hook*.

    See: https://stackoverflow.com/questions/5121931/in-python-how-can-you-load-yaml-mappings-as-ordereddicts
    """

    class _MappingLoader(Loader):
        # Subclassed so the custom constructor does not leak into *Loader*.
        pass

    def _build_mapping(loader, node):
        loader.flatten_mapping(node)
        return object_pairs_hook(loader.construct_pairs(node))

    _MappingLoader.add_constructor(
        yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, _build_mapping
    )

    return yaml.load(stream, _MappingLoader)
def expand(data):
    """Generate configuration sets from YAML input contents.

    A configuration set assigns a value to **every** key in the input.
    Keys mapped to a YAML list are iterated: one configuration set is
    yielded for each element of the cartesian product of all list values.
    For example:

    .. code-block:: yaml

       name: [john, lisa]
       version: [v1, v2]

    yields the four combinations
    ``{'name': 'john', 'version': 'v1'} ... {'name': 'lisa', 'version': 'v2'}``.

    Keys whose value is not a list (scalars, mappings, folded strings) are
    treated as unique and repeated verbatim in every yielded set.  Keys
    starting with one ``_`` (underscore) are also treated as unique, even
    when their value is a list:

    .. code-block:: yaml

       name: [john, lisa]
       _unique: [i1, i2]

    yields ``{'name': 'john', '_unique': ['i1', 'i2']}`` and
    ``{'name': 'lisa', '_unique': ['i1', 'i2']}``.

    Parameters:

      data (str): YAML data to be parsed

    Yields:

      dict: A dictionary of key-value pairs for building the templates
    """

    parsed = _ordered_load(data, yaml.SafeLoader)

    # Split entries into those we iterate over (plain lists) and those that
    # stay fixed in every configuration set (non-lists and "_"-prefixed keys).
    to_iterate = dict()
    fixed = dict()
    for key, value in parsed.items():
        if isinstance(value, list) and not key.startswith("_"):
            to_iterate[key] = value
        else:
            fixed[key] = value

    # Emit one configuration per element of the cartesian product.
    iterated_keys = list(to_iterate.keys())
    for combination in itertools.product(*to_iterate.values()):
        config = dict(fixed)
        config.update(zip(iterated_keys, combination))
        yield config
@click.command(
    epilog="""\b
Examples:
  $ bob pad cross 'results/{{ evaluation.database }}/{{ algorithm }}/{{ evaluation.protocol }}/scores/scores-{{ group }}' \
    -td replaymobile \
    -d replaymobile -p grandtest \
    -d oulunpu -p Protocol_1 \
    -a replaymobile_grandtest_frame-diff-svm \
    -a replaymobile_grandtest_qm-svm-64 \
    -a replaymobile_grandtest_lbp-svm-64 \
    > replaymobile.rst &
"""
)
@click.argument("score_jinja_template")
@click.option(
    "-d",
    "--database",
    "databases",
    multiple=True,
    required=True,
    show_default=True,
    help="Names of the evaluation databases",
)
@click.option(
    "-p",
    "--protocol",
    "protocols",
    multiple=True,
    required=True,
    show_default=True,
    help="Names of the protocols of the evaluation databases",
)
@click.option(
    "-a",
    "--algorithm",
    "algorithms",
    multiple=True,
    required=True,
    show_default=True,
    help="Names of the algorithms",
)
@click.option(
    "-n",
    "--names",
    type=click.File("r"),
    help="Name of algorithms to show in the table. Provide a path "
    "to a json file maps algorithm names to names that you want to "
    "see in the table.",
)
@click.option(
    "-td",
    "--train-database",
    required=True,
    help="The database that was used to train the algorithms.",
)
@click.option(
    "-pn",
    "--pai-names",
    type=click.File("r"),
    help="Name of PAIs to compute the errors per PAI. Provide a path "
    "to a json file maps attack_type in scores to PAIs that you want to "
    "see in the table.",
)
@click.option(
    "-g",
    "--group",
    "groups",
    multiple=True,
    show_default=True,
    default=["train", "dev", "eval"],
)
@bool_option("sort", "s", "whether the table should be sorted.", True)
@common_options.criterion_option(lcriteria=CRITERIA, check=False)
@common_options.far_option()
@common_options.table_option()
@common_options.output_log_metric_option()
@common_options.decimal_option(dflt=2, short="-dec")
@verbosity_option(logger)
@click.pass_context
def cross(
    ctx,
    score_jinja_template,
    databases,
    protocols,
    algorithms,
    names,
    train_database,
    pai_names,
    groups,
    sort,
    decimal,
    verbose,
    **kwargs,
):
    """Cross-db analysis metrics"""
    log_parameters(logger)

    # Optional mapping of algorithm name -> display name for the table.
    names = {} if names is None else json.load(names)

    env = jinja2.Environment(undefined=jinja2.StrictUndefined)

    # Variables substituted into the score-path jinja template; ``expand``
    # iterates the cartesian product of evaluation x algorithm x group.
    data = {
        "evaluation": [
            {"database": db, "protocol": proto}
            for db, proto in zip(databases, protocols)
        ],
        "algorithm": algorithms,
        "group": groups,
    }

    # metrics[(database, protocol, algorithm, group)] ->
    #   (hter, threshold, fta, far, frr); all NaN if the score file is missing.
    metrics = {}

    for variables in expand(yaml.dump(data, Dumper=yaml.SafeDumper)):
        logger.debug(variables)

        score_path = env.from_string(score_jinja_template).render(variables)
        logger.info(score_path)

        database, protocol, algorithm, group = (
            variables["evaluation"]["database"],
            variables["evaluation"]["protocol"],
            variables["algorithm"],
            variables["group"],
        )

        # if algorithm name does not have train_database name in it.
        if train_database not in algorithm and database != train_database:
            score_path = score_path.replace(
                algorithm, database + "_" + algorithm
            )
            logger.info("Score path changed to: %s", score_path)

        if not os.path.exists(score_path):
            metrics[(database, protocol, algorithm, group)] = (
                float("nan"),
            ) * 5
            continue

        scores = load_score(score_path)
        neg, pos = get_negatives_positives(scores)
        (neg, pos), fta = get_fta((neg, pos))

        if group == "eval":
            # Reuse the threshold selected on dev; "dev" is processed before
            # "eval" given the default group ordering.
            threshold = metrics[(database, protocol, algorithm, "dev")][1]
        else:
            try:
                threshold = calc_threshold(
                    ctx.meta["criterion"],
                    pos,
                    [neg],
                    neg,
                    ctx.meta["far_value"],
                )
            except RuntimeError:
                logger.error("Something wrong with {}".format(score_path))
                raise

        far, frr = farfrr(neg, pos, threshold)
        hter = (far + frr) / 2

        metrics[(database, protocol, algorithm, group)] = (
            hter,
            threshold,
            fta,
            far,
            frr,
        )

    logger.debug("metrics: %s", metrics)

    headers = ["Algorithms"]
    for db in databases:
        headers += [db + "\nEER_t", "\nEER_d", "\nAPCER", "\nBPCER", "\nACER"]
    rows = []

    # sort the algorithms based on HTER eval, HTER dev, HTER train
    train_protocol = protocols[databases.index(train_database)]
    if sort:

        def sort_key(alg):
            # BUGFIX: look metrics up with the loop variable ``grp``; the
            # previous code used the stale ``group`` left over from the loop
            # above, so all three sort components came from the same group.
            r = []
            for grp in ("eval", "dev", "train"):
                hter = metrics[(train_database, train_protocol, alg, grp)][0]
                # NaN means the score file was missing; rank it worst (1).
                r.append(1 if math.isnan(hter) else hter)
            return tuple(r)

        algorithms = sorted(algorithms, key=sort_key)

    for algorithm in algorithms:
        # Strip training database/protocol prefixes for a compact row label,
        # then apply the user-provided display-name mapping, if any.
        name = algorithm.replace(train_database + "_", "")
        name = name.replace(train_protocol + "_", "")
        name = names.get(name, name)
        rows.append([name])
        for database, protocol in zip(databases, protocols):
            cell = []
            for group in groups:
                hter, threshold, fta, far, frr = metrics[
                    (database, protocol, algorithm, group)
                ]
                if group == "eval":
                    # eval columns show APCER/BPCER/ACER (far/frr/hter).
                    cell += [far, frr, hter]
                else:
                    cell += [hter]
            # Express rates as percentages with the requested precision.
            # NOTE(review): tabulate's floatfmt is fixed to ".1f" below and
            # may re-truncate what ``decimal`` allows — confirm intended.
            cell = [round(c * 100, decimal) for c in cell]
            rows[-1].extend(cell)

    title = " Trained on {} ".format(train_database)
    title_line = "\n" + "=" * len(title) + "\n"
    # open log file for writing if any (None makes click.echo use stdout)
    ctx.meta["log"] = (
        ctx.meta["log"]
        if ctx.meta["log"] is None
        else open(ctx.meta["log"], "w")
    )
    try:
        click.echo(title_line + title + title_line, file=ctx.meta["log"])
        click.echo(
            tabulate(rows, headers, ctx.meta["tablefmt"], floatfmt=".1f"),
            file=ctx.meta["log"],
        )
    finally:
        # BUGFIX: close the log file instead of leaking the handle.
        if ctx.meta["log"] is not None:
            ctx.meta["log"].close()