Coverage for src/bob/bio/face/embeddings/mxnet.py: 0%
84 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-13 00:04 +0200
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-13 00:04 +0200
1#!/usr/bin/env python
2# vim: set fileencoding=utf-8 :
3# Yu Linghu & Xinyi Zhang <yu.linghu@uzh.ch, xinyi.zhang@uzh.ch>
4# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
6import os
8import numpy as np
10from sklearn.base import BaseEstimator, TransformerMixin
11from sklearn.utils import check_array
13from bob.bio.base.database.utils import download_file, md5_hash
14from bob.bio.face.annotator import MTCNN
def _identity(x):
    """Return ``x`` unchanged; the default (picklable) preprocessor."""
    return x


class MxNetTransformer(TransformerMixin, BaseEstimator):
    """
    Base Transformer for MxNet architectures.

    The network is loaded lazily on the first call to :py:meth:`transform`
    and is dropped when the object is pickled (see ``__getstate__``).

    Parameters
    ----------

    checkpoint_path : str
        Path to the parameter (weights) file of the network. It is passed as
        the parameter file to ``gluon.nn.SymbolBlock.imports``.

    config : str
        JSON file containing the DNN spec (symbol file). It is passed as the
        symbol file to ``gluon.nn.SymbolBlock.imports``.

    use_gpu : bool
        If ``True`` the model is loaded on ``mx.gpu()``, otherwise on
        ``mx.cpu()``.

    memory_demanding : bool
        If ``True``, samples are forwarded one at a time and the outputs are
        stacked; otherwise the whole batch is forwarded at once.

    preprocessor : callable
        A function that transforms the data right before the forward pass.
        The default is the identity (``X = X``). It is a module-level function
        rather than a lambda so that instances remain picklable.
    """

    def __init__(
        self,
        checkpoint_path=None,
        config=None,
        use_gpu=False,
        memory_demanding=False,
        preprocessor=_identity,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.checkpoint_path = checkpoint_path
        self.config = config
        self.use_gpu = use_gpu
        self.model = None
        self.memory_demanding = memory_demanding
        self.preprocessor = preprocessor

    def _load_model(self):
        """Load the network as a Gluon ``SymbolBlock`` into ``self.model``."""
        import warnings

        import mxnet as mx
        from mxnet import gluon

        ctx = mx.gpu() if self.use_gpu else mx.cpu()

        # Warnings raised while importing the network are silenced.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            deserialized_net = gluon.nn.SymbolBlock.imports(
                self.config, ["data"], self.checkpoint_path, ctx=ctx
            )

        self.model = deserialized_net

    def transform(self, X):
        """
        Compute the network output for every sample in ``X``.

        Parameters
        ----------

        X : array-like
            Input samples. They are validated with
            ``check_array(X, allow_nd=True)`` and passed through
            ``self.preprocessor`` before the forward pass.

        Returns
        -------

        numpy.ndarray
            The first network output, one row per sample.
        """
        import mxnet as mx

        if self.model is None:
            self._load_model()

        X = check_array(X, allow_nd=True)
        X = self.preprocessor(X)

        def _transform(batch):
            if isinstance(self.model, mx.mod.BaseModule):
                # Symbolic Module API, e.g. the ``mx.mod.Module`` built by
                # subclasses that load a checkpoint directly.
                data = mx.nd.array(batch)
                self.model.forward(
                    mx.io.DataBatch(data=(data,)), is_train=False
                )
                return self.model.get_outputs()[0].asnumpy()

            # Gluon block (what ``MxNetTransformer._load_model`` builds): it
            # has no ``forward(DataBatch)`` / ``get_outputs`` API and is
            # called directly, with the input on the model's context.
            ctx = mx.gpu() if self.use_gpu else mx.cpu()
            outputs = self.model(mx.nd.array(batch, ctx=ctx))
            if isinstance(outputs, (list, tuple)):
                outputs = outputs[0]
            return outputs.asnumpy()

        if self.memory_demanding:
            features = np.array([_transform(x[None, ...]) for x in X])

            # Every per-sample forward keeps a leading batch axis of size 1,
            # so the stacked result has ndim >= 3 (e.g. N samples each giving
            # a 1xd output). Collapse it back to one row per sample.
            if features.ndim >= 3:
                features = np.vstack(features)

            return features

        return _transform(X)

    def __getstate__(self):
        # The loaded MxNet model is not picklable; drop it; it is reloaded
        # lazily by ``transform``.
        d = self.__dict__.copy()
        d["model"] = None
        return d

    def _more_tags(self):
        return {"requires_fit": False}
