Coverage for src/bob/bio/face/embeddings/mxnet.py: 0%

84 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-13 00:04 +0200

1#!/usr/bin/env python 

2# vim: set fileencoding=utf-8 : 

3# Yu Linghu & Xinyi Zhang <yu.linghu@uzh.ch, xinyi.zhang@uzh.ch> 

4# Tiago de Freitas Pereira <tiago.pereira@idiap.ch> 

5 

6import os 

7 

8import numpy as np 

9 

10from sklearn.base import BaseEstimator, TransformerMixin 

11from sklearn.utils import check_array 

12 

13from bob.bio.base.database.utils import download_file, md5_hash 

14from bob.bio.face.annotator import MTCNN 

15 

16 

class MxNetTransformer(TransformerMixin, BaseEstimator):
    """
    Base Transformer for MxNet architectures.

    The network is loaded lazily on the first call to :meth:`transform` and is
    never part of the pickled state (see :meth:`__getstate__`).

    Parameters:
    -----------

    checkpoint_path : str
       Path containing the checkpoint

    config : str
       json file containing the DNN spec

    use_gpu: bool
       If ``True`` the model is loaded on ``mx.gpu()``, otherwise on ``mx.cpu()``.

    memory_demanding: bool
       If ``True`` the samples are forwarded through the network one at a
       time instead of as one single batch.

    preprocessor:
       A function that will transform the data right before forward. The default transformation is `X=X`

    kwargs:
       Extra keyword arguments handed over to the parent constructor.
    """

    def __init__(
        self,
        checkpoint_path=None,
        config=None,
        use_gpu=False,
        memory_demanding=False,
        preprocessor=lambda x: x,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Constructor arguments, stored untouched.
        self.checkpoint_path = checkpoint_path
        self.config = config
        self.use_gpu = use_gpu
        self.memory_demanding = memory_demanding
        self.preprocessor = preprocessor

        # Populated by `_load_model` the first time `transform` runs.
        self.model = None

    def _load_model(self):
        """Import the network described by `config`/`checkpoint_path` into `self.model`."""
        import warnings

        import mxnet as mx

        from mxnet import gluon

        device = mx.gpu() if self.use_gpu else mx.cpu()

        # Every warning raised while importing the network is silenced.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.model = gluon.nn.SymbolBlock.imports(
                self.config, ["data"], self.checkpoint_path, ctx=device
            )

    def transform(self, X):
        """
        Forward `X` through the network and return the first model output.

        `X` is validated with ``check_array(..., allow_nd=True)`` and then
        passed through `preprocessor` before reaching the network.

        Returns the first entry of ``self.model.get_outputs()`` converted to a
        numpy array. With `memory_demanding` set, the per-sample outputs are
        collected in one array and stacked vertically when that array has 3 or
        more dimensions.
        """
        import mxnet as mx

        if self.model is None:
            self._load_model()

        samples = self.preprocessor(check_array(X, allow_nd=True))

        # NOTE(review): `forward(batch, is_train=False)` / `get_outputs()` is
        # how this file's `ArcFaceInsightFace_LResNet100` drives its
        # `mx.mod.Module`; confirm the `SymbolBlock` built by the base
        # `_load_model` supports the same calls.
        def _forward(batch):
            data_batch = mx.io.DataBatch(data=(mx.nd.array(batch),))
            self.model.forward(data_batch, is_train=False)
            return self.model.get_outputs()[0].asnumpy()

        if not self.memory_demanding:
            return _forward(samples)

        # One sample per forward pass; `sample[None, ...]` restores the
        # leading batch axis that iterating over `samples` removed.
        features = np.array([_forward(sample[None, ...]) for sample in samples])

        # When the result has 3 or more dimensions, stack everything: the
        # features can come from a source where there are `N` samples
        # containing nxd samples.
        if features.ndim >= 3:
            features = np.vstack(features)

        return features

    def __getstate__(self):
        """Return a copy of the instance state with the unpicklable model cleared."""
        state = self.__dict__.copy()
        state["model"] = None
        return state

    def _more_tags(self):
        """Tag the estimator as usable without a prior `fit`."""
        return {"requires_fit": False}

108 

109 

