Coverage for src/bob/bio/face/pytorch/datasets/webface42m.py: 25%

59 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-13 00:04 +0200

1#!/usr/bin/env python 

2# vim: set fileencoding=utf-8 : 

3# Tiago de Freitas Pereira <tiago.pereira@idiap.ch> 

4 

5import csv 

6import os 

7 

8import numpy as np 

9 

10from clapper.rc import UserDefaults 

11from torch.utils.data import Dataset 

12 

13import bob.io.base 

14 

15from bob.bio.base.database.utils import download_file, md5_hash, search_and_open 

16 

17# from bob.bio.face.database import MEDSDatabase, MorphDatabase 

18 

# User-level Bob configuration (presumably ~/.config/bobrc.toml — clapper
# resolves the actual location); supplies the default value for
# "bob.bio.face.webface42M.directory" used by WebFace42M below.
rc = UserDefaults("bobrc.toml")

20 

21 

class WebFace42M(Dataset):
    """
    Pytorch Dataset for the WebFace42M dataset mentioned in

    .. latex::

        @inproceedings {zhu2021webface260m,
        title= {WebFace260M: A Benchmark Unveiling the Power of Million-scale Deep Face Recognition},
        author= {Zheng Zhu, Guan Huang, Jiankang Deng, Yun Ye, Junjie Huang, Xinze Chen,
        Jiagang Zhu, Tian Yang, Jiwen Lu, Dalong Du, Jie Zhou},
        booktitle= {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
        year= {2021}
        }

    This dataset contains 2'059'906 identities and 42'474'558 images.

    .. warning::

        To use this dataset protocol, you need to have the original files of the WebFace42M dataset.
        Once you have it downloaded, please run the following command to set the path for Bob

        .. code-block:: sh

            bob config set bob.bio.face.webface42M.directory [WEBFACE42M PATH]

    """

    # Published totals for the dataset; used by `__len__` and by the label
    # sanity check in `generate_csv`.
    N_IDENTITIES = 2_059_906
    N_IMAGES = 42_474_558

    # Fixed-width record layout of webface42M.csv:
    #   7-digit zero-padded label + "," + 42-char right-justified relative
    #   path + "\n"  ->  7 + 1 + 42 + 1 = 51 bytes per record.
    # This permits O(1) random access via seek() instead of scanning the file.
    _LINE_OFFSET = 51

    def __init__(
        self,
        database_path=rc.get("bob.bio.face.webface42M.directory", ""),
        transform=None,
    ):
        """
        Parameters
        ----------
        database_path : str
            Absolute path to the WebFace42M image tree. Defaults to the
            ``bob.bio.face.webface42M.directory`` Bob configuration entry.
        transform : callable or None
            Optional transform applied to every loaded image.

        Raises
        ------
        ValueError
            If ``database_path`` is empty (i.e. the Bob configuration entry
            was never set).
        """
        self.database_path = database_path

        if database_path == "":
            raise ValueError(
                "`database_path` is empty; please do `bob config set bob.bio.face.webface42M.directory` to set the absolute path of the data"
            )

        # Fetch (or reuse from cache) the pre-generated fixed-width CSV
        # index shipped by Idiap, then open it inside the tarball.
        urls = WebFace42M.urls()
        filename = download_file(
            urls=urls,
            destination_filename="webface42M.tar.gz",
            checksum="50c32cbe61de261466e1ea3af2721cea",
            checksum_fct=md5_hash,
        )
        self.file = search_and_open(filename, "webface42M.csv")

        # Kept as an instance attribute for backward compatibility with
        # previous versions that set it directly in __init__.
        self._line_offset = self._LINE_OFFSET
        self.transform = transform

    def __len__(self):
        # Hard-coded to avoid the very slow full scan:
        #   sum(1 for line in open(self.csv_file))
        return self.N_IMAGES

    def __getitem__(self, idx):
        """Return ``{"data": image, "label": int}`` for the ``idx``-th record.

        Records are fixed-width (see ``_LINE_OFFSET``), so the ``idx``-th
        line starts at a known byte offset and is read with a single seek.
        """
        # Allow negative indexing, as for regular Python sequences.
        if idx < 0:
            idx = len(self) + idx

        self.file.seek(idx * self._line_offset)
        line_sample = self.file.read(self._line_offset).split(",")

        label = int(line_sample[0])

        # The path field is right-justified to 42 chars; strip the padding
        # (and the record's trailing newline) before joining.
        file_name = os.path.join(
            self.database_path, line_sample[1].rstrip("\n").strip()
        )
        image = bob.io.base.load(file_name)

        image = image if self.transform is None else self.transform(image)

        return {"data": image, "label": label}

    @staticmethod
    def urls():
        """Mirrors hosting the pre-generated webface42M.tar.gz index."""
        return [
            "https://www.idiap.ch/software/bob/databases/latest/webface42M.tar.gz",
            "http://www.idiap.ch/software/bob/databases/latest/webface42M.tar.gz",
        ]

    def generate_csv(self, output_csv_directory):
        """
        Generates a bunch of CSV files containing all the files from the WebFace42M dataset.
        The csv's have two columns only ``LABEL, RELATIVE_FILE_PATH`` and use a
        fixed-width 51-byte record layout (see ``_LINE_OFFSET``) so that
        ``__getitem__`` can seek directly to any record.

        Idiap file structure::

            [0-9]_[0-6]_xxx
            |
            -- [0-9]_[0-9]_xxxxxxx
               |
               --- [0-9]_[0-9].jpg

        Parameters
        ----------
        output_csv_directory : str
            Directory where one CSV per top-level directory is written.
        """
        label_checker = np.zeros(self.N_IDENTITIES)
        counter = 0

        # Navigating into the Idiap file structure. sorted() makes the
        # label assignment deterministic across runs and filesystems.
        for directory in sorted(os.listdir(self.database_path)):
            path = os.path.join(self.database_path, directory)
            # Skip stray files BEFORE creating an output CSV, otherwise an
            # empty .csv is left behind for every non-directory entry.
            if not os.path.isdir(path):
                continue

            output_csv_file = os.path.join(
                output_csv_directory, directory + ".csv"
            )

            # newline="" plus lineterminator="\n" keeps every record exactly
            # 7 + 1 + 42 + 1 = 51 bytes. The csv module's default "\r\n"
            # terminator -- and the quoting it triggers when "\n" is embedded
            # inside a field -- would break the fixed-width layout that
            # __getitem__'s seek() arithmetic depends on.
            with open(output_csv_file, "w", newline="") as csv_file:
                csv_writer = csv.writer(
                    csv_file, delimiter=",", lineterminator="\n"
                )

                print(f"Processing {directory}")
                rows = []
                for sub_directory in sorted(os.listdir(path)):
                    sub_path = os.path.join(path, sub_directory)
                    if not os.path.isdir(sub_path):
                        continue
                    # Mark the label only once we know this is a real
                    # identity directory.
                    label_checker[counter] = 1

                    for file in sorted(os.listdir(sub_path)):
                        relative_path = os.path.join(
                            directory, sub_directory, file
                        )
                        # A path longer than 42 chars would silently corrupt
                        # the fixed-width records; fail loudly instead.
                        if len(relative_path) > 42:
                            raise ValueError(
                                f"Relative path longer than 42 characters: {relative_path}"
                            )
                        rows.append(
                            [
                                str(counter).zfill(7),
                                relative_path.rjust(42),
                            ]
                        )
                    counter += 1
                csv_writer.writerows(rows)

        # Checking if all labels were taken
        zero_labels = np.where(label_checker == 0)[0]
        if zero_labels.shape[0] > 0:
            print(zero_labels)

167 print(zero_labels)