class ArcFaceInsightFace_LResNet100(MxNetTransformer):
    """
    Extracts features with the InsightFace ArcFace LResNet100 model (MxNet).

    The pretrained model (`LResNet100E-IR,ArcFace@ms1m-refine-v2
    <https://github.com/deepinsight/insightface>`_) is downloaded, checked
    against an MD5 checksum and extracted when the object is constructed, so
    no model path has to be given by the caller.

    The extracted features can be combined with different algorithms.

    Parameters
    ----------

    memory_demanding : bool
        Forward samples one at a time (see :py:class:`MxNetTransformer`).

    use_gpu : bool
        Load the model on ``mx.gpu()`` instead of ``mx.cpu()``.
    """

    def __init__(self, memory_demanding=False, use_gpu=False):
        urls = [
            "https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/mxnet/arcface_r100_v1_mxnet.tar.gz",
            "http://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/mxnet/arcface_r100_v1_mxnet.tar.gz",
        ]
        # The value returned by ``download_file`` is used as the directory
        # holding the extracted model files.
        model_dir = download_file(
            urls=urls,
            destination_sub_directory="data/mxnet/arcface_r100_v1_mxnet",
            destination_filename="arcface_r100_v1_mxnet.tar.gz",
            checksum="050ce7d6e731e560127c705f61391f48",
            checksum_fct=md5_hash,
            extract=True,
        )
        # Follow the base-class contract: ``config`` is the symbol JSON and
        # ``checkpoint_path`` the parameter file. Both live in ``model_dir``,
        # which is all ``_load_model`` below relies on.
        checkpoint_path = os.path.join(model_dir, "model-0000.params")
        config = os.path.join(model_dir, "model-symbol.json")

        super().__init__(
            checkpoint_path=checkpoint_path,
            config=config,
            use_gpu=use_gpu,
            memory_demanding=memory_demanding,
        )

    def _load_model(self):
        """Load the ``model`` checkpoint (epoch 0) as a bound ``mx.mod.Module``."""
        import mxnet as mx

        # Load ``model-symbol.json`` / ``model-0000.params`` from the
        # directory that contains the checkpoint.
        sym, arg_params, aux_params = mx.model.load_checkpoint(
            os.path.join(os.path.dirname(self.checkpoint_path), "model"), 0
        )

        # Truncate the graph at the ``fc1`` layer output, which is what this
        # extractor returns.
        all_layers = sym.get_internals()
        sym = all_layers["fc1_output"]

        ctx = mx.gpu() if self.use_gpu else mx.cpu()
        model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
        # Bound for a single 3x112x112 input; other batch sizes are
        # presumably handled by Module.forward reshaping - TODO confirm.
        data_shape = (1, 3, 112, 112)
        model.bind(data_shapes=[("data", data_shape)])
        model.set_params(arg_params, aux_params)

        self.model = model
166from bob.bio.base.algorithm import Distance
167from bob.bio.base.pipelines import PipelineSimple
168from bob.bio.face.utils import (
169 cropped_positions_arcface,
170 dnn_default_cropping,
171 embedding_transformer,
172)
def _arcface_cropped_positions(annotation_type, cropped_image_size):
    """Select the crop landmark positions for ``annotation_type``."""
    if annotation_type == "eyes-center" or annotation_type == "bounding-box":
        # Fixed ArcFace eye positions, kept for backward consistency.
        positions = cropped_positions_arcface()
        if annotation_type == "bounding-box":
            # Corner positions let ``BoundingBoxAnnotatorCrop`` be used.
            positions.update(
                {"topleft": (0, 0), "bottomright": cropped_image_size}
            )
        return positions

    if isinstance(annotation_type, list):
        return cropped_positions_arcface(annotation_type)

    return dnn_default_cropping(cropped_image_size, annotation_type)


def arcface_template(embedding, annotation_type, fixed_positions=None):
    """
    Build a face-recognition pipeline around an ArcFace-style ``embedding``.

    Faces are cropped to 112x112 RGB images, annotated on the fly by an MTCNN
    detector, embedded with ``embedding`` and compared with ``Distance``.

    Parameters
    ----------

    embedding : object
        Feature extractor plugged into ``embedding_transformer``.

    annotation_type : str or list
        ``"eyes-center"`` or ``"bounding-box"`` use the ArcFace eye positions
        (plus corner positions for bounding boxes); a list is forwarded to
        ``cropped_positions_arcface``; anything else goes to
        ``dnn_default_cropping``.

    fixed_positions : dict or None
        Forwarded unchanged to ``embedding_transformer``.

    Returns
    -------

    PipelineSimple
        The transformer followed by a ``Distance`` algorithm.
    """
    crop_size = (112, 112)
    positions = _arcface_cropped_positions(annotation_type, crop_size)

    face_detector = MTCNN(min_size=40, factor=0.709, thresholds=(0.1, 0.2, 0.2))
    transformer = embedding_transformer(
        cropped_image_size=crop_size,
        embedding=embedding,
        cropped_positions=positions,
        fixed_positions=fixed_positions,
        color_channel="rgb",
        annotator=face_detector,
    )

    return PipelineSimple(transformer, Distance())
def arcface_insightFace_lresnet100(
    annotation_type, fixed_positions=None, memory_demanding=False, use_gpu=False
):
    """
    Build the ArcFace LResNet100 (MxNet) face-recognition pipeline.

    Parameters
    ----------

    annotation_type : str or list
        Annotation type used to choose the crop positions
        (see ``arcface_template``).

    fixed_positions : dict or None
        Forwarded to ``arcface_template``.

    memory_demanding : bool
        Forward samples one at a time in the extractor.

    use_gpu : bool
        Run the extractor on the GPU. Defaults to ``False``, the previous
        (CPU-only) behavior.

    Returns
    -------

    PipelineSimple
        The pipeline built by ``arcface_template``.
    """
    return arcface_template(
        ArcFaceInsightFace_LResNet100(
            memory_demanding=memory_demanding, use_gpu=use_gpu
        ),
        annotation_type,
        fixed_positions,
    )