Coverage for src/bob/bio/face/pytorch/facexzoo/models.py: 94%

18 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-13 00:04 +0200

1import pkg_resources 

2 

3from bob.bio.base.database.utils import download_file, md5_hash 

4from bob.bio.face.pytorch.facexzoo.backbone_def import BackboneFactory 

5 

# Filesystem path of the backbone configuration YAML, resolved relative to the
# ``bob.bio.face`` package; used as the default ``backbone_conf`` below.
def_backbone_conf = pkg_resources.resource_filename(
    "bob.bio.face",
    "pytorch/facexzoo/backbone_conf.yaml",
)

9 

# Checkpoint registry keyed by architecture name. Each value is a list of
# ``[archive file name, MD5 hex digest, archive name without ".tar.gz"]``.
# The third item is presumably the checkpoint file found inside the archive
# (TODO confirm against the code that loads it).
#
# The rows below hold (architecture, checkpoint tag, MD5); file names are
# derived as ``<architecture>-<tag>.pt[.tar.gz]``.
info = {
    arch: [f"{arch}-{tag}.pt.tar.gz", md5, f"{arch}-{tag}.pt"]
    for arch, tag, md5 in (
        ("AttentionNet", "f4c6f908", "49e435d8d9c075a4f613336090eac242"),
        ("ResNeSt", "e8b132d4", "51eef17ef7c17d1b22bbc13022393f31"),
        ("MobileFaceNet", "ca475a8d", "e5fc0ae59d1a290b58a297b37f015e11"),
        ("ResNet", "e07e7fa1", "13596dfeeb7f40c4b746ad2f0b271c36"),
        ("EfficientNet", "5aed534e", "31c827017fe2029c1ab57371c8e5abf4"),
        ("TF-NAS", "709d8562", "f96fe2683970140568a17c09fff24fab"),
        ("HRNet", "edc4da11", "5ed9920e004af440b623339a7008a758"),
        ("ReXNet", "7c45620c", "b24cf257a25486c52fde5626007b324b"),
        ("GhostNet", "5f026295", "9edb8327c62b62197ad023f21bd865bc"),
    )
}

57 

58 

class FaceXZooModelFactory:
    """Build a FaceXZoo backbone and locate/download its checkpoint archive.

    Parameters
    ----------
    arch : str
        Architecture name; must be a key of ``info``.
    backbone_conf : str
        Backbone configuration path passed to ``BackboneFactory``.
    info : dict
        Mapping from architecture name to a sequence whose item 0 is the
        checkpoint archive file name and item 1 is its MD5 checksum.

    Raises
    ------
    AssertionError
        If ``arch`` is not a key of ``info``.
    """

    def __init__(self, arch, backbone_conf=def_backbone_conf, info=info):
        self.arch = arch
        self.backbone_conf = backbone_conf
        self.info = info

        # Raised explicitly instead of using an ``assert`` statement so the
        # check is not stripped under ``python -O``. The exception type stays
        # ``AssertionError`` so existing callers observe the same error.
        if self.arch not in self.info:
            raise AssertionError(
                f"Unknown architecture {self.arch!r}; "
                f"expected one of {list(self.info)}"
            )

    def get_model(self):
        """Return the backbone built by ``BackboneFactory`` for ``self.arch``."""
        return BackboneFactory(self.arch, self.backbone_conf).get_backbone()

    def get_checkpoint_name(self):
        """Return the checkpoint archive file name registered for ``self.arch``."""
        return self.info[self.arch][0]

    def get_facexzoo_file(self):
        """Download the checkpoint archive and return ``download_file``'s result.

        The archive is requested from the HTTPS URL first in the list, with
        the plain-HTTP URL as the second entry; the registered MD5 checksum
        and ``extract=True`` are forwarded to ``download_file``.
        """
        entry = self.info[self.arch]
        archive_name = entry[0]
        checksum = entry[1]

        urls = [
            "https://www.idiap.ch/software/bob/data/bob/bob.learn.pytorch/facexzoomodels/{}".format(
                archive_name
            ),
            "http://www.idiap.ch/software/bob/data/bob/bob.learn.pytorch/facexzoomodels/{}".format(
                archive_name
            ),
        ]

        return download_file(
            urls=urls,
            destination_filename=archive_name,
            # The sub-directory is named after the full archive file name
            # (including ".tar.gz"), exactly as before.
            destination_sub_directory=f"data/pytorch/{archive_name}/",
            checksum=checksum,
            checksum_fct=md5_hash,
            extract=True,
        )