1#!/usr/bin/env python
2# coding=utf-8
3
4"""Tests for our CLI applications"""
5
6import os
7import re
8import contextlib
9from _pytest.tmpdir import tmp_path
10from bob.extension import rc
11import pkg_resources
12
13import pytest
14from click.testing import CliRunner
15
16from . import mock_dataset
17
# Download test data and get their location if needed (runs once at import
# time; all tests below point the 'bob.med.tb.montgomery.datadir' rc key at
# this folder)
montgomery_datadir = mock_dataset()

# URLs of small pre-trained checkpoints, used so prediction/fine-tuning tests
# do not need to train a model from scratch
_pasa_checkpoint_URL = "http://www.idiap.ch/software/bob/data/bob/bob.med.tb/master/_test_fpasa_checkpoint.pth"
_signstotb_checkpoint_URL = "http://www.idiap.ch/software/bob/data/bob/bob.med.tb/master/_test_signstotb_checkpoint.pth"
_logreg_checkpoint_URL = "http://www.idiap.ch/software/bob/data/bob/bob.med.tb/master/_test_logreg_checkpoint.pth"
# _densenetrs_checkpoint_URL = "http://www.idiap.ch/software/bob/data/bob/bob.med.tb/master/_test_densenetrs_checkpoint.pth"
25
26
@pytest.fixture(scope="session")
def temporary_basedir(tmp_path_factory):
    """Session-wide scratch directory shared by all CLI tests in this file."""
    basedir = tmp_path_factory.mktemp("test-cli")
    return basedir
30
31
@contextlib.contextmanager
def rc_context(**new_config):
    """Temporarily overlay ``new_config`` onto the global ``rc`` configuration.

    The previous configuration is snapshotted on entry and fully restored on
    exit, even when the body raises.
    """
    saved = rc.copy()
    rc.update(new_config)
    try:
        yield
    finally:
        # restore the exact pre-context state
        rc.clear()
        rc.update(saved)
41
42
@contextlib.contextmanager
def stdout_logging():
    """Capture messages sent to the ``bob`` logger into a string buffer.

    Yields an ``io.StringIO`` object whose contents accumulate every
    INFO-or-above record emitted through the ``bob`` logger while the
    context is active.

    The handler is detached in a ``finally`` clause so it is always removed,
    even if the body raises.
    """

    ## copy logging messages to std out
    import logging
    import io

    buf = io.StringIO()
    ch = logging.StreamHandler(buf)
    ch.setFormatter(logging.Formatter("%(message)s"))
    ch.setLevel(logging.INFO)
    logger = logging.getLogger("bob")
    logger.addHandler(ch)
    try:
        yield buf
    finally:
        # bug fix: previously an exception inside the context skipped this
        # line, leaking the handler onto the shared 'bob' logger and
        # polluting the captured output of every later test
        logger.removeHandler(ch)
