# Source code for bob.ip.binseg.models.lwnet

#!/usr/bin/env python
# coding=utf-8

"""Little W-Net

Code was originally developed by Adrian Galdran
(https://github.com/agaldran/lwnet), loosely inspired by
https://github.com/jvanvugt/pytorch-unet

It is based on two simple U-Nets with 3 layers concatenated to each other.  The
first U-Net produces a segmentation map that is used by the second to better
guide segmentation.

Reference: [GALDRAN-2020]_
"""


import torch
import torch.nn


def _conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 2D convolution.

    Parameters
    ----------
    in_planes : int
        Number of input channels
    out_planes : int
        Number of output channels
    stride : int, Optional
        Convolution stride

    Returns
    -------
    conv : :py:class:`torch.nn.Conv2d`
        The 1x1 convolution layer (no bias term)
    """
    conv = torch.nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return conv


class ConvBlock(torch.nn.Module):
    """Two stacked Conv2d -> ReLU -> BatchNorm2d layers, with optional
    max-pooling of the input and an optional residual shortcut.

    Parameters
    ----------
    in_c : int
        Number of input channels
    out_c : int
        Number of output channels
    k_sz : int, Optional
        Convolution kernel size.  A padding of ``(k_sz - 1) // 2`` is used,
        which keeps the spatial size unchanged for odd kernel sizes.
    shortcut : bool, Optional
        If set, the output of a 1x1 convolution + BatchNorm2d applied to the
        (possibly pooled) input is added to the block output.  Otherwise the
        ``shortcut`` attribute is ``False``.
    pool : bool, Optional
        If set, a 2x2 max-pooling is applied to the input before the
        convolutions.  Otherwise no pooling is done and the ``pool``
        attribute is ``False``.
    """

    def __init__(self, in_c, out_c, k_sz=3, shortcut=False, pool=True):
        super(ConvBlock, self).__init__()

        # NOTE: sub-modules are created in the same order as the original
        # implementation (shortcut, pool, convolutions) so parameter
        # initialisation consumes the RNG identically for a given seed.
        if shortcut:
            self.shortcut = torch.nn.Sequential(
                _conv1x1(in_c, out_c), torch.nn.BatchNorm2d(out_c)
            )
        else:
            self.shortcut = False

        pad = (k_sz - 1) // 2

        if pool:
            self.pool = torch.nn.MaxPool2d(kernel_size=2)
        else:
            self.pool = False

        # Sequential indices (0..5) are kept as in the original so that
        # state_dict keys ("block.0.weight", ...) remain compatible.
        block = []
        block.append(
            torch.nn.Conv2d(in_c, out_c, kernel_size=k_sz, padding=pad)
        )
        block.append(torch.nn.ReLU())
        block.append(torch.nn.BatchNorm2d(out_c))
        block.append(
            torch.nn.Conv2d(out_c, out_c, kernel_size=k_sz, padding=pad)
        )
        block.append(torch.nn.ReLU())
        block.append(torch.nn.BatchNorm2d(out_c))
        self.block = torch.nn.Sequential(*block)

    def forward(self, x):
        """Apply optional pooling, the conv stack, and optional shortcut."""
        if self.pool:
            x = self.pool(x)
        out = self.block(x)
        if self.shortcut:
            # shortcut sees the pooled input, matching ``self.block``'s input
            return out + self.shortcut(x)
        return out


class UpsampleBlock(torch.nn.Module):
    """Doubles the spatial resolution of a feature map and maps its channels.

    Parameters
    ----------
    in_c : int
        Number of input channels
    out_c : int
        Number of output channels
    up_mode : str, Optional
        ``"transp_conv"``: a 2x2 transposed convolution with stride 2;
        ``"up_conv"``: bilinear upsampling by a factor of 2 followed by a
        1x1 convolution.

    Raises
    ------
    ValueError
        If ``up_mode`` is not one of the supported values.  ``ValueError``
        derives from ``Exception``, so handlers catching ``Exception`` still
        work.
    """

    def __init__(self, in_c, out_c, up_mode="transp_conv"):
        super(UpsampleBlock, self).__init__()
        block = []
        if up_mode == "transp_conv":
            block.append(
                torch.nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2)
            )
        elif up_mode == "up_conv":
            block.append(
                torch.nn.Upsample(
                    mode="bilinear", scale_factor=2, align_corners=False
                )
            )
            block.append(torch.nn.Conv2d(in_c, out_c, kernel_size=1))
        else:
            raise ValueError(
                "Upsampling mode not supported: {!r}".format(up_mode)
            )
        self.block = torch.nn.Sequential(*block)

    def forward(self, x):
        """Upsample ``x`` by a factor of 2 using the configured mode."""
        return self.block(x)


class ConvBridgeBlock(torch.nn.Module):
    """A single Conv2d -> ReLU -> BatchNorm2d layer that keeps the channel
    count, used to process skip connections.

    Parameters
    ----------
    channels : int
        Number of input (and output) channels
    k_sz : int, Optional
        Convolution kernel size; padding is ``(k_sz - 1) // 2``
    """

    def __init__(self, channels, k_sz=3):
        super().__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(
                channels, channels, kernel_size=k_sz, padding=(k_sz - 1) // 2
            ),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        """Apply the conv-ReLU-BatchNorm layer to ``x``."""
        return self.block(x)


class UpConvBlock(torch.nn.Module):
    """Decoder step: upsample, merge with a skip connection, then convolve.

    Parameters
    ----------
    in_c : int
        Number of channels of the tensor being upsampled
    out_c : int
        Number of output channels (and the channel count expected for the
        skip connection)
    k_sz : int, Optional
        Kernel size for the convolutions
    up_mode : str, Optional
        Upsampling mode passed to ``UpsampleBlock``
    conv_bridge : bool, Optional
        If set, the skip connection is first passed through a
        ``ConvBridgeBlock``
    shortcut : bool, Optional
        Passed to the ``ConvBlock`` applied after concatenation
    """

    def __init__(
        self,
        in_c,
        out_c,
        k_sz=3,
        up_mode="up_conv",
        conv_bridge=False,
        shortcut=False,
    ):
        super().__init__()
        self.conv_bridge = conv_bridge
        self.up_layer = UpsampleBlock(in_c, out_c, up_mode=up_mode)
        # Upsampled features (out_c) and skip features (out_c) are
        # concatenated along dim 1, hence the 2 * out_c input channels.
        self.conv_layer = ConvBlock(
            2 * out_c, out_c, k_sz=k_sz, shortcut=shortcut, pool=False
        )
        if conv_bridge:
            self.conv_bridge_layer = ConvBridgeBlock(out_c, k_sz=k_sz)

    def forward(self, x, skip):
        """Upsample ``x``, concatenate with (bridged) ``skip``, convolve."""
        upsampled = self.up_layer(x)
        bridged = self.conv_bridge_layer(skip) if self.conv_bridge else skip
        return self.conv_layer(torch.cat((upsampled, bridged), dim=1))


class LittleUNet(torch.nn.Module):
    """Little U-Net model

    Parameters
    ----------
    in_c : int
        Number of input channels
    n_classes : int
        Number of output channels of the final 1x1 convolution
    layers : sequence of int
        Channel widths per level.  ``layers[0]`` is the width of the first
        block; the encoder and decoder each have ``len(layers) - 1`` steps.
    k_sz : int, Optional
        Kernel size for all convolution blocks
    up_mode : str, Optional
        Upsampling mode passed to the decoder blocks
    conv_bridge : bool, Optional
        Whether decoder blocks process skip connections with a bridge conv
    shortcut : bool, Optional
        Whether convolution blocks use a residual shortcut
    """

    def __init__(
        self,
        in_c,
        n_classes,
        layers,
        k_sz=3,
        up_mode="transp_conv",
        conv_bridge=True,
        shortcut=True,
    ):
        super(LittleUNet, self).__init__()
        self.n_classes = n_classes
        self.first = ConvBlock(
            in_c=in_c, out_c=layers[0], k_sz=k_sz, shortcut=shortcut, pool=False
        )

        # Encoder: one pooled ConvBlock per adjacent pair of widths.
        self.down_path = torch.nn.ModuleList()
        for c_in, c_out in zip(layers, layers[1:]):
            block = ConvBlock(
                in_c=c_in,
                out_c=c_out,
                k_sz=k_sz,
                shortcut=shortcut,
                pool=True,
            )
            self.down_path.append(block)

        # Decoder: mirror of the encoder, walking the widths in reverse.
        self.up_path = torch.nn.ModuleList()
        reversed_layers = list(reversed(layers))
        for c_in, c_out in zip(reversed_layers, reversed_layers[1:]):
            block = UpConvBlock(
                in_c=c_in,
                out_c=c_out,
                k_sz=k_sz,
                up_mode=up_mode,
                conv_bridge=conv_bridge,
                shortcut=shortcut,
            )
            self.up_path.append(block)

        # init, shamelessly lifted from torchvision/models/resnet.py
        # NOTE: ``self.final`` is created after this loop, so it keeps
        # PyTorch's default Conv2d initialisation; ConvTranspose2d layers
        # are not Conv2d instances and are also left at their defaults.
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(m, (torch.nn.BatchNorm2d, torch.nn.GroupNorm)):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias, 0)

        self.final = torch.nn.Conv2d(layers[0], n_classes, kernel_size=1)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections."""
        x = self.first(x)
        skips = []
        for down in self.down_path:
            # store the activation *before* downsampling for the decoder
            skips.append(x)
            x = down(x)
        # decoder consumes skips from the deepest level upwards
        for up, skip in zip(self.up_path, reversed(skips)):
            x = up(x, skip)
        return self.final(x)


