Coverage for src/bob/bio/face/pytorch/head/arcface.py: 52%
65 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-13 00:04 +0200
1import math
3import torch
4import torch.nn.functional as F
6from torch.nn import Module, Parameter
class ArcFace(Module):
    """Implementation for "ArcFace: Additive Angular Margin Loss for Deep Face Recognition"

    Produces scaled classification logits. For each sample's target class,
    ``cos(theta)`` is replaced by ``cos(theta + margin_arc)``; all other
    classes keep ``cos(theta)``.

    Args:
        feat_dim: Dimension of the input embeddings.
        num_class: Number of classes (logits produced per sample).
        margin_arc: Additive angular margin ``m`` (radians).
        margin_am: Additive cosine margin subtracted from the target cosine
            instead, for samples where ``theta + m`` would exceed ``pi``.
        scale: Factor applied to every logit.
    """

    def __init__(
        self, feat_dim, num_class, margin_arc=0.35, margin_am=0.0, scale=32
    ):
        super(ArcFace, self).__init__()
        # One weight column per class; renorm_ + mul_ rescales every column
        # to at most unit L2 norm.
        self.weight = Parameter(torch.Tensor(feat_dim, num_class))
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.margin_arc = margin_arc
        self.margin_am = margin_am
        self.scale = scale
        self.cos_margin = math.cos(margin_arc)
        self.sin_margin = math.sin(margin_arc)
        # 0 <= theta + m <= pi holds only while theta <= pi - m,
        # i.e. while cos(theta) >= cos(pi - m).
        self.min_cos_theta = math.cos(math.pi - margin_arc)

    def forward(self, feats, labels):
        """Compute margin-adjusted, scaled logits.

        Args:
            feats: 2-D tensor of embeddings, ``(batch, feat_dim)`` (required by
                ``torch.mm`` against the ``(feat_dim, num_class)`` weight).
            labels: Integer tensor holding one target class index per sample.

        Returns:
            Tensor of shape ``(batch, num_class)``.
        """
        kernel_norm = F.normalize(self.weight, dim=0)
        feats = F.normalize(feats)
        cos_theta = torch.mm(feats, kernel_norm).clamp(-1, 1)

        # sqrt has an infinite derivative at 0: if |cos_theta| == 1 exactly,
        # backward would compute 0 / 0 = NaN even for entries whose gradient
        # is masked out below. Clamping to the smallest positive normal value
        # only affects exact zeros; any representable non-zero 1 - cos^2 is
        # many orders of magnitude larger.
        sin_sq = 1.0 - torch.pow(cos_theta, 2)
        sin_theta = torch.sqrt(sin_sq.clamp_min(torch.finfo(sin_sq.dtype).tiny))

        # cos(theta + m) = cos(theta) cos(m) - sin(theta) sin(m)
        cos_theta_m = cos_theta * self.cos_margin - sin_theta * self.sin_margin
        # Where theta + m would leave [0, pi], fall back to cos(theta) - margin_am.
        cos_theta_m = torch.where(
            cos_theta > self.min_cos_theta,
            cos_theta_m,
            cos_theta - self.margin_am,
        )

        # Boolean mask that is True only at each sample's target class.
        index = torch.zeros_like(cos_theta)
        index.scatter_(1, labels.reshape(-1, 1), 1)
        index = index.bool()

        output = torch.where(index, cos_theta_m, cos_theta)
        return output * self.scale
