Coverage for src/bob/bio/face/annotator/mtcnn.py: 97%
59 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-13 00:04 +0200
1# Example taken from:
2# https://github.com/blaueck/tf-mtcnn/blob/master/mtcnn_tfv2.py
4import logging
6import pkg_resources
8from bob.bio.face.color import gray_to_rgb
9from bob.io.image import to_matplotlib
11from . import Base
# Module-level logger, following the standard ``getLogger(__name__)`` idiom so
# records carry this module's dotted path.
logger = logging.getLogger(__name__)
class MTCNN(Base):
    """MTCNN v1 wrapper for Tensorflow 2. See
    https://kpzhang93.github.io/MTCNN_face_detection_alignment/index.html for
    more details on MTCNN.

    Attributes
    ----------
    factor : float
        Factor is a trade-off between performance and speed.
    min_size : int
        Minimum face size to be detected.
    thresholds : list or tuple
        Thresholds are a trade-off between false positives and missed
        detections.
    """

    def __init__(
        self, min_size=40, factor=0.709, thresholds=(0.6, 0.7, 0.7), **kwargs
    ):
        super().__init__(**kwargs)
        self.min_size = min_size
        self.factor = factor
        self.thresholds = thresholds
        # Frozen TF1 inference graph shipped as package data.
        self._graph_path = pkg_resources.resource_filename(
            "bob.bio.face", "data/mtcnn.pb"
        )

        # Avoids loading the graph at initialization; it is built lazily in
        # ``mtcnn_fun`` and dropped on pickling (see ``__getstate__``).
        self._fun = None

    @property
    def mtcnn_fun(self):
        """Callable running the MTCNN graph, built lazily on first access."""
        import tensorflow as tf

        if self._fun is None:
            # wrap graph function as a callable function
            self._fun = tf.compat.v1.wrap_function(
                self._graph_fn,
                [
                    tf.TensorSpec(shape=[None, None, 3], dtype=tf.float32),
                ],
            )
        return self._fun

    def _graph_fn(self, img):
        """Imports the frozen MTCNN graph with ``img`` mapped to its input
        and the detector parameters mapped to the graph's parameter tensors.

        Returns the ``box``, ``prob``, and ``landmarks`` output tensors.
        """
        import tensorflow as tf

        with open(self._graph_path, "rb") as f:
            graph_def = tf.compat.v1.GraphDef.FromString(f.read())

        prob, landmarks, box = tf.compat.v1.import_graph_def(
            graph_def,
            input_map={
                "input:0": img,
                "min_size:0": tf.convert_to_tensor(self.min_size, dtype=float),
                "thresholds:0": tf.convert_to_tensor(
                    self.thresholds, dtype=float
                ),
                "factor:0": tf.convert_to_tensor(self.factor, dtype=float),
            },
            return_elements=["prob:0", "landmarks:0", "box:0"],
            name="",
        )
        return box, prob, landmarks

    def __getstate__(self):
        # Handling unpicklable objects: the wrapped TF function cannot be
        # pickled, so it is excluded here and rebuilt lazily by ``mtcnn_fun``
        # after unpickling.
        state = {}
        for key, value in super().__getstate__().items():
            if key != "_fun":
                state[key] = value
        state["_fun"] = None
        return state

    def detect(self, image):
        """Detects all faces in the image.

        Parameters
        ----------
        image : numpy.ndarray
            An RGB image in Bob format.

        Returns
        -------
        tuple
            A tuple of boxes, probabilities, and landmarks.

        Raises
        ------
        ValueError
            If ``image`` is not an RGB image in Bob (channels-first) format.
        """
        if len(image.shape) == 2:
            image = gray_to_rgb(image)

        # Input must be Bob format and RGB. Raise a real exception instead of
        # using ``assert`` so the check survives ``python -O``.
        if image.shape[0] != 3:
            raise ValueError(
                "Expected an RGB image in Bob format (3, H, W), "
                f"got shape {image.shape}"
            )
        # MTCNN expects BGR opencv format
        image = to_matplotlib(image)
        image = image[..., ::-1]

        boxes, probs, landmarks = self.mtcnn_fun(image)
        return boxes, probs, landmarks

    def annotations(self, image):
        """Detects all faces in the image and returns annotations in bob format.

        Parameters
        ----------
        image : numpy.ndarray
            An RGB image in Bob format.

        Returns
        -------
        list
            A list of annotations. Annotations are dictionaries that contain the
            following keys: ``topleft``, ``bottomright``, ``reye``, ``leye``, ``nose``,
            ``mouthright``, ``mouthleft``, and ``quality``.
        """
        boxes, probs, landmarks = self.detect(image)

        # Iterate over all the detected faces
        annots = []
        for box, prob, lm in zip(boxes, probs, landmarks):
            # Each landmark point pairs lm[i] with lm[i + 5]; the point order
            # is (reye, leye, nose, mouthright, mouthleft).
            topleft = float(box[0]), float(box[1])
            bottomright = float(box[2]), float(box[3])
            right_eye = float(lm[0]), float(lm[5])
            left_eye = float(lm[1]), float(lm[6])
            nose = float(lm[2]), float(lm[7])
            mouthright = float(lm[3]), float(lm[8])
            mouthleft = float(lm[4]), float(lm[9])
            annots.append(
                {
                    "topleft": topleft,
                    "bottomright": bottomright,
                    "reye": right_eye,
                    "leye": left_eye,
                    "nose": nose,
                    "mouthright": mouthright,
                    "mouthleft": mouthleft,
                    "quality": float(prob),
                }
            )
        return annots

    def annotate(self, image, **kwargs):
        """Annotates an image using mtcnn

        Parameters
        ----------
        image : numpy.array
            An RGB image in Bob format.
        **kwargs
            Ignored.

        Returns
        -------
        dict or None
            Annotations contain: (topleft, bottomright, leye, reye, nose,
            mouthleft, mouthright, quality). ``None`` when no face is
            detected.
        """
        # return the annotations for the first/largest face
        annotations = self.annotations(image)

        if annotations:
            return annotations[0]
        return None