Coverage for src/bob/bio/face/pytorch/datasets/demographics.py: 22%
330 statements
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>

"""
Datasets that handle demographic information
"""

import itertools
import logging
import os
import random

import cloudpickle
import numpy as np
import pandas as pd
import torch

from clapper.rc import UserDefaults
from torch.utils.data import Dataset

import bob.io.base

from bob.bio.base.database.utils import download_file, md5_hash
from bob.bio.face.database import (
    MEDSDatabase,
    MobioDatabase,
    MorphDatabase,
    RFWDatabase,
    VGG2Database,
)

logger = logging.getLogger(__name__)
rc = UserDefaults("bobrc.toml")


class DemographicTorchDataset(Dataset):
    """
    Pytorch base dataset that handles demographic information

    Parameters
    ----------

    bob_dataset:
        Instance of a bob database object

    transform: callable
        Optional transformation applied to the sample data

    """

    def __init__(self, bob_dataset, transform=None):
        self.bob_dataset = bob_dataset
        self.transform = transform
        self.load_bucket()

    def __len__(self):
        return len(self.bucket)

    @property
    def n_classes(self):
        return len(self.labels)

    @property
    def n_samples(self):
        return len(self.bucket)

    @property
    def demographic_keys(self):
        return self._demographic_keys

    def __getitem__(self, idx):
        """
        Returns a dictionary containing the keys `data`, `label` and `demography`
        """

        sample = self.bucket[idx]

        image = (
            sample.data
            if self.transform is None
            else self.transform(sample.data)
        )

        # image = image.astype("float32")

        label = self.labels[sample.subject_id]

        demography = self.get_demographics(sample)

        return {"data": image, "label": label, "demography": demography}

    def count_subjects_per_demographics(self):
        """
        Count the number of subjects per demographic group
        """
        all_demographics = list(self.subject_demographic.values())

        # Number of subjects per demographic
        subjects_per_demographics = dict(
            [
                (d, sum(np.array(all_demographics) == d))
                for d in set(all_demographics)
            ]
        )

        return subjects_per_demographics
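
    # A minimal sketch of what `count_subjects_per_demographics` returns,
    # assuming a toy subject --> demographic map (the ids and groups below
    # are illustrative, not from any real database):
    #
    #   self.subject_demographic = {"s1": "A", "s2": "A", "s3": "B"}
    #   self.count_subjects_per_demographics()  # -> {"A": 2, "B": 1}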

    def get_demographic_weights(self, as_dict=True):
        """
        Compute the inverse weighting for each demographic group.

        .. warning::
           This is not the same function as `get_demographic_class_weights`.

        Parameters
        ----------

        as_dict: bool
            If `True`, returns the weights as a dict.

        """
        n_identities = len(self.subject_demographic)

        # Number of subjects per demographic
        subjects_per_demographics = self.count_subjects_per_demographics()

        # Inverse probability (1-p_i)/p_i
        demographic_weights = dict()
        for i in subjects_per_demographics:
            p_i = subjects_per_demographics[i] / n_identities
            demographic_weights[i] = (1 - p_i) / p_i

        p_accumulator = sum(demographic_weights.values())
        # Scaling the inverse probabilities so they sum to one
        for i in demographic_weights:
            demographic_weights[i] /= p_accumulator

        # Return as a dictionary
        if as_dict:
            return demographic_weights

        # Returning as a list (more appropriate for NN training)
        return [demographic_weights[k] for k in self.demographic_keys]
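
    # A worked example of the inverse weighting above, assuming three
    # demographic groups with 10, 30 and 60 subjects (illustrative numbers):
    #
    #   p = {"A": 0.1, "B": 0.3, "C": 0.6}
    #   w = {k: (1 - v) / v for k, v in p.items()}  # {"A": 9.0, "B": 2.33, "C": 0.67}
    #   total = sum(w.values())                     # 12.0
    #   w = {k: v / total for k, v in w.items()}    # {"A": 0.75, "B": 0.19, "C": 0.06}
    #
    # so minority groups receive proportionally larger weights.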

    def get_demographic_class_weights(self):
        """
        Compute the class weights based on the demographics

        Returns
        -------

        weights: torch.Tensor
            A tensor containing the weight of each class
        """
        subjects_per_demographics = self.count_subjects_per_demographics()
        demographic_weights = self.get_demographic_weights()

        weights = [
            demographic_weights[v] / subjects_per_demographics[v]
            for k, v in self.subject_demographic.items()
        ]

        return torch.Tensor(weights)
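
    # The resulting tensor holds one weight per class (subject), in label
    # order, and can be fed to a weighted loss; a hedged sketch:
    #
    #   weights = dataset.get_demographic_class_weights()
    #   criterion = torch.nn.CrossEntropyLoss(weight=weights)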


class MedsTorchDataset(DemographicTorchDataset):
    """
    MEDS torch interface

    .. warning::
       Unfortunately, this dataset contains several identities with only ONE sample,
       which makes it impossible to properly use for, e.g., contrastive learning.
       If that is an issue, set `take_from_znorm=False`, so the `dev` or `eval` sets are used.

    Parameters
    ----------

    protocol: str
        One of the MEDS available protocols, check :py:class:`bob.bio.face.database.MEDSDatabase`

    database_path: str
        Database path

    database_extension: str
        Database extension

    transform: callable
        Transformation function applied to the input sample

    take_from_znorm: bool
        If `True`, takes the samples from the `treferences` and `zprobes` methods, which come from the training set.
        If `False`, takes the samples from the `references` and `probes` methods; in this case the variable `group` is considered.

    group: str
        Which group to load when `take_from_znorm=False` (`dev` or `eval`)

    """

    def __init__(
        self,
        protocol,
        database_path,
        database_extension=".h5",
        transform=None,
        take_from_znorm=False,
        group="dev",
    ):
        bob_dataset = MEDSDatabase(
            protocol=protocol,
            dataset_original_directory=database_path,
            dataset_original_extension=database_extension,
        )
        self.take_from_znorm = take_from_znorm
        self.group = group
        super().__init__(bob_dataset, transform=transform)

    def load_bucket(self):
        self._target_metadata = "rac"

        if self.take_from_znorm:
            self.bucket = [
                s for sset in self.bob_dataset.zprobes() for s in sset
            ]
            self.bucket += [
                s for sset in self.bob_dataset.treferences() for s in sset
            ]
        else:
            self.bucket = [
                s
                for sset in self.bob_dataset.probes(group=self.group)
                for s in sset
            ]
            self.bucket += [
                s
                for sset in self.bob_dataset.references(group=self.group)
                for s in sset
            ]

        offset = 0
        self.labels = dict()
        self.subject_demographic = dict()

        for s in self.bucket:
            if s.subject_id not in self.labels:
                self.labels[s.subject_id] = offset
                self.subject_demographic[s.subject_id] = getattr(
                    s, self._target_metadata
                )
                offset += 1

        metadata_keys = set(self.subject_demographic.values())
        self._demographic_keys = dict(
            zip(metadata_keys, range(len(metadata_keys)))
        )

    def get_demographics(self, sample):
        demographic_key = getattr(sample, "rac")
        return self._demographic_keys[demographic_key]
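

# A hedged usage sketch of the MEDS interface (the database path below is
# hypothetical; it assumes the MEDS data files are available locally):
#
#   dataset = MedsTorchDataset(
#       protocol="verification_fold1",
#       database_path="/path/to/meds",
#   )
#   loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
#   batch = next(iter(loader))  # dict with "data", "label" and "demography"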


class VGG2TorchDataset(DemographicTorchDataset):
    """
    VGG2 for torch.

    This interface makes use of :any:`bob.bio.face.database.VGG2Database`.

    The "race" labels below were annotated by students during the period 2018-2020. Race labels taken from: MasterEBTSv10.0.809302017_Final.pdf

    - A: Asian in general (Chinese, Japanese, Filipino, Korean, Polynesian, Indonesian, Samoan, or any other Pacific Islander)
    - B: A person having origins in any of the black racial groups of Africa
    - I: American Indian, Asian Indian, Eskimo, or Alaskan native
    - U: Of indeterminable race
    - W: Caucasian, Mexican, Puerto Rican, Cuban, Central or South American, or other Spanish culture or origin, regardless of race
    - N: None of the above

    Gender information was taken from the original dataset.
    The following genders are available:

    - male
    - female

    .. note::
       Some important information about this interface.
       We have the following statistics:

       - n_classes = 8631
       - n_demographics: 12 ['m-A': 0, 'm-B': 1, 'm-I': 2, 'm-U': 3, 'm-W': 4, 'm-N': 5, 'f-A': 6, 'f-B': 7, 'f-I': 8, 'f-U': 9, 'f-W': 10, 'f-N': 11]

    .. note::
       The distribution of the combined race and gender demographics is the following:
       {'m-B': 552, 'm-U': 64, 'm-W': 3903, 'f-W': 2657, 'f-A': 286, 'f-U': 34, 'f-I': 298, 'f-N': 2, 'f-B': 200, 'm-N': 1, 'm-I': 366, 'm-A': 268}

       Note that `m-N` has 1 subject and `f-N` has 2 subjects.
       For this reason, the `N` race is removed from this interface by default;
       we can't learn anything from one sample.

    Parameters
    ----------

    protocol: str
        One of the VGG2 available protocols, check :py:class:`bob.bio.face.database.VGG2Database`

    database_path: str
        Path containing the raw data

    database_extension: str
        Extension of the image files

    transform: callable
        Transformation function applied to the input sample

    load_bucket_from_cache: bool
        If set, it will load the list of available samples from the cache

    train: bool
        If set, it will prepare a bucket for training

    include_u_n: bool
        If `True`, includes 'U' (of indeterminable race) and 'N' (none of the above) in the list of races

    """

    def __init__(
        self,
        protocol,
        database_path,
        database_extension=".jpg",
        transform=None,
        load_bucket_from_cache=True,
        include_u_n=False,
        train=True,
    ):
        bob_dataset = VGG2Database(
            protocol=protocol,
            dataset_original_directory=database_path,
            dataset_original_extension=database_extension,
        )
        self.load_bucket_from_cache = load_bucket_from_cache

        # Percentage of the samples used for training
        self._percentage_for_training = 0.8
        self.train = train

        # All possible metadata
        self._possible_genders = ["m", "f"]

        # self._possible_races = ["A", "B", "I", "U", "W", "N"]
        self._possible_races = ["A", "B", "I", "W"]
        if include_u_n:
            self._possible_races += ["U", "N"]

        super().__init__(bob_dataset, transform=transform)

    def decode_race(self, race):
        # Any race outside of the supported list is mapped to "W"
        # return race if race in self._possible_races else "N"
        return race if race in self._possible_races else "W"

    def get_key(self, sample):
        return f"{sample.gender}-{self.decode_race(sample.race)}"

    def get_cache_path(self):
        filename = (
            "vgg2_short_cached_bucket.pickle"
            if self.bob_dataset.protocol == "vgg2-short"
            else "vgg2_full_cached_bucket.pickle"
        )

        return os.path.join(
            rc.get(
                "bob_data_folder",
                os.path.join(os.path.expanduser("~"), "bob_data"),
            ),
            "datasets",
            filename,
        )
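
    # The cache path above honours the `bob_data_folder` key managed by
    # clapper's UserDefaults (typically stored in `~/.config/bobrc.toml`);
    # a hedged sketch of that configuration entry:
    #
    #   # bobrc.toml
    #   bob_data_folder = "/scratch/bob_data"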

    def cache_bucket(self, bucket):
        """
        Cache the list of samples into a temporary directory
        """
        bucket_filename = self.get_cache_path()
        os.makedirs(os.path.dirname(bucket_filename), exist_ok=True)
        with open(bucket_filename, "wb") as f:
            cloudpickle.dump(bucket, f)

    def load_cached_bucket(self):
        bucket_filename = self.get_cache_path()
        with open(bucket_filename, "rb") as f:
            bucket = cloudpickle.load(f)
        return bucket

    def load_bucket(self):
        # Defining the demographic keys
        self._demographic_keys = [
            f"{gender}-{race}"
            for gender in self._possible_genders
            for race in self._possible_races
        ]
        self._demographic_keys = dict(
            [(d, i) for i, d in enumerate(self._demographic_keys)]
        )

        # Loading the bucket from cache
        if self.load_bucket_from_cache and os.path.exists(
            self.get_cache_path()
        ):
            self.bucket = self.load_cached_bucket()
        else:
            self.bucket = [
                s for s in self.bob_dataset.background_model_samples()
            ]
            # Caching the bucket
            self.cache_bucket(self.bucket)

        # Mapping subject_id to labels
        self.labels = sorted(list(set([s.subject_id for s in self.bucket])))
        self.labels = dict([(label, i) for i, label in enumerate(self.labels)])

        # Splitting the bucket into training and development sets
        all_indexes = np.array([self.labels[x.subject_id] for x in self.bucket])
        indexes = []
        if self.train:
            for i in range(self.n_classes):
                ind = np.where(all_indexes == i)[0]
                indexes += list(
                    ind[
                        0 : int(
                            np.floor(len(ind) * self._percentage_for_training)
                        )
                    ]
                )
        else:
            for i in range(self.n_classes):
                ind = np.where(all_indexes == i)[0]
                indexes += list(
                    ind[
                        int(
                            np.floor(len(ind) * self._percentage_for_training)
                        ) :
                    ]
                )

        # Redefining the bucket
        self.bucket = list(np.array(self.bucket)[indexes])

        # Mapping subject and demographics for fast access
        self.subject_demographic = dict()

        for s in self.bucket:
            if s.subject_id not in self.subject_demographic:
                self.subject_demographic[s.subject_id] = self.get_key(s)

    def get_demographics(self, sample):
        demographic_key = self.get_key(sample)
        return self._demographic_keys[demographic_key]
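

# A hedged usage sketch of the VGG2 interface (paths are hypothetical; the
# first call builds the sample cache, subsequent calls reuse it):
#
#   train_ds = VGG2TorchDataset(
#       protocol="vgg2-short", database_path="/path/to/vgg2", train=True
#   )
#   dev_ds = VGG2TorchDataset(
#       protocol="vgg2-short", database_path="/path/to/vgg2", train=False
#   )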


class MorphTorchDataset(DemographicTorchDataset):
    """
    MORPH torch interface

    .. warning::
       Unfortunately, this dataset contains several identities with only ONE sample,
       which makes it impossible to properly use for, e.g., contrastive learning.
       If that is an issue, set `take_from_znorm=False`, so the `dev` or `eval` sets are used.

    Parameters
    ----------

    protocol: str
        One of the MORPH available protocols, check :py:class:`bob.bio.face.database.MorphDatabase`

    database_path: str
        Database path

    database_extension: str
        Database extension

    transform: callable
        Transformation function applied to the input sample

    take_from_znorm: bool
        If `True`, takes the samples from the `treferences` and `zprobes` methods, which come from the training set.
        If `False`, takes the samples from the `references` and `probes` methods; in this case the variable `group` is considered.

    group: str
        Which group to load when `take_from_znorm=False` (`dev` or `eval`)

    """

    def __init__(
        self,
        protocol,
        database_path,
        database_extension=".h5",
        transform=None,
        take_from_znorm=True,
        group="dev",
    ):
        bob_dataset = MorphDatabase(
            protocol=protocol,
            dataset_original_directory=database_path,
            dataset_original_extension=database_extension,
        )
        self.take_from_znorm = take_from_znorm
        self.group = group

        super().__init__(bob_dataset, transform=transform)

    def load_bucket(self):
        # The MORPH dataset has an intersection between zprobes and treferences;
        # the subjects below are excluded to avoid that overlap
        self.excluding_list = [
            "190276",
            "332158",
            "111942",
            "308129",
            "334074",
            "350814",
            "131677",
            "168724",
            "276055",
            "275589",
            "286810",
        ]

        if self.take_from_znorm:
            self.bucket = [
                s for sset in self.bob_dataset.zprobes() for s in sset
            ]
            self.bucket += [
                s
                for sset in self.bob_dataset.treferences()
                for s in sset
                if sset.subject_id not in self.excluding_list
            ]
        else:
            self.bucket = [
                s
                for sset in self.bob_dataset.probes(group=self.group)
                for s in sset
            ]
            self.bucket += [
                s
                for sset in self.bob_dataset.references(group=self.group)
                for s in sset
            ]

        offset = 0
        self.labels = dict()
        self.subject_demographic = dict()

        for s in self.bucket:
            if s.subject_id not in self.labels:
                self.labels[s.subject_id] = offset
                self.subject_demographic[s.subject_id] = f"{s.rac}-{s.sex}"
                offset += 1

        metadata_keys = set(self.subject_demographic.values())
        self._demographic_keys = dict(
            zip(metadata_keys, range(len(metadata_keys)))
        )

    def get_demographics(self, sample):
        demographic_key = f"{sample.rac}-{sample.sex}"
        return self._demographic_keys[demographic_key]


class RFWTorchDataset(DemographicTorchDataset):
    def __init__(
        self, protocol, database_path, database_extension=".h5", transform=None
    ):
        bob_dataset = RFWDatabase(
            protocol=protocol,
            dataset_original_directory=database_path,
            dataset_original_extension=database_extension,
        )
        super().__init__(bob_dataset, transform=transform)

    def load_bucket(self):
        # NOTE: this method mirrors the MEDS/MORPH interfaces above so that
        # the base class can be constructed; it assumes RFW samples expose
        # `subject_id` and `race`
        self.bucket = [s for sset in self.bob_dataset.zprobes() for s in sset]
        self.bucket += [
            s for sset in self.bob_dataset.treferences() for s in sset
        ]

        offset = 0
        self.labels = dict()
        self.subject_demographic = dict()

        for s in self.bucket:
            if s.subject_id not in self.labels:
                self.labels[s.subject_id] = offset
                self.subject_demographic[s.subject_id] = s.race
                offset += 1

        self._demographic_keys = self.load_demographics()

    def load_demographics(self):
        target_metadata = "race"
        metadata_keys = set(
            [
                getattr(sset, target_metadata)
                for sset in self.bob_dataset.zprobes()
            ]
            + [
                getattr(sset, target_metadata)
                for sset in self.bob_dataset.treferences()
            ]
        )
        metadata_keys = dict(zip(metadata_keys, range(len(metadata_keys))))
        return metadata_keys

    def get_demographics(self, sample):
        demographic_key = getattr(sample, "race")
        return self._demographic_keys[demographic_key]


class MobioTorchDataset(DemographicTorchDataset):
    def __init__(
        self, protocol, database_path, database_extension=".h5", transform=None
    ):
        bob_dataset = MobioDatabase(
            protocol=protocol,
            dataset_original_directory=database_path,
            dataset_original_extension=database_extension,
        )

        super().__init__(bob_dataset, transform=transform)

    def load_bucket(self):
        self._target_metadata = "gender"
        self.bucket = [s for s in self.bob_dataset.background_model_samples()]
        offset = 0
        self.labels = dict()
        self.subject_demographic = dict()

        for s in self.bucket:
            if s.subject_id not in self.labels:
                self.labels[s.subject_id] = offset
                self.subject_demographic[s.subject_id] = getattr(
                    s, self._target_metadata
                )
                offset += 1

        metadata_keys = set(self.subject_demographic.values())
        self._demographic_keys = dict(
            zip(metadata_keys, range(len(metadata_keys)))
        )

    def __len__(self):
        return len(self.bucket)

    def get_demographics(self, sample):
        demographic_key = getattr(sample, self._target_metadata)
        return self._demographic_keys[demographic_key]


class MSCelebTorchDataset(DemographicTorchDataset):
    """
    This interface makes use of a CSV file containing gender and
    race annotations.

    The "race" labels below were annotated by students during the period 2018-2020. Race labels taken from: MasterEBTSv10.0.809302017_Final.pdf

    - A: Asian in general (Chinese, Japanese, Filipino, Korean, Polynesian, Indonesian, Samoan, or any other Pacific Islander)
    - B: A person having origins in any of the black racial groups of Africa
    - I: American Indian, Asian Indian, Eskimo, or Alaskan native
    - U: Of indeterminable race
    - W: Caucasian, Mexican, Puerto Rican, Cuban, Central or South American, or other Spanish culture or origin, regardless of race
    - N: None of the above

    Gender and country information was taken from Wikidata: https://www.wikidata.org/wiki/Wikidata:Main_Page
    The following genders are available:

    - male
    - female
    - other

    .. note::
       Some important information about this interface.
       If `include_unknow_demographics==False` we have the following statistics:

       - n_classes = 81279
       - n_demographics: 15 ['male-A', 'male-B', 'male-I', 'male-U', 'male-W', 'female-A', 'female-B', 'female-I', 'female-U', 'female-W', 'other-A', 'other-B', 'other-I', 'other-U', 'other-W']

       If `include_unknow_demographics==True` we have the following statistics:

       - n_classes = 89735
       - n_demographics: 18 ['male-A', 'male-B', 'male-I', 'male-N', 'male-U', 'male-W', 'female-A', 'female-B', 'female-I', 'female-N', 'female-U', 'female-W', 'other-A', 'other-B', 'other-I', 'other-N', 'other-U', 'other-W']

    Parameters
    ----------

    database_path: str
        Path containing the raw data

    database_extension: str
        Extension of the image files

    idiap_path: bool
        If set, it will use the idiap standard relative path to load the data (e.g. [BASE_PATH]/chunk_[n]/[user_id])

    include_unknow_demographics: bool
        If set, it will include subjects whose race was set to `N` (None of the above)

    load_bucket_from_cache: bool
        If set, it will load the list of available samples from the cache

    """

    def __init__(
        self,
        database_path,
        database_extension=".png",
        idiap_path=True,
        include_unknow_demographics=False,
        load_bucket_from_cache=True,
        transform=None,
    ):
        self.idiap_path = idiap_path
        self.database_path = database_path
        self.database_extension = database_extension
        self.include_unknow_demographics = include_unknow_demographics
        self.load_bucket_from_cache = load_bucket_from_cache
        self.transform = transform

        # Private keys
        self._possible_genders = ["male", "female", "other"]

        urls = MSCelebTorchDataset.urls()
        filename = (
            download_file(
                urls=urls,
                destination_filename="msceleb_race_wikidata.tar.gz",
                checksum="76339d73f352faa00c155f7040e772bb",
                checksum_fct=md5_hash,
                extract=True,
            )
            / "msceleb_race_wikidata.csv"
        )

        self.load_bucket(filename)

    @staticmethod
    def urls():
        return [
            "https://www.idiap.ch/software/bob/databases/latest/msceleb_race_wikidata.tar.gz",
            "http://www.idiap.ch/software/bob/databases/latest/msceleb_race_wikidata.tar.gz",
        ]

    def get_cache_path(self):
        filename = (
            "msceleb_cached_bucket_WITH_unknow_demographics.csv"
            if self.include_unknow_demographics
            else "msceleb_cached_bucket_WITHOUT_unknow_demographics.csv"
        )

        return os.path.join(
            rc.get(
                "bob_data_folder",
                os.path.join(os.path.expanduser("~"), "bob_data"),
            ),
            "datasets",
            filename,
        )

    def cache_bucket(self, bucket):
        """
        Cache the list of samples into a temporary directory
        """
        bucket_filename = self.get_cache_path()
        os.makedirs(os.path.dirname(bucket_filename), exist_ok=True)
        with open(bucket_filename, "w") as f:
            for b in bucket:
                f.write(f"{b}\n")

    def load_cached_bucket(self):
        """
        Load the bucket from the cache
        """
        bucket_filename = self.get_cache_path()
        with open(bucket_filename) as f:
            return [line.rstrip("\n") for line in f.readlines()]

    def __len__(self):
        return len(self.bucket)

    def load_bucket(self, csv_filename):
        dataframe = pd.read_csv(csv_filename)

        # Possible races
        # {'A', 'B', 'I', 'N', 'U', 'W', nan}

        filtered_dataframe = (
            dataframe.loc[
                (dataframe.RACE == "A")
                | (dataframe.RACE == "B")
                | (dataframe.RACE == "I")
                | (dataframe.RACE == "U")
                | (dataframe.RACE == "W")
                | (dataframe.RACE == "N")
            ]
            if self.include_unknow_demographics
            else dataframe.loc[
                (dataframe.RACE == "A")
                | (dataframe.RACE == "B")
                | (dataframe.RACE == "I")
                | (dataframe.RACE == "U")
                | (dataframe.RACE == "W")
            ]
        )

        filtered_dataframe_list = filtered_dataframe[
            ["idiap_chunk", "ID"]
        ].to_csv()

        # Building the relative path (chunk/ID) of every subject
        subject_relative_paths = [
            os.path.join(ll.split(",")[1], ll.split(",")[2])
            for ll in filtered_dataframe_list.split("\n")[1:-1]
        ]

        if self.load_bucket_from_cache and os.path.exists(
            self.get_cache_path()
        ):
            self.bucket = self.load_cached_bucket()
        else:
            # Listing all images on the fly
            logger.warning(
                "Fetching all sample paths on the fly. This might take some minutes. "
                f"Then this will be cached in {self.get_cache_path()} and loaded from this cache."
            )

            self.bucket = [
                os.path.join(subject, f)
                for subject in subject_relative_paths
                for f in os.listdir(os.path.join(self.database_path, subject))
                if f[-4:] == self.database_extension
            ]
            self.cache_bucket(self.bucket)

        self.labels = dict(
            [
                (k.split("/")[-1], i)
                for i, k in enumerate(subject_relative_paths)
            ]
        )

        # Setting the possible demographics and the demographic keys
        filtered_dataframe = filtered_dataframe.set_index("ID")
        self.metadata = filtered_dataframe[["GENDER", "RACE"]].to_dict(
            orient="index"
        )

        self._demographic_keys = [
            f"{gender}-{race}"
            for gender in self._possible_genders
            for race in sorted(set(filtered_dataframe["RACE"]))
        ]
        self._demographic_keys = dict(
            [(d, i) for i, d in enumerate(self._demographic_keys)]
        )

        # Creating a map between the subject and the demographic
        self.subject_demographic = dict(
            [(m, self.get_demographics(m)) for m in self.metadata]
        )

    def get_demographics(self, subject_id):
        race = self.metadata[subject_id]["RACE"]
        gender = self.metadata[subject_id]["GENDER"]

        gender = "other" if gender != "male" and gender != "female" else gender

        return self._demographic_keys[f"{gender}-{race}"]

    def __getitem__(self, idx):
        sample = self.bucket[idx]

        subject_id = sample.split("/")[-2]

        # Loading and transforming the image
        image = bob.io.base.load(os.path.join(self.database_path, sample))
        image = image if self.transform is None else self.transform(image)

        label = self.labels[subject_id]

        # Getting the demographics
        demography = self.get_demographics(subject_id)

        return {"data": image, "label": label, "demography": demography}
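

# A hedged usage sketch of the MS-Celeb interface (the path below is
# hypothetical; the annotation CSV is downloaded automatically on first use):
#
#   dataset = MSCelebTorchDataset(database_path="/path/to/msceleb")
#   item = dataset[0]  # dict with "data", "label" and "demography"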


class SiameseDemographicWrapper(Dataset):
    """
    This class wraps one of the demographic interfaces above and
    yields random positive and negative pairs of samples

    Parameters
    ----------

    demographic_dataset:
        Instance of a :any:`DemographicTorchDataset`

    max_positive_pairs_per_subject: int
        Maximum number of positive pairs sampled per subject

    negative_pairs_per_subject: int
        Number of negative pairs to sample per subject (interpreted by the
        dense and light builders below)

    dense_negatives: bool
        If `True`, negative pairs are built with `create_dense_negative_pairs`;
        otherwise with `create_light_negative_pairs`

    """

    def __init__(
        self,
        demographic_dataset,
        max_positive_pairs_per_subject=20,
        negative_pairs_per_subject=3,
        dense_negatives=False,
    ):
        self.demographic_dataset = demographic_dataset
        self.max_positive_pairs_per_subject = max_positive_pairs_per_subject
        self.negative_pairs_per_subject = negative_pairs_per_subject

        # Creating a bucket mapping each sample to its identity
        self.siamese_bucket = dict()
        for b in demographic_dataset.bucket:
            if b.subject_id not in self.siamese_bucket:
                self.siamese_bucket[b.subject_id] = []

            self.siamese_bucket[b.subject_id].append(b)

        positive_pairs = self.create_positive_pairs()
        if dense_negatives:
            negative_pairs = self.create_dense_negative_pairs()
        else:
            negative_pairs = self.create_light_negative_pairs()

        # Redefining the bucket as the list of pairs
        self.siamese_bucket = negative_pairs + positive_pairs

        self.labels = np.hstack(
            (np.zeros(len(negative_pairs)), np.ones(len(positive_pairs)))
        )

    def __len__(self):
        return len(self.siamese_bucket)

    def create_positive_pairs(self):
        # Creating positive pairs for each identity
        positives = []
        random.seed(0)
        for b in self.siamese_bucket:
            samples = self.siamese_bucket[b]
            random.shuffle(samples)

            # All possible pair combinations, capped per subject
            samples = itertools.combinations(samples, 2)

            positives += list(samples)[0 : self.max_positive_pairs_per_subject]

        return positives
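
    # For n samples of one identity, `itertools.combinations` above yields
    # n * (n - 1) / 2 candidate pairs; e.g. 10 samples give 45 pairs, which
    # are then capped at `max_positive_pairs_per_subject`.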

    def create_dense_negative_pairs(self):
        """
        Create negative pairs.
        Here we create negative pairs only within the same demographic group,
        since we know that pairs from different demographic groups lead to
        poor scores.

        .. warning::
           The list of negative pairs is dense.
           For each combination of subjects within a particular demographic,
           we take `negative_pairs_per_subject` image pairs.
           Hence, the number of negative pairs can explode as a function of
           the number of subjects.
           For example, pairing 1000 identities gives us 499500 subject
           combinations; taking `3` image pairs for each of these combinations
           gives us ~1.5M negative pairs.
           Hence, be careful with that.

        """

        random.seed(0)
        negatives = []

        # Inverting the subject --> demographic map into demographic --> subjects
        demographic_subject = dict()
        for k, v in self.demographic_dataset.subject_demographic.items():
            demographic_subject[v] = demographic_subject.get(v, []) + [k]

        # For each demographic, pick the negative pairs
        for d in demographic_subject:
            subject_combinations = itertools.combinations(
                demographic_subject[d], 2
            )

            for s_c in subject_combinations:
                subject_i = self.siamese_bucket[s_c[0]]
                subject_j = self.siamese_bucket[s_c[1]]
                random.shuffle(subject_i)
                random.shuffle(subject_j)

                # Taking the first `negative_pairs_per_subject` cross-subject pairs
                for i, p in enumerate(itertools.product(subject_i, subject_j)):
                    if i == self.negative_pairs_per_subject:
                        break
                    negatives += ((p[0], p[1]),)

        return negatives

    def create_light_negative_pairs(self):
        """
        Create negative pairs.
        Here we create negative pairs only within the same demographic group,
        since we know that pairs from different demographic groups lead to
        poor scores.

        .. warning::
           This function generates a light set of negative pairs.
           The number of pairs equals the number of subjects in a particular
           demographic multiplied by `negative_pairs_per_subject`.
           For example, a demographic with 1000 identities and `3` negative
           pairs per subject gives us 3000 negative pairs.

        """

        random.seed(0)
        negatives = []

        # Inverting the subject --> demographic map into demographic --> subjects
        demographic_subject = dict()
        for k, v in self.demographic_dataset.subject_demographic.items():
            demographic_subject[v] = demographic_subject.get(v, []) + [k]

        # For each demographic, pick the negative pairs
        for d in demographic_subject:
            n_subjects = len(demographic_subject[d])

            subject_combinations = list(
                itertools.combinations(demographic_subject[d], 2)
            )
            # Shuffling these combinations
            random.shuffle(subject_combinations)

            for s_c in subject_combinations[
                0 : n_subjects * self.negative_pairs_per_subject
            ]:
                subject_i = self.siamese_bucket[s_c[0]]
                subject_j = self.siamese_bucket[s_c[1]]
                random.shuffle(subject_i)
                random.shuffle(subject_j)

                negatives += ((subject_i[0], subject_j[0]),)

        return negatives

    def __getitem__(self, idx):
        sample = self.siamese_bucket[idx]
        label = self.labels[idx]

        # Transforming both images of the pair
        image_i = sample[0].data
        image_j = sample[1].data

        image_i = (
            image_i
            if self.demographic_dataset.transform is None
            else self.demographic_dataset.transform(image_i)
        )
        image_j = (
            image_j
            if self.demographic_dataset.transform is None
            else self.demographic_dataset.transform(image_j)
        )

        # Getting the demographics from the first sample of the pair
        demography = self.demographic_dataset.get_demographics(sample[0])

        return {
            "data": (image_i, image_j),
            "label": label,
            "demography": demography,
        }
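

# A hedged usage sketch of the siamese wrapper, reusing one of the
# demographic datasets above (paths hypothetical):
#
#   base = MedsTorchDataset(
#       protocol="verification_fold1",
#       database_path="/path/to/meds",
#   )
#   pairs = SiameseDemographicWrapper(base, max_positive_pairs_per_subject=10)
#   item = pairs[0]  # {"data": (image_i, image_j), "label": 0.0 or 1.0, ...}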