# coding: utf-8
import torch.nn as nn
import torch
from .mynet3 import F_mynet3
from .BAM import BAM
from .PAM2 import PAM
def define_F(in_c, f_c, type='unet'):
    """Build the feature extractor F; only the 'mynet3' backbone is implemented."""
    if type == 'mynet3':
        print("using mynet3 backbone")
        return F_mynet3(backbone='resnet18', in_c=in_c, f_c=f_c, output_stride=32)
    else:
        raise NotImplementedError('no such F type!')
def weights_init(m):
    """DCGAN-style init: N(0, 0.02) for conv weights, N(1, 0.02) for BatchNorm scale."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
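
# A minimal sanity check for weights_init (illustrative, not part of the original
# file): after applying it, a Conv2d weight should be roughly N(0, 0.02).
#   conv = nn.Conv2d(3, 16, kernel_size=3)
#   weights_init(conv)
#   print(conv.weight.data.std())  # ~0.02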
class CDSA(nn.Module):
    """Self-attention module for change detection.

    Wraps either a BAM or a PAM attention block and applies it jointly
    to the two input feature maps.
    """
    def __init__(self, in_c, ds=1, mode='BAM'):
        super(CDSA, self).__init__()
        self.in_C = in_c
        self.ds = ds
        print('ds: ', self.ds)
        self.mode = mode
        if self.mode == 'BAM':
            self.Self_Att = BAM(self.in_C, ds=self.ds)
        elif self.mode == 'PAM':
            self.Self_Att = PAM(in_channels=self.in_C, out_channels=self.in_C,
                                sizes=[1, 2, 4, 8], ds=self.ds)
        else:
            raise NotImplementedError('no such attention mode: %s' % self.mode)
        self.apply(weights_init)
    def forward(self, x1, x2):
        # Concatenate the two feature maps along the width dimension (N, C, H, W),
        # run self-attention over the joint map, then split it back in two.
        width = x1.shape[3]
        x = torch.cat((x1, x2), 3)
        x = self.Self_Att(x)
        return x[:, :, :, 0:width], x[:, :, :, width:]
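

if __name__ == '__main__':
    # Minimal usage sketch (illustrative, not from the original file). Because of
    # the relative imports above, run this as a module, e.g.
    # `python -m <package>.<this_module>`. Shapes assume the attention block
    # preserves the spatial size of its input.
    att = CDSA(in_c=64, ds=1, mode='BAM')
    t1 = torch.randn(2, 64, 32, 32)  # features of the image at time 1
    t2 = torch.randn(2, 64, 32, 32)  # features of the image at time 2
    y1, y2 = att(t1, t2)
    print(y1.shape, y2.shape)  # expected: torch.Size([2, 64, 32, 32]) twice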