# Source code for bob.med.tb.utils.grad_cams

#!/usr/bin/env python
# coding: utf-8
#
# Author:   Kazuto Nakashima
# URL:      http://kazuto1011.github.io
# Created:  2017-05-26

from collections import Sequence

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm import tqdm


class BaseWrapper(object):
    """Shared plumbing for gradient-based saliency explainers.

    Keeps a handle on the full model (``model_with_norm``) and on its inner
    ``.model`` attribute, caches the outputs of the most recent forward pass,
    and records hook handles so subclasses can detach them later.

    Args:
        model: a ``torch.nn.Module`` exposing a ``.model`` attribute; the
            inner ``.model`` is what subclasses attach hooks to, while the
            outer module is what actually gets called.
    """

    def __init__(self, model):
        super().__init__()
        first_param = next(model.parameters())
        self.device = first_param.device
        self.model_with_norm = model
        self.model = model.model
        # Handles returned by register_*_hook calls made by subclasses.
        self.handlers = []

    def _encode_one_hot(self, ids):
        """Build a mask shaped like ``self.logits`` holding 1.0 at ``ids``.

        ``ids`` is scattered along dimension 1, so it must be an integer
        (long) tensor indexing into that dimension of the logits.
        """
        mask = torch.zeros_like(self.logits)
        mask = mask.to(self.device)
        mask.scatter_(1, ids, 1.0)
        return mask

    def forward(self, image):
        """Run the model on ``image`` and return its sigmoid scores, sorted.

        Side effects: stores the trailing dimensions of ``image`` (after the
        first two) as ``self.image_shape``, plus ``self.logits`` and
        ``self.probs`` for later use by :meth:`backward`.

        Returns:
            The ``(values, indices)`` pair from sorting ``self.probs`` along
            dimension 1 in descending order.
        """
        self.image_shape = image.shape[2:]
        raw_scores = self.model_with_norm(image)
        self.logits = raw_scores
        self.probs = torch.sigmoid(raw_scores)
        # Highest-scoring entries come first.
        return self.probs.sort(dim=1, descending=True)

    def backward(self, ids):
        """Backpropagate from the logit entries selected by ``ids``.

        Gradients on the model are cleared first. The graph is retained so
        this can be called repeatedly after a single :meth:`forward`.
        """
        seed_grad = self._encode_one_hot(ids)
        self.model_with_norm.zero_grad()
        self.logits.backward(gradient=seed_grad, retain_graph=True)

    def generate(self):
        """Produce a saliency map; concrete subclasses must override this."""
        raise NotImplementedError

    def remove_hook(self):
        """Detach every hook whose handle was recorded in ``self.handlers``."""
        for hook_handle in self.handlers:
            hook_handle.remove()
class GradCAM(BaseWrapper):
    """
    "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization"
    https://arxiv.org/pdf/1610.02391.pdf
    Look at Figure 2 on page 4

    Hooks are attached to modules of ``self.model`` (the inner model held by
    the base wrapper). Each forward pass records the detached output of every
    hooked module in ``fmap_pool``. Each backward pass records the detached
    gradient w.r.t. that output in ``grad_pool``. Both pools are keyed by the
    module's name from ``named_modules()``.

    Args:
        model: passed to the base wrapper; must expose a ``.model`` attribute.
        candidate_layers: optional collection of module names to hook. When
            ``None``, every named module (including the root, name ``""``) is
            hooked.
    """

    def __init__(self, model, candidate_layers=None):
        super(GradCAM, self).__init__(model)
        self.fmap_pool = {}
        self.grad_pool = {}
        self.candidate_layers = candidate_layers  # list

        def save_fmaps(key):
            def forward_hook(module, input, output):
                self.fmap_pool[key] = output.detach()

            return forward_hook

        def save_grads(key):
            def backward_hook(module, grad_in, grad_out):
                self.grad_pool[key] = grad_out[0].detach()

            return backward_hook

        # If any candidates are not specified, the hook is registered to all the layers.
        # NOTE(review): ``register_backward_hook`` is deprecated in recent
        # PyTorch in favour of ``register_full_backward_hook``. The latter can
        # fail when a downstream op modifies the hooked output in place
        # (e.g. ReLU(inplace=True)), so it is left unchanged here.
        for name, module in self.model.named_modules():
            if self.candidate_layers is None or name in self.candidate_layers:
                self.handlers.append(module.register_forward_hook(save_fmaps(name)))
                self.handlers.append(module.register_backward_hook(save_grads(name)))

    def _find(self, pool, target_layer):
        """Return ``pool[target_layer]``.

        Raises:
            ValueError: if nothing was recorded for ``target_layer``.
        """
        if target_layer in pool:
            return pool[target_layer]
        raise ValueError("Invalid layer name: {}".format(target_layer))

    def generate(self, target_layer):
        """Compute per-sample Grad-CAM maps for ``target_layer``.

        Reads the feature maps and gradients captured by the hooks. It must
        therefore run after ``forward`` (which also sets ``self.image_shape``)
        and ``backward`` have been called.

        Returns:
            A tensor of shape ``(B, 1, H, W)``, where ``(H, W)`` is
            ``self.image_shape``. Each sample is min-max normalized to
            ``[0, 1]``. A sample whose map is constant (e.g. entirely zeroed
            by the ReLU) is returned as all zeros rather than NaN.

        Raises:
            ValueError: if ``target_layer`` was not hooked or was not reached.
        """
        fmaps = self._find(self.fmap_pool, target_layer)
        grads = self._find(self.grad_pool, target_layer)
        # Channel weights: spatially averaged gradients (alpha_k in the paper).
        weights = F.adaptive_avg_pool2d(grads, 1)
        gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
        gcam = F.relu(gcam)
        gcam = F.interpolate(
            gcam, self.image_shape, mode="bilinear", align_corners=False
        )

        B, C, H, W = gcam.shape
        gcam = gcam.view(B, -1)
        gcam -= gcam.min(dim=1, keepdim=True)[0]
        peak = gcam.max(dim=1, keepdim=True)[0]
        # A constant map has peak 0 after the shift; dividing would give
        # 0/0 = NaN. Divide those rows by 1 instead (they stay all-zero) and
        # leave every other row normalized exactly as before.
        peak = torch.where(peak > 0, peak, torch.ones_like(peak))
        gcam /= peak
        gcam = gcam.view(B, C, H, W)
        return gcam