Coverage for src/bob/bio/face/pytorch/head/arcface.py: 52%

65 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-13 00:04 +0200

1import math 

2 

3import torch 

4import torch.nn.functional as F 

5 

6from torch.nn import Module, Parameter 

7 

8 

class ArcFace(Module):
    """Implementation for "ArcFace: Additive Angular Margin Loss for Deep Face Recognition".

    The head holds one learnable weight column per class. For each sample,
    the logit of its ground-truth class becomes ``cos(theta + margin_arc)``.
    Here ``theta`` is the angle between the L2-normalised feature and the
    L2-normalised class column. All other logits stay ``cos(theta)``, and
    every logit is finally multiplied by ``scale``.

    Parameters
    ----------
    feat_dim : int
        Dimension of the input embeddings.
    num_class : int
        Number of classes (columns of the weight / output logits).
    margin_arc : float
        Additive angular margin ``m``, in radians.
    margin_am : float
        Additive cosine margin subtracted from the target logit instead,
        when ``theta + m`` would leave ``[0, pi]``.
    scale : float
        Multiplier applied to every logit.
    """

    def __init__(
        self, feat_dim, num_class, margin_arc=0.35, margin_am=0.0, scale=32
    ):
        super().__init__()
        # Weight layout is (feat_dim, num_class): one column per class.
        self.weight = Parameter(torch.empty(feat_dim, num_class))
        # Uniform init, then rescale every column (the dim=1 sub-tensors)
        # to unit L2 norm: renorm caps each at 1e-5, mul brings it to 1.
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.margin_arc = margin_arc
        self.margin_am = margin_am
        self.scale = scale
        # Constants for cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
        self.cos_margin = math.cos(margin_arc)
        self.sin_margin = math.sin(margin_arc)
        # theta + m <= pi  <=>  cos(theta) >= cos(pi - m).
        self.min_cos_theta = math.cos(math.pi - margin_arc)

    def forward(self, feats, labels):
        """Return the scaled, margin-adjusted logits.

        Parameters
        ----------
        feats : torch.Tensor
            2-D embeddings of shape ``(batch, feat_dim)``. ``torch.mm``
            against the ``(feat_dim, num_class)`` weight requires this shape.
        labels : torch.Tensor
            Integer class index per sample. The tensor is flattened and cast
            to int64, so int32/int16 labels are accepted too.

        Returns
        -------
        torch.Tensor
            Logits of shape ``(batch, num_class)``.
        """
        # Cosine similarity between every sample and every class column.
        kernel_norm = F.normalize(self.weight, dim=0)
        cos_theta = torch.mm(F.normalize(feats), kernel_norm).clamp(-1, 1)
        # NOTE(review): d/dx sqrt(x) is unbounded at x = 0, so wherever
        # |cos_theta| reaches exactly 1 the backward pass can yield inf/NaN;
        # consider clamping the radicand to a small epsilon if observed.
        sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
        cos_theta_m = cos_theta * self.cos_margin - sin_theta * self.sin_margin
        # 0 <= theta + m <= pi  ==>  -m <= theta <= pi - m. Since
        # 0 <= theta <= pi, only theta <= pi - m has to hold, i.e.
        # cos_theta >= cos(pi - m). Outside that range, fall back to an
        # additive cosine margin.
        cos_theta_m = torch.where(
            cos_theta > self.min_cos_theta,
            cos_theta_m,
            cos_theta - self.margin_am,
        )
        # Boolean mask of each sample's ground-truth class. scatter_ requires
        # an int64 index, hence the explicit .long().
        one_hot = torch.zeros_like(cos_theta)
        one_hot.scatter_(1, labels.reshape(-1, 1).long(), 1)
        target_mask = one_hot.bool()
        # Target positions take the margin logit; all others keep cos(theta).
        output = torch.where(target_mask, cos_theta_m, cos_theta)
        return output * self.scale

46 

47 

