import torch
import torch.nn.functional as F
from torch import nn


class BAM(nn.Module):
    """Basic self-attention module (BAM).

    Applies scaled dot-product self-attention over average-pooled
    feature maps, then upsamples the result back to the input
    resolution and adds it as a learnable residual.
    """
    def __init__(self, in_dim, ds=8, activation=nn.ReLU):
        super(BAM, self).__init__()
        self.channel_in = in_dim
        self.key_channel = in_dim // 8
        self.activation = activation  # stored but not applied in forward
        self.ds = ds  # spatial downsampling factor applied before attention
        self.pool = nn.AvgPool2d(self.ds)

        # 1x1 convolutions project the input into query/key/value spaces;
        # queries and keys are reduced to in_dim // 8 channels to save memory.
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)

        # Learnable residual weight, initialised to zero so the module
        # starts out as an identity mapping.
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)  # normalise over the key positions
    def forward(self, input):
        """
        inputs:
            input : input feature maps (B x C x H x W)
        returns:
            out : gamma-weighted self-attention output added to the input;
                  the internal attention map has shape B x N x N, where
                  N = (H * W) / ds**2
        """
        x = self.pool(input)  # downsample by ds to shrink the attention map
        batch_size, C, height, width = x.size()
        N = height * width  # number of spatial positions after pooling

        # B x N x C' queries and B x C' x N keys, with C' = in_dim // 8
        proj_query = self.query_conv(x).view(batch_size, -1, N).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, N)

        # Scaled dot-product similarity between all position pairs: B x N x N
        energy = torch.bmm(proj_query, proj_key)
        energy = (self.key_channel ** -0.5) * energy

        attention = self.softmax(energy)  # each row sums to 1 over key positions

        proj_value = self.value_conv(x).view(batch_size, -1, N)  # B x C x N

        # Aggregate values with the attention weights, restore B x C x H x W
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, height, width)

        # Upsample back to the input resolution (robust to sizes that are not
        # exact multiples of ds) and add as a gamma-weighted residual.
        out = F.interpolate(out, size=input.shape[2:])
        out = self.gamma * out + input

        return out
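

# Minimal usage sketch. The batch size, 64-channel width, and 32x32
# resolution below are illustrative assumptions, not from the original code;
# they only need to satisfy in_dim >= 8 and spatial dims >= ds.
if __name__ == "__main__":
    bam = BAM(in_dim=64, ds=8)
    feats = torch.randn(2, 64, 32, 32)  # B x C x H x W
    out = bam(feats)
    print(out.shape)  # torch.Size([2, 64, 32, 32]); output matches input shape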