Coverage for src/bob/bio/face/pytorch/facexzoo/models.py: 94%
18 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-13 00:04 +0200
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-13 00:04 +0200
1import pkg_resources
3from bob.bio.base.database.utils import download_file, md5_hash
4from bob.bio.face.pytorch.facexzoo.backbone_def import BackboneFactory
# Default backbone configuration: the ``backbone_conf.yaml`` resource shipped
# inside the ``bob.bio.face`` package, resolved through ``pkg_resources``.
def_backbone_conf = pkg_resources.resource_filename(
    "bob.bio.face",
    "pytorch/facexzoo/backbone_conf.yaml",
)
# Checkpoint registry, keyed by architecture name.  Each value is a list
# ``[archive file name, MD5 checksum, checkpoint file name]``:
#   * index 0 is the ``.pt.tar.gz`` archive name used to download the model,
#   * index 1 is the MD5 hex digest the download is verified against,
#   * index 2 is the ``.pt`` file name (presumably the file contained in the
#     archive -- confirm against the consumers of this table).
# Every file name follows ``<architecture>-<tag>.pt[.tar.gz]``, so the table is
# generated from (architecture, tag, md5) triples.
info = {
    arch: [f"{arch}-{tag}.pt.tar.gz", md5, f"{arch}-{tag}.pt"]
    for arch, tag, md5 in (
        ("AttentionNet", "f4c6f908", "49e435d8d9c075a4f613336090eac242"),
        ("ResNeSt", "e8b132d4", "51eef17ef7c17d1b22bbc13022393f31"),
        ("MobileFaceNet", "ca475a8d", "e5fc0ae59d1a290b58a297b37f015e11"),
        ("ResNet", "e07e7fa1", "13596dfeeb7f40c4b746ad2f0b271c36"),
        ("EfficientNet", "5aed534e", "31c827017fe2029c1ab57371c8e5abf4"),
        ("TF-NAS", "709d8562", "f96fe2683970140568a17c09fff24fab"),
        ("HRNet", "edc4da11", "5ed9920e004af440b623339a7008a758"),
        ("ReXNet", "7c45620c", "b24cf257a25486c52fde5626007b324b"),
        ("GhostNet", "5f026295", "9edb8327c62b62197ad023f21bd865bc"),
    )
}
class FaceXZooModelFactory:
    """Build a FaceXZoo backbone and fetch its pretrained checkpoint archive.

    Parameters
    ----------
    arch
        Architecture name; must be a key of ``info``.
    backbone_conf
        Backbone configuration handed to ``BackboneFactory``. Defaults to the
        module-level ``def_backbone_conf``.
    info
        Mapping from architecture name to a sequence whose item 0 is the
        checkpoint archive file name and item 1 is its MD5 checksum. Defaults
        to the module-level ``info`` table (shared, not copied).

    Raises
    ------
    AssertionError
        If ``arch`` is not present in ``info``.
    """

    # Download locations handed to ``download_file``; the HTTPS mirror is
    # listed first, followed by the plain-HTTP one.
    _BASE_URLS = (
        "https://www.idiap.ch/software/bob/data/bob/bob.learn.pytorch/facexzoomodels/",
        "http://www.idiap.ch/software/bob/data/bob/bob.learn.pytorch/facexzoomodels/",
    )

    def __init__(self, arch, backbone_conf=def_backbone_conf, info=info):
        self.arch = arch
        self.backbone_conf = backbone_conf
        self.info = info

        # Explicit check instead of ``assert``: assert statements are removed
        # under ``python -O``, which would let an unknown architecture through
        # and defer the failure to a bare KeyError later on.  AssertionError is
        # kept as the exception type so existing callers are unaffected.
        if self.arch not in self.info:
            raise AssertionError(
                f"Unknown FaceXZoo architecture {self.arch!r}; "
                f"available architectures: {list(self.info)}"
            )

    def get_model(self):
        """Return the backbone built by ``BackboneFactory`` for ``self.arch``.

        This method only constructs the backbone from ``self.backbone_conf``;
        it does not itself download or load the checkpoint.
        """
        return BackboneFactory(self.arch, self.backbone_conf).get_backbone()

    def get_checkpoint_name(self):
        """Return item 0 of the ``info`` entry (the archive file name)."""
        return self.info[self.arch][0]

    def get_facexzoo_file(self):
        """Download the checkpoint archive for ``self.arch``.

        The archive is requested from the Idiap mirrors, verified against the
        registered MD5 checksum and extracted (``extract=True``).

        Returns
        -------
        The value returned by ``download_file`` (presumably the local path of
        the downloaded data -- confirm against ``bob.bio.base``).
        """
        entry = self.info[self.arch]
        archive_name = entry[0]
        checksum = entry[1]

        urls = [f"{base}{archive_name}" for base in self._BASE_URLS]

        return download_file(
            urls=urls,
            destination_filename=archive_name,
            # The sub-directory is named after the archive itself.
            destination_sub_directory=f"data/pytorch/{archive_name}/",
            checksum=checksum,
            checksum_fct=md5_hash,
            extract=True,
        )