Coverage for src/bob/bio/vein/script/compare_rois.py: 0%

69 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-12 23:27 +0200

1#!/usr/bin/env python 

2# vim: set fileencoding=utf-8 : 

3 

4"""Compares two set of masks and prints some error metrics 

5 

6This program requires that the masks for both databases (one representing the 

7ground-truth and a second the database with an automated method) are 

8represented as HDF5 files containing a ``mask`` object, which should be boolean 

9in nature. 

10 

11 

12Usage: %(prog)s [-v...] [-n X] <ground-truth> <database> 

13 %(prog)s --help 

14 %(prog)s --version 

15 

16 

17Arguments: 

18 <ground-truth> Path to a set of files that contain ground truth annotations 

19 for the ROIs you wish to compare. 

20 <database> Path to a similar set of files as in `<ground-truth>`, but 

21 with ROIs calculated automatically. Every HDF5 in this 

22 directory will be matched to an equivalent file in the 

23 `<ground-truth>` database and their masks will be compared 

24 

25 

26Options: 

27 -h, --help Shows this help message and exits 

28 -V, --version Prints the version and exits 

29 -v, --verbose Increases the output verbosity level 

30 -n N, --annotate=N Print out the N worst cases available in the database, 

31 taking into consideration the various metrics analyzed 

32 

33 

34Example: 

35 

36 1. Just run for basic statistics: 

37 

38 $ %(prog)s -vvv gt/ automatic/ 

39 

40 2. Identify worst 5 samples in the database according to a certain criterion: 

41 

42 $ %(prog)s -vv -n 5 gt/ automatic/ 

43 

44""" 

45 

46import fnmatch 

47import operator 

48import os 

49import sys 

50 

51import clapper.logging 

52import h5py 

53import numpy 

54 

55logger = clapper.logging.setup("bob.bio.vein") 

56 

57 

def make_catalog(d):
    """Recursively collects the stems of all HDF5 files under directory ``d``


    Parameters:

        d (str): A path representing a directory that will be scanned for
            .hdf5 files


    Returns

        list: A sorted list of stems (paths relative to ``d``) that represent
            files of type HDF5 in that directory.  Each file should contain
            two objects: ``image`` and ``mask``.

    """

    logger.info("Scanning directory `%s'..." % d)
    found = []
    for path, dirs, files in os.walk(d):
        rel = os.path.relpath(path, d)
        logger.debug("Scanning sub-directory `%s'..." % rel)
        matches = fnmatch.filter(files, "*.hdf5")
        if not matches:
            continue
        logger.debug("Found %d files" % len(matches))
        found.extend(os.path.join(rel, name) for name in matches)
    logger.info("Found a total of %d files at `%s'" % (len(found), d))
    return sorted(found)

87 

88 

def sort_table(table, cols):
    """Sorts a table by multiple columns


    Parameters:

        table (:py:class:`list` of :py:class:`list`): Or tuple of tuples,
            where each inner list represents a row

        cols (list, tuple): Specifies the column numbers to sort by e.g.
            (1,0) would sort by column 1, then by column 0


    Returns:

        list: of lists, with the table re-ordered as you see fit.

    """

    # Repeated stable sorts, applied from the least- to the most-significant
    # column, yield the same ordering as a single sort on the column tuple.
    ordered = table
    for col_index in reversed(cols):
        ordered = sorted(ordered, key=operator.itemgetter(col_index))
    return ordered

111 

112 

def mean_std_for_column(table, column):
    """Calculates the mean and standard deviation for the column in question


    Parameters:

        table (:py:class:`list` of :py:class:`list`): Or tuple of tuples,
            where each inner list represents a row

        column (int): The number of the column from where to extract the data
            for calculating the mean and the standard-deviation.


    Returns:

        float: mean

        float: (unbiased) standard deviation

    """

    values = numpy.asarray([row[column] for row in table])
    # ddof=1 gives the unbiased (sample) standard deviation
    return values.mean(), values.std(ddof=1)

136 

137 

def main(user_input=None):
    """Compares two directories of HDF5 ROI masks and prints error metrics.

    Parameters:

        user_input (:py:class:`list`, Optional): A list of command-line
            arguments to parse instead of ``sys.argv[1:]`` (handy for unit
            testing).
    """

    if user_input is not None:
        argv = user_input
    else:
        argv = sys.argv[1:]

    import docopt
    import pkg_resources

    completions = dict(
        prog=os.path.basename(sys.argv[0]),
        version=pkg_resources.require("bob.bio.vein")[0].version,
    )

    args = docopt.docopt(
        __doc__ % completions,
        argv=argv,
        version=completions["version"],
    )

    # Sets-up logging
    verbosity = int(args["--verbose"])
    clapper.logging.set_verbosity_level(logger, verbosity)

    # Catalogs: both trees must contain exactly the same relative stems
    gt = make_catalog(args["<ground-truth>"])
    db = make_catalog(args["<database>"])

    if gt != db:
        raise RuntimeError("Ground-truth and database have different files!")

    # Calculate all metrics required
    from ..preprocessor import utils

    metrics = []
    for k in gt:
        gt_file = os.path.join(args["<ground-truth>"], k)
        db_file = os.path.join(args["<database>"], k)
        # BUGFIX: ``h5py.File`` objects have no ``read()`` method (that was
        # the legacy ``bob.io.base.HDF5File`` API).  Datasets are looked-up
        # by name and loaded into memory with ``[()]``.  Files are opened
        # read-only and closed deterministically via context managers.
        with h5py.File(gt_file, "r") as f:
            gt_roi = f["mask"][()]
        with h5py.File(db_file, "r") as f:
            db_roi = f["mask"][()]
        metrics.append(
            (
                k,
                utils.jaccard_index(gt_roi, db_roi),
                utils.intersect_ratio(gt_roi, db_roi),
                utils.intersect_ratio_of_complement(gt_roi, db_roi),
            )
        )
        # BUGFIX: the last specifier read ``%5.g`` (width 5, precision 0)
        # instead of ``%.5g``, inconsistent with the other two
        logger.info("%s: JI = %.5g, M1 = %.5g, M2 = %.5g" % metrics[-1])

    # Print statistics -- maps column index in ``metrics`` rows to a label
    names = (
        (1, "Jaccard index"),
        (2, "Intersection ratio (m1)"),
        (3, "Intersection ratio of complement (m2)"),
    )
    print("Statistics:")
    for k, name in names:
        mean, std = mean_std_for_column(metrics, k)
        print(name + ": " + "%.2e +- %.2e" % (mean, std))

    # Print worst cases, if the user asked so
    if args["--annotate"] is not None:
        N = int(args["--annotate"])
        if N <= 0:
            raise docopt.DocoptExit("Argument to --annotate should be >0")

        print("Worst cases by metric:")
        for k, name in names:
            print(name + ":")

            # For JI and m1 lower values are worse (take the ascending head);
            # for m2 higher values are worse (take the descending tail)
            if k in (1, 2):
                worst = sort_table(metrics, (k,))[:N]
            else:
                worst = reversed(sort_table(metrics, (k,))[-N:])

            for n, l in enumerate(worst):
                fname = os.path.join(args["<database>"], l[0])
                print(" %d. [%.2e] %s" % (n, l[k], fname))