class MagFace(Module):
    """Implementation for "MagFace: A Universal Representation for Face Recognition and Quality Assessment"

    taken from https://github.com/JDAI-CV/FaceX-Zoo/blob/5b63794ba7649fe78a29d2ce0d0216c7773f6174/head/MagFace.py

    For each sample's target class, ``cos(theta)`` is replaced by
    ``cos(theta + m(a))``. The angular margin ``m(a)`` depends on the sample's
    embedding magnitude ``a`` (its L2 norm, clamped to ``[l_a, u_a]``) and
    grows linearly from ``l_margin`` at ``l_a`` to ``u_margin`` at ``u_a``.

    Args:
        feat_dim: Dimension of the input embeddings.
        num_class: Number of classes (logits produced per sample).
        margin_am: Additive cosine margin subtracted from the target cosine
            instead, for samples where ``theta + m`` would exceed ``pi``.
        scale: Factor applied to every logit.
        l_a: Lower bound the embedding magnitude is clamped to.
        u_a: Upper bound the embedding magnitude is clamped to.
        l_margin: Angular margin (radians) at magnitude ``l_a``.
        u_margin: Angular margin (radians) at magnitude ``u_a``.
        lamda: Weight of the magnitude regularizer returned by ``forward``.
    """

    def __init__(
        self,
        feat_dim,
        num_class,
        margin_am=0.0,
        scale=32,
        l_a=10,
        u_a=110,
        l_margin=0.45,
        u_margin=0.8,
        lamda=20,
    ):
        super(MagFace, self).__init__()
        # One weight column per class; renorm_ + mul_ rescales every column
        # to at most unit L2 norm.
        self.weight = Parameter(torch.Tensor(feat_dim, num_class))
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.margin_am = margin_am
        self.scale = scale
        self.l_a = l_a
        self.u_a = u_a
        self.l_margin = l_margin
        self.u_margin = u_margin
        self.lamda = lamda

    def calc_margin(self, x):
        """Linearly map magnitude ``x`` from [l_a, u_a] onto [l_margin, u_margin]."""
        margin = (self.u_margin - self.l_margin) / (self.u_a - self.l_a) * (
            x - self.l_a
        ) + self.l_margin
        return margin

    def forward(self, feats, labels):
        """Compute adaptive-margin logits and the magnitude regularizer.

        Args:
            feats: 2-D tensor of un-normalized embeddings, ``(batch, feat_dim)``.
            labels: Integer tensor holding one target class index per sample.

        Returns:
            A tuple ``(logits, reg)``. ``logits`` has shape
            ``(batch, num_class)``. ``reg`` has shape ``(batch, 1)`` and equals
            ``lamda * (a / u_a**2 + 1 / a)`` per sample; reducing it (e.g.
            averaging) is presumably left to the caller.
        """
        # Per-sample magnitude a, clamped to [l_a, u_a]; shape (batch, 1).
        x_norm = torch.norm(feats, dim=1, keepdim=True).clamp(self.l_a, self.u_a)
        ada_margin = self.calc_margin(x_norm)
        cos_m, sin_m = torch.cos(ada_margin), torch.sin(ada_margin)
        # Magnitude regularizer g(a) = a / u_a^2 + 1 / a.
        loss_g = 1 / (self.u_a**2) * x_norm + 1 / (x_norm)

        kernel_norm = F.normalize(self.weight, dim=0)
        feats = F.normalize(feats)
        cos_theta = torch.mm(feats, kernel_norm).clamp(-1, 1)

        # sqrt has an infinite derivative at 0: if |cos_theta| == 1 exactly,
        # backward would compute 0 / 0 = NaN even for entries whose gradient
        # is masked out below. Clamping to the smallest positive normal value
        # only affects exact zeros.
        sin_sq = 1.0 - torch.pow(cos_theta, 2)
        sin_theta = torch.sqrt(sin_sq.clamp_min(torch.finfo(sin_sq.dtype).tiny))

        # cos(theta + m) with a per-sample m; (batch, 1) broadcasts over classes.
        cos_theta_m = cos_theta * cos_m - sin_theta * sin_m
        # 0 <= theta + m <= pi holds only while cos(theta) >= cos(pi - m);
        # otherwise fall back to cos(theta) - margin_am.
        min_cos_theta = torch.cos(math.pi - ada_margin)
        cos_theta_m = torch.where(
            cos_theta > min_cos_theta, cos_theta_m, cos_theta - self.margin_am
        )

        # Boolean mask (uint8 masks are deprecated for indexing) that is True
        # only at each sample's target class.
        index = torch.zeros_like(cos_theta)
        index.scatter_(1, labels.reshape(-1, 1), 1)
        index = index.bool()

        output = torch.where(index, cos_theta_m, cos_theta)
        return output * self.scale, self.lamda * loss_g