Coverage for src/bob/bio/face/pytorch/datasets/demographics.py: 22%

330 statements  


#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>

"""
Datasets that handle demographic information
"""


import itertools
import logging
import os
import random

import cloudpickle
import numpy as np
import pandas as pd
import torch

from clapper.rc import UserDefaults
from torch.utils.data import Dataset

import bob.io.base

from bob.bio.base.database.utils import download_file, md5_hash
from bob.bio.face.database import (
    MEDSDatabase,
    MobioDatabase,
    MorphDatabase,
    RFWDatabase,
    VGG2Database,
)

logger = logging.getLogger(__name__)
rc = UserDefaults("bobrc.toml")


class DemographicTorchDataset(Dataset):
    """
    PyTorch base dataset that handles demographic information.

    Parameters
    ----------

    bob_dataset:
        Instance of a bob database object

    transform: callable
        Optional transform applied to the sample data
    """

    def __init__(self, bob_dataset, transform=None):
        self.bob_dataset = bob_dataset
        self.transform = transform
        self.load_bucket()

    def __len__(self):
        return len(self.bucket)

    @property
    def n_classes(self):
        return len(self.labels)

    @property
    def n_samples(self):
        return len(self.bucket)

    @property
    def demographic_keys(self):
        return self._demographic_keys

    def __getitem__(self, idx):
        """
        Returns a dictionary containing the following keys: data, label, demography
        """

        sample = self.bucket[idx]

        image = (
            sample.data
            if self.transform is None
            else self.transform(sample.data)
        )

        # image = image.astype("float32")

        label = self.labels[sample.subject_id]

        demography = self.get_demographics(sample)

        return {"data": image, "label": label, "demography": demography}
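
    # A minimal usage sketch (hedged: the protocol name and path are
    # illustrative; any concrete subclass in this module works, since this
    # base class only requires `load_bucket` and `get_demographics`):
    #
    #     dataset = MedsTorchDataset(
    #         protocol="verification_fold1", database_path="/path/to/meds"
    #     )
    #     loader = torch.utils.data.DataLoader(dataset, batch_size=32)
    #     batch = next(iter(loader))  # keys: "data", "label", "demography"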

    def count_subjects_per_demographics(self):
        """
        Count the number of subjects per demographic group.
        """
        all_demographics = list(self.subject_demographic.values())

        # Number of subjects per demographic
        subjects_per_demographics = dict(
            [
                (d, sum(np.array(all_demographics) == d))
                for d in set(all_demographics)
            ]
        )

        return subjects_per_demographics

    def get_demographic_weights(self, as_dict=True):
        """
        Compute the inverse weight of each demographic group.

        .. warning::
           This is not the same function as `get_demographic_class_weights`.

        Parameters
        ----------

        as_dict: bool
            If `True`, returns the weights as a dict.
        """
        n_identities = len(self.subject_demographic)

        # Number of subjects per demographic
        subjects_per_demographics = self.count_subjects_per_demographics()

        # Inverse probability (1-p_i)/p_i
        demographic_weights = dict()
        for i in subjects_per_demographics:
            p_i = subjects_per_demographics[i] / n_identities
            demographic_weights[i] = (1 - p_i) / p_i

        p_accumulator = sum(demographic_weights.values())
        # Scaling the inverse probabilities so they sum to one
        for i in demographic_weights:
            demographic_weights[i] /= p_accumulator

        # Return as a dictionary
        if as_dict:
            return demographic_weights

        # Returning as a list (this is more appropriate for NN training)
        return [demographic_weights[k] for k in self.demographic_keys]
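
    # Worked example (toy numbers, not from any dataset): with 2 subjects in
    # group "A" and 8 in group "B", p_A = 0.2 and p_B = 0.8, so the raw
    # inverse weights are (1 - 0.2) / 0.2 = 4.0 and (1 - 0.8) / 0.8 = 0.25.
    # After scaling by their sum (4.25), `get_demographic_weights` returns
    # {"A": 0.941..., "B": 0.058...}.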

    def get_demographic_class_weights(self):
        """
        Compute the class weights based on the demographics.

        Returns
        -------
        weights: torch.Tensor
            A tensor containing the weight of each class
        """

        subjects_per_demographics = self.count_subjects_per_demographics()
        demographic_weights = self.get_demographic_weights()

        weights = [
            demographic_weights[v] / subjects_per_demographics[v]
            for v in self.subject_demographic.values()
        ]

        return torch.Tensor(weights)
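
    # These per-class weights can feed a weighted loss directly (a sketch,
    # assuming `dataset` is an instance of any subclass below):
    #
    #     weights = dataset.get_demographic_class_weights()
    #     criterion = torch.nn.CrossEntropyLoss(weight=weights)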


class MedsTorchDataset(DemographicTorchDataset):
    """
    MEDS torch interface

    .. warning::
       Unfortunately, this dataset contains several identities with only ONE
       sample. Hence, it is impossible to properly use it for contrastive
       learning, for instance. If that is the case, please set
       `take_from_znorm=False`, so the `dev` or `eval` sets are used.


    Parameters
    ----------

    protocol: str
        One of the MEDS available protocols, check :py:class:`bob.bio.face.database.MEDSDatabase`

    database_path: str
        Database path

    database_extension: str
        Database extension

    transform: callable
        Transformation function applied to the input sample

    take_from_znorm: bool
        If `True`, takes the samples from the `treferences` and `zprobes` methods, which come from the training set.
        If `False`, takes the samples from the `references` and `probes` methods; the variable `group` is then considered.

    group: str
        Group to load when `take_from_znorm` is `False` (e.g. `dev`)
    """

    def __init__(
        self,
        protocol,
        database_path,
        database_extension=".h5",
        transform=None,
        take_from_znorm=False,
        group="dev",
    ):
        bob_dataset = MEDSDatabase(
            protocol=protocol,
            dataset_original_directory=database_path,
            dataset_original_extension=database_extension,
        )
        self.take_from_znorm = take_from_znorm
        self.group = group
        super().__init__(bob_dataset, transform=transform)

    def load_bucket(self):
        self._target_metadata = "rac"

        if self.take_from_znorm:
            self.bucket = [
                s for sset in self.bob_dataset.zprobes() for s in sset
            ]
            self.bucket += [
                s for sset in self.bob_dataset.treferences() for s in sset
            ]
        else:
            self.bucket = [
                s
                for sset in self.bob_dataset.probes(group=self.group)
                for s in sset
            ]
            self.bucket += [
                s
                for sset in self.bob_dataset.references(group=self.group)
                for s in sset
            ]

        offset = 0
        self.labels = dict()
        self.subject_demographic = dict()

        for s in self.bucket:
            if s.subject_id not in self.labels:
                self.labels[s.subject_id] = offset
                self.subject_demographic[s.subject_id] = getattr(
                    s, self._target_metadata
                )
                offset += 1

        metadata_keys = set(self.subject_demographic.values())
        self._demographic_keys = dict(
            zip(metadata_keys, range(len(metadata_keys)))
        )

    def get_demographics(self, sample):
        demographic_key = getattr(sample, "rac")
        return self._demographic_keys[demographic_key]
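
    # Sketch (illustrative protocol/path): `take_from_znorm` toggles between
    # the z-norm samples (training set) and the dev/eval samples:
    #
    #     train_set = MedsTorchDataset(
    #         "verification_fold1", "/path/to/meds", take_from_znorm=True
    #     )
    #     eval_set = MedsTorchDataset(
    #         "verification_fold1", "/path/to/meds",
    #         take_from_znorm=False, group="eval",
    #     )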


class VGG2TorchDataset(DemographicTorchDataset):
    """
    VGG2 for torch.

    This interface makes use of :any:`bob.bio.face.database.VGG2Database`.

    The "race" labels below were annotated by students during the period
    2018-2020. Race labels were taken from: MasterEBTSv10.0.809302017_Final.pdf

    - A: Asian in general (Chinese, Japanese, Filipino, Korean, Polynesian, Indonesian, Samoan, or any other Pacific Islander)
    - B: A person having origins in any of the black racial groups of Africa
    - I: American Indian, Asian Indian, Eskimo, or Alaskan native
    - U: Of indeterminable race
    - W: Caucasian, Mexican, Puerto Rican, Cuban, Central or South American, or other Spanish culture or origin, regardless of race
    - N: None of the above


    Gender information was taken from the original dataset.
    The following genders are available:

    - male
    - female


    .. note::
       Some important information about this interface.
       We have the following statistics:

       - n_classes = 8631
       - n_demographics: 12 ['m-A': 0, 'm-B': 1, 'm-I': 2, 'm-U': 3, 'm-W': 4, 'm-N': 5, 'f-A': 6, 'f-B': 7, 'f-I': 8, 'f-U': 9, 'f-W': 10, 'f-N': 11]


    .. note::

       The distribution of the combined race and gender demographics is:
       {'m-B': 552, 'm-U': 64, 'm-W': 3903, 'f-W': 2657, 'f-A': 286, 'f-U': 34, 'f-I': 298, 'f-N': 2, 'f-B': 200, 'm-N': 1, 'm-I': 366, 'm-A': 268}

       Note that `m-N` has 1 subject and `f-N` has 2 subjects.
       For this reason, this race is removed from this interface by default.
       We can't learn anything from one sample.


    Parameters
    ----------
    protocol: str
        One of the VGG2 available protocols, check :py:class:`bob.bio.face.database.VGG2Database`

    database_path: str
        Path containing the raw data

    database_extension: str
        Default file extension of the raw data

    transform: callable
        Transformation function applied to the input sample

    load_bucket_from_cache: bool
        If set, it will load the list of available samples from the cache

    train: bool
        If set, it will prepare a bucket for training.

    include_u_n: bool
        If `True`, it will include 'U' (Undefined) and 'N' (None) in the list of races.

    """

    def __init__(
        self,
        protocol,
        database_path,
        database_extension=".jpg",
        transform=None,
        load_bucket_from_cache=True,
        include_u_n=False,
        train=True,
    ):
        bob_dataset = VGG2Database(
            protocol=protocol,
            dataset_original_directory=database_path,
            dataset_original_extension=database_extension,
        )
        self.load_bucket_from_cache = load_bucket_from_cache

        # Percentage of the samples used for training
        self._percentage_for_training = 0.8
        self.train = train

        # All possible metadata
        self._possible_genders = ["m", "f"]

        # self._possible_races = ["A", "B", "I", "U", "W", "N"]
        self._possible_races = ["A", "B", "I", "W"]
        if include_u_n:
            self._possible_races += ["U", "N"]

        super().__init__(bob_dataset, transform=transform)

    def decode_race(self, race):
        # return race if race in self._possible_races else "N"
        return race if race in self._possible_races else "W"

    def get_key(self, sample):
        return f"{sample.gender}-{self.decode_race(sample.race)}"

    def get_cache_path(self):
        filename = (
            "vgg2_short_cached_bucket.pickle"
            if self.bob_dataset.protocol == "vgg2-short"
            else "vgg2_full_cached_bucket.pickle"
        )

        return os.path.join(
            rc.get(
                "bob_data_folder",
                os.path.join(os.path.expanduser("~"), "bob_data"),
            ),
            "datasets",
            f"{filename}",
        )

    def cache_bucket(self, bucket):
        """
        Cache the list of samples into a temporary directory
        """
        bucket_filename = self.get_cache_path()
        os.makedirs(os.path.dirname(bucket_filename), exist_ok=True)
        with open(bucket_filename, "wb") as f:
            cloudpickle.dump(bucket, f)

    def load_cached_bucket(self):
        bucket_filename = self.get_cache_path()
        with open(bucket_filename, "rb") as f:
            bucket = cloudpickle.load(f)
        return bucket

    def load_bucket(self):
        # Defining the demographics keys
        self._demographic_keys = [
            f"{gender}-{race}"
            for gender in self._possible_genders
            for race in self._possible_races
        ]
        self._demographic_keys = dict(
            [(d, i) for i, d in enumerate(self._demographic_keys)]
        )

        # Loading the bucket from cache
        if self.load_bucket_from_cache and os.path.exists(
            self.get_cache_path()
        ):
            self.bucket = self.load_cached_bucket()
        else:
            self.bucket = [
                s for s in self.bob_dataset.background_model_samples()
            ]
            # Caching the bucket
            self.cache_bucket(self.bucket)

        # Mapping subject_id with labels
        self.labels = sorted(list(set([s.subject_id for s in self.bucket])))
        self.labels = dict([(label, i) for i, label in enumerate(self.labels)])

        # Splitting the bucket into training and development sets
        all_indexes = np.array([self.labels[x.subject_id] for x in self.bucket])
        indexes = []
        if self.train:
            for i in range(self.n_classes):
                ind = np.where(all_indexes == i)[0]
                indexes += list(
                    ind[
                        0 : int(
                            np.floor(len(ind) * self._percentage_for_training)
                        )
                    ]
                )
        else:
            for i in range(self.n_classes):
                ind = np.where(all_indexes == i)[0]
                indexes += list(
                    ind[
                        int(
                            np.floor(len(ind) * self._percentage_for_training)
                        ) :
                    ]
                )

        # Redefining the bucket
        self.bucket = list(np.array(self.bucket)[indexes])

        # Mapping subject and demographics for fast access
        self.subject_demographic = dict()

        for s in self.bucket:
            if s.subject_id not in self.subject_demographic:
                self.subject_demographic[s.subject_id] = self.get_key(s)

    def get_demographics(self, sample):
        demographic_key = self.get_key(sample)
        return self._demographic_keys[demographic_key]
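
    # Sketch (the "vgg2-short" protocol name comes from `get_cache_path`
    # above; the path is illustrative). Train and "dev" views share the same
    # per-identity 80/20 split, hence the same class list:
    #
    #     train_set = VGG2TorchDataset("vgg2-short", "/path/to/vgg2", train=True)
    #     dev_set = VGG2TorchDataset("vgg2-short", "/path/to/vgg2", train=False)
    #     assert train_set.n_classes == dev_set.n_classes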


class MorphTorchDataset(DemographicTorchDataset):
    """
    MORPH torch interface

    .. warning::
       Unfortunately, this dataset contains several identities with only ONE
       sample. Hence, it is impossible to properly use it for contrastive
       learning, for instance. If that is the case, please set
       `take_from_znorm=False`, so the `dev` or `eval` sets are used.


    Parameters
    ----------

    protocol: str
        One of the MORPH available protocols, check :py:class:`bob.bio.face.database.MorphDatabase`

    database_path: str
        Database path

    database_extension: str
        Database extension

    transform: callable
        Transformation function applied to the input sample

    take_from_znorm: bool
        If `True`, takes the samples from the `treferences` and `zprobes` methods, which come from the training set.
        If `False`, takes the samples from the `references` and `probes` methods; the variable `group` is then considered.

    group: str
        Group to load when `take_from_znorm` is `False` (e.g. `dev`)
    """

    def __init__(
        self,
        protocol,
        database_path,
        database_extension=".h5",
        transform=None,
        take_from_znorm=True,
        group="dev",
    ):
        bob_dataset = MorphDatabase(
            protocol=protocol,
            dataset_original_directory=database_path,
            dataset_original_extension=database_extension,
        )
        self.take_from_znorm = take_from_znorm
        self.group = group

        super().__init__(bob_dataset, transform=transform)

    def load_bucket(self):
        # The MORPH dataset has an intersection between `zprobes` and
        # `treferences`; these subject IDs are excluded from `treferences`:
        self.excluding_list = [
            "190276",
            "332158",
            "111942",
            "308129",
            "334074",
            "350814",
            "131677",
            "168724",
            "276055",
            "275589",
            "286810",
        ]

        if self.take_from_znorm:
            self.bucket = [
                s for sset in self.bob_dataset.zprobes() for s in sset
            ]
            self.bucket += [
                s
                for sset in self.bob_dataset.treferences()
                for s in sset
                if sset.subject_id not in self.excluding_list
            ]
        else:
            self.bucket = [
                s
                for sset in self.bob_dataset.probes(group=self.group)
                for s in sset
            ]
            self.bucket += [
                s
                for sset in self.bob_dataset.references(group=self.group)
                for s in sset
            ]

        offset = 0
        self.labels = dict()
        self.subject_demographic = dict()

        for s in self.bucket:
            if s.subject_id not in self.labels:
                self.labels[s.subject_id] = offset
                self.subject_demographic[s.subject_id] = f"{s.rac}-{s.sex}"
                offset += 1

        metadata_keys = set(self.subject_demographic.values())
        self._demographic_keys = dict(
            zip(metadata_keys, range(len(metadata_keys)))
        )

    def get_demographics(self, sample):
        demographic_key = f"{sample.rac}-{sample.sex}"
        return self._demographic_keys[demographic_key]
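
    # Sketch (protocol name and path illustrative): MORPH demographic keys
    # combine race and sex (e.g. "B-M"), and their integer indices are exposed
    # via the `demographic_keys` property:
    #
    #     dataset = MorphTorchDataset("verification_fold1", "/path/to/morph")
    #     print(dataset.demographic_keys)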


class RFWTorchDataset(DemographicTorchDataset):
    def __init__(
        self, protocol, database_path, database_extension=".h5", transform=None
    ):
        bob_dataset = RFWDatabase(
            protocol=protocol,
            dataset_original_directory=database_path,
            dataset_original_extension=database_extension,
        )
        super().__init__(bob_dataset, transform=transform)

    def load_demographics(self):
        target_metadata = "race"
        metadata_keys = set(
            [
                getattr(sset, target_metadata)
                for sset in self.bob_dataset.zprobes()
            ]
            + [
                getattr(sset, target_metadata)
                for sset in self.bob_dataset.treferences()
            ]
        )
        metadata_keys = dict(zip(metadata_keys, range(len(metadata_keys))))
        return metadata_keys

    def get_demographics(self, sample):
        demographic_key = getattr(sample, "race")
        return self._demographic_keys[demographic_key]


class MobioTorchDataset(DemographicTorchDataset):
    def __init__(
        self, protocol, database_path, database_extension=".h5", transform=None
    ):
        bob_dataset = MobioDatabase(
            protocol=protocol,
            dataset_original_directory=database_path,
            dataset_original_extension=database_extension,
        )

        super().__init__(bob_dataset, transform=transform)

    def load_bucket(self):
        self._target_metadata = "gender"
        self.bucket = [s for s in self.bob_dataset.background_model_samples()]
        offset = 0
        self.labels = dict()
        self.subject_demographic = dict()

        for s in self.bucket:
            if s.subject_id not in self.labels:
                self.labels[s.subject_id] = offset
                self.subject_demographic[s.subject_id] = getattr(
                    s, self._target_metadata
                )
                offset += 1

        metadata_keys = set(self.subject_demographic.values())
        self._demographic_keys = dict(
            zip(metadata_keys, range(len(metadata_keys)))
        )

    def __len__(self):
        return len(self.bucket)

    def get_demographics(self, sample):
        demographic_key = getattr(sample, self._target_metadata)
        return self._demographic_keys[demographic_key]
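
    # Sketch (protocol name and path illustrative): MOBIO groups subjects by
    # gender only, so `demographic_keys` maps each gender to an index:
    #
    #     dataset = MobioTorchDataset("mobile0-male-female", "/path/to/mobio")
    #     dataset.demographic_keys  # e.g. {"male": 0, "female": 1}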


class MSCelebTorchDataset(DemographicTorchDataset):
    """
    This interface makes use of a CSV file containing gender and race
    annotations; the file is downloaded automatically (see the `urls`
    static method).

    The "race" labels below were annotated by students during the period
    2018-2020. Race labels were taken from: MasterEBTSv10.0.809302017_Final.pdf

    - A: Asian in general (Chinese, Japanese, Filipino, Korean, Polynesian, Indonesian, Samoan, or any other Pacific Islander)
    - B: A person having origins in any of the black racial groups of Africa
    - I: American Indian, Asian Indian, Eskimo, or Alaskan native
    - U: Of indeterminable race
    - W: Caucasian, Mexican, Puerto Rican, Cuban, Central or South American, or other Spanish culture or origin, regardless of race
    - N: None of the above


    Gender and country information was taken from Wikidata: https://www.wikidata.org/wiki/Wikidata:Main_Page
    The following genders are available:

    - male
    - female
    - other


    .. note::
       Some important information about this interface.
       If `include_unknow_demographics==False` we have the following statistics:

       - n_classes = 81279
       - n_demographics: 15 ['male-A', 'male-B', 'male-I', 'male-U', 'male-W', 'female-A', 'female-B', 'female-I', 'female-U', 'female-W', 'other-A', 'other-B', 'other-I', 'other-U', 'other-W']

       If `include_unknow_demographics==True` we have the following statistics:

       - n_classes = 89735
       - n_demographics: 18 ['male-A', 'male-B', 'male-I', 'male-N', 'male-U', 'male-W', 'female-A', 'female-B', 'female-I', 'female-N', 'female-U', 'female-W', 'other-A', 'other-B', 'other-I', 'other-N', 'other-U', 'other-W']


    Parameters
    ----------
    database_path: str
        Path containing the raw data

    database_extension: str
        Default file extension of the raw images

    idiap_path: bool
        If set, it will use the Idiap standard relative path to load the data (e.g. [BASE_PATH]/chunk_[n]/[user_id])

    include_unknow_demographics: bool
        If set, it will include subjects whose race was set to `N` (None of the above)

    load_bucket_from_cache: bool
        If set, it will load the list of available samples from the cache

    transform: callable
        Transformation function applied to the input sample

    """

    def __init__(
        self,
        database_path,
        database_extension=".png",
        idiap_path=True,
        include_unknow_demographics=False,
        load_bucket_from_cache=True,
        transform=None,
    ):
        self.idiap_path = idiap_path
        self.database_path = database_path
        self.database_extension = database_extension
        self.include_unknow_demographics = include_unknow_demographics
        self.load_bucket_from_cache = load_bucket_from_cache
        self.transform = transform

        # Private keys
        self._possible_genders = ["male", "female", "other"]

        urls = MSCelebTorchDataset.urls()
        filename = (
            download_file(
                urls=urls,
                destination_filename="msceleb_race_wikidata.tar.gz",
                checksum="76339d73f352faa00c155f7040e772bb",
                checksum_fct=md5_hash,
                extract=True,
            )
            / "msceleb_race_wikidata.csv"
        )

        self.load_bucket(filename)

    @staticmethod
    def urls():
        return [
            "https://www.idiap.ch/software/bob/databases/latest/msceleb_race_wikidata.tar.gz",
            "http://www.idiap.ch/software/bob/databases/latest/msceleb_race_wikidata.tar.gz",
        ]

    def get_cache_path(self):
        filename = (
            "msceleb_cached_bucket_WITH_unknow_demographics.csv"
            if self.include_unknow_demographics
            else "msceleb_cached_bucket_WITHOUT_unknow_demographics.csv"
        )

        return os.path.join(
            rc.get(
                "bob_data_folder",
                os.path.join(os.path.expanduser("~"), "bob_data"),
            ),
            "datasets",
            f"{filename}",
        )

    def cache_bucket(self, bucket):
        """
        Cache the list of samples into a temporary directory
        """
        bucket_filename = self.get_cache_path()
        os.makedirs(os.path.dirname(bucket_filename), exist_ok=True)
        with open(bucket_filename, "w") as f:
            for b in bucket:
                f.write(f"{b}\n")

    def load_cached_bucket(self):
        """
        Load the bucket from the cache
        """
        bucket_filename = self.get_cache_path()
        with open(bucket_filename) as f:
            return [line.rstrip("\n") for line in f.readlines()]

    def __len__(self):
        return len(self.bucket)

    def load_bucket(self, csv_filename):
        dataframe = pd.read_csv(csv_filename)

        # Possible races
        # {'A', 'B', 'I', 'N', 'U', 'W', nan}

        filtered_dataframe = (
            dataframe.loc[
                (dataframe.RACE == "A")
                | (dataframe.RACE == "B")
                | (dataframe.RACE == "I")
                | (dataframe.RACE == "U")
                | (dataframe.RACE == "W")
                | (dataframe.RACE == "N")
            ]
            if self.include_unknow_demographics
            else dataframe.loc[
                (dataframe.RACE == "A")
                | (dataframe.RACE == "B")
                | (dataframe.RACE == "I")
                | (dataframe.RACE == "U")
                | (dataframe.RACE == "W")
            ]
        )

        filtered_dataframe_list = filtered_dataframe[
            ["idiap_chunk", "ID"]
        ].to_csv()

        # Defining the number of classes
        subject_relative_paths = [
            os.path.join(ll.split(",")[1], ll.split(",")[2])
            for ll in filtered_dataframe_list.split("\n")[1:-1]
        ]

        if self.load_bucket_from_cache and os.path.exists(
            self.get_cache_path()
        ):
            self.bucket = self.load_cached_bucket()
        else:
            # Defining all images
            logger.warning(
                "Fetching all sample paths on the fly. This might take some minutes. "
                f"They will then be cached in {self.get_cache_path()} and loaded from there."
            )

            self.bucket = [
                os.path.join(subject, f)
                for subject in subject_relative_paths
                for f in os.listdir(os.path.join(self.database_path, subject))
                if f[-4:] == self.database_extension
            ]
            self.cache_bucket(self.bucket)

        self.labels = dict(
            [
                (k.split("/")[-1], i)
                for i, k in enumerate(subject_relative_paths)
            ]
        )

        # Setting the possible demographics and the demographic keys
        filtered_dataframe = filtered_dataframe.set_index("ID")
        self.metadata = filtered_dataframe[["GENDER", "RACE"]].to_dict(
            orient="index"
        )

        self._demographic_keys = [
            f"{gender}-{race}"
            for gender in self._possible_genders
            for race in sorted(set(filtered_dataframe["RACE"]))
        ]
        self._demographic_keys = dict(
            [(d, i) for i, d in enumerate(self._demographic_keys)]
        )

        # Creating a map between the subject and the demographic
        self.subject_demographic = dict(
            [(m, self.get_demographics(m)) for m in self.metadata]
        )

    def get_demographics(self, subject_id):
        race = self.metadata[subject_id]["RACE"]
        gender = self.metadata[subject_id]["GENDER"]

        gender = "other" if gender != "male" and gender != "female" else gender

        return self._demographic_keys[f"{gender}-{race}"]

    def __getitem__(self, idx):
        sample = self.bucket[idx]

        subject_id = sample.split("/")[-2]

        # Loading and transforming the image
        image = bob.io.base.load(os.path.join(self.database_path, sample))

        image = image if self.transform is None else self.transform(image)

        label = self.labels[subject_id]

        # Getting the demographics
        demography = self.get_demographics(subject_id)

        return {"data": image, "label": label, "demography": demography}
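
    # Sketch (the path is illustrative; the annotation CSV is fetched
    # automatically on construction):
    #
    #     dataset = MSCelebTorchDataset("/path/to/msceleb")
    #     dataset.n_classes    # 81279 without the unknown demographics
    #     sample = dataset[0]  # {"data": ..., "label": ..., "demography": ...}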


class SiameseDemographicWrapper(Dataset):
    """
    This class wraps the current demographic interface and
    dumps random positive and negative pairs of samples
    """

    def __init__(
        self,
        demographic_dataset,
        max_positive_pairs_per_subject=20,
        negative_pairs_per_subject=3,
        dense_negatives=False,
    ):
        self.demographic_dataset = demographic_dataset
        self.max_positive_pairs_per_subject = max_positive_pairs_per_subject
        self.negative_pairs_per_subject = negative_pairs_per_subject

        # Creating a bucket mapping the items of the bucket with their respective identities
        self.siamese_bucket = dict()
        for b in demographic_dataset.bucket:
            if b.subject_id not in self.siamese_bucket:
                self.siamese_bucket[b.subject_id] = []

            self.siamese_bucket[b.subject_id].append(b)

        positive_pairs = self.create_positive_pairs()
        if dense_negatives:
            negative_pairs = self.create_dense_negative_pairs()
        else:
            negative_pairs = self.create_light_negative_pairs()

        # Redefining the bucket
        self.siamese_bucket = negative_pairs + positive_pairs

        self.labels = np.hstack(
            (np.zeros(len(negative_pairs)), np.ones(len(positive_pairs)))
        )

    def __len__(self):
        return len(self.siamese_bucket)

    def create_positive_pairs(self):
        # Creating positive pairs for each identity
        positives = []
        random.seed(0)
        for b in self.siamese_bucket:
            samples = self.siamese_bucket[b]
            random.shuffle(samples)

            # All possible pair combinations
            samples = itertools.combinations(samples, 2)

            positives += [
                s
                for s in list(samples)[0 : self.max_positive_pairs_per_subject]
            ]

        return positives

    def create_dense_negative_pairs(self):
        """
        Create negative pairs.
        Here we create negative pairs only from the same demographic group,
        since we know that pairs from different demographics lead to
        poor scores.


        .. warning::
           The list of negative pairs is dense.
           For each combination of subjects within a particular demographic,
           we take `negative_pairs_per_subject` samples.
           Hence, the number of negative pairs can explode as a function
           of the number of subjects.
           For example, combining 1000 identities pairwise gives us
           499500 subject pairs. Taking `3` image pairs for each of these
           combinations of identities gives us ~1.5M negative pairs.
           Hence, be careful with that.

        """

        # Inverting subject
        random.seed(0)
        negatives = []

        # Creating the dictionary mapping demographics --> subjects
        demographic_subject = dict()
        for k, v in self.demographic_dataset.subject_demographic.items():
            demographic_subject[v] = demographic_subject.get(v, []) + [k]

        # For each demographic, pick the negative pairs
        for d in demographic_subject:
            subject_combinations = itertools.combinations(
                demographic_subject[d], 2
            )

            for s_c in subject_combinations:
                subject_i = self.siamese_bucket[s_c[0]]
                subject_j = self.siamese_bucket[s_c[1]]
                random.shuffle(subject_i)
                random.shuffle(subject_j)

                # All possible combinations
                for i, p in enumerate(itertools.product(subject_i, subject_j)):
                    if i == self.negative_pairs_per_subject:
                        break
                    negatives += ((p[0], p[1]),)

        return negatives

    def create_light_negative_pairs(self):
        """
        Create negative pairs.
        Here we create negative pairs only from the same demographic group,
        since we know that pairs from different demographics lead to
        poor scores.

        .. warning::
           This function generates a light set of negative pairs.
           The number of pairs is the number of subjects in a particular
           demographic multiplied by `negative_pairs_per_subject`.
           For example, a demographic with 1000 identities and `3` pairs
           per subject gives us 3000 negative pairs.

        """

        # Inverting subject
        random.seed(0)
        negatives = []

        # Creating the dictionary mapping demographics --> subjects
        demographic_subject = dict()
        for k, v in self.demographic_dataset.subject_demographic.items():
            demographic_subject[v] = demographic_subject.get(v, []) + [k]

        # For each demographic, pick the negative pairs
        for d in demographic_subject:
            n_subjects = len(demographic_subject[d])

            subject_combinations = list(
                itertools.combinations(demographic_subject[d], 2)
            )
            # Shuffling these combinations
            random.shuffle(subject_combinations)

            for s_c in subject_combinations[
                0 : n_subjects * self.negative_pairs_per_subject
            ]:
                subject_i = self.siamese_bucket[s_c[0]]
                subject_j = self.siamese_bucket[s_c[1]]
                random.shuffle(subject_i)
                random.shuffle(subject_j)

                negatives += ((subject_i[0], subject_j[0]),)

        return negatives
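
    # Worked count (numbers from the docstring above): for a demographic with
    # n_subjects = 1000 and negative_pairs_per_subject = 3, this method keeps
    # only 1000 * 3 = 3000 shuffled subject combinations (one image pair each),
    # instead of the ~1.5M pairs the dense variant would produce.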

    def __getitem__(self, idx):
        sample = self.siamese_bucket[idx]
        label = self.labels[idx]

        # subject_id = sample.split("/")[-2]

        # Transforming the images
        image_i = sample[0].data
        image_j = sample[1].data

        image_i = (
            image_i
            if self.demographic_dataset.transform is None
            else self.demographic_dataset.transform(image_i)
        )
        image_j = (
            image_j
            if self.demographic_dataset.transform is None
            else self.demographic_dataset.transform(image_j)
        )

        # Getting the demographics from the first sample of the pair
        demography = self.demographic_dataset.get_demographics(sample[0])

        # demography = self.get_demographics(subject_id)

        return {
            "data": (image_i, image_j),
            "label": label,
            "demography": demography,
        }
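

# A final usage sketch (hedged: the dataset arguments are illustrative). The
# wrapper yields pairs suitable for a siamese/contrastive loss:
#
#     base = MedsTorchDataset(
#         "verification_fold1", "/path/to/meds",
#         take_from_znorm=False, group="dev",
#     )
#     pairs = SiameseDemographicWrapper(base, max_positive_pairs_per_subject=10)
#     loader = torch.utils.data.DataLoader(pairs, batch_size=16, shuffle=True)
#     batch = next(iter(loader))  # batch["data"] is an (image_i, image_j) tuple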