Coverage for src/bob/bio/face/database/ijbc.py: 35%

86 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-13 00:04 +0200

1import copy 

2import logging 

3import os 

4 

5from functools import partial 

6 

7import pandas as pd 

8 

9from clapper.rc import UserDefaults 

10 

11import bob.io.base 

12 

13from bob.bio.base.pipelines.abstract_classes import Database 

14from bob.pipelines import hash_string 

15from bob.pipelines.sample import DelayedSample, SampleSet 

16 

17logger = logging.getLogger(__name__) 

18rc = UserDefaults("bobrc.toml") 

19 

20 

21def _make_sample_from_template_row(row, image_directory): 

22 # Appending this key, so we can handle parallel writing correctly 

23 # paying the penalty of having duplicate checkpoint files 

24 key = os.path.splitext(row["FILENAME"])[0] + "-" + str(row["SUBJECT_ID"]) 

25 

26 return DelayedSample( 

27 load=partial( 

28 bob.io.base.load, os.path.join(image_directory, row["FILENAME"]) 

29 ), 

30 template_id=str(row["TEMPLATE_ID"]), 

31 subject_id=str(row["SUBJECT_ID"]), 

32 key=key, 

33 # gender=row["GENDER"], 

34 # indoor_outdoor=row["INDOOR_OUTDOOR"], 

35 # skintone=row["SKINTONE"], 

36 # yaw=row["YAW"], 

37 # rool=row["ROLL"], 

38 # occ1=row["OCC1"], 

39 # occ2=row["OCC2"], 

40 # occ3=row["OCC3"], 

41 # occ4=row["OCC4"], 

42 # occ5=row["OCC5"], 

43 # occ6=row["OCC6"], 

44 # occ7=row["OCC7"], 

45 # occ8=row["OCC8"], 

46 # occ9=row["OCC9"], 

47 # occ10=row["OCC10"], 

48 # occ11=row["OCC11"], 

49 # occ12=row["OCC12"], 

50 # occ13=row["OCC13"], 

51 # occ14=row["OCC14"], 

52 # occ15=row["OCC15"], 

53 # occ16=row["OCC16"], 

54 # occ17=row["OCC17"], 

55 # occ18=row["OCC18"], 

56 annotations={ 

57 "topleft": (float(row["FACE_Y"]), float(row["FACE_X"])), 

58 "bottomright": ( 

59 float(row["FACE_Y"]) + float(row["FACE_HEIGHT"]), 

60 float(row["FACE_X"]) + float(row["FACE_WIDTH"]), 

61 ), 

62 "size": (float(row["FACE_HEIGHT"]), float(row["FACE_WIDTH"])), 

63 }, 

64 ) 

65 

66 

67def _make_sample_set_from_template_group(template_group, image_directory): 

68 samples = list( 

69 template_group.apply( 

70 _make_sample_from_template_row, 

71 axis=1, 

72 image_directory=image_directory, 

73 ) 

74 ) 

75 return SampleSet( 

76 samples, 

77 template_id=samples[0].template_id, 

78 subject_id=samples[0].subject_id, 

79 key=samples[0].template_id, 

80 ) 

81 

82 

83class IJBCDatabase(Database): # TODO Make this a CSVDatabase? 

