import torch
import torch.nn as nn

try:
    from inplace_abn import InPlaceABN
except ImportError:
    InPlaceABN = None


class Conv2dReLU(nn.Sequential):
    """Conv2d -> (optional) BatchNorm/InPlaceABN -> ReLU."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
    ):
        if use_batchnorm == "inplace" and InPlaceABN is None:
            raise RuntimeError(
                "In order to use `use_batchnorm='inplace'` the inplace_abn package must be installed. "
                "To install see: https://github.com/mapillary/inplace_abn"
            )

        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not use_batchnorm,  # bias is redundant when a norm layer follows
        )
        relu = nn.ReLU(inplace=True)

        if use_batchnorm == "inplace":
            # InPlaceABN fuses normalization and activation, so the separate ReLU is dropped.
            bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
            relu = nn.Identity()
        elif use_batchnorm:
            bn = nn.BatchNorm2d(out_channels)
        else:
            bn = nn.Identity()

        super(Conv2dReLU, self).__init__(conv, bn, relu)


class SCSEModule(nn.Module):
    """Concurrent spatial and channel squeeze-and-excitation (scSE) attention."""

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        # Channel squeeze-and-excitation: global pooling + bottleneck MLP as 1x1 convs.
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction, in_channels, 1),
            nn.Sigmoid(),
        )
        # Spatial squeeze-and-excitation: per-pixel gate from a 1x1 conv.
        self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())

    def forward(self, x):
        return x * self.cSE(x) + x * self.sSE(x)


class CBAMChannel(nn.Module):
    """Channel attention branch of CBAM (shared MLP over avg- and max-pooled features)."""

    def __init__(self, in_channels, reduction=16):
        super(CBAMChannel, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return x * self.sigmoid(out)


class CBAMSpatial(nn.Module):
    """Spatial attention branch of CBAM (conv over channel-wise avg and max maps)."""

    def __init__(self, in_channels, kernel_size=7):
        super(CBAMSpatial, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.conv1(out)
        return x * self.sigmoid(out)


class CBAM(nn.Module):
    """
    Woo S, Park J, Lee J Y, et al. CBAM: Convolutional Block Attention Module[C]
    //Proceedings of the European Conference on Computer Vision (ECCV). 2018.
    """

    def __init__(self, in_channels, reduction=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.ChannelGate = CBAMChannel(in_channels, reduction)
        # Pass kernel_size by keyword: CBAMSpatial's first positional argument is
        # in_channels, so CBAMSpatial(kernel_size) would silently ignore the value.
        self.SpatialGate = CBAMSpatial(in_channels, kernel_size=kernel_size)

    def forward(self, x):
        x = self.ChannelGate(x)
        x = self.SpatialGate(x)
        return x
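# Usage sketch (illustrative addition, not part of the original module): the
# attention blocks above are drop-in, shape-preserving layers, e.g.
#
#   x = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width)
#   assert SCSEModule(in_channels=64)(x).shape == x.shape
#   assert CBAM(in_channels=64, reduction=16, kernel_size=7)(x).shape == x.shape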
""" def __init__(self, in_channels, out_channels, map_num=4): super(ECAM, self).__init__() self.ca1 = CBAMChannel(in_channels * map_num, reduction=16) self.ca2 = CBAMChannel(in_channels, reduction=16 // 4) self.up = nn.ConvTranspose2d(in_channels * map_num, in_channels * map_num, 2, stride=2) self.conv_final = nn.Conv2d(in_channels * map_num, out_channels, kernel_size=1) def forward(self, x): """ x (list[tensor] or tuple(tensor)) """ out = torch.cat(x, 1) intra = torch.sum(torch.stack(x), dim=0) ca2 = self.ca2(intra) out = self.ca1(out) * (out + ca2.repeat(1, 4, 1, 1)) out = self.up(out) out = self.conv_final(out) return out class SEModule(nn.Module): """ Hu J, Shen L, Sun G. Squeeze-and-excitation networks[C] //Proceedings of the IEEE conference on computer vision and pattern recognition. 2018: 7132-7141. """ def __init__(self, in_channels, reduction=16): super(SEModule, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( nn.Linear(in_channels, in_channels // reduction, bias=False), nn.ReLU(inplace=True), nn.Linear(in_channels // reduction, in_channels, bias=False), nn.Sigmoid() ) def forward(self, x): b, c, _, _ = x.size() y = self.avg_pool(x).view(b, c) y = self.fc(y).view(b, c, 1, 1) return x * y.expand_as(x) class ArgMax(nn.Module): def __init__(self, dim=None): super().__init__() self.dim = dim def forward(self, x): return torch.argmax(x, dim=self.dim) class Clamp(nn.Module): def __init__(self, min=0, max=1): super().__init__() self.min, self.max = min, max def forward(self, x): return torch.clamp(x, self.min, self.max) class Activation(nn.Module): def __init__(self, name, **params): super().__init__() if name is None or name == 'identity': self.activation = nn.Identity(**params) elif name == 'sigmoid': self.activation = nn.Sigmoid() elif name == 'softmax2d': self.activation = nn.Softmax(dim=1, **params) elif name == 'softmax': self.activation = nn.Softmax(**params) elif name == 'logsoftmax': self.activation = nn.LogSoftmax(**params) elif name == 'tanh': self.activation = nn.Tanh() elif name == 'argmax': self.activation = ArgMax(**params) elif name == 'argmax2d': self.activation = ArgMax(dim=1, **params) elif name == 'clamp': self.activation = Clamp(**params) elif callable(name): self.activation = name(**params) else: raise ValueError('Activation should be callable/sigmoid/softmax/logsoftmax/tanh/None; got {}'.format(name)) def forward(self, x): return self.activation(x) class Attention(nn.Module): def __init__(self, name, **params): super().__init__() if name is None: self.attention = nn.Identity(**params) elif name == 'scse': self.attention = SCSEModule(**params) elif name == 'cbam_channel': self.attention = CBAMChannel(**params) elif name == 'cbam_spatial': self.attention = CBAMSpatial(**params) elif name == 'cbam': self.attention = CBAM(**params) elif name == 'se': self.attention = SEModule(**params) else: raise ValueError("Attention {} is not implemented".format(name)) def forward(self, x): return self.attention(x) class Flatten(nn.Module): def forward(self, x): return x.view(x.shape[0], -1)