Coverage for src/bob/bio/face/pytorch/datasets/webface42m.py: 25%
59 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-13 00:04 +0200
1#!/usr/bin/env python
2# vim: set fileencoding=utf-8 :
3# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
5import csv
6import os
8import numpy as np
10from clapper.rc import UserDefaults
11from torch.utils.data import Dataset
13import bob.io.base
15from bob.bio.base.database.utils import download_file, md5_hash, search_and_open
17# from bob.bio.face.database import MEDSDatabase, MorphDatabase
# Configuration object for the ``bobrc.toml`` user-defaults file (clapper
# ``UserDefaults``); ``WebFace42M`` looks up the
# ``bob.bio.face.webface42M.directory`` key in it to locate the dataset.
rc = UserDefaults("bobrc.toml")
class WebFace42M(Dataset):
    """
    PyTorch Dataset for the WebFace42M dataset described in

    .. latex::

        @inproceedings {zhu2021webface260m,
        title= {WebFace260M: A Benchmark Unveiling the Power of Million-scale Deep Face Recognition},
        author= {Zheng Zhu, Guan Huang, Jiankang Deng, Yun Ye, Junjie Huang, Xinze Chen,
                 Jiagang Zhu, Tian Yang, Jiwen Lu, Dalong Du, Jie Zhou},
        booktitle= {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
        year= {2021}
        }

    This dataset contains 2'059'906 identities and 42'474'558 images.

    Samples are located through a fixed-width index file (``webface42M.csv``)
    in which every record is exactly 51 characters long::

        LLLLLLL,<relative image path right-justified to 42 characters><newline>

    so that sample ``i`` starts at offset ``i * 51`` and can be read with a
    single ``seek`` + ``read``.

    .. warning::

        To use this dataset protocol, you need to have the original files of the WebFace42M dataset.
        Once you have it downloaded, please run the following command to set the path for Bob

        .. code-block:: sh

            bob config set bob.bio.face.webface42M.directory [WEBFACE42M PATH]

    """

    # Dataset size, as stated in the class docstring.
    _N_IDENTITIES = 2059906
    _N_SAMPLES = 42474558
    # Size of one index record: 7-digit label + "," + 42-char path + "\n".
    _line_offset = 51

    def __init__(
        self,
        database_path=None,
        transform=None,
    ):
        """
        Parameters
        ----------
        database_path : str or None
            Root directory of the WebFace42M images. When ``None`` (the
            default), the ``bob.bio.face.webface42M.directory`` configuration
            value is used (``""`` if unset).
        transform : callable or None
            Optional function applied to every loaded image.

        Raises
        ------
        ValueError
            If the resolved ``database_path`` is empty.
        """
        # Look the configured path up when the dataset is built, not once at
        # import time as a default-argument expression would.
        if database_path is None:
            database_path = rc.get("bob.bio.face.webface42M.directory", "")
        self.database_path = database_path

        if database_path == "":
            raise ValueError(
                "`database_path` is empty; please do `bob config set bob.bio.face.webface42M.directory` to set the absolute path of the data"
            )

        filename = download_file(
            urls=WebFace42M.urls(),
            destination_filename="webface42M.tar.gz",
            checksum="50c32cbe61de261466e1ea3af2721cea",
            checksum_fct=md5_hash,
        )
        # NOTE(review): this handle stays open for the dataset's lifetime and
        # __getitem__ does a non-atomic seek()+read() on it; if several worker
        # processes end up sharing the underlying descriptor, reads may
        # interleave -- confirm before using multi-process data loading.
        self.file = search_and_open(filename, "webface42M.csv")
        self.transform = transform

    def __len__(self):
        # Hard-coded to avoid counting the ~42M lines of the index file.
        return self._N_SAMPLES

    def __getitem__(self, idx):
        """
        Return ``{"data": image, "label": label}`` for sample ``idx``.

        Negative indices count from the end. Out-of-range indices raise
        ``IndexError`` (which also lets sequence-protocol iteration over the
        dataset stop at the end instead of failing while parsing).
        """
        n_samples = len(self)
        position = idx + n_samples if idx < 0 else idx
        if not 0 <= position < n_samples:
            raise IndexError(
                f"index {idx} is out of range for WebFace42M with {n_samples} samples"
            )

        # Records are fixed-width, so the record start is a direct offset.
        self.file.seek(position * self._line_offset)
        fields = self.file.read(self._line_offset).split(",")

        label = int(fields[0])

        # The path is right-justified with spaces and terminated by "\n".
        file_name = os.path.join(
            self.database_path, fields[1].rstrip("\n").strip()
        )
        image = bob.io.base.load(file_name)

        if self.transform is not None:
            image = self.transform(image)

        return {"data": image, "label": label}

    @staticmethod
    def urls():
        return [
            "https://www.idiap.ch/software/bob/databases/latest/webface42M.tar.gz",
            "http://www.idiap.ch/software/bob/databases/latest/webface42M.tar.gz",
        ]

    def generate_csv(self, output_csv_directory):
        """
        Generates one fixed-width CSV file per top-level directory of the
        WebFace42M dataset, listing all of its image files.

        The CSVs have two columns only, ``LABEL,RELATIVE_FILE_PATH``. The label
        is zero-padded to 7 digits and the path is right-justified to 42
        characters, so every line is exactly ``_line_offset`` (51) bytes long,
        which is the layout ``__getitem__`` reads. One label is assigned per
        identity directory, consecutively across all files. Directories are
        processed in sorted order, so concatenating the per-directory files
        in that same order yields a single index with increasing labels.

        Idiap file structure

        [0-9]_[0-6]_xxx
          |
          -- [0-9]_[0-9]_xxxxxxx
                |
                --- [0-9]_[0-9].jpg

        Parameters
        ----------
        output_csv_directory : str
            Existing directory where the ``<directory>.csv`` files are written.

        Raises
        ------
        ValueError
            If a record does not fit the fixed width (e.g. a relative path
            longer than 42 characters, or a non-ASCII path).
        """
        label_checker = np.zeros(self._N_IDENTITIES)
        counter = 0

        # Navigating into the Idiap file structure; sorted for reproducible labels
        for directory in sorted(os.listdir(self.database_path)):
            path = os.path.join(self.database_path, directory)
            # Check before creating the output file, so stray entries do not
            # leave empty CSVs behind.
            if not os.path.isdir(path):
                continue

            print(f"Processing {directory}")
            records = []
            for sub_directory in sorted(os.listdir(path)):
                sub_path = os.path.join(path, sub_directory)
                if not os.path.isdir(sub_path):
                    continue

                # Only real identity directories consume a label.
                label_checker[counter] = 1
                label = str(counter).zfill(7)
                for file in sorted(os.listdir(sub_path)):
                    relative_path = os.path.join(
                        directory, sub_directory, file
                    ).rstrip("\n")
                    # Written directly rather than via csv.writer: the writer
                    # would quote the newline-terminated field and append its
                    # own "\r\n", breaking the fixed record width.
                    record = f"{label},{relative_path.rjust(42)}\n"
                    if len(record.encode("utf-8")) != self._line_offset:
                        raise ValueError(
                            f"'{relative_path}' does not fit a fixed-width record of {self._line_offset} bytes"
                        )
                    records.append(record)
                counter += 1

            output_csv_file = os.path.join(
                output_csv_directory, directory + ".csv"
            )
            # newline="" disables newline translation so that each record
            # stays exactly _line_offset bytes on every platform.
            with open(
                output_csv_file, "w", encoding="utf-8", newline=""
            ) as csv_file:
                csv_file.writelines(records)

        # Checking if all labels were taken
        zero_labels = np.where(label_checker == 0)[0]
        if zero_labels.shape[0] > 0:
            print(zero_labels)