Coverage for src/bob/bio/face/embeddings/pytorch.py: 85%
244 statements
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Yu Linghu & Xinyi Zhang <yu.linghu@uzh.ch, xinyi.zhang@uzh.ch>
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>

import imp
import os

import numpy as np
import torch

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import check_array

from bob.bio.base.algorithm import Distance
from bob.bio.base.database.utils import download_file, md5_hash
from bob.bio.base.pipelines import PipelineSimple
from bob.bio.face.annotator import MTCNN
from bob.bio.face.pytorch.facexzoo import FaceXZooModelFactory
from bob.bio.face.utils import (
    cropped_positions_arcface,
    dnn_default_cropping,
    embedding_transformer,
)


class PyTorchModel(TransformerMixin, BaseEstimator):
    """
    Base Transformer using pytorch models

    Parameters
    ----------

    checkpoint_path: str
        Path containing the checkpoint

    config:
        Path containing some configuration file (e.g. .json, .prototxt)

    preprocessor:
        A function that will transform the data right before forward. The default transformation is `X/255`
    """

    def __init__(
        self,
        checkpoint_path=None,
        config=None,
        preprocessor=lambda x: x / 255,
        memory_demanding=False,
        device=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.checkpoint_path = checkpoint_path
        self.config = config
        self.model = None
        self.preprocessor = preprocessor
        self.memory_demanding = memory_demanding
        # Fall back to CUDA (when available) or CPU only if no device was
        # explicitly requested.
        self.device = torch.device(
            device or ("cuda" if torch.cuda.is_available() else "cpu")
        )

    def transform(self, X):
        """__call__(image) -> feature

        Extracts the features from the given image.

        Parameters
        ----------

        image : 2D :py:class:`numpy.ndarray` (floats)
            The image to extract the features from.

        Returns
        -------

        feature : 2D or 3D :py:class:`numpy.ndarray` (floats)
            The list of features extracted from the image.
        """
        if self.model is None:
            self._load_model()
        X = check_array(X, allow_nd=True)
        X = torch.Tensor(X)
        with torch.no_grad():
            X = self.preprocessor(X)

        def _transform(X):
            with torch.no_grad():
                return self.model(X.to(self.device)).cpu().detach().numpy()

        if self.memory_demanding:
            features = np.array([_transform(x[None, ...]) for x in X])

            # If ndim > 3, stack everything: the enroll features can come from
            # a source with `N` samples, each containing n x d features.
            if features.ndim >= 3:
                features = np.vstack(features)

            return features

        else:
            return _transform(X)

    def __getstate__(self):
        # Handling unpicklable objects
        d = self.__dict__.copy()
        d["model"] = None
        return d

    def _more_tags(self):
        return {"requires_fit": False}

    def place_model_on_device(self):
        if self.model is not None:
            self.model.to(self.device)
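

# Sketch (illustration only): the minimal contract a concrete transformer has
# to fulfil on top of PyTorchModel is a `_load_model` that sets `self.model`,
# switches it to eval mode and places it on `self.device`. The tiny stand-in
# network below is an assumption for the sake of the example.
def _example_pytorch_model_subclass():
    import torch.nn as nn

    class _ToyEmbedding(PyTorchModel):
        def _load_model(self):
            # Stand-in "network": flattens a 3 x 112 x 112 crop into a vector.
            self.model = nn.Sequential(nn.Flatten(), nn.Identity())
            self.model.eval()
            self.place_model_on_device()

    return _ToyEmbedding(device="cpu")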


class RunnableModel(PyTorchModel):
    """
    Runnable pytorch model

    With this class it is possible to pass a loaded pytorch model as an argument.

    The parameter `model` can be set in two ways.
    The first one is via a pytorch module instance, like in the example below:

    >>> model = my_py_torch_model.load("PATH")
    >>> transformer = RunnableModel(model) #doctest: +SKIP

    The second one is via a `callable` that loads the model.
    This mode is useful while running under Dask, since it avoids the
    serialization and data transfer of the model, which can be a problem on
    Idiap's network.
    See the example below on how to set the `model` as a callable:

    >>> from functools import partial
    >>> model = partial(my_py_torch_model.load, "PATH")
    >>> transformer = RunnableModel(model) #doctest: +SKIP

    Parameters
    ----------

    model:
        Loaded pytorch model OR a function that loads the model.
        Providing a function that loads the model might be useful on the
        Idiap grid.

    preprocessor:
        A function that will transform the data right before forward. The default transformation is `X/255`

    memory_demanding:

    device:
    """

    def __init__(
        self,
        model,
        preprocessor=lambda x: x / 255,
        memory_demanding=False,
        device=None,
        **kwargs,
    ):
        super(RunnableModel, self).__init__(
            preprocessor=preprocessor,
            memory_demanding=memory_demanding,
            device=device,
            **kwargs,
        )
        # `torch.nn.Module` instances are themselves callable, so check for a
        # loader function explicitly.
        if callable(model) and not isinstance(model, torch.nn.Module):
            self.model = None
            self._model_fn = model
            self.is_loaded_by_function = True
        else:
            self.model = model
            self.model.eval()
            self.is_loaded_by_function = False

    def _load_model(self):
        self.model = self._model_fn()
        self.model.eval()

    def __getstate__(self):
        # Only drop the model when it can be rebuilt from the loader function;
        # otherwise fall back to pickling the instance dictionary as-is.
        if self.is_loaded_by_function:
            return super(RunnableModel, self).__getstate__()
        return self.__dict__.copy()
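

# Usage sketch (illustration only): passing a *callable* instead of an
# instantiated network, as described in the RunnableModel docstring, so the
# model is only built on the worker that actually runs `transform`. The toy
# loader below is a stand-in for something like `partial(torch.load, "PATH")`.
def _example_runnable_model_from_callable():
    import torch.nn as nn

    def build_model():
        return nn.Sequential(nn.Flatten(), nn.Identity())

    transformer = RunnableModel(build_model, device="cpu")
    assert transformer.is_loaded_by_function
    return transformer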


class AFFFE_2021(PyTorchModel):
    """
    AFFFE Pytorch network that extracts 1000-dimensional features, trained by Manuel Gunther, as described in [LGB18]_
    """

    def __init__(self, memory_demanding=False, device=None, **kwargs):
        urls = [
            "https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/AFFFE-42a53f19.tar.gz",
            "http://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/AFFFE-42a53f19.tar.gz",
        ]

        path = download_file(
            urls=urls,
            destination_sub_directory="data/pytorch/AFFFE-42a53f19",
            destination_filename="AFFFE-42a53f19.tar.gz",
            checksum="1358bbcda62cb59b85b2418ef1f81e9b",
            checksum_fct=md5_hash,
            extract=True,
        )
        config = os.path.join(path, "AFFFE.py")
        checkpoint_path = os.path.join(path, "AFFFE.pth")

        super(AFFFE_2021, self).__init__(
            checkpoint_path,
            config,
            memory_demanding=memory_demanding,
            device=device,
            **kwargs,
        )

    def _load_model(self):
        _ = imp.load_source("MainModel", self.config)
        self.model = torch.load(self.checkpoint_path, map_location=self.device)

        self.model.eval()
        self.place_model_on_device()


def _get_iresnet_file():
    urls = [
        "https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet-91a5de61.tar.gz",
        "http://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet-91a5de61.tar.gz",
    ]

    return download_file(
        urls=urls,
        destination_sub_directory="data/pytorch/iresnet-91a5de61",
        destination_filename="iresnet-91a5de61.tar.gz",
        checksum="3976c0a539811d888ef5b6217e5de425",
        checksum_fct=md5_hash,
        extract=True,
    )


class IResnet34(PyTorchModel):
    """
    ArcFace model (RESNET 34) from Insightface ported to pytorch
    """

    def __init__(
        self,
        preprocessor=lambda x: (x - 127.5) / 128.0,
        memory_demanding=False,
        device=None,
        **kwargs,
    ):
        path = _get_iresnet_file()

        config = os.path.join(path, "iresnet.py")
        checkpoint_path = os.path.join(path, "iresnet34-5b0d0e90.pth")

        super(IResnet34, self).__init__(
            checkpoint_path,
            config,
            memory_demanding=memory_demanding,
            preprocessor=preprocessor,
            device=device,
            **kwargs,
        )

    def _load_model(self):
        model = imp.load_source("module", self.config).iresnet34(
            self.checkpoint_path
        )
        self.model = model

        self.model.eval()
        self.place_model_on_device()


class IResnet50(PyTorchModel):
    """
    ArcFace model (RESNET 50) from Insightface ported to pytorch
    """

    def __init__(
        self,
        preprocessor=lambda x: (x - 127.5) / 128.0,
        memory_demanding=False,
        device=None,
        **kwargs,
    ):
        path = _get_iresnet_file()

        config = os.path.join(path, "iresnet.py")
        checkpoint_path = os.path.join(path, "iresnet50-7f187506.pth")

        super(IResnet50, self).__init__(
            checkpoint_path,
            config,
            memory_demanding=memory_demanding,
            preprocessor=preprocessor,
            device=device,
            **kwargs,
        )

    def _load_model(self):
        model = imp.load_source("module", self.config).iresnet50(
            self.checkpoint_path
        )
        self.model = model

        self.model.eval()
        self.place_model_on_device()


class IResnet100(PyTorchModel):
    """
    ArcFace model (RESNET 100) from Insightface ported to pytorch
    """

    def __init__(
        self,
        preprocessor=lambda x: (x - 127.5) / 128.0,
        memory_demanding=False,
        device=None,
        **kwargs,
    ):
        path = _get_iresnet_file()

        config = os.path.join(path, "iresnet.py")
        checkpoint_path = os.path.join(path, "iresnet100-73e07ba7.pth")

        super(IResnet100, self).__init__(
            checkpoint_path,
            config,
            memory_demanding=memory_demanding,
            preprocessor=preprocessor,
            device=device,
            **kwargs,
        )

    def _load_model(self):
        model = imp.load_source("module", self.config).iresnet100(
            self.checkpoint_path
        )
        self.model = model

        self.model.eval()
        self.place_model_on_device()


class OxfordVGG2Resnets(PyTorchModel):
    """
    Get the transformer for the resnet based models from Oxford.
    All these models were trained on the VGG2 dataset.

    Models taken from: https://www.robots.ox.ac.uk/~albanie

    Parameters
    ----------

    model_name: str
        One of the 4 models available (`resnet50_scratch_dag`, `resnet50_ft_dag`, `senet50_ft_dag`, `senet50_scratch_dag`).
    """

    def __init__(
        self,
        model_name,
        memory_demanding=False,
        device=None,
        **kwargs,
    ):
        urls = [
            "https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/oxford_resnet50_vgg2.tar.gz",
            "http://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/oxford_resnet50_vgg2.tar.gz",
        ]

        path = download_file(
            urls=urls,
            destination_sub_directory="data/pytorch/oxford_resnet50_vgg2",
            destination_filename="oxford_resnet50_vgg2.tar.gz",
            checksum="c8e1ed3715d83647b4a02e455213aaf0",
            checksum_fct=md5_hash,
            extract=True,
        )

        models_available = [
            "resnet50_scratch_dag",
            "resnet50_ft_dag",
            "senet50_ft_dag",
            "senet50_scratch_dag",
        ]
        if model_name not in models_available:
            raise ValueError(
                f"Invalid model {model_name}. The models available are {models_available}"
            )

        self.model_name = model_name
        config = os.path.join(path, model_name, f"{model_name}.py")
        checkpoint_path = os.path.join(path, model_name, f"{model_name}.pth")

        super(OxfordVGG2Resnets, self).__init__(
            checkpoint_path,
            config,
            memory_demanding=memory_demanding,
            preprocessor=self.dag_preprocessor,
            device=device,
            **kwargs,
        )

    def dag_preprocessor(self, X):
        """
        Normalize the input per channel using `self.meta`.

        The original Caffe models expect inputs of shape `H x W x C` with BGR
        channel ordering; here the data is kept channels-first and only the
        per-channel mean/std normalization is applied.
        """

        # Convert to H x W x C
        # X = torch.moveaxis(X, 1, 3)

        # Subtract the mean and divide by the std, channel by channel
        X[:, 0, :, :] = (X[:, 0, :, :] - self.meta["mean"][0]) / self.meta[
            "std"
        ][0]
        X[:, 1, :, :] = (X[:, 1, :, :] - self.meta["mean"][1]) / self.meta[
            "std"
        ][1]
        X[:, 2, :, :] = (X[:, 2, :, :] - self.meta["mean"][2]) / self.meta[
            "std"
        ][2]

        return X

    def _load_model(self):
        if self.model_name == "resnet50_scratch_dag":
            model = imp.load_source("module", self.config).resnet50_scratch_dag(
                weights_path=self.checkpoint_path
            )
        elif self.model_name == "resnet50_ft_dag":
            model = imp.load_source("module", self.config).resnet50_ft_dag(
                weights_path=self.checkpoint_path
            )
        elif self.model_name == "senet50_scratch_dag":
            model = imp.load_source("module", self.config).senet50_scratch_dag(
                weights_path=self.checkpoint_path
            )
        else:
            model = imp.load_source("module", self.config).senet50_ft_dag(
                weights_path=self.checkpoint_path
            )

        self.model = model
        self.meta = self.model.meta

        self.model.eval()
        self.place_model_on_device()

    def transform(self, X):
        """__call__(image) -> feature

        Extracts the features from the given image.

        Parameters
        ----------

        image : 2D :py:class:`numpy.ndarray` (floats)
            The image to extract the features from.

        Returns
        -------

        feature : 2D or 3D :py:class:`numpy.ndarray` (floats)
            The list of features extracted from the image.
        """
        if self.model is None:
            self._load_model()
        X = check_array(X, allow_nd=True)
        X = torch.Tensor(X)
        with torch.no_grad():
            X = self.preprocessor(X)

        def _transform(X):
            with torch.no_grad():
                # Fetch the `pool5_7x7_s1` output and drop the singleton
                # spatial dimensions.
                return (
                    self.model(X.to(self.device))[1]
                    .cpu()
                    .detach()
                    .numpy()[:, :, 0, 0]
                )

        if self.memory_demanding:
            features = np.array([_transform(x[None, ...]) for x in X])

            # If ndim > 3, stack everything: the enroll features can come from
            # a source with `N` samples, each containing n x d features.
            if features.ndim >= 3:
                features = np.vstack(features)

            return features

        else:
            return _transform(X)
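

# Sketch (illustration only): the per-channel normalisation performed by
# `dag_preprocessor`, shown on a dummy batch with made-up mean/std values;
# the real values come from `self.meta`, which is read from the checkpoint in
# `_load_model`.
def _example_channelwise_normalisation():
    dummy_meta = {"mean": [131.0, 103.0, 91.0], "std": [1.0, 1.0, 1.0]}  # placeholders
    X = torch.rand(2, 3, 224, 224) * 255  # two channels-first RGB crops
    for channel in range(3):
        X[:, channel, :, :] = (
            X[:, channel, :, :] - dummy_meta["mean"][channel]
        ) / dummy_meta["std"][channel]
    return X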


class IResnet100Elastic(PyTorchModel):
    """
    iResnet100 model from the paper:

    Boutros, Fadi, et al. "ElasticFace: Elastic Margin Loss for Deep Face Recognition." arXiv preprint arXiv:2109.09416 (2021).
    """

    def __init__(
        self,
        preprocessor=lambda x: (x - 127.5) / 128.0,
        memory_demanding=False,
        device=None,
        **kwargs,
    ):
        urls = [
            "https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet100-elastic.tar.gz",
            "http://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet100-elastic.tar.gz",
        ]

        path = download_file(
            urls=urls,
            destination_sub_directory="data/pytorch/iresnet100-elastic",
            destination_filename="iresnet100-elastic.tar.gz",
            checksum="0ac36db3f0f94930993afdb27faa4f02",
            checksum_fct=md5_hash,
            extract=True,
        )

        config = os.path.join(path, "iresnet.py")
        checkpoint_path = os.path.join(path, "iresnet100-elastic.pt")

        super(IResnet100Elastic, self).__init__(
            checkpoint_path,
            config,
            memory_demanding=memory_demanding,
            preprocessor=preprocessor,
            device=device,
            **kwargs,
        )

    def _load_model(self):
        model = imp.load_source("module", self.config).iresnet100(
            self.checkpoint_path
        )
        self.model = model

        self.model.eval()
        self.place_model_on_device()


class FaceXZooModel(PyTorchModel):
    """
    FaceXZoo models
    """

    def __init__(
        self,
        preprocessor=lambda x: (x - 127.5) / 128.0,
        memory_demanding=False,
        device=None,
        arch="AttentionNet",
        **kwargs,
    ):
        self.arch = arch
        _model = FaceXZooModelFactory(self.arch)
        path = _model.get_facexzoo_file()
        config = None
        checkpoint_path = os.path.join(path, self.arch + ".pt")

        super(FaceXZooModel, self).__init__(
            checkpoint_path,
            config,
            memory_demanding=memory_demanding,
            preprocessor=preprocessor,
            device=device,
            **kwargs,
        )

    def _load_model(self):
        _model = FaceXZooModelFactory(self.arch)
        self.model = _model.get_model()

        model_dict = self.model.state_dict()

        pretrained_dict = torch.load(
            self.checkpoint_path, map_location=torch.device("cpu")
        )["state_dict"]

        # The released checkpoints store the backbone weights under a
        # "backbone." prefix, so look each local key up with that prefix.
        new_pretrained_dict = {}
        for k in model_dict:
            new_pretrained_dict[k] = pretrained_dict["backbone." + k]
        model_dict.update(new_pretrained_dict)
        self.model.load_state_dict(model_dict)

        self.model.eval()
        self.place_model_on_device()
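

# Sketch (illustration only): the key remapping done in
# `FaceXZooModel._load_model`, shown on toy dictionaries that stand in for real
# state dicts. Every key of the local model is looked up in the checkpoint
# under the "backbone." prefix before loading.
def _example_backbone_prefix_remapping():
    model_dict = {"conv1.weight": None, "fc.weight": None}
    pretrained_dict = {
        "backbone.conv1.weight": 1,
        "backbone.fc.weight": 2,
        "head.weight": 3,  # extra head weights are simply ignored
    }
    remapped = {k: pretrained_dict["backbone." + k] for k in model_dict}
    return remapped  # {"conv1.weight": 1, "fc.weight": 2}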


def iresnet_template(embedding, annotation_type, fixed_positions=None):
    # DEFINE CROPPING
    cropped_image_size = (112, 112)

    if (
        annotation_type == "eyes-center"
        or annotation_type == "bounding-box"
        or isinstance(annotation_type, list)
    ):
        # Hard-coding eye positions for backward consistency
        cropped_positions = cropped_positions_arcface(annotation_type)
        if annotation_type == "bounding-box":
            # This will allow us to use `BoundingBoxAnnotatorCrop`
            cropped_positions.update(
                {"topleft": (0, 0), "bottomright": cropped_image_size}
            )

    else:
        cropped_positions = dnn_default_cropping(
            cropped_image_size, annotation_type
        )

    annotator = MTCNN(min_size=40, factor=0.709, thresholds=(0.1, 0.2, 0.2))
    transformer = embedding_transformer(
        cropped_image_size=cropped_image_size,
        embedding=embedding,
        cropped_positions=cropped_positions,
        fixed_positions=fixed_positions,
        color_channel="rgb",
        annotator=annotator,
    )

    algorithm = Distance()

    return PipelineSimple(transformer, algorithm)
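

# Usage sketch (illustration only): the baselines below are thin wrappers
# around `iresnet_template`; building one directly looks like this. Note that
# instantiating the embedding downloads the model files on first use.
def _example_iresnet_template_pipeline():
    pipeline = iresnet_template(
        embedding=IResnet50(memory_demanding=False, device=torch.device("cpu")),
        annotation_type="eyes-center",
    )
    assert isinstance(pipeline, PipelineSimple)
    return pipeline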


def AttentionNet(
    annotation_type,
    fixed_positions=None,
    memory_demanding=False,
    device=torch.device("cpu"),
):
    """
    Get the AttentionNet pipeline which will crop the face to :math:`112 \\times 112` and
    use :py:class:`FaceXZooModel` with the `AttentionNet` architecture to extract the features

    .. warning::

       If you are at Idiap, please use the option `-l sge-gpu` while running the `pipeline simple` pipeline.

    Parameters
    ----------

    annotation_type: str
        Type of the annotations (e.g. `eyes-center`)

    fixed_positions: dict
        Set it if the faces in your images are registered to a fixed position in the image

    memory_demanding: bool
    """
    return iresnet_template(
        embedding=FaceXZooModel(
            arch="AttentionNet",
            memory_demanding=memory_demanding,
            device=device,
        ),
        annotation_type=annotation_type,
        fixed_positions=fixed_positions,
    )


def ResNeSt(
    annotation_type,
    fixed_positions=None,
    memory_demanding=False,
    device=torch.device("cpu"),
):
    """
    Get the ResNeSt pipeline which will crop the face to :math:`112 \\times 112` and
    use :py:class:`FaceXZooModel` with the `ResNeSt` architecture to extract the features

    .. warning::

       If you are at Idiap, please use the option `-l sge-gpu` while running the `pipeline simple` pipeline.

    Parameters
    ----------

    annotation_type: str
        Type of the annotations (e.g. `eyes-center`)

    fixed_positions: dict
        Set it if the faces in your images are registered to a fixed position in the image

    memory_demanding: bool
    """
    return iresnet_template(
        embedding=FaceXZooModel(
            arch="ResNeSt", memory_demanding=memory_demanding, device=device
        ),
        annotation_type=annotation_type,
        fixed_positions=fixed_positions,
    )


def MobileFaceNet(
    annotation_type,
    fixed_positions=None,
    memory_demanding=False,
    device=torch.device("cpu"),
):
    """
    Get the MobileFaceNet pipeline which will crop the face to :math:`112 \\times 112` and
    use :py:class:`FaceXZooModel` with the `MobileFaceNet` architecture to extract the features

    .. warning::

       If you are at Idiap, please use the option `-l sge-gpu` while running the `pipeline simple` pipeline.

    Parameters
    ----------

    annotation_type: str
        Type of the annotations (e.g. `eyes-center`)

    fixed_positions: dict
        Set it if the faces in your images are registered to a fixed position in the image

    memory_demanding: bool
    """
    return iresnet_template(
        embedding=FaceXZooModel(
            arch="MobileFaceNet",
            memory_demanding=memory_demanding,
            device=device,
        ),
        annotation_type=annotation_type,
        fixed_positions=fixed_positions,
    )


def ResNet(
    annotation_type,
    fixed_positions=None,
    memory_demanding=False,
    device=torch.device("cpu"),
):
    """
    Get the ResNet pipeline which will crop the face to :math:`112 \\times 112` and
    use :py:class:`FaceXZooModel` with the `ResNet` architecture to extract the features

    .. warning::

       If you are at Idiap, please use the option `-l sge-gpu` while running the `pipeline simple` pipeline.

    Parameters
    ----------

    annotation_type: str
        Type of the annotations (e.g. `eyes-center`)

    fixed_positions: dict
        Set it if the faces in your images are registered to a fixed position in the image

    memory_demanding: bool
    """
    return iresnet_template(
        embedding=FaceXZooModel(
            arch="ResNet", memory_demanding=memory_demanding, device=device
        ),
        annotation_type=annotation_type,
        fixed_positions=fixed_positions,
    )


def EfficientNet(
    annotation_type,
    fixed_positions=None,
    memory_demanding=False,
    device=torch.device("cpu"),
):
    """
    Get the EfficientNet pipeline which will crop the face to :math:`112 \\times 112` and
    use :py:class:`FaceXZooModel` with the `EfficientNet` architecture to extract the features

    .. warning::

       If you are at Idiap, please use the option `-l sge-gpu` while running the `pipeline simple` pipeline.

    Parameters
    ----------

    annotation_type: str
        Type of the annotations (e.g. `eyes-center`)

    fixed_positions: dict
        Set it if the faces in your images are registered to a fixed position in the image

    memory_demanding: bool
    """
    return iresnet_template(
        embedding=FaceXZooModel(
            arch="EfficientNet",
            memory_demanding=memory_demanding,
            device=device,
        ),
        annotation_type=annotation_type,
        fixed_positions=fixed_positions,
    )


def TF_NAS(
    annotation_type,
    fixed_positions=None,
    memory_demanding=False,
    device=torch.device("cpu"),
):
    """
    Get the TF_NAS pipeline which will crop the face to :math:`112 \\times 112` and
    use :py:class:`FaceXZooModel` with the `TF-NAS` architecture to extract the features

    .. warning::

       If you are at Idiap, please use the option `-l sge-gpu` while running the `pipeline simple` pipeline.

    Parameters
    ----------

    annotation_type: str
        Type of the annotations (e.g. `eyes-center`)

    fixed_positions: dict
        Set it if the faces in your images are registered to a fixed position in the image

    memory_demanding: bool
    """
    return iresnet_template(
        embedding=FaceXZooModel(
            arch="TF-NAS", memory_demanding=memory_demanding, device=device
        ),
        annotation_type=annotation_type,
        fixed_positions=fixed_positions,
    )


def HRNet(
    annotation_type,
    fixed_positions=None,
    memory_demanding=False,
    device=torch.device("cpu"),
):
    """
    Get the HRNet pipeline which will crop the face to :math:`112 \\times 112` and
    use :py:class:`FaceXZooModel` with the `HRNet` architecture to extract the features

    .. warning::

       If you are at Idiap, please use the option `-l sge-gpu` while running the `pipeline simple` pipeline.

    Parameters
    ----------

    annotation_type: str
        Type of the annotations (e.g. `eyes-center`)

    fixed_positions: dict
        Set it if the faces in your images are registered to a fixed position in the image

    memory_demanding: bool
    """
    return iresnet_template(
        embedding=FaceXZooModel(
            arch="HRNet", memory_demanding=memory_demanding, device=device
        ),
        annotation_type=annotation_type,
        fixed_positions=fixed_positions,
    )


def ReXNet(
    annotation_type,
    fixed_positions=None,
    memory_demanding=False,
    device=torch.device("cpu"),
):
    """
    Get the ReXNet pipeline which will crop the face to :math:`112 \\times 112` and
    use :py:class:`FaceXZooModel` with the `ReXNet` architecture to extract the features

    .. warning::

       If you are at Idiap, please use the option `-l sge-gpu` while running the `pipeline simple` pipeline.

    Parameters
    ----------

    annotation_type: str
        Type of the annotations (e.g. `eyes-center`)

    fixed_positions: dict
        Set it if the faces in your images are registered to a fixed position in the image

    memory_demanding: bool
    """
    return iresnet_template(
        embedding=FaceXZooModel(
            arch="ReXNet", memory_demanding=memory_demanding, device=device
        ),
        annotation_type=annotation_type,
        fixed_positions=fixed_positions,
    )


def GhostNet(
    annotation_type,
    fixed_positions=None,
    memory_demanding=False,
    device=torch.device("cpu"),
):
    """
    Get the GhostNet pipeline which will crop the face to :math:`112 \\times 112` and
    use :py:class:`FaceXZooModel` with the `GhostNet` architecture to extract the features

    .. warning::

       If you are at Idiap, please use the option `-l sge-gpu` while running the `pipeline simple` pipeline.

    Parameters
    ----------

    annotation_type: str
        Type of the annotations (e.g. `eyes-center`)

    fixed_positions: dict
        Set it if the faces in your images are registered to a fixed position in the image

    memory_demanding: bool
    """
    return iresnet_template(
        embedding=FaceXZooModel(
            arch="GhostNet", memory_demanding=memory_demanding, device=device
        ),
        annotation_type=annotation_type,
        fixed_positions=fixed_positions,
    )


def iresnet34(
    annotation_type,
    fixed_positions=None,
    memory_demanding=False,
    device=torch.device("cpu"),
):
    """
    Get the Resnet34 pipeline which will crop the face to :math:`112 \\times 112` and
    use :py:class:`IResnet34` to extract the features

    Code referenced from https://raw.githubusercontent.com/nizhib/pytorch-insightface/master/insightface/iresnet.py
    https://github.com/nizhib/pytorch-insightface

    Parameters
    ----------

    annotation_type: str
        Type of the annotations (e.g. `eyes-center`)

    fixed_positions: dict
        Set it if the faces in your images are registered to a fixed position in the image

    memory_demanding: bool
    """
    return iresnet_template(
        embedding=IResnet34(memory_demanding=memory_demanding, device=device),
        annotation_type=annotation_type,
        fixed_positions=fixed_positions,
    )


def iresnet50(
    annotation_type,
    fixed_positions=None,
    memory_demanding=False,
    device=torch.device("cpu"),
):
    """
    Get the Resnet50 pipeline which will crop the face to :math:`112 \\times 112` and
    use :py:class:`IResnet50` to extract the features

    Code referenced from https://raw.githubusercontent.com/nizhib/pytorch-insightface/master/insightface/iresnet.py
    https://github.com/nizhib/pytorch-insightface

    Parameters
    ----------

    annotation_type: str
        Type of the annotations (e.g. `eyes-center`)

    fixed_positions: dict
        Set it if the faces in your images are registered to a fixed position in the image

    memory_demanding: bool
    """
    return iresnet_template(
        embedding=IResnet50(memory_demanding=memory_demanding, device=device),
        annotation_type=annotation_type,
        fixed_positions=fixed_positions,
    )


def iresnet100(
    annotation_type,
    fixed_positions=None,
    memory_demanding=False,
    device=torch.device("cpu"),
):
    """
    Get the Resnet100 pipeline which will crop the face to :math:`112 \\times 112` and
    use :py:class:`IResnet100` to extract the features

    Code referenced from https://raw.githubusercontent.com/nizhib/pytorch-insightface/master/insightface/iresnet.py
    https://github.com/nizhib/pytorch-insightface

    Parameters
    ----------

    annotation_type: str
        Type of the annotations (e.g. `eyes-center`)

    fixed_positions: dict
        Set it if the faces in your images are registered to a fixed position in the image

    memory_demanding: bool
    """
    return iresnet_template(
        embedding=IResnet100(memory_demanding=memory_demanding, device=device),
        annotation_type=annotation_type,
        fixed_positions=fixed_positions,
    )


def iresnet100_elastic(
    annotation_type,
    fixed_positions=None,
    memory_demanding=False,
    device=torch.device("cpu"),
):
    """
    Get the Resnet100 ElasticFace pipeline which will crop the face to :math:`112 \\times 112` and
    use :py:class:`IResnet100Elastic` to extract the features

    Code referenced from https://raw.githubusercontent.com/nizhib/pytorch-insightface/master/insightface/iresnet.py
    https://github.com/nizhib/pytorch-insightface

    Parameters
    ----------

    annotation_type: str
        Type of the annotations (e.g. `eyes-center`)

    fixed_positions: dict
        Set it if the faces in your images are registered to a fixed position in the image

    memory_demanding: bool
    """
    return iresnet_template(
        embedding=IResnet100Elastic(
            memory_demanding=memory_demanding, device=device
        ),
        annotation_type=annotation_type,
        fixed_positions=fixed_positions,
    )


def afffe_baseline(
    annotation_type,
    fixed_positions=None,
    memory_demanding=False,
    device=torch.device("cpu"),
):
    """
    Get the AFFFE pipeline which will crop the face to :math:`224 \\times 224` and
    use :py:class:`AFFFE_2021` to extract the features

    Parameters
    ----------

    annotation_type: str
        Type of the annotations (e.g. `eyes-center`)

    fixed_positions: dict
        Set it if the faces in your images are registered to a fixed position in the image
    """

    # DEFINE CROPPING
    cropped_image_size = (224, 224)

    if annotation_type == "eyes-center":
        # Hard-coding eye positions for backward consistency
        cropped_positions = {"leye": (110, 144), "reye": (110, 96)}
    else:
        cropped_positions = dnn_default_cropping(
            cropped_image_size, annotation_type
        )

    transformer = embedding_transformer(
        cropped_image_size=cropped_image_size,
        embedding=AFFFE_2021(memory_demanding=memory_demanding, device=device),
        cropped_positions=cropped_positions,
        fixed_positions=fixed_positions,
        color_channel="rgb",
        annotator="mtcnn",
    )

    algorithm = Distance()

    return PipelineSimple(transformer, algorithm)


def oxford_vgg2_resnets(
    model_name, annotation_type, fixed_positions=None, memory_demanding=False
):
    """
    Get the pipeline for the resnet based models from Oxford.
    All these models were trained on the VGG2 dataset.

    Models taken from: https://www.robots.ox.ac.uk/~albanie

    Parameters
    ----------

    model_name: str
        One of the 4 models available (`resnet50_scratch_dag`, `resnet50_ft_dag`, `senet50_ft_dag`, `senet50_scratch_dag`).

    annotation_type: str
        Type of the annotations (e.g. `eyes-center`)

    fixed_positions: dict
        Set it if the faces in your images are registered to a fixed position in the image
    """

    # DEFINE CROPPING
    cropped_image_size = (224, 224)

    if annotation_type == "eyes-center":
        # Coordinates taken from : https://www.merlin.uzh.ch/contributionDocument/download/14240
        cropped_positions = {"leye": (100, 159), "reye": (100, 65)}
    else:
        cropped_positions = dnn_default_cropping(
            cropped_image_size, annotation_type
        )

    transformer = embedding_transformer(
        cropped_image_size=cropped_image_size,
        embedding=OxfordVGG2Resnets(
            model_name=model_name, memory_demanding=memory_demanding
        ),
        cropped_positions=cropped_positions,
        fixed_positions=fixed_positions,
        color_channel="rgb",
        annotator="mtcnn",
    )

    algorithm = Distance()

    return PipelineSimple(transformer, algorithm)
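

# Usage sketch (illustration only): building the Oxford VGG2 pipeline for one
# of the four published models; the returned object is a regular
# PipelineSimple, like the ones produced by the other baselines above.
def _example_oxford_vgg2_pipeline():
    pipeline = oxford_vgg2_resnets(
        model_name="resnet50_ft_dag",
        annotation_type="eyes-center",
    )
    assert isinstance(pipeline, PipelineSimple)
    return pipeline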