import torch
import torch.nn.functional as F
from torch import nn


class BAM(nn.Module):
    """ Basic self-attention module """

    def __init__(self, in_dim, ds=8, activation=nn.ReLU):
        super(BAM, self).__init__()
        self.channel_in = in_dim
        self.key_channel = self.channel_in // 8
        self.activation = activation
        self.ds = ds
        # Downsample the input before computing attention so the N x N affinity matrix stays small.
        self.pool = nn.AvgPool2d(self.ds)
        print('ds: ', ds)
        # 1x1 convolutions project the input into query/key/value spaces.
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # Learnable scale initialised to zero; note it is not applied in forward() below.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input):
        """
        inputs :
            input : input feature maps (B x C x H x W)
        returns :
            out : self-attention value + input feature (same shape as input)
        """
        x = self.pool(input)
        m_batchsize, C, height, width = x.size()
        # Query: B x N x C', Key: B x C' x N, with C' = C // 8 and N = (H/ds) * (W/ds).
        proj_query = self.query_conv(x).view(m_batchsize, -1, height * width).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, height * width)
        # Scaled dot-product affinity: B x N x N.
        energy = torch.bmm(proj_query, proj_key)
        energy = (self.key_channel ** -.5) * energy
        attention = self.softmax(energy)
        # Value: B x C x N.
        proj_value = self.value_conv(x).view(m_batchsize, -1, height * width)
        # Aggregate values with the attention weights, then restore the spatial layout.
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)
        # Upsample back to the input resolution and add the residual connection.
        out = F.interpolate(out, size=input.shape[2:])
        out = out + input
        return out
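

# A minimal usage sketch, not part of the original module: the batch size, channel
# count, and spatial resolution below are illustrative assumptions only. It shows
# that BAM is a drop-in residual block whose output shape matches its input.
if __name__ == "__main__":
    bam = BAM(in_dim=64, ds=8)
    feats = torch.randn(2, 64, 64, 64)   # B x C x H x W feature map (assumed sizes)
    out = bam(feats)
    assert out.shape == feats.shape       # residual self-attention preserves the shape
    print(out.shape)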