84 """ 

85 

86 This package contains the access API and descriptions for the IARPA Janus Benchmark C -- IJB-C database. 

87 The actual raw data can be downloaded from the original web page: http://www.nist.gov/programs-projects/face-challenges (note that not everyone might be eligible for downloading the data). 

88 

89 Included in the database, there are list files defining verification as well as closed- and open-set identification protocols. 

90 For verification, two different protocols are provided. 

91 For the ``1:1`` protocol, gallery and probe templates are combined using several images and video frames for each subject. 

92 Compared gallery and probe templates share the same gender and skin tone -- these have been matched to make the comparisons more realistic and difficult. 

93 

94 For closed-set identification, the gallery of the ``1:1`` protocol is used, while probes stem from either only images, mixed images and video frames, or plain videos. 

95 For open-set identification, the same probes are evaluated, but the gallery is split into two parts, either of which is left out to provide unknown probe templates, i.e., probe templates with no matching subject in the gallery. 

96 In any case, scores are computed between all (active) gallery templates and all probes. 

97 

98 The IJB-C dataset provides additional evaluation protocols for face detection and clustering, but these are (not yet) part of this interface. 

99 

100 

101 .. warning:: 

102 

103 To use this dataset protocol, you need to have the original files of the IJBC datasets. 

104 Once you have it downloaded, please run the following command to set the path for Bob 

105 

106 .. code-block:: sh 

107 

108 bob config set bob.bio.face.ijbc.directory [IJBC PATH] 

109 

110 

111 The code below allows you to fetch the gallery and probes of the "1:1" protocol. 

112 

113 .. code-block:: python 

114 

115 >>> from bob.bio.face.database import IJBCDatabase 

116 >>> ijbc = IJBCDatabase(protocol="test1") 

117 

118 >>> # Fetching the gallery 

119 >>> references = ijbc.references() 

120 >>> # Fetching the probes 

121 >>> probes = ijbc.probes() 

122 

123 """ 

124 

125 name = "ijbc" 

126 category = "face" 

127 dataset_protocols_name = "ijbc.tar.gz" 

128 dataset_protocols_urls = [ 

129 "https://www.idiap.ch/software/bob/databases/latest/face/ijbc-????.tar.gz", 

130 "http://www.idiap.ch/software/bob/databases/latest/face/ijbc-????.tar.gz", 

131 ] 

132 dataset_protocols_hash = "????" 

133 

134 def __init__( 

135 self, 

136 protocol, 

137 original_directory=rc.get("bob.bio.face.ijbc.directory"), 

138 **kwargs, 

139 ): 

140 import warnings 

141 

142 warnings.warn( 

143 f"The {self.name} database is not yet adapted to this version of bob. Please port it or ask for it to be ported (This one actually needs to be converted to a CSVDatabase).", 

144 DeprecationWarning, 

145 ) 

146 

147 if original_directory is None or not os.path.exists(original_directory): 

148 raise ValueError( 

149 f"Invalid or non existent `original_directory`: {original_directory}" 

150 ) 

151 

152 self._check_protocol(protocol) 

153 super().__init__( 

154 protocol=protocol, 

155 annotation_type="bounding-box", 

156 fixed_positions=None, 

157 memory_demanding=True, 

158 ) 

159 

160 self.image_directory = os.path.join(original_directory, "images") 

161 self.protocol_directory = os.path.join(original_directory, "protocols") 

162 self._cached_probes = None 

163 self._cached_references = None 

164 self.hash_fn = hash_string 

165 

166 self._load_metadata(protocol) 

167 

168 # For the test4 protocols 

169 if "test4" in protocol: 

170 self.score_all_vs_all = True 

171 

172 def _load_metadata(self, protocol): 

173 # Load CSV files 

174 if protocol == "test1" or protocol == "test2": 

175 self.reference_templates = pd.read_csv( 

176 os.path.join( 

177 self.protocol_directory, protocol, "enroll_templates.csv" 

178 ) 

179 ) 

180 

181 self.probe_templates = pd.read_csv( 

182 os.path.join( 

183 self.protocol_directory, protocol, "verif_templates.csv" 

184 ) 

185 ) 

186 

187 self.matches = pd.read_csv( 

188 os.path.join(self.protocol_directory, protocol, "match.csv"), 

189 names=["ENROLL_TEMPLATE_ID", "VERIF_TEMPLATE_ID"], 

190 ).astype("str") 

191 

192 # TODO: temporarily disabling the metadata 

