import torch
import torch.nn as nn

try:
    from inplace_abn import InPlaceABN
except ImportError:
    InPlaceABN = None


class Conv2dReLU(nn.Sequential):
    """Conv2d followed by an optional normalization layer and ReLU.

    ``use_batchnorm`` may be ``True`` (BatchNorm2d), ``False`` (no normalization)
    or ``"inplace"`` (InPlaceABN, which fuses normalization and activation).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
    ):
        if use_batchnorm == "inplace" and InPlaceABN is None:
            raise RuntimeError(
                "In order to use `use_batchnorm='inplace'` the inplace_abn package must be installed. "
                "To install see: https://github.com/mapillary/inplace_abn"
            )

        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not use_batchnorm,  # a norm layer makes the conv bias redundant
        )
        relu = nn.ReLU(inplace=True)

        if use_batchnorm == "inplace":
            # InPlaceABN applies its own activation, so the separate ReLU is dropped.
            bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
            relu = nn.Identity()
        elif use_batchnorm and use_batchnorm != "inplace":
            bn = nn.BatchNorm2d(out_channels)
        else:
            bn = nn.Identity()

        super(Conv2dReLU, self).__init__(conv, bn, relu)
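
# Usage sketch (illustrative only; the channel counts and input shape below are
# assumptions for the example, not values taken from this repository):
#
#     block = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_batchnorm=True)
#     out = block(torch.rand(1, 3, 64, 64))   # -> shape (1, 16, 64, 64)
#
# Passing use_batchnorm="inplace" requires the inplace_abn package (see the
# RuntimeError message above).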


class SCSEModule(nn.Module):
    """Concurrent spatial and channel squeeze-and-excitation (scSE) attention.

    Roy A G, Navab N, Wachinger C. Concurrent spatial and channel
    'squeeze & excitation' in fully convolutional networks[C]//MICCAI. 2018.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        # Channel squeeze-and-excitation: global pooling + bottleneck MLP.
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction, in_channels, 1),
            nn.Sigmoid(),
        )
        # Spatial squeeze-and-excitation: a 1x1 conv produces a single attention map.
        self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())

    def forward(self, x):
        return x * self.cSE(x) + x * self.sSE(x)
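
# Usage sketch (sizes are assumptions for the example). The output is the sum of
# the channel-gated and spatially-gated inputs, so the shape is preserved:
#
#     scse = SCSEModule(in_channels=64, reduction=16)
#     y = scse(torch.rand(2, 64, 32, 32))   # -> shape (2, 64, 32, 32)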


class CBAMChannel(nn.Module):
    """Channel attention gate of CBAM: shared MLP over average- and max-pooled features."""

    def __init__(self, in_channels, reduction=16):
        super(CBAMChannel, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared bottleneck MLP, implemented with 1x1 convolutions.
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return x * self.sigmoid(out)


class CBAMSpatial(nn.Module):
    """Spatial attention gate of CBAM: a conv over channel-wise average and max maps.

    Note: ``in_channels`` is unused and kept only for interface consistency;
    the gate always operates on the 2-channel pooled descriptor.
    """

    def __init__(self, in_channels, kernel_size=7):
        super(CBAMSpatial, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.conv1(out)
        return x * self.sigmoid(out)


class CBAM(nn.Module):
    """
    Woo S, Park J, Lee J Y, et al. CBAM: Convolutional block attention module[C]
    //Proceedings of the European conference on computer vision (ECCV). 2018.
    """

    def __init__(self, in_channels, reduction=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.ChannelGate = CBAMChannel(in_channels, reduction)
        # Pass both arguments explicitly so a non-default kernel_size is not
        # silently interpreted as in_channels.
        self.SpatialGate = CBAMSpatial(in_channels, kernel_size)

    def forward(self, x):
        # Channel attention first, then spatial attention, as in the paper.
        x = self.ChannelGate(x)
        x = self.SpatialGate(x)
        return x
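
# Usage sketch (illustrative sizes). With the explicit argument passing above,
# a custom kernel_size now reaches the spatial gate:
#
#     cbam = CBAM(in_channels=64, reduction=16, kernel_size=7)
#     y = cbam(torch.rand(2, 64, 32, 32))   # shape preserved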


class ECAM(nn.Module):
    """
    Ensemble Channel Attention Module for UNetPlusPlus.
    Fang S, Li K, Shao J, et al. SNUNet-CD: A Densely Connected Siamese Network for Change Detection of VHR Images[J].
    IEEE Geoscience and Remote Sensing Letters, 2021.
    Not completely consistent with the paper; to be improved.
    """

    def __init__(self, in_channels, out_channels, map_num=4):
        super(ECAM, self).__init__()
        self.map_num = map_num
        self.ca1 = CBAMChannel(in_channels * map_num, reduction=16)
        self.ca2 = CBAMChannel(in_channels, reduction=16 // 4)
        self.up = nn.ConvTranspose2d(in_channels * map_num, in_channels * map_num, 2, stride=2)
        self.conv_final = nn.Conv2d(in_channels * map_num, out_channels, kernel_size=1)

    def forward(self, x):
        """
        x (list[tensor] or tuple(tensor)): ``map_num`` feature maps with identical shapes.
        """
        out = torch.cat(x, 1)
        # Intra-branch attention on the element-wise sum of all maps.
        intra = torch.sum(torch.stack(x), dim=0)
        ca2 = self.ca2(intra)
        # Repeat by map_num instead of a hard-coded 4 so non-default map_num works.
        out = self.ca1(out) * (out + ca2.repeat(1, self.map_num, 1, 1))
        out = self.up(out)
        out = self.conv_final(out)
        return out
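
# Usage sketch (illustrative channel counts and spatial size). ECAM takes a list
# of map_num same-shaped feature maps and upsamples by 2x before the final 1x1 conv:
#
#     ecam = ECAM(in_channels=32, out_channels=2, map_num=4)
#     maps = [torch.rand(2, 32, 64, 64) for _ in range(4)]
#     y = ecam(maps)   # -> shape (2, 2, 128, 128)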


class SEModule(nn.Module):
    """
    Hu J, Shen L, Sun G. Squeeze-and-excitation networks[C]
    //Proceedings of the IEEE conference on computer vision and pattern recognition. 2018: 7132-7141.
    """

    def __init__(self, in_channels, reduction=16):
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, in_channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        # Squeeze: global average pooling to (b, c); excite: bottleneck MLP to channel weights.
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
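
# Usage sketch (sizes are assumptions). Unlike SCSE, this module uses nn.Linear
# on the pooled (b, c) vector, then broadcasts the channel weights back to NCHW:
#
#     se = SEModule(in_channels=64, reduction=16)
#     y = se(torch.rand(2, 64, 32, 32))   # channel-wise reweighting, shape preserved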


class ArgMax(nn.Module):
    def __init__(self, dim=None):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return torch.argmax(x, dim=self.dim)


class Clamp(nn.Module):
    def __init__(self, min=0, max=1):
        super().__init__()
        self.min, self.max = min, max

    def forward(self, x):
        return torch.clamp(x, self.min, self.max)


class Activation(nn.Module):
    def __init__(self, name, **params):
        super().__init__()
        if name is None or name == 'identity':
            self.activation = nn.Identity(**params)
        elif name == 'sigmoid':
            self.activation = nn.Sigmoid()
        elif name == 'softmax2d':
            self.activation = nn.Softmax(dim=1, **params)
        elif name == 'softmax':
            self.activation = nn.Softmax(**params)
        elif name == 'logsoftmax':
            self.activation = nn.LogSoftmax(**params)
        elif name == 'tanh':
            self.activation = nn.Tanh()
        elif name == 'argmax':
            self.activation = ArgMax(**params)
        elif name == 'argmax2d':
            self.activation = ArgMax(dim=1, **params)
        elif name == 'clamp':
            self.activation = Clamp(**params)
        elif callable(name):
            self.activation = name(**params)
        else:
            raise ValueError(
                'Activation should be callable or one of identity/sigmoid/softmax/softmax2d/'
                'logsoftmax/tanh/argmax/argmax2d/clamp/None; got {}'.format(name)
            )

    def forward(self, x):
        return self.activation(x)
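
# Usage sketch (illustrative tensor sizes). 'softmax2d' applies softmax over the
# channel dimension; None gives an identity pass-through:
#
#     act = Activation('softmax2d')
#     probs = act(torch.rand(2, 3, 16, 16))   # per-pixel class probabilities
#     act = Activation(None)                  # identity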


class Attention(nn.Module):
    def __init__(self, name, **params):
        super().__init__()
        if name is None:
            self.attention = nn.Identity(**params)
        elif name == 'scse':
            self.attention = SCSEModule(**params)
        elif name == 'cbam_channel':
            self.attention = CBAMChannel(**params)
        elif name == 'cbam_spatial':
            self.attention = CBAMSpatial(**params)
        elif name == 'cbam':
            self.attention = CBAM(**params)
        elif name == 'se':
            self.attention = SEModule(**params)
        else:
            raise ValueError("Attention {} is not implemented".format(name))

    def forward(self, x):
        return self.attention(x)
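
# Usage sketch (the channel count is an assumption for the example). Keyword
# arguments are forwarded unchanged to the selected attention module:
#
#     attn = Attention('cbam', in_channels=64, reduction=16, kernel_size=7)
#     y = attn(torch.rand(2, 64, 32, 32))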


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.shape[0], -1)
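

if __name__ == "__main__":
    # Minimal smoke test sketch (shapes are illustrative and not part of the library API):
    # every attention variant should preserve the input shape.
    feats = torch.rand(2, 64, 32, 32)
    for name in (None, 'scse', 'se', 'cbam_channel', 'cbam_spatial', 'cbam'):
        attn = Attention(name, in_channels=64)  # nn.Identity ignores the extra kwarg for name=None
        assert attn(feats).shape == feats.shape
    print("all attention variants preserve the input shape")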