#!/usr/bin/env python
# coding: utf-8
#
# Author: Kazuto Nakashima
# URL: http://kazuto1011.github.io
# Created: 2017-05-26

from collections.abc import Sequence

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm import tqdm


class BaseWrapper(object):
    def __init__(self, model):
        super(BaseWrapper, self).__init__()
        self.device = next(model.parameters()).device
        # `model` is expected to apply its own input normalization and to
        # expose the underlying backbone as `model.model`.
        self.model_with_norm = model
        self.model = model.model
        self.handlers = []  # a set of hook function handlers

    def _encode_one_hot(self, ids):
        # One-hot mask over the logits, selecting the target class indices.
        one_hot = torch.zeros_like(self.logits).to(self.device)
        one_hot.scatter_(1, ids, 1.0)
        return one_hot

    def forward(self, image):
        """
        Forward pass: cache the logits and per-class probabilities, then
        return (sorted probabilities, class indices) in descending order.
        """
        self.image_shape = image.shape[2:]
        self.logits = self.model_with_norm(image)
        self.probs = torch.sigmoid(self.logits)
        return self.probs.sort(dim=1, descending=True)  # ordered results

    def backward(self, ids):
        """
        Class-specific backpropagation: backpropagate only the logits
        selected by `ids`.
        """
        one_hot = self._encode_one_hot(ids)
        self.model_with_norm.zero_grad()
        self.logits.backward(gradient=one_hot, retain_graph=True)

    def generate(self):
        raise NotImplementedError

    def remove_hook(self):
        """
        Remove all the forward/backward hook functions
        """
        for handle in self.handlers:
            handle.remove()

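
# NOTE: illustrative sketch, not part of the original module. BaseWrapper
# assumes the wrapped model normalizes its own inputs and exposes the raw
# backbone as `.model`; the actual wrapper lives elsewhere in this project,
# so the class below is only a hypothetical example of that interface.
class NormalizedModel(nn.Module):
    """Hypothetical wrapper: normalize inputs, then run the wrapped backbone."""

    def __init__(self, model, mean, std):
        super(NormalizedModel, self).__init__()
        self.model = model  # raw backbone, as expected by BaseWrapper
        self.register_buffer("mean", torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, -1, 1, 1))

    def forward(self, x):
        return self.model((x - self.mean) / self.std)
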
class GradCAM(BaseWrapper):
    """
    "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization"
    https://arxiv.org/pdf/1610.02391.pdf
    Look at Figure 2 on page 4
    """

    def __init__(self, model, candidate_layers=None):
        super(GradCAM, self).__init__(model)
        self.fmap_pool = {}  # layer name -> feature maps from the forward pass
        self.grad_pool = {}  # layer name -> gradients from the backward pass
        self.candidate_layers = candidate_layers  # list

        def save_fmaps(key):
            def forward_hook(module, input, output):
                self.fmap_pool[key] = output.detach()

            return forward_hook

        def save_grads(key):
            def backward_hook(module, grad_in, grad_out):
                self.grad_pool[key] = grad_out[0].detach()

            return backward_hook

        # If no candidate layers are specified, hooks are registered on all layers.
        for name, module in self.model.named_modules():
            if self.candidate_layers is None or name in self.candidate_layers:
                self.handlers.append(module.register_forward_hook(save_fmaps(name)))
                # register_backward_hook is deprecated in recent PyTorch releases;
                # register_full_backward_hook is the non-deprecated alternative.
                self.handlers.append(module.register_backward_hook(save_grads(name)))

    def _find(self, pool, target_layer):
        if target_layer in pool:
            return pool[target_layer]
        else:
            raise ValueError("Invalid layer name: {}".format(target_layer))

    def generate(self, target_layer):
        fmaps = self._find(self.fmap_pool, target_layer)
        grads = self._find(self.grad_pool, target_layer)
        # Channel-wise weights: global average pooling of the gradients.
        weights = F.adaptive_avg_pool2d(grads, 1)

        # Weighted combination of the feature maps, followed by ReLU.
        gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
        gcam = F.relu(gcam)
        # Upsample to the input resolution.
        gcam = F.interpolate(
            gcam, self.image_shape, mode="bilinear", align_corners=False
        )

        # Min-max normalize each map to [0, 1]
        # (assumes at least one positive activation per map).
        B, C, H, W = gcam.shape
        gcam = gcam.view(B, -1)
        gcam -= gcam.min(dim=1, keepdim=True)[0]
        gcam /= gcam.max(dim=1, keepdim=True)[0]
        gcam = gcam.view(B, C, H, W)

        return gcam
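

# Usage sketch (illustrative, with assumptions): a torchvision ResNet-18 is
# wrapped in the hypothetical NormalizedModel above, and "layer4" is used as
# the target layer; both the backbone and the layer name are assumptions,
# not requirements of GradCAM itself.
if __name__ == "__main__":
    import torchvision  # assumed available; any CNN backbone would do

    backbone = torchvision.models.resnet18().eval()  # load real weights in practice
    model = NormalizedModel(
        backbone, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    gcam = GradCAM(model, candidate_layers=["layer4"])
    image = torch.rand(1, 3, 224, 224)  # stand-in for a preprocessed image batch

    probs, ids = gcam.forward(image)                # per-class scores, sorted
    gcam.backward(ids=ids[:, [0]])                  # backprop w.r.t. the top class
    regions = gcam.generate(target_layer="layer4")  # (B, 1, H, W) heatmaps in [0, 1]
    gcam.remove_hook()
    print(regions.shape)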