1#!/usr/bin/env python
2# coding=utf-8
3
4
5"""Tests for Shenzhen CXR dataset"""
6
7import numpy
8import pytest
9
10from ...binseg.data.shenzhen import dataset
11from .utils import count_bw
12
13
def test_protocol_consistency():
    """Check the layout of the "default" protocol of the dataset.

    The protocol must expose exactly three subsets ("train", "validation"
    and "test") holding 396, 56 and 114 samples respectively, and the key of
    every sample must start with ``CXR_png``.
    """

    protocol = dataset.subsets("default")
    assert len(protocol) == 3

    # (subset name, expected number of samples), checked in the same order
    # as the subsets are meant to be used
    expected_sizes = (("train", 396), ("validation", 56), ("test", 114))

    for name, size in expected_sizes:
        assert name in protocol
        samples = protocol[name]
        assert len(samples) == size
        for sample in samples:
            assert sample.key.startswith("CXR_png")
33
34
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.shenzhen.datadir")
def test_loading():
    """Load every sample of the "default" protocol and sanity-check it.

    For each sample of the "train", "validation" and "test" subsets, the
    loaded data must be a dictionary with exactly two entries: ``data``
    (mode ``RGB``) and ``label`` (mode ``1``).  The label pixel count must
    lie between the smallest and largest expected image sizes, and the
    white/black pixel proportion must stay below a fixed threshold.
    """

    min_image_size = (1130, 948)
    max_image_size = (3001, 3001)

    # the pixel-count bounds do not depend on the sample: compute them once
    min_pixels = numpy.prod(min_image_size)
    max_pixels = numpy.prod(max_image_size)

    def _check_sample(s, bw_threshold_label):
        """Validate one sample and return its white/black label proportion.

        Parameters
        ----------
        s
            Sample to check; its ``data`` attribute is loaded and its
            ``key`` attribute is used in failure messages.
        bw_threshold_label
            Exclusive upper bound for the proportion of white over black
            pixels in the label.

        Returns
        -------
        The proportion ``w / b`` of white over black pixels in the label.
        """

        data = s.data
        assert isinstance(data, dict)
        assert len(data) == 2

        assert "data" in data
        assert data["data"].mode == "RGB"

        assert "label" in data
        assert data["label"].mode == "1"

        b, w = count_bw(data["label"])
        total = b + w
        assert total >= min_pixels, (
            f"Counts of black + white ({b}+{w}) lower than smallest image "
            f"total image size ({min_pixels}) at '{s.key}':label"
        )
        assert total <= max_pixels, (
            f"Counts of black + white ({b}+{w}) higher than largest image "
            f"total image size ({max_pixels}) at '{s.key}':label"
        )

        # without black pixels the proportion below is undefined: fail with
        # a message naming the sample instead of a bare ZeroDivisionError
        assert b > 0, (
            f"No black pixels counted at '{s.key}':label - cannot compute "
            f"the white/black proportion, this could indicate a loading "
            f"problem!"
        )

        ratio = w / b
        assert ratio < bw_threshold_label, (
            f"The proportion between black and white pixels "
            f"({w}/{b}={ratio:.3f}) is larger than the allowed threshold "
            f"of {bw_threshold_label} at '{s.key}':label - this could "
            f"indicate a loading problem!"
        )

        # to visualize images, uncomment the following code; it should
        # display an image with a faded background representing the original
        # data, blended with green labels.
        # from ..data.utils import overlayed_image
        # display = overlayed_image(data["data"], data["label"])
        # display.show()
        # import ipdb; ipdb.set_trace()

        return ratio

    limit = None  # use this to limit testing to first images only
    subset = dataset.subsets("default")
    for name in ("train", "validation", "test"):
        for s in subset[name][:limit]:
            _check_sample(s, 0.77)
85
86
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.shenzhen.datadir")
def test_check():
    """Run the dataset self-check and require that it returns zero.

    NOTE(review): the return value of ``dataset.check()`` is presumably an
    error count - confirm against the dataset implementation.
    """
    outcome = dataset.check()
    assert outcome == 0