193 """ 

194 self.metadata = pd.read_csv( 

195 os.path.join(self.protocol_directory, "ijbc_metadata_with_age.csv"), 

196 usecols=[ 

197 "SUBJECT_ID", 

198 "FILENAME", 

199 "FACE_X", 

200 "FACE_Y", 

201 "FACE_WIDTH", 

202 "FACE_HEIGHT", 

203 "SIGHTING_ID", 

204 "FACIAL_HAIR", 

205 "AGE", 

206 "INDOOR_OUTDOOR", 

207 "SKINTONE", 

208 "GENDER", 

209 "YAW", 

210 "ROLL", 

211 ] 

212 + [f"OCC{i}" for i in range(1, 19)], 

213 ) 

214 

215 # LEFT JOIN WITH METADATA 

216 self.probe_templates = pd.merge( 

217 self.probe_templates, 

218 self.metadata, 

219 on=[ 

220 "SUBJECT_ID", 

221 "FILENAME", 

222 "FACE_X", 

223 "FACE_Y", 

224 "FACE_WIDTH", 

225 "FACE_HEIGHT", 

226 ], 

227 how="left", 

228 ) 

229 

230 # LEFT JOIN WITH METADATA 

231 self.reference_templates = pd.merge( 

232 self.reference_templates, 

233 self.metadata, 

234 on=[ 

235 "SUBJECT_ID", 

236 "FILENAME", 

237 "FACE_X", 

238 "FACE_Y", 

239 "FACE_WIDTH", 

240 "FACE_HEIGHT", 

241 ], 

242 how="left", 

243 ) 

244 """ 

245 

246 elif "test4" in protocol: 

247 gallery_file = ( 

248 "gallery_G1.csv" if "G1" in protocol else "gallery_G2.csv" 

249 ) 

250 

251 self.reference_templates = pd.read_csv( 

252 os.path.join(self.protocol_directory, "test4", gallery_file) 

253 ) 

254 

255 self.probe_templates = pd.read_csv( 

256 os.path.join(self.protocol_directory, "test4", "probes.csv") 

257 ) 

258 

259 self.matches = None 

260 

261 else: 

262 raise ValueError( 

263 f"Protocol `{protocol}` not supported. We do accept merge requests :-)" 

264 ) 

265 

266 def background_model_samples(self): 

267 return None 

268 

269 def probes(self, group="dev"): 

270 self._check_group(group) 

271 if self._cached_probes is None: 

272 logger.info( 

273 "Loading probes. This operation might take some minutes" 

274 ) 

275 

276 self._cached_probes = list( 

277 self.probe_templates.groupby("TEMPLATE_ID").apply( 

278 _make_sample_set_from_template_group, 

279 image_directory=self.image_directory, 

280 ) 

281 ) 

282 

283 # Wiring probes with references 

284 if self.protocol == "test1" or self.protocol == "test2": 

285 # Link probes to the references they have to be compared with 

286 # We might make that faster if we manage to write it as a Panda instruction 

287 grouped_matches = self.matches.groupby("VERIF_TEMPLATE_ID") 

288 for probe_sampleset in self._cached_probes: 

289 probe_sampleset.references = list( 

290 grouped_matches.get_group(probe_sampleset.template_id)[ 

291 "ENROLL_TEMPLATE_ID" 

292 ] 

293 ) 

294 elif "test4" in self.protocol: 

295 references = [s.template_id for s in self.references()] 

296 # You compare with all biometric references 

297 for probe_sampleset in self._cached_probes: 

298 probe_sampleset.references = copy.deepcopy(references) 

299 pass 

300 

301 else: 

302 raise ValueError(f"Invalid protocol: {self.protocol}") 

303 

304 return self._cached_probes 

305 

306 def references(self, group="dev"): 

307 self._check_group(group) 

308 if self._cached_references is None: 

309 logger.info( 

310 "Loading templates. This operation might take some minutes" 

311 ) 

312 

313 self._cached_references = list( 

314 self.reference_templates.groupby("TEMPLATE_ID").apply( 

315 _make_sample_set_from_template_group, 

316 image_directory=self.image_directory, 

317 ) 

318 ) 

319 

320 return self._cached_references 

321 

322 def all_samples(self, group="dev"): 

323 self._check_group(group) 

324 

325 return self.references() + self.probes() 

326 

327 def groups(self): 

328 return ["dev"] 

329 

330 def protocols(self): 

331 return ["test1", "test2", "test4-G1", "test4-G2"] 

332 

333 def _check_protocol(self, protocol): 

334 assert ( 

335 protocol in self.protocols() 

336 ), "Invalid protocol `{}` not in {}".format(protocol, self.protocols()) 

337 

338 def _check_group(self, group): 

339 assert group in self.groups(), "Invalid group `{}` not in {}".format( 

340 group, self.groups() 

341 )