#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import torch.nn
from torch.nn import Conv2d, ConvTranspose2d
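
# NOTE: ``conv_with_kaiming_uniform`` and ``convtrans_with_kaiming_uniform``
# are used throughout this module but their definitions are not part of this
# excerpt.  The helpers below are a minimal sketch, assuming they simply wrap
# Conv2d/ConvTranspose2d with Kaiming-uniform weight init and zeroed biases.


def conv_with_kaiming_uniform(
    in_channels, out_channels, kernel_size, stride=1, padding=0
):
    """Conv2d with Kaiming-uniform initialization (assumed helper)."""
    conv = Conv2d(
        in_channels, out_channels, kernel_size, stride=stride, padding=padding
    )
    torch.nn.init.kaiming_uniform_(conv.weight, a=1)
    torch.nn.init.constant_(conv.bias, 0)
    return conv


def convtrans_with_kaiming_uniform(
    in_channels, out_channels, kernel_size, stride=1, padding=0
):
    """ConvTranspose2d with Kaiming-uniform initialization (assumed helper)."""
    conv = ConvTranspose2d(
        in_channels, out_channels, kernel_size, stride=stride, padding=padding
    )
    torch.nn.init.kaiming_uniform_(conv.weight, a=1)
    torch.nn.init.constant_(conv.bias, 0)
    return conv
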
class UpsampleCropBlock(torch.nn.Module):
    """
    Combines Conv2d, ConvTranspose2d and cropping. Simulates the caffe2 crop
    layer in the forward function.

    Used for DRIU and HED.

    Parameters
    ----------

    in_channels : int
        number of channels of the intermediate layer
    out_channels : int
        number of output channels
    up_kernel_size : int
        kernel size for the transposed convolution
    up_stride : int
        stride for the transposed convolution
    up_padding : int
        padding for the transposed convolution
    pixelshuffle : bool
        whether to upsample with ``PixelShuffle_ICNR`` instead of a
        transposed convolution (default: ``False``)
    """

def __init__(
self,
in_channels,
out_channels,
up_kernel_size,
up_stride,
up_padding,
pixelshuffle=False,
):
super().__init__()
# NOTE: Kaiming init, replace with torch.nn.Conv2d and torch.nn.ConvTranspose2d to get original DRIU impl.
self.conv = conv_with_kaiming_uniform(
in_channels, out_channels, 3, 1, 1
)
if pixelshuffle:
self.upconv = PixelShuffle_ICNR(
out_channels, out_channels, scale=up_stride
)
else:
self.upconv = convtrans_with_kaiming_uniform(
out_channels,
out_channels,
up_kernel_size,
up_stride,
up_padding,
)
    def forward(self, x, input_res):
        """Forward pass of UpsampleCropBlock.

        Upsampled feature maps are cropped to the resolution of the input
        image.

        Parameters
        ----------

        x : torch.Tensor
            input feature map
        input_res : tuple
            resolution of the input image, in the format ``(height, width)``
        """
img_h = input_res[0]
img_w = input_res[1]
x = self.conv(x)
x = self.upconv(x)
# determine center crop
# height
up_h = x.shape[2]
h_crop = up_h - img_h
h_s = h_crop // 2
h_e = up_h - (h_crop - h_s)
# width
up_w = x.shape[3]
w_crop = up_w - img_w
w_s = w_crop // 2
w_e = up_w - (w_crop - w_s)
# perform crop
# needs explicit ranges for onnx export
x = x[:, :, h_s:h_e, w_s:w_e] # crop to input size
return x
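
# Usage sketch for ``UpsampleCropBlock`` (all shapes below are assumptions
# for illustration): a VGG feature map at 1/4 resolution is upsampled by a
# stride-4 transposed convolution and center-cropped back to the exact input
# resolution.  Wrapped in a function so that importing stays side-effect free.
def _demo_upsample_crop():
    block = UpsampleCropBlock(
        in_channels=128, out_channels=1, up_kernel_size=8, up_stride=4,
        up_padding=0,
    )
    feats = torch.rand(1, 128, 152, 152)  # e.g. 1/4 of a 608x608 image
    out = block(feats, (608, 608))  # upconv yields 612x612, cropped to 608x608
    assert out.shape == (1, 1, 608, 608)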

def ifnone(a, b):
"``a`` if ``a`` is not None, otherwise ``b``."
return b if a is None else a

def icnr(x, scale=2, init=torch.nn.init.kaiming_normal_):
"""https://docs.fast.ai/layers.html#PixelShuffle_ICNR
ICNR init of ``x``, with ``scale`` and ``init`` function.
"""
ni, nf, h, w = x.shape
ni2 = int(ni / (scale ** 2))
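    # initialize ni2 sub-kernels and replicate each scale**2 times, so that
    # the channels PixelShuffle scatters into the same output map start out
    # identical (no checkerboard pattern at initialization)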
k = init(torch.zeros([ni2, nf, h, w])).transpose(0, 1)
k = k.contiguous().view(ni2, nf, -1)
k = k.repeat(1, 1, scale ** 2)
k = k.contiguous().view([nf, ni, h, w]).transpose(0, 1)
x.data.copy_(k)
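
# Sanity-check sketch for ``icnr`` (illustrative, not part of the original
# module): after ICNR init with zeroed bias, a 1x1 conv followed by
# ``torch.nn.PixelShuffle`` behaves like nearest-neighbour upsampling, i.e.
# every ``scale x scale`` output block is constant.
def _demo_icnr(scale=2):
    conv = torch.nn.Conv2d(8, 4 * scale ** 2, kernel_size=1)
    torch.nn.init.constant_(conv.bias, 0)  # a non-zero bias would break the check
    icnr(conv.weight, scale=scale)
    shuf = torch.nn.PixelShuffle(scale)
    with torch.no_grad():
        y = shuf(conv(torch.rand(1, 8, 5, 5)))  # -> (1, 4, 10, 10)
    # rebuild y from the top-left corner of every block; equality means each
    # scale x scale block is constant
    nearest = y[:, :, ::scale, ::scale]
    nearest = nearest.repeat_interleave(scale, 2).repeat_interleave(scale, 3)
    assert torch.allclose(y, nearest)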

class PixelShuffle_ICNR(torch.nn.Module):
    """https://docs.fast.ai/layers.html#PixelShuffle_ICNR

    Upsample by ``scale`` from ``ni`` filters to ``nf`` (default ``ni``),
    using ``torch.nn.PixelShuffle`` and ``icnr`` init.
    """

def __init__(self, ni: int, nf: int = None, scale: int = 2):
super().__init__()
nf = ifnone(nf, ni)
self.conv = conv_with_kaiming_uniform(ni, nf * (scale ** 2), 1)
icnr(self.conv.weight)
self.shuf = torch.nn.PixelShuffle(scale)
# Blurring over (h*w) kernel
# "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
# - https://arxiv.org/abs/1806.02658
self.pad = torch.nn.ReplicationPad2d((1, 0, 1, 0))
self.blur = torch.nn.AvgPool2d(2, stride=1)
self.relu = torch.nn.ReLU(inplace=True)
    def forward(self, x):
x = self.shuf(self.relu(self.conv(x)))
x = self.blur(self.pad(x))
return x
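
# Usage sketch for ``PixelShuffle_ICNR`` (shapes are assumptions): the block
# doubles the spatial resolution by default while keeping the channel count;
# the replication-pad + 2x2 average-pool pair blurs without changing the size.
def _demo_pixelshuffle_icnr():
    up = PixelShuffle_ICNR(64)  # nf defaults to ni, scale defaults to 2
    y = up(torch.rand(1, 64, 32, 32))
    assert y.shape == (1, 64, 64, 64)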

class UnetBlock(torch.nn.Module):
    """Decoder block for a VGG-based U-Net.

    Upsamples the incoming feature map, concatenates it with the encoder
    skip connection, and refines the result with two 3x3 layers.

    Parameters
    ----------

    up_in_c : int
        number of channels of the incoming (deeper) feature map
    x_in_c : int
        number of channels of the encoder skip connection
    pixel_shuffle : bool
        whether to upsample with ``PixelShuffle_ICNR`` instead of a
        transposed convolution (default: ``False``)
    middle_block : bool
        if ``True``, keep the channel count unchanged across the upsampling
        step (used for the middle block of a VGG-based U-Net)
    """

    def __init__(
        self, up_in_c, x_in_c, pixel_shuffle=False, middle_block=False
    ):
super().__init__()
# middle block for VGG based U-Net
if middle_block:
up_out_c = up_in_c
else:
up_out_c = up_in_c // 2
cat_channels = x_in_c + up_out_c
inner_channels = cat_channels // 2
if pixel_shuffle:
self.upsample = PixelShuffle_ICNR(up_in_c, up_out_c)
else:
self.upsample = convtrans_with_kaiming_uniform(
up_in_c, up_out_c, 2, 2
)
self.convtrans1 = convtrans_with_kaiming_uniform(
cat_channels, inner_channels, 3, 1, 1
)
self.convtrans2 = convtrans_with_kaiming_uniform(
inner_channels, inner_channels, 3, 1, 1
)
self.relu = torch.nn.ReLU(inplace=True)
    def forward(self, up_in, x_in):
up_out = self.upsample(up_in)
cat_x = torch.cat([up_out, x_in], dim=1)
x = self.relu(self.convtrans1(cat_x))
x = self.relu(self.convtrans2(x))
return x
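
# Decoder-step usage sketch for ``UnetBlock`` (shapes are assumptions): a
# deeper feature map is upsampled 2x, concatenated with the skip connection
# at the matching resolution, and refined at constant spatial size.
def _demo_unet_block():
    block = UnetBlock(up_in_c=512, x_in_c=256)
    up_in = torch.rand(1, 512, 16, 16)  # deeper decoder features
    x_in = torch.rand(1, 256, 32, 32)   # encoder skip connection
    y = block(up_in, x_in)
    assert y.shape == (1, 256, 32, 32)  # inner_channels = (256 + 256) // 2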