class MagFace(Module):
    """Implementation for "MagFace: A Universal Representation for Face Recognition and Quality Assessment"

    taken from https://github.com/JDAI-CV/FaceX-Zoo/blob/5b63794ba7649fe78a29d2ce0d0216c7773f6174/head/MagFace.py

    Like ArcFace, but the additive angular margin is chosen per sample from
    the feature magnitude ``a = ||x||``. ``a`` is clamped to ``[l_a, u_a]``
    and mapped linearly onto ``[l_margin, u_margin]``. The head also returns
    the magnitude regulariser ``lamda * (a / u_a**2 + 1 / a)``.

    Parameters
    ----------
    feat_dim : int
        Dimension of the input embeddings.
    num_class : int
        Number of classes (columns of the weight / output logits).
    margin_am : float
        Additive cosine margin subtracted from the target logit instead,
        when ``theta + margin`` would leave ``[0, pi]``.
    scale : float
        Multiplier applied to every logit.
    l_a, u_a : float
        Lower / upper bound the feature magnitude is clamped to.
    l_margin, u_margin : float
        Angular margins (radians) used at magnitude ``l_a`` / ``u_a``.
    lamda : float
        Weight of the returned magnitude regulariser.
    """

    def __init__(
        self,
        feat_dim,
        num_class,
        margin_am=0.0,
        scale=32,
        l_a=10,
        u_a=110,
        l_margin=0.45,
        u_margin=0.8,
        lamda=20,
    ):
        super().__init__()
        # Weight layout is (feat_dim, num_class): one column per class.
        self.weight = Parameter(torch.empty(feat_dim, num_class))
        # Uniform init, then rescale every column (the dim=1 sub-tensors)
        # to unit L2 norm: renorm caps each at 1e-5, mul brings it to 1.
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.margin_am = margin_am
        self.scale = scale
        self.l_a = l_a
        self.u_a = u_a
        self.l_margin = l_margin
        self.u_margin = u_margin
        self.lamda = lamda

    def calc_margin(self, x):
        """Map magnitude ``x`` linearly from ``[l_a, u_a]`` to ``[l_margin, u_margin]``."""
        margin = (self.u_margin - self.l_margin) / (self.u_a - self.l_a) * (
            x - self.l_a
        ) + self.l_margin
        return margin

    def forward(self, feats, labels):
        """Return the scaled logits and the weighted magnitude regulariser.

        Parameters
        ----------
        feats : torch.Tensor
            2-D un-normalised embeddings of shape ``(batch, feat_dim)``.
            Their L2 norm drives the per-sample margin.
        labels : torch.Tensor
            Integer class index per sample. The tensor is flattened and cast
            to int64, so int32/int16 labels are accepted too.

        Returns
        -------
        tuple of torch.Tensor
            ``(logits, reg)``: logits of shape ``(batch, num_class)``, and
            ``lamda * g(a)`` of shape ``(batch, 1)``.
        """
        # Feature magnitude a = ||x||, clamped to [l_a, u_a]; shape (batch, 1).
        x_norm = torch.norm(feats, dim=1, keepdim=True).clamp(
            self.l_a, self.u_a
        )
        ada_margin = self.calc_margin(x_norm)
        cos_m, sin_m = torch.cos(ada_margin), torch.sin(ada_margin)
        # Magnitude regulariser g(a) = a / u_a**2 + 1 / a.
        loss_g = 1 / (self.u_a**2) * x_norm + 1 / (x_norm)

        # Cosine similarity between every sample and every class column.
        kernel_norm = F.normalize(self.weight, dim=0)
        cos_theta = torch.mm(F.normalize(feats), kernel_norm).clamp(-1, 1)
        # NOTE(review): d/dx sqrt(x) is unbounded at x = 0, so wherever
        # |cos_theta| reaches exactly 1 the backward pass can yield inf/NaN.
        sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
        # The per-sample (batch, 1) margins broadcast over the class axis.
        cos_theta_m = cos_theta * cos_m - sin_theta * sin_m
        # 0 <= theta + m <= pi  ==>  -m <= theta <= pi - m. Since
        # 0 <= theta <= pi, only theta <= pi - m has to hold, i.e.
        # cos_theta >= cos(pi - m). Outside that range, fall back to an
        # additive cosine margin.
        min_cos_theta = torch.cos(math.pi - ada_margin)
        cos_theta_m = torch.where(
            cos_theta > min_cos_theta, cos_theta_m, cos_theta - self.margin_am
        )
        # Boolean (not uint8) mask of each sample's ground-truth class:
        # PyTorch deprecated uint8 masks for indexing. scatter_ needs int64.
        one_hot = torch.zeros_like(cos_theta)
        one_hot.scatter_(1, labels.reshape(-1, 1).long(), 1)
        target_mask = one_hot.bool()
        # Target positions take the margin logit; all others keep cos(theta).
        output = torch.where(target_mask, cos_theta_m, cos_theta)
        return output * self.scale, self.lamda * loss_g