class LittleWNet(torch.nn.Module):
    """Little W-Net model, concatenating two Little U-Net models

    The first U-Net receives the input; the second receives the input
    concatenated (along dim 1) with the first U-Net's output.

    Parameters
    ----------
    n_classes : int, Optional
        Number of output channels of each U-Net
    in_c : int, Optional
        Number of input channels
    layers : sequence of int, Optional
        Channel widths passed to both U-Nets
    conv_bridge : bool, Optional
        Passed to both U-Nets
    shortcut : bool, Optional
        Passed to both U-Nets
    mode : str, Optional
        When ``"train"``, ``forward`` returns ``(x1, x2)`` (outputs of both
        U-Nets); any other value returns only the second U-Net's output.
    """

    def __init__(
        self,
        n_classes=1,
        in_c=3,
        layers=(8, 16, 32),
        conv_bridge=True,
        shortcut=True,
        mode="train",
    ):
        super().__init__()
        common = dict(
            n_classes=n_classes,
            layers=layers,
            conv_bridge=conv_bridge,
            shortcut=shortcut,
        )
        self.unet1 = LittleUNet(in_c=in_c, **common)
        # the second U-Net also sees the first one's n_classes output maps
        self.unet2 = LittleUNet(in_c=in_c + n_classes, **common)
        self.n_classes = n_classes
        self.mode = mode

    def forward(self, x):
        """Run both U-Nets; see the class docstring for the return value."""
        first = self.unet1(x)
        second = self.unet2(torch.cat([x, first], dim=1))
        return (first, second) if self.mode == "train" else second


def lunet(in_c=3, n_classes=1):
    """Builds Little U-Net segmentation network (uninitialized)

    Parameters
    ----------

    in_c : :py:class:`int`, Optional
        Number of input channels the network should operate with

    n_classes : :py:class:`int`, Optional
        Number of output classes


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for Little U-Net

    """
    # The previous body referenced undefined names (``input_channels`` /
    # ``output_classes``) and raised NameError on every call.
    return LittleUNet(
        in_c=in_c,
        n_classes=n_classes,
        layers=[8, 16, 32],
        conv_bridge=True,
        shortcut=True,
    )


def lwnet(input_channels=3, output_classes=1):
    """Builds Little W-Net segmentation network (uninitialized)

    Parameters
    ----------

    input_channels : :py:class:`int`, Optional
        Number of input channels the network should operate with

    output_classes : :py:class:`int`, Optional
        Number of output classes


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for Little W-Net

    """
    model = LittleWNet(
        n_classes=output_classes,
        in_c=input_channels,
        layers=[8, 16, 32],
        conv_bridge=True,
        shortcut=True,
    )
    return model