class ArcFaceInsightFace_LResNet100(MxNetTransformer):
    """
    Extracts features using deep face recognition models under MxNet Interfaces.

    The pretrained model archive is fetched with `download_file` (checksum
    verified, extracted) when the class is instantiated; the network itself is
    only loaded when `_load_model` is called.

    Examples: (Pretrained ResNet models): `LResNet100E-IR,ArcFace@ms1m-refine-v2 <https://github.com/deepinsight/insightface>`_

    The extracted features can be combined with different algorithms.

    Parameters:
    -----------

    memory_demanding: bool
       Forwarded to :class:`MxNetTransformer`.

    use_gpu: bool
       If ``True`` the model is bound to ``mx.gpu()``, otherwise to ``mx.cpu()``.
    """

    def __init__(self, memory_demanding=False, use_gpu=False):
        urls = [
            "https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/mxnet/arcface_r100_v1_mxnet.tar.gz",
            "http://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/mxnet/arcface_r100_v1_mxnet.tar.gz",
        ]
        # NOTE(review): the returned value is used as a directory below, so
        # `download_file(..., extract=True)` presumably returns the folder the
        # archive was extracted into - confirm against bob.bio.base.
        model_dir = download_file(
            urls=urls,
            destination_sub_directory="data/mxnet/arcface_r100_v1_mxnet",
            destination_filename="arcface_r100_v1_mxnet.tar.gz",
            checksum="050ce7d6e731e560127c705f61391f48",
            checksum_fct=md5_hash,
            extract=True,
        )

        # Follow the base-class meaning of the two attributes: `config` is the
        # json network spec and `checkpoint_path` the weights file. They were
        # previously assigned the other way around. Both files sit in the same
        # directory, which is all `_load_model` derives from `checkpoint_path`.
        super().__init__(
            checkpoint_path=os.path.join(model_dir, "model-0000.params"),
            config=os.path.join(model_dir, "model-symbol.json"),
            use_gpu=use_gpu,
            memory_demanding=memory_demanding,
        )

    def _load_model(self):
        """Load epoch 0 of the ``model`` checkpoint and bind it as an `mx.mod.Module`."""
        import mxnet as mx

        # Checkpoint prefix is `<model directory>/model`, epoch 0.
        prefix = os.path.join(os.path.dirname(self.checkpoint_path), "model")
        sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, 0)

        # Cut the graph at the `fc1_output` layer.
        sym = sym.get_internals()["fc1_output"]

        # LOADING CHECKPOINT
        ctx = mx.gpu() if self.use_gpu else mx.cpu()
        model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
        # The module is bound to a single (1, 3, 112, 112) "data" input.
        model.bind(data_shapes=[("data", (1, 3, 112, 112))])
        model.set_params(arg_params, aux_params)

        self.model = model

164 

165 

166from bob.bio.base.algorithm import Distance 

167from bob.bio.base.pipelines import PipelineSimple 

168from bob.bio.face.utils import ( 

169 cropped_positions_arcface, 

170 dnn_default_cropping, 

171 embedding_transformer, 

172) 

173 

174 

def arcface_template(embedding, annotation_type, fixed_positions=None):
    """
    Build a `PipelineSimple` around `embedding` using 112x112 face crops.

    Parameters:
    -----------

    embedding:
       Object passed as the `embedding` argument of `embedding_transformer`.

    annotation_type:
       Selects how the cropped positions are obtained: ``"eyes-center"`` and
       ``"bounding-box"`` use `cropped_positions_arcface()`, a list is passed
       to `cropped_positions_arcface`, anything else goes through
       `dnn_default_cropping`.

    fixed_positions:
       Passed unchanged to `embedding_transformer`.

    Returns a `PipelineSimple` combining the transformer with a `Distance`
    algorithm.
    """
    # DEFINE CROPPING
    crop_size = (112, 112)

    if annotation_type in ("eyes-center", "bounding-box"):
        positions = cropped_positions_arcface()
        if annotation_type == "bounding-box":
            # This will allow us to use `BoundingBoxAnnotatorCrop`
            positions.update({"topleft": (0, 0), "bottomright": crop_size})
    elif isinstance(annotation_type, list):
        positions = cropped_positions_arcface(annotation_type)
    else:
        positions = dnn_default_cropping(crop_size, annotation_type)

    face_annotator = MTCNN(
        min_size=40, factor=0.709, thresholds=(0.1, 0.2, 0.2)
    )

    pipeline_transformer = embedding_transformer(
        cropped_image_size=crop_size,
        embedding=embedding,
        cropped_positions=positions,
        fixed_positions=fixed_positions,
        color_channel="rgb",
        annotator=face_annotator,
    )

    return PipelineSimple(pipeline_transformer, Distance())

209 

210 

def arcface_insightFace_lresnet100(
    annotation_type, fixed_positions=None, memory_demanding=False
):
    """
    Return the `arcface_template` pipeline built on `ArcFaceInsightFace_LResNet100`.

    `memory_demanding` is handed to the embedding; `annotation_type` and
    `fixed_positions` are handed to `arcface_template`.
    """
    embedding = ArcFaceInsightFace_LResNet100(memory_demanding=memory_demanding)
    return arcface_template(
        embedding=embedding,
        annotation_type=annotation_type,
        fixed_positions=fixed_positions,
    )