Coverage for src/bob/bio/face/database/rfw.py: 20%

189 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-13 00:04 +0200

1import copy 

2import logging 

3import os 

4 

5from functools import partial 

6 

7import numpy as np 

8 

9from clapper.rc import UserDefaults 

10 

11import bob.io.base 

12 

13from bob.bio.base.database.utils import download_file, md5_hash 

14from bob.bio.base.pipelines.abstract_classes import Database 

15from bob.pipelines.sample import DelayedSample, SampleSet 

16 

# Package-level logger shared by the bob.bio.face modules.
logger = logging.getLogger("bob.bio.face")

# User configuration (dataset paths, extensions, ...) read from the user's
# bobrc.toml file via clapper; queried with rc.get("<key>", default).
rc = UserDefaults("bobrc.toml")

19 

20 

class RFWDatabase(Database):  # TODO Make this a CSVDatabase?
    """
    Dataset interface for the Racial faces in the wild dataset:

    The RFW is a subset of the MS-Celeb 1M dataset, and it's composed of 44332 images split into 11416 identities.
    There are four "race" labels in this dataset (`African`, `Asian`, `Caucasian`, and `Indian`).
    Furthermore, with the help of https://query.wikidata.org/ we've added information about gender and
    country of birth.

    We offer two evaluation protocols.
    The first one, called "original" is the original protocol from its publication. It contains ~24k comparisons in total.
    Worth noting that this evaluation protocol has an issue. It considers only comparisons of pairs of images from the same
    "race".
    To close this gap, we've created a protocol called "idiap" that extends the original protocol to one where impostors comparisons
    (or non-mated) is possible. This is closer to a real-world scenario.

    .. warning::
        The following identities are associated with two races in the original dataset

         - m.023915
         - m.0z08d8y
         - m.0bk56n
         - m.04f4wpb
         - m.0gc2xf9
         - m.08dyjb
         - m.05y2fd
         - m.0gbz836
         - m.01pw5d
         - m.0cm83zb
         - m.02qmpkk
         - m.05xpnv

    For more information check:

    .. code-block:: latex

        @inproceedings{wang2019racial,
        title={Racial faces in the wild: Reducing racial bias by information maximization adaptation network},
        author={Wang, Mei and Deng, Weihong and Hu, Jiani and Tao, Xunqiang and Huang, Yaohai},
        booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
        pages={692--702},
        year={2019}
        }

    """

    name = "rfw"
    category = "face"
    dataset_protocols_name = "rfw.tar.gz"
    dataset_protocols_urls = [
        "https://www.idiap.ch/software/bob/databases/latest/face/rfw-83549522.tar.gz",
        "http://www.idiap.ch/software/bob/databases/latest/face/rfw-83549522.tar.gz",
    ]
    dataset_protocols_hash = "83549522"

    demographics_urls = [
        "https://www.idiap.ch/software/bob/databases/latest/msceleb_wikidata_demographics.csv.tar.gz",
        "http://www.idiap.ch/software/bob/databases/latest/msceleb_wikidata_demographics.csv.tar.gz",
    ]

    def __init__(
        self,
        protocol,
        original_directory=rc.get("bob.bio.face.rfw.directory"),
        **kwargs,
    ):
        """
        Parameters
        ----------
        protocol : str
            Evaluation protocol, either ``"original"`` or ``"idiap"``.
        original_directory : str
            Root directory of the raw RFW data. Defaults to the
            ``bob.bio.face.rfw.directory`` key from the user's bobrc.

        Raises
        ------
        ValueError
            If ``original_directory`` is ``None`` or does not exist.
        """
        if original_directory is None or not os.path.exists(original_directory):
            raise ValueError(f"Invalid or non existent {original_directory=}")

        self._check_protocol(protocol)
        self._races = ["African", "Asian", "Caucasian", "Indian"]
        self.original_directory = original_directory
        self._default_extension = rc.get("bob.bio.face.rfw.extension", ".jpg")

        super().__init__(
            protocol=protocol,
            annotation_type="eyes-center",
            fixed_positions=None,
            memory_demanding=False,
            **kwargs,
        )

        self._pairs = dict()
        self._first_reference_of_subject = (
            dict()
        )  # Used with the Idiap protocol
        self._inverted_pairs = dict()
        self._id_race = dict()  # ID -- > RACE
        self._race_ids = dict()  # RACE --> ID
        self._landmarks = dict()
        # Lazily-filled caches for the sample-set accessors below.
        self._cached_biometric_references = None
        self._cached_probes = None
        self._cached_zprobes = None
        # NOTE: the original code assigned this attribute twice; once is enough.
        self._cached_treferences = None
        self._discarded_subjects = (
            []
        )  # Some subjects were labeled with both races
        self._load_metadata(target_set="test")
        self._demographics = None
        self._demographics = self._get_demographics_dict()

        # Setting the seed for the IDIAP PROTOCOL,
        # so we have a consistent set of probes
        self._idiap_protocol_seed = 652

        # Number of samples used to Z-Norm and T-Norm (per race)
        self._nzprobes = 25
        self._ntreferences = 25

    def _get_demographics_dict(self):
        """
        Get the dictionary with GENDER and COUNTRY of birth.
        Data obtained using the wiki data `https://query.wikidata.org/` using the following sparql query

        '''
        SELECT ?item ?itemLabel ?genderLabel ?countryLabel WHERE {
        ?item wdt:P31 wd:Q5.
        ?item ?label "{MY_NAME_HERE}"@en .
        optional{ ?item wdt:P21 ?gender.}
        optional{ ?item wdt:P27 ?country.}
        SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
        }
        '''

        """
        if self._demographics is None:
            # Only resolve (and possibly download) the CSV the first time;
            # subsequent calls return the cached dictionary directly.
            filename = (
                download_file(
                    urls=RFWDatabase.demographics_urls,
                    destination_sub_directory="protocols/" + RFWDatabase.category,
                    checksum="8eb0e3c93647dfa0c13fade5db96d73a",
                    checksum_fct=md5_hash,
                    extract=True,
                )
                / "msceleb_wikidata_demographics.csv"
            )
            self._demographics = dict()
            with open(filename) as f:
                for line in f.readlines():
                    line = line.split(",")
                    # CSV columns: identity, name, gender, country
                    self._demographics[line[0]] = [
                        line[2],
                        line[3].rstrip("\n"),
                    ]

        return self._demographics

    def _get_subject_from_key(self, key):
        # Keys look like "<subject_id>_000<n>"; drop the 5-char suffix.
        return key[:-5]

    def _load_metadata(self, target_set="test"):
        """Parse the per-race ``*_pairs.txt`` files and fill the pair,
        race and subject bookkeeping dictionaries."""
        for race in self._races:
            pair_file = os.path.join(
                self.original_directory,
                target_set,
                "txts",
                race,
                f"{race}_pairs.txt",
            )

            # `with` ensures the pair file is closed (the original leaked
            # the file handle via `open(...).readlines()`).
            with open(pair_file) as f:
                pair_lines = f.readlines()

            for line in pair_lines:
                line = line.split("\t")
                line[-1] = line[-1].rstrip("\n")

                key = f"{line[0]}_000{line[1]}"
                subject_id = self._get_subject_from_key(key)
                dict_key = f"{race}/{subject_id}/{key}"

                if subject_id not in self._id_race:
                    self._id_race[subject_id] = race
                else:
                    # A few identities appear under two races in the raw
                    # dataset; drop every sample that relabels them.
                    if (
                        self._id_race[subject_id] != race
                        and subject_id not in self._discarded_subjects
                    ):
                        logger.warning(
                            f"{subject_id} was already labeled as {self._id_race[subject_id]}, and it's illogical to be relabeled as {race}. "
                            f"This seems a problem with RFW dataset, so we are removing all samples linking {subject_id} as {race}"
                        )
                        self._discarded_subjects.append(subject_id)
                        continue

                # Positive or negative pairs
                if len(line) == 3:
                    # Genuine pair: same subject, second image index in line[2]
                    k_value = f"{line[0]}_000{line[2]}"
                else:
                    # Impostor pair: other subject in line[2], image index in line[3]
                    k_value = f"{line[2]}_000{line[3]}"
                dict_value = (
                    f"{race}/{self._get_subject_from_key(k_value)}/{k_value}"
                )

                if dict_key not in self._pairs:
                    self._pairs[dict_key] = []
                self._pairs[dict_key].append(dict_value)

        # Picking the first reference
        if self.protocol == "idiap":
            for p in self._pairs:
                _, subject_id, template_id = p.split("/")
                if subject_id in self._first_reference_of_subject:
                    continue
                self._first_reference_of_subject[subject_id] = template_id

        # Preparing the probes
        self._inverted_pairs = self._invert_dict(self._pairs)
        self._race_ids = self._invert_dict(self._id_race)

    def _invert_dict(self, dict_pairs):
        """Invert a mapping; values that are lists are expanded so each
        element becomes a key pointing back to the original key(s)."""
        inverted_pairs = dict()

        for k in dict_pairs:
            if isinstance(dict_pairs[k], list):
                for v in dict_pairs[k]:
                    if v not in inverted_pairs:
                        inverted_pairs[v] = []
                    inverted_pairs[v].append(k)
            else:
                v = dict_pairs[k]
                if v not in inverted_pairs:
                    inverted_pairs[v] = []
                inverted_pairs[v].append(k)
        return inverted_pairs

    def background_model_samples(self):
        # RFW defines no training/background set for these protocols.
        return []

    def _get_zt_samples(self, seed):
        """Randomly draw ``self._nzprobes`` train-set samples per race,
        reproducibly for the given ``seed``, for Z-/T-normalization."""
        cache = []

        # Setting the seed for the IDIAP PROTOCOL,
        # so we have a consistent set of probes
        np.random.seed(seed)

        for race in self._races:
            data_dir = os.path.join(
                self.original_directory, "train", "data", race
            )
            files = os.listdir(data_dir)
            # SHUFFLING
            np.random.shuffle(files)
            files = files[0 : self._nzprobes]

            # RFW original data is not super organized
            # train data from Caucasians are stored differently
            if race == "Caucasian":
                for f in files:
                    template_id = os.listdir(os.path.join(data_dir, f))[0]
                    key = f"{race}/{f}/{template_id[:-4]}"
                    cache.append(
                        self._make_sampleset(
                            key, target_set="train", get_demographic=False
                        )
                    )

            else:
                for f in files:
                    key = f"{race}/{race}/{f[:-4]}"
                    cache.append(
                        self._make_sampleset(
                            key, target_set="train", get_demographic=False
                        )
                    )
        return cache

    def zprobes(self, group="dev", proportion=1.0):
        """Return (cached) Z-normalization probes, each linked to every
        reference template of the group."""
        if self._cached_zprobes is None:
            self._cached_zprobes = self._get_zt_samples(
                self._idiap_protocol_seed + 1
            )
            references = list(
                set([s.template_id for s in self.references(group=group)])
            )
            for p in self._cached_zprobes:
                p.references = copy.deepcopy(references)

        return self._cached_zprobes

    def treferences(self, group="dev", proportion=1.0):
        """Return (cached) T-normalization references."""
        if self._cached_treferences is None:
            self._cached_treferences = self._get_zt_samples(
                self._idiap_protocol_seed + 2
            )

        # BUGFIX: the original returned self._cached_zprobes here, so the
        # T-norm cohort silently duplicated the Z-norm probes.
        return self._cached_treferences

    def probes(self, group="dev"):
        """Return (cached) probe sample sets; under the "idiap" protocol the
        comparison list is extended with one cross-race reference per race."""
        self._check_group(group)
        if self._cached_probes is None:
            # Setting the seed for the IDIAP PROTOCOL,
            # so we have a consistent set of probes
            np.random.seed(self._idiap_protocol_seed)

            self._cached_probes = []
            for key in self._inverted_pairs:
                sset = self._make_sampleset(key)
                # template_id is the last component of each pair key
                sset.references = [
                    pair.split("/")[-1] for pair in self._inverted_pairs[key]
                ]

                # If it's the idiap protocol, we should
                # extend the list of comparisons
                if self.protocol == "idiap":
                    # Picking one reference per race
                    extra_references = []
                    for k in self._race_ids:
                        # Discard samples from the same race
                        if k == sset.race:
                            continue

                        index = np.random.randint(len(self._race_ids[k]))
                        random_subject_id = self._race_ids[k][index]

                        # Search for the first reference id in with this identity
                        extra_references.append(
                            self._first_reference_of_subject[random_subject_id]
                        )

                    assert len(extra_references) == 3

                    sset.references += extra_references

                self._cached_probes.append(sset)
        return self._cached_probes

    def _fetch_landmarks(self, filename, key):
        """Lazily parse the landmark file once and return the eye
        annotations (``reye``/``leye`` as (y, x) tuples) for ``key``."""
        if key not in self._landmarks:
            with open(filename) as f:
                for line in f.readlines():
                    line = line.split("\t")
                    # pattern 'm.0c7mh2_0003.jpg'[:-4]
                    k = line[0].split("/")[-1][:-4]
                    self._landmarks[k] = dict()
                    # Landmark files store x before y; annotations are (y, x).
                    self._landmarks[k]["reye"] = (
                        float(line[3]),
                        float(line[2]),
                    )
                    self._landmarks[k]["leye"] = (
                        float(line[5]),
                        float(line[4]),
                    )

        return self._landmarks[key]

    def _make_sampleset(self, item, target_set="test", get_demographic=True):
        """Build a :any:`SampleSet` for ``item`` (``race/subject/template``),
        resolving the on-disk path and eye annotations."""
        race, subject_id, template_id = item.split("/")

        # RFW original data is not super organized
        # Test and train data are stored differently

        key = f"{race}/{subject_id}/{template_id}"

        path = (
            os.path.join(
                self.original_directory,
                f"{target_set}/data/{race}",
                subject_id,
                template_id + self._default_extension,
            )
            if (target_set == "test" or race == "Caucasian")
            else os.path.join(
                self.original_directory,
                f"{target_set}/data/{race}",
                template_id + self._default_extension,
            )
        )

        # Caucasian train landmarks live in a separate erratum file.
        annotations = (
            self._fetch_landmarks(
                os.path.join(
                    self.original_directory, "erratum1", "Caucasian_lmk.txt"
                ),
                template_id,
            )
            if (target_set == "train" and race == "Caucasian")
            else self._fetch_landmarks(
                os.path.join(
                    self.original_directory,
                    f"{target_set}/txts/{race}/{race}_lmk.txt",
                ),
                template_id,
            )
        )

        samples = [
            DelayedSample(
                partial(
                    bob.io.base.load,
                    path,
                ),
                key=key,
                annotations=annotations,
                template_id=template_id,
                subject_id=subject_id,
            )
        ]

        if get_demographic:
            gender = self._demographics[subject_id][0]
            country = self._demographics[subject_id][1]

            return SampleSet(
                samples,
                key=key,
                template_id=template_id,
                subject_id=subject_id,
                race=race,
                gender=gender,
                country=country,
            )
        else:
            return SampleSet(
                samples,
                key=key,
                template_id=template_id,
                subject_id=subject_id,
                race=race,
            )

    def references(self, group="dev"):
        """Return (cached) the reference sample sets, one per pair key."""
        self._check_group(group)

        if self._cached_biometric_references is None:
            self._cached_biometric_references = []
            for key in self._pairs:
                self._cached_biometric_references.append(
                    self._make_sampleset(key)
                )

        return self._cached_biometric_references

    def all_samples(self, group="dev"):
        """Return every sample set (references and probes) of the group."""
        self._check_group(group)

        return self.references() + self.probes()

    def groups(self):
        return ["dev"]

    def protocols(self):
        return ["original", "idiap"]

    def _check_protocol(self, protocol):
        assert (
            protocol in self.protocols()
        ), "Invalid protocol `{}` not in {}".format(protocol, self.protocols())

    def _check_group(self, group):
        assert group in self.groups(), "Invalid group `{}` not in {}".format(
            group, self.groups()
        )