58
59
60def _assert_exit_0(result):
61
62 assert (
63 result.exit_code == 0
64 ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
65
66
def _data_file(f):
    """Return the path of test-data file ``f`` shipped inside this package."""
    relative = os.path.join("data", f)
    return pkg_resources.resource_filename(__name__, relative)
69
70
def _check_help(entry_point):
    """Invoke ``entry_point --help`` and verify a clean exit with usage text."""
    result = CliRunner().invoke(entry_point, ["--help"])
    _assert_exit_0(result)
    assert result.output.startswith("Usage:")
77
78
def test_config_help():
    """``config --help`` prints a usage screen and exits cleanly."""
    from ..scripts import config

    _check_help(config.config)
83
84
def test_config_list_help():
    """``config list --help`` prints a usage screen and exits cleanly."""
    # aliased so the ``list`` builtin is not shadowed locally
    from ..scripts.config import list as list_command

    _check_help(list_command)
89
90
def test_config_list():
    """``config list`` enumerates the dataset and model config modules."""
    # aliased so the ``list`` builtin is not shadowed locally
    from ..scripts.config import list as list_command

    runner = CliRunner()
    result = runner.invoke(list_command)
    _assert_exit_0(result)
    assert "module: bob.med.tb.configs.datasets" in result.output
    assert "module: bob.med.tb.configs.models" in result.output
99
100
def test_config_list_v():
    """Verbose ``config list`` still reports the dataset and model modules."""
    # aliased so the ``list`` builtin is not shadowed locally
    from ..scripts.config import list as list_command

    result = CliRunner().invoke(list_command, ["--verbose"])
    _assert_exit_0(result)
    assert "module: bob.med.tb.configs.datasets" in result.output
    assert "module: bob.med.tb.configs.models" in result.output
108
109
def test_config_describe_help():
    """``config describe --help`` prints a usage screen and exits cleanly."""
    from ..scripts import config

    _check_help(config.describe)
114
115
def test_config_describe_montgomery():
    """Describing the ``montgomery`` config mentions the dataset blurb."""
    from ..scripts.config import describe

    result = CliRunner().invoke(describe, ["montgomery"])
    _assert_exit_0(result)
    assert "Montgomery dataset for TB detection" in result.output
123
124
def test_dataset_help():
    """``dataset --help`` prints a usage screen and exits cleanly."""
    from ..scripts import dataset

    _check_help(dataset.dataset)
129
130
def test_dataset_list_help():
    """``dataset list --help`` prints a usage screen and exits cleanly."""
    # aliased so the ``list`` builtin is not shadowed locally
    from ..scripts.dataset import list as list_command

    _check_help(list_command)
135
136
def test_dataset_list():
    """``dataset list`` prints the inventory of supported datasets."""
    # aliased so the ``list`` builtin is not shadowed locally
    from ..scripts.dataset import list as list_command

    runner = CliRunner()
    result = runner.invoke(list_command)
    _assert_exit_0(result)
    assert result.output.startswith("Supported datasets:")
144
145
def test_dataset_check_help():
    """``dataset check --help`` prints a usage screen and exits cleanly."""
    from ..scripts import dataset

    _check_help(dataset.check)
150
151
def test_dataset_check():
    """A limited ``dataset check`` run completes without error."""
    from ..scripts.dataset import check

    result = CliRunner().invoke(check, ["--verbose", "--limit=2"])
    _assert_exit_0(result)
158
159
def test_main_help():
    """The top-level ``tb`` command prints a usage screen and exits cleanly."""
    from ..scripts import tb

    _check_help(tb.tb)
164
165
def test_train_help():
    """``train --help`` prints a usage screen and exits cleanly."""
    from ..scripts import train

    _check_help(train.train)
170
171
172def _str_counter(substr, s):
173 return sum(1 for _ in re.finditer(substr, s, re.MULTILINE))
174
175
def test_predict_help():
    """``predict --help`` prints a usage screen and exits cleanly."""
    from ..scripts import predict

    _check_help(predict.predict)
180
181
def test_predtojson_help():
    """``predtojson --help`` prints a usage screen and exits cleanly."""
    from ..scripts import predtojson

    _check_help(predtojson.predtojson)
186
187
def test_aggregpred_help():
    """``aggregpred --help`` prints a usage screen and exits cleanly."""
    from ..scripts import aggregpred

    _check_help(aggregpred.aggregpred)
192
193
def test_evaluate_help():
    """``evaluate --help`` prints a usage screen and exits cleanly."""
    from ..scripts import evaluate

    _check_help(evaluate.evaluate)
198
199
def test_compare_help():
    """``compare --help`` prints a usage screen and exits cleanly."""
    from ..scripts import compare

    _check_help(compare.compare)
204
205
def test_train_pasa_montgomery(temporary_basedir):
    """Single-epoch 'pasa' training on (mocked) Montgomery writes all artifacts.

    Runs the ``train`` CLI end-to-end and checks both the files created in
    the output folder and the expected counts of key log messages.
    """

    # Temporarily modify Montgomery datadir
    new_value = {"bob.med.tb.montgomery.datadir": montgomery_datadir}
    with rc_context(**new_value):

        from ..scripts.train import train

        runner = CliRunner()

        with stdout_logging() as buf:

            output_folder = str(temporary_basedir / "results")
            result = runner.invoke(
                train,
                [
                    "pasa",
                    "montgomery",
                    "-vv",
                    "--epochs=1",
                    "--batch-size=1",
                    "--normalization=current",
                    f"--output-folder={output_folder}",
                ],
            )
            _assert_exit_0(result)

            # training must leave weights, a checkpoint pointer and CSV logs
            assert os.path.exists(
                os.path.join(output_folder, "model_final_epoch.pth")
            )
            assert os.path.exists(
                os.path.join(output_folder, "model_lowest_valid_loss.pth")
            )
            assert os.path.exists(
                os.path.join(output_folder, "last_checkpoint")
            )
            assert os.path.exists(os.path.join(output_folder, "constants.csv"))
            assert os.path.exists(os.path.join(output_folder, "trainlog.csv"))
            assert os.path.exists(
                os.path.join(output_folder, "model_summary.txt")
            )

            # regex -> expected number of occurrences in the captured log
            keywords = {
                r"^Found \(dedicated\) '__train__' set for training$": 1,
                r"^Found \(dedicated\) '__valid__' set for validation$": 1,
                r"^Continuing from epoch 0$": 1,
                r"^Saving model summary at.*$": 1,
                r"^Model has.*$": 1,
                r"^Saving checkpoint": 2,
                r"^Total training time:": 1,
                r"^Z-normalization with mean": 1,
            }
            buf.seek(0)
            logging_output = buf.read()

            for k, v in keywords.items():
                assert _str_counter(k, logging_output) == v, (
                    f"Count for string '{k}' appeared "
                    f"({_str_counter(k, logging_output)}) "
                    f"instead of the expected {v}:\nOutput:\n{logging_output}"
                )
268
def test_predict_pasa_montgomery(temporary_basedir):
    """Prediction with a pre-trained 'pasa' checkpoint writes per-subset CSVs.

    Uses the remote checkpoint at ``_pasa_checkpoint_URL`` (downloaded by
    the CLI) and checks predictions exist for train/validation/test, plus
    the expected log-message counts.
    """

    # Temporarily modify Montgomery datadir
    new_value = {"bob.med.tb.montgomery.datadir": montgomery_datadir}
    with rc_context(**new_value):

        from ..scripts.predict import predict

        runner = CliRunner()

        with stdout_logging() as buf:

            output_folder = str(temporary_basedir / "predictions")
            result = runner.invoke(
                predict,
                [
                    "pasa",
                    "montgomery",
                    "-vv",
                    "--batch-size=1",
                    "--relevance-analysis",
                    f"--weight={_pasa_checkpoint_URL}",
                    f"--output-folder={output_folder}",
                ],
            )
            _assert_exit_0(result)

            # check predictions are there (one CSV per subset)
            predictions_file1 = os.path.join(
                output_folder, "train/predictions.csv"
            )
            predictions_file2 = os.path.join(
                output_folder, "validation/predictions.csv"
            )
            predictions_file3 = os.path.join(
                output_folder, "test/predictions.csv"
            )
            assert os.path.exists(predictions_file1)
            assert os.path.exists(predictions_file2)
            assert os.path.exists(predictions_file3)

            # regex -> expected number of occurrences in the captured log
            keywords = {
                r"^Loading checkpoint from.*$": 1,
                r"^Total time:.*$": 3,
                r"^Relevance analysis.*$": 3,
            }
            buf.seek(0)
            logging_output = buf.read()

            for k, v in keywords.items():
                assert _str_counter(k, logging_output) == v, (
                    f"Count for string '{k}' appeared "
                    f"({_str_counter(k, logging_output)}) "
                    f"instead of the expected {v}:\nOutput:\n{logging_output}"
                )
324
325
def test_predtojson(temporary_basedir):
    """``predtojson`` merges prediction CSVs into a single JSON dataset."""

    # Temporarily modify Montgomery datadir
    new_value = {"bob.med.tb.montgomery.datadir": montgomery_datadir}
    with rc_context(**new_value):

        from ..scripts.predtojson import predtojson

        runner = CliRunner()

        with stdout_logging() as buf:

            # the same fixture CSV is registered twice, as 'train' and 'test'
            predictions = _data_file("test_predictions.csv")
            output_folder = str(temporary_basedir / "pred_to_json")
            result = runner.invoke(
                predtojson,
                [
                    "-vv",
                    "train",
                    f"{predictions}",
                    "test",
                    f"{predictions}",
                    f"--output-folder={output_folder}",
                ],
            )
            _assert_exit_0(result)

            # check json file is there
            assert os.path.exists(os.path.join(output_folder, "dataset.json"))

            # regex -> expected number of occurrences in the captured log
            keywords = {
                f"Output folder: {output_folder}": 1,
                r"Saving JSON file...": 1,
                r"^Loading predictions from.*$": 2,
            }
            buf.seek(0)
            logging_output = buf.read()

            for k, v in keywords.items():
                assert _str_counter(k, logging_output) == v, (
                    f"Count for string '{k}' appeared "
                    f"({_str_counter(k, logging_output)}) "
                    f"instead of the expected {v}:\nOutput:\n{logging_output}"
                )
370
371
def test_evaluate_pasa_montgomery(temporary_basedir):
    """Evaluation over previously-written predictions produces CSV/PDF reports.

    NOTE(review): reads the ``predictions`` folder written by
    ``test_predict_pasa_montgomery`` — relies on in-session test ordering.
    """

    # Temporarily modify Montgomery datadir
    new_value = {"bob.med.tb.montgomery.datadir": montgomery_datadir}
    with rc_context(**new_value):

        from ..scripts.evaluate import evaluate

        runner = CliRunner()

        with stdout_logging() as buf:

            prediction_folder = str(temporary_basedir / "predictions")
            output_folder = str(temporary_basedir / "evaluations")
            result = runner.invoke(
                evaluate,
                [
                    "-vv",
                    "montgomery",
                    f"--predictions-folder={prediction_folder}",
                    f"--output-folder={output_folder}",
                    "--threshold=train",
                    "--steps=2000",
                ],
            )
            _assert_exit_0(result)

            # check evaluations are there
            assert os.path.exists(os.path.join(output_folder, "test.csv"))
            assert os.path.exists(os.path.join(output_folder, "train.csv"))
            assert os.path.exists(
                os.path.join(output_folder, "test_score_table.pdf")
            )
            assert os.path.exists(
                os.path.join(output_folder, "train_score_table.pdf")
            )

            # regex -> expected number of occurrences in the captured log
            keywords = {
                r"^Skipping dataset '__train__'": 1,
                r"^Evaluating threshold on.*$": 1,
                r"^Maximum F1-score of.*$": 4,
                r"^Set --f1_threshold=.*$": 1,
                r"^Set --eer_threshold=.*$": 1,
            }
            buf.seek(0)
            logging_output = buf.read()

            for k, v in keywords.items():
                assert _str_counter(k, logging_output) == v, (
                    f"Count for string '{k}' appeared "
                    f"({_str_counter(k, logging_output)}) "
                    f"instead of the expected {v}:\nOutput:\n{logging_output}"
                )
425
426
def test_compare_pasa_montgomery(temporary_basedir):
    """``compare`` plots and tabulates two prediction sets at a fixed threshold.

    NOTE(review): reads the ``predictions`` folder written by
    ``test_predict_pasa_montgomery`` — relies on in-session test ordering.
    """

    # Temporarily modify Montgomery datadir
    new_value = {"bob.med.tb.montgomery.datadir": montgomery_datadir}
    with rc_context(**new_value):

        from ..scripts.compare import compare

        runner = CliRunner()

        with stdout_logging() as buf:

            predictions_folder = str(temporary_basedir / "predictions")
            output_folder = str(temporary_basedir / "comparisons")
            result = runner.invoke(
                compare,
                [
                    "-vv",
                    "train",
                    f"{predictions_folder}/train/predictions.csv",
                    "test",
                    f"{predictions_folder}/test/predictions.csv",
                    f"--output-figure={output_folder}/compare.pdf",
                    f"--output-table={output_folder}/table.txt",
                    "--threshold=0.5",
                ],
            )
            _assert_exit_0(result)

            # check comparisons are there
            assert os.path.exists(os.path.join(output_folder, "compare.pdf"))
            assert os.path.exists(os.path.join(output_folder, "table.txt"))

            # regex -> expected number of occurrences in the captured log
            keywords = {
                r"^Dataset '\*': threshold =.*$": 1,
                r"^Loading predictions from.*$": 2,
                r"^Tabulating performance summary...": 1,
            }
            buf.seek(0)
            logging_output = buf.read()

            for k, v in keywords.items():
                assert _str_counter(k, logging_output) == v, (
                    f"Count for string '{k}' appeared "
                    f"({_str_counter(k, logging_output)}) "
                    f"instead of the expected {v}:\nOutput:\n{logging_output}"
                )
474
475
def test_train_signstotb_montgomery_rs(temporary_basedir):
    """Single-epoch 'signs_to_tb' fine-tuning from a downloaded checkpoint.

    No rc override is needed here: the ``montgomery_rs`` config does not
    read image files from the Montgomery datadir (presumably it ships
    pre-extracted radiological-sign features — TODO confirm).
    """

    from ..scripts.train import train

    runner = CliRunner()

    with stdout_logging() as buf:

        output_folder = str(temporary_basedir / "results")
        result = runner.invoke(
            train,
            [
                "signs_to_tb",
                "montgomery_rs",
                "-vv",
                "--epochs=1",
                "--batch-size=1",
                f"--weight={_signstotb_checkpoint_URL}",
                f"--output-folder={output_folder}",
            ],
        )
        _assert_exit_0(result)

        # training must leave weights, a checkpoint pointer and CSV logs
        assert os.path.exists(
            os.path.join(output_folder, "model_final_epoch.pth")
        )
        assert os.path.exists(
            os.path.join(output_folder, "model_lowest_valid_loss.pth")
        )
        assert os.path.exists(os.path.join(output_folder, "last_checkpoint"))
        assert os.path.exists(os.path.join(output_folder, "constants.csv"))
        assert os.path.exists(os.path.join(output_folder, "trainlog.csv"))
        assert os.path.exists(os.path.join(output_folder, "model_summary.txt"))

        # regex -> expected number of occurrences in the captured log
        keywords = {
            r"^Found \(dedicated\) '__train__' set for training$": 1,
            r"^Found \(dedicated\) '__valid__' set for validation$": 1,
            r"^Continuing from epoch 0$": 1,
            r"^Saving model summary at.*$": 1,
            r"^Model has.*$": 1,
            r"^Saving checkpoint": 2,
            r"^Total training time:": 1,
        }
        buf.seek(0)
        logging_output = buf.read()

        for k, v in keywords.items():
            assert _str_counter(k, logging_output) == v, (
                f"Count for string '{k}' appeared "
                f"({_str_counter(k, logging_output)}) "
                f"instead of the expected {v}:\nOutput:\n{logging_output}"
            )
528
529
def test_predict_signstotb_montgomery_rs(temporary_basedir):
    """'signs_to_tb' prediction with relevance analysis writes CSV + RA plots."""

    from ..scripts.predict import predict

    runner = CliRunner()

    with stdout_logging() as buf:

        output_folder = str(temporary_basedir / "predictions")
        result = runner.invoke(
            predict,
            [
                "signs_to_tb",
                "montgomery_rs",
                "-vv",
                "--batch-size=1",
                "--relevance-analysis",
                f"--weight={_signstotb_checkpoint_URL}",
                f"--output-folder={output_folder}",
            ],
        )
        _assert_exit_0(result)

        # check predictions are there, plus one relevance plot per subset
        predictions_file = os.path.join(output_folder, "train/predictions.csv")
        RA1 = os.path.join(output_folder, "train_RA.pdf")
        RA2 = os.path.join(output_folder, "validation_RA.pdf")
        RA3 = os.path.join(output_folder, "test_RA.pdf")
        assert os.path.exists(predictions_file)
        assert os.path.exists(RA1)
        assert os.path.exists(RA2)
        assert os.path.exists(RA3)

        # regex -> expected number of occurrences in the captured log;
        # 3 * 15: presumably one timed pass per feature (15) for each of the
        # 3 subsets during relevance analysis — TODO confirm
        keywords = {
            r"^Loading checkpoint from.*$": 1,
            r"^Total time:.*$": 3 * 15,
            r"^Starting relevance analysis for subset.*$": 3,
            r"^Creating and saving plot at.*$": 3,
        }
        buf.seek(0)
        logging_output = buf.read()

        for k, v in keywords.items():
            assert _str_counter(k, logging_output) == v, (
                f"Count for string '{k}' appeared "
                f"({_str_counter(k, logging_output)}) "
                f"instead of the expected {v}:\nOutput:\n{logging_output}"
            )
578
579
def test_train_logreg_montgomery_rs(temporary_basedir):
    """Single-epoch logistic-regression training from a downloaded checkpoint."""

    from ..scripts.train import train

    runner = CliRunner()

    with stdout_logging() as buf:

        output_folder = str(temporary_basedir / "results")
        result = runner.invoke(
            train,
            [
                "logistic_regression",
                "montgomery_rs",
                "-vv",
                "--epochs=1",
                "--batch-size=1",
                f"--weight={_logreg_checkpoint_URL}",
                f"--output-folder={output_folder}",
            ],
        )
        _assert_exit_0(result)

        # training must leave weights, a checkpoint pointer and CSV logs
        assert os.path.exists(
            os.path.join(output_folder, "model_final_epoch.pth")
        )
        assert os.path.exists(
            os.path.join(output_folder, "model_lowest_valid_loss.pth")
        )
        assert os.path.exists(os.path.join(output_folder, "last_checkpoint"))
        assert os.path.exists(os.path.join(output_folder, "constants.csv"))
        assert os.path.exists(os.path.join(output_folder, "trainlog.csv"))
        assert os.path.exists(os.path.join(output_folder, "model_summary.txt"))

        # regex -> expected number of occurrences in the captured log
        keywords = {
            r"^Found \(dedicated\) '__train__' set for training$": 1,
            r"^Found \(dedicated\) '__valid__' set for validation$": 1,
            r"^Continuing from epoch 0$": 1,
            r"^Saving model summary at.*$": 1,
            r"^Model has.*$": 1,
            r"^Saving checkpoint": 2,
            r"^Total training time:": 1,
        }
        buf.seek(0)
        logging_output = buf.read()

        for k, v in keywords.items():
            assert _str_counter(k, logging_output) == v, (
                f"Count for string '{k}' appeared "
                f"({_str_counter(k, logging_output)}) "
                f"instead of the expected {v}:\nOutput:\n{logging_output}"
            )
632
633
def test_predict_logreg_montgomery_rs(temporary_basedir):
    """Logistic-regression prediction also saves a model-weights plot."""

    from ..scripts.predict import predict

    runner = CliRunner()

    with stdout_logging() as buf:

        output_folder = str(temporary_basedir / "predictions")
        result = runner.invoke(
            predict,
            [
                "logistic_regression",
                "montgomery_rs",
                "-vv",
                "--batch-size=1",
                f"--weight={_logreg_checkpoint_URL}",
                f"--output-folder={output_folder}",
            ],
        )
        _assert_exit_0(result)

        # check predictions are there, plus the weights visualization
        predictions_file = os.path.join(output_folder, "train/predictions.csv")
        wfile = os.path.join(output_folder, "LogReg_Weights.pdf")
        assert os.path.exists(predictions_file)
        assert os.path.exists(wfile)

        # regex -> expected number of occurrences in the captured log
        keywords = {
            r"^Loading checkpoint from.*$": 1,
            r"^Total time:.*$": 3,
            r"^Logistic regression identified: saving model weights.*$": 1,
        }
        buf.seek(0)
        logging_output = buf.read()

        for k, v in keywords.items():
            assert _str_counter(k, logging_output) == v, (
                f"Count for string '{k}' appeared "
                f"({_str_counter(k, logging_output)}) "
                f"instead of the expected {v}:\nOutput:\n{logging_output}"
            )
676
677
def test_aggregpred(temporary_basedir):
    """``aggregpred`` concatenates prediction CSVs into one aggregated file.

    NOTE(review): reads the ``predictions`` folder written by
    ``test_predict_pasa_montgomery`` — relies on in-session test ordering.
    """

    # Temporarily modify Montgomery datadir
    new_value = {"bob.med.tb.montgomery.datadir": montgomery_datadir}
    with rc_context(**new_value):

        from ..scripts.aggregpred import aggregpred

        runner = CliRunner()

        with stdout_logging() as buf:

            # the same CSV is aggregated with itself
            predictions = str(
                temporary_basedir / "predictions" / "train" / "predictions.csv"
            )
            output_folder = str(temporary_basedir / "aggregpred")
            result = runner.invoke(
                aggregpred,
                [
                    "-vv",
                    f"{predictions}",
                    f"{predictions}",
                    f"--output-folder={output_folder}",
                ],
            )
            _assert_exit_0(result)

            # check csv file is there
            assert os.path.exists(os.path.join(output_folder, "aggregpred.csv"))

            # regex -> expected number of occurrences in the captured log
            keywords = {
                f"Output folder: {output_folder}": 1,
                r"Saving aggregated CSV file...": 1,
                r"^Loading predictions from.*$": 2,
            }
            buf.seek(0)
            logging_output = buf.read()

            for k, v in keywords.items():
                assert _str_counter(k, logging_output) == v, (
                    f"Count for string '{k}' appeared "
                    f"({_str_counter(k, logging_output)}) "
                    f"instead of the expected {v}:\nOutput:\n{logging_output}"
                )
722
723
724# Not enough RAM available to do this test
725# def test_predict_densenetrs_montgomery(temporary_basedir):
726
727# # Temporarily modify Montgomery datadir
728# new_value = {"bob.med.tb.montgomery.datadir": montgomery_datadir}
729# with rc_context(**new_value):
730
731# from ..scripts.predict import predict
732
733# runner = CliRunner()
734
735# with stdout_logging() as buf:
736
737# output_folder = str(temporary_basedir / "predictions")
738# result = runner.invoke(
739# predict,
740# [
741# "densenet_rs",
742# "montgomery_f0_rgb",
743# "-vv",
744# "--batch-size=1",
745# f"--weight={_densenetrs_checkpoint_URL}",
746# f"--output-folder={output_folder}",
747# "--grad-cams"
748# ],
749# )
750# _assert_exit_0(result)
751
752# # check predictions are there
753# predictions_file1 = os.path.join(output_folder, "train/predictions.csv")
754# predictions_file2 = os.path.join(output_folder, "validation/predictions.csv")
755# predictions_file3 = os.path.join(output_folder, "test/predictions.csv")
756# assert os.path.exists(predictions_file1)
757# assert os.path.exists(predictions_file2)
758# assert os.path.exists(predictions_file3)
759# # check some grad cams are there
760# cam1 = os.path.join(output_folder, "train/cams/MCUCXR_0002_0_cam.png")
761# cam2 = os.path.join(output_folder, "train/cams/MCUCXR_0126_1_cam.png")
762# cam3 = os.path.join(output_folder, "train/cams/MCUCXR_0275_1_cam.png")
763# cam4 = os.path.join(output_folder, "validation/cams/MCUCXR_0399_1_cam.png")
764# cam5 = os.path.join(output_folder, "validation/cams/MCUCXR_0113_1_cam.png")
765# cam6 = os.path.join(output_folder, "validation/cams/MCUCXR_0013_0_cam.png")
766# cam7 = os.path.join(output_folder, "test/cams/MCUCXR_0027_0_cam.png")
767# cam8 = os.path.join(output_folder, "test/cams/MCUCXR_0094_0_cam.png")
768# cam9 = os.path.join(output_folder, "test/cams/MCUCXR_0375_1_cam.png")
769# assert os.path.exists(cam1)
770# assert os.path.exists(cam2)
771# assert os.path.exists(cam3)
772# assert os.path.exists(cam4)
773# assert os.path.exists(cam5)
774# assert os.path.exists(cam6)
775# assert os.path.exists(cam7)
776# assert os.path.exists(cam8)
777# assert os.path.exists(cam9)
778
779# keywords = {
780# r"^Loading checkpoint from.*$": 1,
781# r"^Total time:.*$": 3,
782# r"^Grad cams folder:.*$": 3,
783# }
784# buf.seek(0)
785# logging_output = buf.read()
786
787# for k, v in keywords.items():
788# assert _str_counter(k, logging_output) == v, (
789# f"Count for string '{k}' appeared "
790# f"({_str_counter(k, logging_output)}) "
791# f"instead of the expected {v}:\nOutput:\n{logging_output}"
792# )