Coverage for src/bob/bio/face/pytorch/lightning/backbone_head.py: 98%

47 statements  

coverage.py v7.6.0, created at 2024-07-13 00:04 +0200

import numpy as np
import pytorch_lightning as pl
import scipy.spatial
import torch

from torch.optim.lr_scheduler import ReduceLROnPlateau


class BackboneHeadModel(pl.LightningModule):
    def __init__(
        self,
        backbone,
        head,
        loss_fn,
        optimizer_fn,
        backbone_checkpoint_file=None,
        **kwargs,
    ):

21 """ 

22 Pytorch-lightining (https://pytorch-lightning.readthedocs.io/) model composed of two `torch.nn.Module`: 

23 `backbone` and `head`. 

24 

25 

26 Use this model if you want to compose a lightning model that mixing a standard backbone 

27 (Resnet, InceptionResnet, EfficientNet....) and a head (ArcFace, regular cross entropy). 

28 

29 

30 .. note:: 

31 The `validation_step` of this module runs a validation in the level of embeddings, doing 

32 closed-set identification. 

33 Hence, it's mandatory to have a validation dataloader containg pairs of samples of the same identity in a sequence 

34 

35 

36 

37 Parameters 

38 ---------- 

39 

40 backbone: `torch.nn.Module` 

41 Backbone module 

42 

43 head: `torch.nn.Module` 

44 Head module 

45 

46 loss_fn: 

47 A loss function 

48 

49 optimizer_fn: 

50 A `torch.optim` function 

51 

52 backbone_checkpoint_path: 

53 Path for the backbone 

54 

55 

        Example
        -------
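        A minimal usage sketch (`my_backbone` and `my_head` are hypothetical
        placeholders for any compatible pair of `torch.nn.Module` instances,
        and the checkpoint path is illustrative):

        >>> from functools import partial
        >>> import torch
        >>> model = BackboneHeadModel(
        ...     backbone=my_backbone,  # maps inputs to embeddings
        ...     head=my_head,  # maps (embedding, label) to logits, e.g. ArcFace
        ...     loss_fn=torch.nn.CrossEntropyLoss(),
        ...     optimizer_fn=partial(torch.optim.SGD, lr=0.1, momentum=0.9),
        ...     backbone_checkpoint_file="backbone.pth",
        ... )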

62 """ 

        super().__init__(**kwargs)
        self.backbone = backbone
        self.head = head
        self.loss_fn = loss_fn
        self.optimizer_fn = optimizer_fn
        self.backbone_checkpoint_file = backbone_checkpoint_file

    def forward(self, inputs):
        # in lightning, forward defines the prediction/inference actions
        return self.backbone(inputs)

    def on_train_epoch_end(self):
        # Snapshot the backbone weights on their own, so they can be reused
        # later without the head
        if self.backbone_checkpoint_file is not None:
            torch.save(
                self.backbone.state_dict(),
                self.backbone_checkpoint_file,
            )
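    # A minimal sketch of how the snapshot written above can be reloaded into
    # a fresh backbone for feature extraction (the path is hypothetical):
    #
    #   backbone.load_state_dict(torch.load("backbone.pth"))
    #   backbone.eval()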

    def validation_step(self, val_batch, batch_idx):
        data = val_batch["data"]
        labels = val_batch["label"].cpu().detach().numpy()

        val_embedding = torch.nn.functional.normalize(self.forward(data), p=2)
        val_embedding = val_embedding.cpu().detach().numpy()
        n = val_embedding.shape[0]

        # Cosine distance between all pairs of embeddings in the batch
        pdist = scipy.spatial.distance.pdist(val_embedding, metric="cosine")

        # Square distance matrix, with the diagonal set to infinity so that a
        # sample is never matched against itself
        distances = scipy.spatial.distance.squareform(pdist)
        np.fill_diagonal(distances, np.inf)

        # Closed-set identification: predict the label of the nearest neighbor
        predictions = labels[np.argmin(distances, axis=1)]

        accuracy = np.mean(predictions == labels)
        self.log("validation/accuracy", accuracy)

    def training_step(self, batch, batch_idx):
        data = batch["data"]
        label = batch["label"]

        embedding = self.backbone(data)

        logits = self.head(embedding, label)

        loss = self.loss_fn(logits, label)

        self.log("train/loss", loss)

        return loss

    def configure_optimizers(self):
        config = dict()
        optimizer = self.optimizer_fn(params=self.parameters())
        config["optimizer"] = optimizer

        # Reduce the learning rate by a factor of 10 when the monitored metric
        # plateaus for 5 epochs
        lr_scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=5)
        config["lr_scheduler"] = lr_scheduler

        config["monitor"] = "train/loss"

        return config
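
# A minimal sketch of how this module is typically driven, assuming
# hypothetical dataloaders and a `model` built as in the docstring example:
#
#   import pytorch_lightning as pl
#   trainer = pl.Trainer(max_epochs=10)
#   trainer.fit(model, train_dataloader, val_dataloader)
#
# Lightning then consumes the dict returned by `configure_optimizers`: the
# `ReduceLROnPlateau` scheduler multiplies the learning rate by `factor=0.1`
# whenever the monitored "train/loss" stops improving for `patience=5` epochs.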