Coverage for src/bob/bio/face/pytorch/lightning/backbone_head.py: 98%
47 statements
coverage.py v7.6.0, created at 2024-07-13 00:04 +0200
import numpy as np
import pytorch_lightning as pl
import scipy.spatial
import torch

from torch.optim.lr_scheduler import ReduceLROnPlateau
class BackboneHeadModel(pl.LightningModule):
    def __init__(
        self,
        backbone,
        head,
        loss_fn,
        optimizer_fn,
        backbone_checkpoint_file=None,
        **kwargs,
    ):
21 """
22 Pytorch-lightining (https://pytorch-lightning.readthedocs.io/) model composed of two `torch.nn.Module`:
23 `backbone` and `head`.
26 Use this model if you want to compose a lightning model that mixing a standard backbone
27 (Resnet, InceptionResnet, EfficientNet....) and a head (ArcFace, regular cross entropy).
30 .. note::
31 The `validation_step` of this module runs a validation in the level of embeddings, doing
32 closed-set identification.
33 Hence, it's mandatory to have a validation dataloader containg pairs of samples of the same identity in a sequence
37 Parameters
38 ----------
40 backbone: `torch.nn.Module`
41 Backbone module
43 head: `torch.nn.Module`
44 Head module
46 loss_fn:
47 A loss function
49 optimizer_fn:
50 A `torch.optim` function
52 backbone_checkpoint_path:
53 Path for the backbone
        Example
        -------
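
        A minimal sketch of the composition (``MyBackbone`` and ``MyHead`` are
        hypothetical placeholders, not classes shipped with this module):

        >>> import functools
        >>> import torch
        >>> model = BackboneHeadModel(  # doctest: +SKIP
        ...     backbone=MyBackbone(),
        ...     head=MyHead(),
        ...     loss_fn=torch.nn.CrossEntropyLoss(),
        ...     optimizer_fn=functools.partial(torch.optim.SGD, lr=0.1),
        ... )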
62 """
        super().__init__(**kwargs)
        self.backbone = backbone
        self.head = head
        self.loss_fn = loss_fn
        self.optimizer_fn = optimizer_fn
        self.backbone_checkpoint_file = backbone_checkpoint_file
    def forward(self, inputs):
        # in lightning, forward defines the prediction/inference actions
        return self.backbone(inputs)
    def on_train_epoch_end(self):
        # Persist the backbone weights on their own at the end of every epoch
        if self.backbone_checkpoint_file is not None:
            torch.save(
                self.backbone.state_dict(),
                self.backbone_checkpoint_file,
            )
    def validation_step(self, val_batch, batch_idx):
        data = val_batch["data"]
        labels = val_batch["label"].cpu().detach().numpy()

        val_embedding = torch.nn.functional.normalize(self.forward(data), p=2)
        val_embedding = val_embedding.cpu().detach().numpy()
        n = val_embedding.shape[0]

        # Cosine distance between all pairs of embeddings in the batch
        pdist = scipy.spatial.distance.pdist(val_embedding, metric="cosine")

        # Square matrix filled with infinity
        predictions = np.ones((n, n)) * np.inf

        # Filling the upper triangle (without the diagonal) with the pdist
        predictions[np.triu_indices(n, k=1)] = pdist

        # Predicting: each sample takes the label of its nearest neighbor
        predictions = labels[np.argmin(predictions, axis=1)]

        accuracy = sum(predictions == labels) / n
        self.log("validation/accuracy", accuracy)
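
    # Note on the scheme above: because only the upper triangle is filled,
    # each sample is matched only against samples that come *after* it in the
    # batch, and `np.argmin` on an all-`inf` row falls back to index 0 for the
    # last sample. This is why the validation dataloader must present samples
    # of the same identity consecutively (see the class docstring); the logged
    # accuracy is a proxy metric rather than an exact closed-set rate.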
    def training_step(self, batch, batch_idx):
        data = batch["data"]
        label = batch["label"]

        embedding = self.backbone(data)

        logits = self.head(embedding, label)

        loss = self.loss_fn(logits, label)
        self.log("train/loss", loss)

        return loss
    def configure_optimizers(self):
        config = dict()
        optimizer = self.optimizer_fn(params=self.parameters())
        config["optimizer"] = optimizer

        # Scale the learning rate by 0.1 when "train/loss" plateaus for 5 epochs
        lr_scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=5)
        config["lr_scheduler"] = lr_scheduler

        # Metric monitored by the scheduler (logged in `training_step`)
        config["monitor"] = "train/loss"

        return config
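
# -----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): wiring the
# model into a ``pytorch_lightning.Trainer``. The toy backbone, the
# ``LinearHead`` class, and the commented-out dataloaders below are
# hypothetical stand-ins, not objects shipped with this package.
if __name__ == "__main__":
    import functools

    # Toy backbone mapping flattened 28x28 inputs to 64-d embeddings
    backbone = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(28 * 28, 64),
    )

    class LinearHead(torch.nn.Module):
        """Toy head: a plain linear classifier over the embeddings.

        `label` is unused here, but margin-based heads (e.g. ArcFace) need it,
        hence the two-argument signature expected by `training_step`.
        """

        def __init__(self, embedding_size=64, n_classes=10):
            super().__init__()
            self.fc = torch.nn.Linear(embedding_size, n_classes)

        def forward(self, embedding, label):
            return self.fc(embedding)

    model = BackboneHeadModel(
        backbone=backbone,
        head=LinearHead(),
        loss_fn=torch.nn.CrossEntropyLoss(),
        optimizer_fn=functools.partial(torch.optim.SGD, lr=0.1),
        backbone_checkpoint_file="backbone.pth",
    )

    # The dataloaders must yield dicts with "data" and "label" keys, matching
    # `training_step`/`validation_step` above; they are assumed to exist:
    # trainer = pl.Trainer(max_epochs=1)
    # trainer.fit(model, train_loader, val_loader)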