Add AI

@@ -1,5 +1,6 @@
 from misc.utils import Register

 AI_METHOD = Register('AI Method')
 from .packages.sta_net import STANet

 from .main import AIPlugin
+from .basic_cd import DPFCN, DVCA, RCNN
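
A note on the registry: `Register` (from misc.utils) is not part of this diff. From how AI_METHOD is used below (AI_METHOD.keys(), AI_METHOD[key], @AI_METHOD.register), it is presumably a small dict-like name registry. A minimal sketch, assuming registered classes expose a static get_name():

class Register(dict):
    def __init__(self, name):
        super().__init__()
        self.name = name

    def register(self, cls):
        # used as a class decorator: @AI_METHOD.register
        self[cls.get_name()] = cls
        return cls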
@@ -1,210 +1,183 @@
-from PyQt5.QtCore import QSettings, pyqtSignal, Qt
-from PyQt5.QtGui import QIcon, QTextBlock, QTextCursor
-from PyQt5.QtWidgets import QDialog, QLabel, QComboBox, QDialogButtonBox, QFileDialog, QHBoxLayout, QMessageBox, QProgressBar, QPushButton, QTextEdit, QVBoxLayout
-import subprocess
-import threading
-import os
-import sys
-import abc
-from . import AI_METHOD
-from plugins.misc import AlgFrontend
-from rscder.utils.icons import IconInstance
-import ai_method.subprcess_python as sp
-
-
-class AIMethodDialog(QDialog):
-
-    stage_end = pyqtSignal(int)
-    stage_log = pyqtSignal(str)
-
-    ENV = 'base'
-    name = ''
-    stages = []
-    setting_widget = None
-
-    def __init__(self, parent=None) -> None:
-        super().__init__(parent)
-
-        # self.setting_widget: AlgFrontend = setting_widget
-
-        vlayout = QVBoxLayout()
-
-        hlayout = QHBoxLayout()
-        select_label = QLabel('模式选择:')
-        self.select_mode = QComboBox()
-        self.select_mode.addItem('----------', 'NoValue')
-        for stage in self.stages:
-            self.select_mode.addItem(stage[1], stage[0])
-
-        setting_btn = QPushButton(IconInstance().SELECT, '配置')
-
-        hlayout.addWidget(select_label)
-        hlayout.addWidget(self.select_mode, 2)
-        hlayout.addWidget(setting_btn)
-        self.setting_args = []
-
-        def show_setting():
-            if self.setting_widget is None:
-                return
-            dialog = QDialog(parent)
-            vlayout = QVBoxLayout()
-            dialog.setLayout(vlayout)
-            widget = self.setting_widget.get_widget(dialog)
-            vlayout.addWidget(widget)
-
-            dbtn_box = QDialogButtonBox(dialog)
-            dbtn_box.setStandardButtons(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
-            dbtn_box.button(QDialogButtonBox.Ok).setText('确定')
-            dbtn_box.button(QDialogButtonBox.Cancel).setText('取消')
-            dbtn_box.button(QDialogButtonBox.Ok).clicked.connect(dialog.accept)
-            dbtn_box.button(QDialogButtonBox.Cancel).clicked.connect(dialog.reject)
-            vlayout.addWidget(dbtn_box, 1, Qt.AlignRight)
-            dialog.setMinimumHeight(500)
-            dialog.setMinimumWidth(900)
-            if dialog.exec_() == QDialog.Accepted:
-                self.setting_args = self.setting_widget.get_params(widget)
-
-        setting_btn.clicked.connect(show_setting)
-
-        btnbox = QDialogButtonBox(self)
-        btnbox.setStandardButtons(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
-        btnbox.button(QDialogButtonBox.Ok).setText('确定')
-        btnbox.button(QDialogButtonBox.Cancel).setText('取消')
-
-        processbar = QProgressBar(self)
-        processbar.setMaximum(1)
-        processbar.setMinimum(0)
-        processbar.setEnabled(False)
-
-        def run_stage():
-            self.log = f'开始{self.current_stage}...\n'
-            btnbox.button(QDialogButtonBox.Cancel).setEnabled(True)
-            processbar.setMaximum(0)
-            processbar.setValue(0)
-            processbar.setEnabled(True)
-            q = threading.Thread(target=self.run_stage, args=(self.current_stage,))
-            q.start()
-
-        btnbox.button(QDialogButtonBox.Cancel).setEnabled(False)
-        btnbox.accepted.connect(run_stage)
-        btnbox.rejected.connect(self._stage_stop)
-        self.processbar = processbar
-
-        vlayout.addLayout(hlayout)
-        self.detail = QTextEdit(self)
-        vlayout.addWidget(self.detail)
-        vlayout.addWidget(processbar)
-        vlayout.addWidget(btnbox)
-        self.detail.setReadOnly(True)
-        self.detail.setText('等待开始...')
-        self.setLayout(vlayout)
-
-        self.setMinimumHeight(500)
-        self.setMinimumWidth(500)
-        self.setWindowIcon(IconInstance().AI_DETECT)
-        self.setWindowTitle(self.name)
-
-        self.stage_end.connect(self._stage_end)
-        self.log = '等待开始...\n'
-        self.stage_log.connect(self._stage_log)
-        self.p = None
-
-    @property
-    def current_stage(self):
-        if self.select_mode.currentData() == 'NoValue':
-            return None
-        return self.select_mode.currentData()
-
-    @property
-    def _python_base(self):
-        return os.path.abspath(os.path.join(os.path.dirname(sys.executable), '..', '..'))
-
-    @property
-    def activate_env(self):
-        script = os.path.join(self._python_base, 'Scripts', 'activate')
-        if self.ENV == 'base':
-            return script
-        else:
-            return script + ' ' + self.ENV
-
-    @property
-    def python_path(self):
-        if self.ENV == 'base':
-            return self._python_base
-        return os.path.join(self._python_base, 'envs', self.ENV, 'python.exe')
-
-    @property
-    def workdir(self):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def stage_script(self, stage):
-        return None
-
-    def _stage_log(self, l):
-        self.log += l + '\n'
-        self.detail.setText(self.log)
-        cursor = self.detail.textCursor()
-        cursor.movePosition(QTextCursor.End)
-        self.detail.setTextCursor(cursor)
-
-    def closeEvent(self, e) -> None:
-        if self.p is None:
-            return super().closeEvent(e)
-        dialog = QMessageBox(QMessageBox.Warning, '警告', '关闭窗口将停止' + self.current_stage + '!', QMessageBox.Ok | QMessageBox.Cancel)
-        dialog.button(QMessageBox.Ok).clicked.connect(dialog.accepted)
-        dialog.button(QMessageBox.Cancel).clicked.connect(dialog.rejected)
-        dialog.show()
-        r = dialog.exec()
-        if r == QMessageBox.Cancel:
-            e.ignore()
-            return
-        return super().closeEvent(e)
-
-    def _stage_stop(self):
-        if self.p is not None:
-            try:
-                self.stage_log.emit(f'用户停止{self.stage}...')
-                self.p.kill()
-            except:
-                pass
-
-    def _stage_end(self, c):
-        self.processbar.setMaximum(1)
-        self.processbar.setValue(0)
-        self.processbar.setEnabled(False)
-        self.log += '完成!'
-        self.detail.setText(self.log)
-
-    def run_stage(self, stage):
-        if stage is None:
-            return
-        ss = self.stage_script(stage)
-        if ss is None:
-            self.stage_log.emit(f'开始{stage}时未发现脚本')
-            self.stage_end.emit(1)
-            return
-
-        if self.workdir is None:
-            self.stage_log.emit('未配置工作目录!')
-            self.stage_end.emit(2)
-            return
-        args = [ss, *self.setting_args]
-        self.p = sp.SubprocessWraper(self.python_path, self.workdir, args, self.activate_env)
-
-        for line in self.p.run():
-            self.stage_log.emit(line)
-        self.stage_end.emit(self.p.returncode)
-
-
-from collections import OrderedDict
-
-
-def option_to_gui(parent, options: OrderedDict):
-    for key in options:
-        pass
-
-
-def gui_to_option(widget, options: OrderedDict):
-    pass
+from osgeo import gdal, gdal_array
+import os
+import math
+import numpy as np  # needed for np.random and np.uint8 below
+from rscder.utils.icons import IconInstance
+from rscder.utils.project import PairLayer, Project
+from rscder.utils.geomath import geo2imageRC, imageRC2geo
+from misc.main import AlgFrontend
+from ai_method import AI_METHOD
+from .packages import get_model
+
+
+class BasicAICD(AlgFrontend):
+
+    @staticmethod
+    def get_icon():
+        return IconInstance().ARITHMETIC3
+
+    @staticmethod
+    def run_alg(pth1: str, pth2: str, layer_parent: PairLayer, send_message=None, model=None, *args, **kargs):
+
+        if model is None and send_message is not None:
+            send_message.emit('未能加载模型!')
+            return
+
+        ds1: gdal.Dataset = gdal.Open(pth1)
+        ds2: gdal.Dataset = gdal.Open(pth2)
+
+        cell_size = (512, 512)
+        xsize = layer_parent.size[0]
+        ysize = layer_parent.size[1]
+
+        band = ds1.RasterCount
+        yblocks = ysize // cell_size[1]
+        xblocks = xsize // cell_size[0]
+
+        driver = gdal.GetDriverByName('GTiff')
+        out_tif = os.path.join(Project().other_path, 'temp.tif')
+        out_ds = driver.Create(out_tif, xsize, ysize, 1, gdal.GDT_Float32)
+        geo = layer_parent.grid.geo
+        proj = layer_parent.grid.proj
+        out_ds.SetGeoTransform(geo)
+        out_ds.SetProjection(proj)
+
+        max_diff = 0
+        min_diff = math.inf
+
+        start1x, start1y = geo2imageRC(ds1.GetGeoTransform(
+        ), layer_parent.mask.xy[0], layer_parent.mask.xy[1])
+        end1x, end1y = geo2imageRC(ds1.GetGeoTransform(
+        ), layer_parent.mask.xy[2], layer_parent.mask.xy[3])
+        start2x, start2y = geo2imageRC(ds2.GetGeoTransform(
+        ), layer_parent.mask.xy[0], layer_parent.mask.xy[1])
+        end2x, end2y = geo2imageRC(ds2.GetGeoTransform(
+        ), layer_parent.mask.xy[2], layer_parent.mask.xy[3])
+
+        for j in range(yblocks + 1):
+            if send_message is not None:
+                send_message.emit(f'计算{j}/{yblocks}')
+            for i in range(xblocks + 1):
+                # lists, not tuples: the clamping below assigns by index
+                block_xy1 = [start1x + i * cell_size[0], start1y + j * cell_size[1]]
+                block_xy2 = [start2x + i * cell_size[0], start2y + j * cell_size[1]]
+                block_xy = [i * cell_size[0], j * cell_size[1]]
+
+                if block_xy1[1] > end1y or block_xy2[1] > end2y:
+                    break
+                if block_xy1[0] > end1x or block_xy2[0] > end2x:
+                    break
+                block_size = list(cell_size)
+
+                if block_xy[1] + block_size[1] > ysize:
+                    block_xy[1] = ysize - block_size[1]
+                if block_xy[0] + block_size[0] > xsize:
+                    block_xy[0] = xsize - block_size[0]
+
+                if block_xy1[1] + block_size[1] > end1y:
+                    block_xy1[1] = end1y - block_size[1]
+                if block_xy1[0] + block_size[0] > end1x:
+                    block_xy1[0] = end1x - block_size[0]
+                if block_xy2[1] + block_size[1] > end2y:
+                    block_xy2[1] = end2y - block_size[1]
+                if block_xy2[0] + block_size[0] > end2x:
+                    block_xy2[0] = end2x - block_size[0]
+
+                block_data1 = ds1.ReadAsArray(*block_xy1, *block_size)
+                block_data2 = ds2.ReadAsArray(*block_xy2, *block_size)
+                if band == 1:
+                    block_data1 = block_data1[None, ...]
+                    block_data2 = block_data2[None, ...]
+
+                block_diff = model(block_data1, block_data2)
+
+                # track the value range for the normalization pass below
+                max_diff = max(max_diff, float(block_diff.max()))
+                min_diff = min(min_diff, float(block_diff.min()))
+
+                out_ds.GetRasterBand(1).WriteArray(block_diff, *block_xy)
+            if send_message is not None:
+                send_message.emit(f'完成{j}/{yblocks}')
+        del ds2
+        del ds1
+        out_ds.FlushCache()
+        del out_ds
+        if send_message is not None:
+            send_message.emit('归一化概率中...')
+        temp_in_ds = gdal.Open(out_tif)
+
+        out_normal_tif = os.path.join(Project().cmi_path, '{}_{}_cmi.tif'.format(
+            layer_parent.name, int(np.random.rand() * 100000)))
+        out_normal_ds = driver.Create(
+            out_normal_tif, xsize, ysize, 1, gdal.GDT_Byte)
+        out_normal_ds.SetGeoTransform(geo)
+        out_normal_ds.SetProjection(proj)
+        for j in range(yblocks + 1):
+            block_xy = (0, j * cell_size[1])
+            if block_xy[1] > ysize:
+                break
+            block_size = (xsize, cell_size[1])
+            if block_xy[1] + block_size[1] > ysize:
+                block_size = (xsize, ysize - block_xy[1])
+            block_data = temp_in_ds.ReadAsArray(*block_xy, *block_size)
+            block_data = (block_data - min_diff) / (max_diff - min_diff) * 255
+            block_data = block_data.astype(np.uint8)
+            out_normal_ds.GetRasterBand(1).WriteArray(block_data, *block_xy)
+        del temp_in_ds
+        del out_normal_ds
+        try:
+            os.remove(out_tif)
+        except OSError:
+            pass
+        if send_message is not None:
+            send_message.emit('计算完成')
+        return out_normal_tif
+
+
+@AI_METHOD.register
+class DVCA(BasicAICD):
+
+    @staticmethod
+    def get_name():
+        return 'DVCA'
+
+    @staticmethod
+    def run_alg(pth1: str, pth2: str, layer_parent: PairLayer, send_message=None, *args, **kargs):
+        model = get_model('DVCA')
+        return BasicAICD.run_alg(pth1, pth2, layer_parent, send_message, model, *args, **kargs)
+
+
+@AI_METHOD.register
+class DPFCN(BasicAICD):
+
+    @staticmethod
+    def get_name():
+        return 'DPFCN'
+
+    @staticmethod
+    def run_alg(pth1: str, pth2: str, layer_parent: PairLayer, send_message=None, *args, **kargs):
+        model = get_model('DPFCN')
+        return BasicAICD.run_alg(pth1, pth2, layer_parent, send_message, model, *args, **kargs)
+
+
+@AI_METHOD.register
+class RCNN(BasicAICD):
+
+    @staticmethod
+    def get_name():
+        return 'RCNN'
+
+    @staticmethod
+    def run_alg(pth1: str, pth2: str, layer_parent: PairLayer, send_message=None, *args, **kargs):
+        model = get_model('RCNN')
+        return BasicAICD.run_alg(pth1, pth2, layer_parent, send_message, model, *args, **kargs)
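
Usage sketch for BasicAICD.run_alg above (illustrative, not part of the diff): `model` is any callable taking two (bands, rows, cols) arrays and returning a 2-D change-intensity map, which run_alg mosaics into a GeoTIFF and rescales to 0-255. The registered methods obtain theirs via get_model(); a trivial stand-in:

import numpy as np

def mean_abs_diff(block1, block2):
    # per-pixel mean absolute difference across bands
    return np.abs(block1.astype(np.float32) - block2.astype(np.float32)).mean(axis=0)

# out_tif = BasicAICD.run_alg(path1, path2, layer_parent, model=mean_abs_diff)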
@@ -1,27 +1,194 @@
-from rscder.plugins.basic import BasicPlugin
-from rscder.gui.actions import ActionManager
-from ai_method import AI_METHOD
-from PyQt5.QtCore import Qt
-from PyQt5.QtWidgets import QAction, QToolBar, QMenu, QDialog, QHBoxLayout, QVBoxLayout, QPushButton, QWidget, QLabel, QLineEdit, QComboBox, QDialogButtonBox
-from rscder.utils.icons import IconInstance
-from plugins.misc.main import AlgFrontend
-from functools import partial
-
-
-class AIPlugin(BasicPlugin):
-
-    @staticmethod
-    def info():
-        return {
-            'name': 'AIPlugin',
-            'description': 'AIPlugin',
-            'author': 'RSCDER',
-            'version': '1.0.0',
-        }
-
-    def set_action(self):
-        ai_menu = ActionManager().ai_menu
-        # ai_menu.setIcon(IconInstance().UNSUPERVISED)
-        # ActionManager().change_detection_menu.addMenu(ai_menu)
-        toolbar = ActionManager().add_toolbar('AI method')
-        for key in AI_METHOD.keys():
-            alg: AlgFrontend = AI_METHOD[key]
+from functools import partial
+from threading import Thread
+from plugins.misc.main import AlgFrontend
+from rscder.gui.actions import ActionManager
+from rscder.plugins.basic import BasicPlugin
+from PyQt5.QtCore import Qt
+from PyQt5.QtWidgets import QAction, QMenu, QDialog, QHBoxLayout, QVBoxLayout, QPushButton, QWidget, QDialogButtonBox
+
+from rscder.gui.layercombox import PairLayerCombox
+from rscder.utils.icons import IconInstance
+from filter_collection import FILTER
+from . import AI_METHOD
+from thres import THRES
+from misc import table_layer, AlgSelectWidget
+from follow import FOLLOW
+import os
+
+
+class AICDMethod(QDialog):
+    def __init__(self, parent=None, alg: AlgFrontend = None):
+        super(AICDMethod, self).__init__(parent)
+        self.alg = alg
+        self.setWindowTitle('AI变化检测:{}'.format(alg.get_name()))
+        self.setWindowIcon(IconInstance().LOGO)
+        self.initUI()
+        self.setMinimumWidth(500)
+
+    def initUI(self):
+        # layer selection
+        self.layer_combox = PairLayerCombox(self)
+        layerbox = QHBoxLayout()
+        layerbox.addWidget(self.layer_combox)
+
+        self.filter_select = AlgSelectWidget(self, FILTER)
+        self.param_widget = self.alg.get_widget(self)
+        self.unsupervised_menu = self.param_widget
+        self.thres_select = AlgSelectWidget(self, THRES)
+
+        self.ok_button = QPushButton('确定', self)
+        self.ok_button.setIcon(IconInstance().OK)
+        self.ok_button.clicked.connect(self.accept)
+        self.ok_button.setDefault(True)
+
+        self.cancel_button = QPushButton('取消', self)
+        self.cancel_button.setIcon(IconInstance().CANCEL)
+        self.cancel_button.clicked.connect(self.reject)
+        self.cancel_button.setDefault(False)
+        buttonbox = QDialogButtonBox(self)
+        buttonbox.addButton(self.ok_button, QDialogButtonBox.NoRole)
+        buttonbox.addButton(self.cancel_button, QDialogButtonBox.NoRole)
+        buttonbox.setCenterButtons(True)
+
+        totalvlayout = QVBoxLayout()
+        totalvlayout.addLayout(layerbox)
+        totalvlayout.addWidget(self.filter_select)
+        if self.param_widget is not None:
+            totalvlayout.addWidget(self.param_widget)
+        totalvlayout.addWidget(self.thres_select)
+        totalvlayout.addStretch(1)
+        hbox = QHBoxLayout()
+        hbox.addStretch(1)
+        hbox.addWidget(buttonbox)
+        totalvlayout.addLayout(hbox)
+
+        self.setLayout(totalvlayout)
+
+
+@FOLLOW.register
+class AICDFollow(AlgFrontend):
+
+    @staticmethod
+    def get_name():
+        return 'AI变化检测'
+
+    @staticmethod
+    def get_icon():
+        return IconInstance().UNSUPERVISED
+
+    @staticmethod
+    def get_widget(parent=None):
+        widget = QWidget(parent)
+        layer_combox = PairLayerCombox(widget)
+        layer_combox.setObjectName('layer_combox')
+
+        filter_select = AlgSelectWidget(widget, FILTER)
+        filter_select.setObjectName('filter_select')
+        ai_select = AlgSelectWidget(widget, AI_METHOD)
+        ai_select.setObjectName('ai_select')
+        thres_select = AlgSelectWidget(widget, THRES)
+        thres_select.setObjectName('thres_select')
+
+        totalvlayout = QVBoxLayout()
+        totalvlayout.addWidget(layer_combox)
+        totalvlayout.addWidget(filter_select)
+        totalvlayout.addWidget(ai_select)
+        totalvlayout.addWidget(thres_select)
+        totalvlayout.addStretch()
+
+        widget.setLayout(totalvlayout)
+
+        return widget
+
+    @staticmethod
+    def get_params(widget: QWidget = None):
+        if widget is None:
+            return dict()
+
+        layer_combox = widget.findChild(PairLayerCombox, 'layer_combox')
+        filter_select = widget.findChild(AlgSelectWidget, 'filter_select')
+        ai_select = widget.findChild(AlgSelectWidget, 'ai_select')
+        thres_select = widget.findChild(AlgSelectWidget, 'thres_select')
+
+        layer1 = layer_combox.layer1
+        pth1 = layer_combox.layer1.path
+        pth2 = layer_combox.layer2.path
+
+        falg, fparams = filter_select.get_alg_and_params()
+        cdalg, cdparams = ai_select.get_alg_and_params()
+        thalg, thparams = thres_select.get_alg_and_params()
+
+        if cdalg is None or thalg is None:
+            return dict()
+
+        return dict(
+            layer1=layer1,
+            pth1=pth1,
+            pth2=pth2,
+            falg=falg,
+            fparams=fparams,
+            cdalg=cdalg,
+            cdparams=cdparams,
+            thalg=thalg,
+            thparams=thparams,
+        )
+
+    @staticmethod
+    def run_alg(layer1=None,
+                pth1=None,
+                pth2=None,
+                falg=None,
+                fparams=None,
+                cdalg=None,
+                cdparams=None,
+                thalg=None,
+                thparams=None,
+                send_message=None):
+
+        if cdalg is None or thalg is None:
+            return
+
+        name = layer1.name
+        method_info = dict()
+        if falg is not None:
+            pth1 = falg.run_alg(pth1, name=name, send_message=send_message, **fparams)
+            pth2 = falg.run_alg(pth2, name=name, send_message=send_message, **fparams)
+            method_info['滤波算法'] = falg.get_name()
+        else:
+            method_info['滤波算法'] = '无'
+
+        cdpth = cdalg.run_alg(pth1, pth2, layer1.layer_parent, send_message=send_message, **cdparams)
+
+        if falg is not None:
+            try:
+                os.remove(pth1)
+                os.remove(pth2)
+            except OSError:
+                pass
+
+        thpth, th = thalg.run_alg(cdpth, name=name, send_message=send_message, **thparams)
+
+        method_info['变化检测算法'] = cdalg.get_name()
+        method_info['二值化算法'] = thalg.get_name()
+
+        table_layer(thpth, layer1, name, cdpath=cdpth, th=th, method_info=method_info, send_message=send_message)
+
+
+class AIPlugin(BasicPlugin):
+
+    @staticmethod
+    def info():
+        return {
+            'name': 'AI 变化检测',
+            'author': 'RSC',
+            'version': '1.0.0',
+            'description': 'AI 变化检测',
+            'category': 'Ai method'
+        }
+
+    def set_action(self):
+        AI_menu = QMenu('&AI变化检测', self.mainwindow)
+        AI_menu.setIcon(IconInstance().AI_DETECT)
+        ActionManager().change_detection_menu.addMenu(AI_menu)
+        toolbar = ActionManager().add_toolbar('AI method')
+        for key in AI_METHOD.keys():
+            alg: AlgFrontend = AI_METHOD[key]
@@ -30,15 +197,58 @@ class AIPlugin(BasicPlugin):
             else:
                 name = alg.get_name()

-            action = QAction(alg.get_icon(), name, ai_menu)
+            action = QAction(alg.get_icon(), name, AI_menu)
             func = partial(self.run_cd, alg)
             action.triggered.connect(func)
             toolbar.addAction(action)
-            ai_menu.addAction(action)
+            AI_menu.addAction(action)

-    def run_cd(self, alg: AlgFrontend):
-        dialog = alg.get_widget(self.mainwindow)
-        dialog.setWindowModality(Qt.NonModal)
-        dialog.show()
-        # dialog.exec()
+    def run_cd(self, alg):
+        dialog = AICDMethod(self.mainwindow, alg)
+        if dialog.exec_() == QDialog.Accepted:
+            t = Thread(target=self.run_cd_alg, args=(dialog,))
+            t.start()
+
+    def run_cd_alg(self, w: AICDMethod):
+        layer1 = w.layer_combox.layer1
+        pth1 = w.layer_combox.layer1.path
+        pth2 = w.layer_combox.layer2.path
+        name = layer1.layer_parent.name
+
+        falg, fparams = w.filter_select.get_alg_and_params()
+        cdalg = w.alg
+        cdparams = w.alg.get_params(w.param_widget)
+        thalg, thparams = w.thres_select.get_alg_and_params()
+
+        if cdalg is None or thalg is None:
+            return
+        method_info = dict()
+        if falg is not None:
+            pth1 = falg.run_alg(pth1, name=name, send_message=self.send_message, **fparams)
+            pth2 = falg.run_alg(pth2, name=name, send_message=self.send_message, **fparams)
+            method_info['滤波算法'] = falg.get_name()
+
+        cdpth = cdalg.run_alg(pth1, pth2, layer1.layer_parent, send_message=self.send_message, **cdparams)
+
+        if falg is not None:
+            try:
+                os.remove(pth1)
+                os.remove(pth2)
+            except OSError:
+                pass
+
+        thpth, th = thalg.run_alg(cdpth, name=name, send_message=self.send_message, **thparams)
+
+        method_info['变化检测算法'] = cdalg.get_name()
+        method_info['二值化算法'] = thalg.get_name()
+
+        table_layer(thpth, layer1, name, cdpath=cdpth, th=th, method_info=method_info, send_message=self.send_message)
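
For reference, the contract that AlgFrontend subclasses follow throughout this plugin (a sketch inferred from usage here; the base class itself lives in plugins/misc and is not part of this diff, and ExampleCD is hypothetical):

class ExampleCD(AlgFrontend):

    @staticmethod
    def get_name():
        return 'Example'

    @staticmethod
    def get_icon():
        return IconInstance().ARITHMETIC3

    @staticmethod
    def get_widget(parent=None):
        return None          # no extra parameter widget

    @staticmethod
    def get_params(widget=None):
        return dict()

    @staticmethod
    def run_alg(pth1, pth2, layer_parent, send_message=None, **kargs):
        raise NotImplementedError  # should return the result raster path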
plugins/ai_method/packages/STA_net/.gitignore
@@ -1,19 +0,0 @@
.DS_Store
checkpoints/
results/
result/
build/
*.pth
*/*.pth
*/*/*.pth
torch.egg-info/
*/*.pyc
*/**/*.pyc
*/**/**/*.pyc
*/**/**/**/*.pyc
*/**/**/**/**/*.pyc
*/*.so*
*/**/*.so*
*/**/*.dylib*
.idea
@@ -1,25 +0,0 @@
BSD 2-Clause License

Copyright (c) 2020, justchenhao
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -1,233 +0,0 @@
# STANet for remote sensing image change detection

It is the implementation of the paper: A Spatial-Temporal Attention-Based Method and a New Dataset for Remote Sensing Image Change Detection.

Here, we provide the pytorch implementation of the spatial-temporal attention neural network (STANet) for remote sensing image change detection.

## Change log

20210112:

- add the pretraining weight of PAM. [baidupan link](https://pan.baidu.com/s/1O1kg7JWunqd87ajtVMM6pg), code: 2rja

20201105:

- add a demo for quick start.
- add more dataset loader modes.
- enhance the image augmentation module (crop and rotation).

20200601:

- first commit

## Prerequisites

- Windows or Linux
- Python 3.6+
- CPU or NVIDIA GPU
- CUDA 9.0+
- PyTorch > 1.0
- visdom

## Installation

Clone this repo:

```bash
git clone https://github.com/justchenhao/STANet
cd STANet
```

Install [PyTorch](http://pytorch.org/) 1.0+ and other dependencies (e.g., torchvision, [visdom](https://github.com/facebookresearch/visdom) and [dominate](https://github.com/Knio/dominate)).

## Quick Start

You can run a demo to get started.

```bash
python demo.py
```

The input samples are in `samples`. After successfully running this script, you can find the predicted results in `samples/output`.

## Prepare Datasets

### download the change detection dataset

You could download the LEVIR-CD at https://justchenhao.github.io/LEVIR/;

The path list in the downloaded folder is as follows:

```
path to LEVIR-CD:
├─train
│  ├─A
│  ├─B
│  ├─label
├─val
│  ├─A
│  ├─B
│  ├─label
├─test
│  ├─A
│  ├─B
│  ├─label
```

where A contains images of pre-phase, B contains images of post-phase, and label contains label maps.

### cut bitemporal image pairs

The original image in LEVIR-CD has a size of 1024 * 1024, which will consume too much memory when training. Therefore, we can cut the original images into smaller patches (e.g., 256 * 256, or 512 * 512). In our paper, we cut the original image into patches of 256 * 256 size without overlapping.

Make sure that the corresponding patch samples in the A, B, and label subfolders have the same name.
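
For example, a minimal cutting script along these lines (a sketch, not shipped with the repo; `cut_patches` and the output folder names are illustrative):

```python
import os
from PIL import Image

def cut_patches(src_dir, dst_dir, patch=256):
    os.makedirs(dst_dir, exist_ok=True)
    for fname in os.listdir(src_dir):
        img = Image.open(os.path.join(src_dir, fname))
        w, h = img.size
        base, ext = os.path.splitext(fname)
        # non-overlapping tiles, row by row
        for y in range(0, h - patch + 1, patch):
            for x in range(0, w - patch + 1, patch):
                img.crop((x, y, x + patch, y + patch)).save(
                    os.path.join(dst_dir, f'{base}_{y}_{x}{ext}'))

# apply the same cutting to A, B and label so patch names stay aligned
for sub in ('A', 'B', 'label'):
    cut_patches(os.path.join('LEVIR-CD/train', sub),
                os.path.join('LEVIR-CD-256/train', sub))
```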

## Train

### Monitor training status

To view training results and loss plots, run this script and click the URL [http://localhost:8097](http://localhost:8097/).

```bash
python -m visdom.server
```

### train with our base method

Run the following script:

```bash
python ./train.py --save_epoch_freq 1 --angle 15 --dataroot path-to-LEVIR-CD-train --val_dataroot path-to-LEVIR-CD-val --name LEVIR-CDF0 --lr 0.001 --model CDF0 --batch_size 8 --load_size 256 --crop_size 256 --preprocess rotate_and_crop
```

Once finished, you could find the best model and the log files in the project folder.

### train with Basic spatial-temporal Attention Module (BAM) method

```bash
python ./train.py --save_epoch_freq 1 --angle 15 --dataroot path-to-LEVIR-CD-train --val_dataroot path-to-LEVIR-CD-val --name LEVIR-CDFA0 --lr 0.001 --model CDFA --SA_mode BAM --batch_size 8 --load_size 256 --crop_size 256 --preprocess rotate_and_crop
```

### train with Pyramid spatial-temporal Attention Module (PAM) method

```bash
python ./train.py --save_epoch_freq 1 --angle 15 --dataroot path-to-LEVIR-CD-train --val_dataroot path-to-LEVIR-CD-val --name LEVIR-CDFAp0 --lr 0.001 --model CDFA --SA_mode PAM --batch_size 8 --load_size 256 --crop_size 256 --preprocess rotate_and_crop
```

## Test

You could edit the file val.py, for example:

```python
if __name__ == '__main__':
    opt = TestOptions().parse()  # get training options
    opt = make_val_opt(opt)
    opt.phase = 'test'
    opt.dataroot = 'path-to-LEVIR-CD-test'  # data root
    opt.dataset_mode = 'changedetection'
    opt.n_class = 2
    opt.SA_mode = 'PAM'  # BAM | PAM
    opt.arch = 'mynet3'
    opt.model = 'CDFA'  # model type
    opt.name = 'LEVIR-CDFAp0'  # project name
    opt.results_dir = './results/'  # save predicted images
    opt.epoch = 'best-epoch-in-val'  # which epoch to test
    opt.num_test = np.inf
    val(opt)
```

then run the script: `python val.py`. Once finished, you can find the prediction log file in the project directory and predicted image files in the result directory.

## Using other dataset mode

### List mode

```bash
list=train
lr=0.001
dataset_mode=list
dataroot=path-to-dataroot
name=project_name

python ./train.py --num_threads 4 --display_id 0 --dataroot ${dataroot} --val_dataroot ${dataroot} --save_epoch_freq 1 --niter 100 --angle 15 --niter_decay 100 --display_env FAp0 --SA_mode PAM --name $name --lr $lr --model CDFA --batch_size 4 --dataset_mode $dataset_mode --val_dataset_mode $dataset_mode --split $list --load_size 256 --crop_size 256 --preprocess resize_rotate_and_crop
```

In this case, the data structure should be the following:

```
"""
data structure
-dataroot
    ├─A
        ├─train1.png
        ...
    ├─B
        ├─train1.png
        ...
    ├─label
        ├─train1.png
        ...
    └─list
        ├─val.txt
        ├─test.txt
        └─train.txt

# In list/train.txt, each row is the filename of one sample,
# for example:
list/train.txt
    train1.png
    train2.png
    ...
"""
```

### Concat mode for loading multiple datasets (each default mode is List)

```bash
list=train
lr=0.001
dataset_type=CD_data1,CD_data2,...,
val_dataset_type=CD_data
dataset_mode=concat
name=project_name

python ./train.py --num_threads 4 --display_id 0 --dataset_type $dataset_type --val_dataset_type $val_dataset_type --save_epoch_freq 1 --niter 100 --angle 15 --niter_decay 100 --display_env FAp0 --SA_mode PAM --name $name --lr $lr --model CDFA --batch_size 4 --dataset_mode $dataset_mode --val_dataset_mode $dataset_mode --split $list --load_size 256 --crop_size 256 --preprocess resize_rotate_and_crop
```

Note, in this case, you should modify `get_dataset_info` in `data/data_config.py` to add the corresponding `dataset_name` and `dataroot`:

```python
if dataset_type == 'LEVIR_CD':
    root = 'path-to-LEVIR_CD-dataroot'
elif ...
    # add more datasets ...
```

## Other TIPS

For more Training/Testing guides, you could see the option files in the `./options/` folder.

## Citation

If you use this code for your research, please cite our papers.

```
@Article{rs12101662,
    AUTHOR = {Chen, Hao and Shi, Zhenwei},
    TITLE = {A Spatial-Temporal Attention-Based Method and a New Dataset for Remote Sensing Image Change Detection},
    JOURNAL = {Remote Sensing},
    VOLUME = {12},
    YEAR = {2020},
    NUMBER = {10},
    ARTICLE-NUMBER = {1662},
    URL = {https://www.mdpi.com/2072-4292/12/10/1662},
    ISSN = {2072-4292},
    DOI = {10.3390/rs12101662}
}
```

## Acknowledgments

Our code is inspired by [pytorch-CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
@@ -1,108 +0,0 @@
import importlib
import torch.utils.data
from data.base_dataset import BaseDataset
import torch
from torch.utils.data import ConcatDataset
from data.data_config import get_dataset_info


def find_dataset_using_name(dataset_name):
    """Import the module "data/[dataset_name]_dataset.py".

    In the file, the class called DatasetNameDataset() will
    be instantiated. It has to be a subclass of BaseDataset,
    and it is case-insensitive.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        if name.lower() == target_dataset_name.lower() \
           and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

    return dataset


def get_option_setter(dataset_name):
    """Return the static method <modify_commandline_options> of the dataset class."""
    dataset_class = find_dataset_using_name(dataset_name)
    return dataset_class.modify_commandline_options


def create_dataset(opt):
    """Create a dataset given the option.

    This function wraps the class CustomDatasetDataLoader.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    data_loader = CustomDatasetDataLoader(opt)
    dataset = data_loader.load_data()
    return dataset


def create_single_dataset(opt, dataset_type_):
    # build one 'list'-mode dataset for the given dataset type
    dataset_class = find_dataset_using_name('list')
    # get dataset root
    opt.dataroot = get_dataset_info(dataset_type_)
    return dataset_class(opt)


class CustomDatasetDataLoader():
    """Wrapper class of Dataset class that performs multi-threaded data loading"""

    def __init__(self, opt):
        """Initialize this class
        Step 1: create a dataset instance given the name [dataset_mode]
        Step 2: create a multi-threaded data loader.
        """
        self.opt = opt
        print(opt.dataset_mode)
        if opt.dataset_mode == 'concat':
            # concatenate several datasets
            datasets = []
            # the list of dataset types to concat
            self.dataset_type = opt.dataset_type.split(',')
            # drop the empty entry left by a trailing ","
            if self.dataset_type[-1] == '':
                self.dataset_type = self.dataset_type[:-1]
            for dataset_type_ in self.dataset_type:
                dataset_ = create_single_dataset(opt, dataset_type_)
                datasets.append(dataset_)
            self.dataset = ConcatDataset(datasets)
        else:
            dataset_class = find_dataset_using_name(opt.dataset_mode)
            self.dataset = dataset_class(opt)

        print("dataset [%s] was created" % type(self.dataset).__name__)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.num_threads),
            drop_last=True)

    def load_data(self):
        return self

    def __len__(self):
        """Return the number of data in the dataset"""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Return a batch of data"""
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data
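
Usage sketch for the loader above (illustrative; the dataroot value is a placeholder): the naming convention maps dataset_mode 'changedetection' to data/changedetection_dataset.py and its class ChangeDetectionDataset, and 'list' to listDataset.

from options.train_options import TrainOptions
from data import create_dataset

opt = TrainOptions().parse()
opt.dataset_mode = 'changedetection'   # -> data/changedetection_dataset.py
opt.dataroot = 'path-to-LEVIR-CD/train'
dataset = create_dataset(opt)          # a CustomDatasetDataLoader
for batch in dataset:                  # dicts with 'A', 'B', 'L' and *_paths keys
    print(batch['A'].shape, batch['B'].shape)
    break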
@@ -1,189 +0,0 @@
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.

It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
"""
import random
import numpy as np
import torch.utils.data as data
from PIL import Image, ImageFilter
import torchvision.transforms as transforms
from abc import ABC, abstractmethod
import math


class BaseDataset(data.Dataset, ABC):
    """This class is an abstract base class (ABC) for datasets.

    To create a subclass, you need to implement the following four functions:
    -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>: return the size of dataset.
    -- <__getitem__>: get a data point.
    -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the class; save the options in the class

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        self.opt = opt
        self.root = opt.dataroot

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def __len__(self):
        """Return the total number of images in the dataset."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns:
            a dictionary of data with their names. It usually contains the data itself and its metadata information.
        """
        pass


def get_params(opt, size, test=False):
    w, h = size
    new_h = h
    new_w = w
    angle = 0
    if opt.preprocess == 'resize_and_crop':
        new_h = new_w = opt.load_size
    if 'rotate' in opt.preprocess and test is False:
        angle = random.uniform(0, opt.angle)
        new_w = int(new_w * math.cos(angle * math.pi / 180)
                    + new_h * math.sin(angle * math.pi / 180))
        new_h = int(new_h * math.cos(angle * math.pi / 180)
                    + new_w * math.sin(angle * math.pi / 180))
        new_w = min(new_w, new_h)
        new_h = min(new_w, new_h)
    x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
    y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
    flip = random.random() > 0.5  # left-right
    return {'crop_pos': (x, y), 'flip': flip, 'angle': angle}


def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC,
                  convert=True, normalize=True, test=False):
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    if 'resize' in opt.preprocess:
        osize = [opt.load_size, opt.load_size]
        transform_list.append(transforms.Resize(osize, method))
    # gaussian blur
    if 'blur' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __blur(img)))

    if 'rotate' in opt.preprocess and test == False:
        if params is None:
            transform_list.append(transforms.RandomRotation(5))
        else:
            degree = params['angle']
            transform_list.append(transforms.Lambda(lambda img: __rotate(img, degree)))

    if 'crop' in opt.preprocess:
        if params is None:
            transform_list.append(transforms.RandomCrop(opt.crop_size))
        else:
            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'],
                                                                       opt.crop_size)))

    if not opt.no_flip:
        if params is None:
            transform_list.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
    if convert:
        transform_list += [transforms.ToTensor()]
    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def __blur(img):
    if img.mode == 'RGB':
        img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
    return img


def __rotate(img, degree):
    if img.mode == 'RGB':
        # set img padding == 128 (gray)
        img2 = img.convert('RGBA')
        rot = img2.rotate(degree, expand=1)
        fff = Image.new('RGBA', rot.size, (128,) * 4)
        out = Image.composite(rot, fff, rot)
        img = out.convert(img.mode)
        return img
    else:
        # set label padding == 0
        img2 = img.convert('RGBA')
        rot = img2.rotate(degree, expand=1)
        # a white image same size as rotated image
        fff = Image.new('RGBA', rot.size, (255,) * 4)
        # create a composite image using the alpha layer of rot as a mask
        out = Image.composite(rot, fff, rot)
        img = out.convert(img.mode)
        return img


def __crop(img, pos, size):
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    # only crop when the image is larger than the crop size; otherwise pad
    if (ow > tw and oh > th):
        return img.crop((x1, y1, x1 + tw, y1 + th))

    size = [size, size]
    if img.mode == 'RGB':
        new_image = Image.new('RGB', size, (128, 128, 128))
        new_image.paste(img, (int((1 + size[1] - img.size[0]) / 2),
                              int((1 + size[0] - img.size[1]) / 2)))
        return new_image
    else:
        new_image = Image.new(img.mode, size, 255)
        # upper left corner
        new_image.paste(img, (int((1 + size[1] - img.size[0]) / 2),
                              int((1 + size[0] - img.size[1]) / 2)))
        return new_image


def __flip(img, flip):
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img


def __print_size_warning(ow, oh, w, h):
    """Print warning information about image size (only print once)"""
    if not hasattr(__print_size_warning, 'has_printed'):
        print("The image size needs to be a multiple of 4. "
              "The loaded image size was (%d, %d), so it was adjusted to "
              "(%d, %d). This adjustment will be done to all images "
              "whose sizes are not multiples of 4" % (ow, oh, w, h))
        __print_size_warning.has_printed = True
@@ -1,63 +0,0 @@
from data.base_dataset import BaseDataset, get_transform, get_params
from data.image_folder import make_dataset
from PIL import Image
import os
import numpy as np


class ChangeDetectionDataset(BaseDataset):
    """This dataset class can load a set of images specified by the path --dataroot /path/to/data.

    datafolder-tree
    dataroot:.
        ├─A
        ├─B
        ├─label
    """

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        folder_A = 'A'
        folder_B = 'B'
        folder_L = 'label'
        self.istest = False
        if opt.phase == 'test':
            self.istest = True
        self.A_paths = sorted(make_dataset(os.path.join(opt.dataroot, folder_A), opt.max_dataset_size))
        self.B_paths = sorted(make_dataset(os.path.join(opt.dataroot, folder_B), opt.max_dataset_size))
        if not self.istest:
            self.L_paths = sorted(make_dataset(os.path.join(opt.dataroot, folder_L), opt.max_dataset_size))

        print(self.A_paths)

    def __getitem__(self, index):
        A_path = self.A_paths[index]
        B_path = self.B_paths[index]
        A_img = Image.open(A_path).convert('RGB')
        B_img = Image.open(B_path).convert('RGB')

        transform_params = get_params(self.opt, A_img.size, test=self.istest)
        # apply the same transform to A, B and L
        transform = get_transform(self.opt, transform_params, test=self.istest)

        A = transform(A_img)
        B = transform(B_img)

        if self.istest:
            return {'A': A, 'A_paths': A_path, 'B': B, 'B_paths': B_path}

        L_path = self.L_paths[index]
        tmp = np.array(Image.open(L_path), dtype=np.uint32) / 255
        L_img = Image.fromarray(tmp)
        transform_L = get_transform(self.opt, transform_params, method=Image.NEAREST, normalize=False,
                                    test=self.istest)
        L = transform_L(L_img)

        return {'A': A, 'A_paths': A_path,
                'B': B, 'B_paths': B_path,
                'L': L, 'L_paths': L_path}

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.A_paths)
@@ -1,12 +0,0 @@


def get_dataset_info(dataset_type):
    """define dataset_name and its dataroot"""
    root = ''
    if dataset_type == 'LEVIR_CD':
        root = 'path-to-LEVIR_CD-dataroot'
    # add more datasets ...
    else:
        raise TypeError("dataset_type %s is not defined" % dataset_type)

    return root
@@ -1,66 +0,0 @@
"""A modified image folder class

We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both current directory and its subdirectories.
"""

import torch.utils.data as data

from PIL import Image
import os
import os.path


IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif', '.tiff'
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir, max_dataset_size=float("inf")):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images[:min(max_dataset_size, len(images))]


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)
@@ -1,101 +0,0 @@
import os
import collections
import numpy as np
from data.base_dataset import BaseDataset, get_transform, get_params
from PIL import Image


class listDataset(BaseDataset):
    """
    data structure
    -dataroot
        ├─A
            ├─train1.png
            ...
        ├─B
            ├─train1.png
            ...
        ├─label
            ├─train1.png
            ...
        └─list
            ├─val.txt
            ├─test.txt
            └─train.txt

    # In list/train.txt, each row is the filename of one sample,
    # for example:
    list/train.txt
        train1.png
        train2.png
        ...
    """
    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.split = opt.split
        self.files = collections.defaultdict(list)
        # test/val mode: no scale or rotation augmentation is applied
        self.istest = opt.phase != 'train'

        path = os.path.join(self.root, 'list', self.split + '.txt')
        file_list = tuple(open(path, 'r'))
        file_list = [id_.rstrip() for id_ in file_list]
        self.files[self.split] = file_list

    def __len__(self):
        return len(self.files[self.split])

    def __getitem__(self, index):
        paths = self.files[self.split][index]

        path = paths.split(" ")

        A_path = os.path.join(self.root, 'A', path[0])
        B_path = os.path.join(self.root, 'B', path[0])
        L_path = os.path.join(self.root, 'label', path[0])

        A = Image.open(A_path).convert('RGB')
        B = Image.open(B_path).convert('RGB')

        tmp = np.array(Image.open(L_path), dtype=np.uint32) / 255
        L = Image.fromarray(tmp)

        transform_params = get_params(self.opt, A.size, self.istest)
        transform = get_transform(self.opt, transform_params, test=self.istest)
        # labels are not normalized
        transform_L = get_transform(self.opt, transform_params, method=Image.NEAREST, normalize=False,
                                    test=self.istest)
        A = transform(A)
        B = transform(B)
        L = transform_L(L)

        return {'A': A, 'A_paths': A_path, 'B': B, 'B_paths': B_path, 'L': L, 'L_paths': L_path}


# Leave code for debugging purposes
if __name__ == '__main__':
    from options.train_options import TrainOptions
    opt = TrainOptions().parse()
    opt.dataroot = r'I:\data\change_detection\LEVIR-CD-r'
    opt.split = 'train'
    opt.load_size = 500
    opt.crop_size = 500
    opt.batch_size = 1
    opt.dataset_mode = 'list'
    from data import create_dataset
    dataset = create_dataset(opt)
    import matplotlib.pyplot as plt
    from util.util import tensor2im

    for i, data in enumerate(dataset):
        A = data['A']
        L = data['L']
        A = tensor2im(A)
        color = tensor2im(L)[:, :, 0] * 255
        plt.imshow(A)
        plt.show()
@@ -1,82 +0,0 @@
import time
from options.test_options import TestOptions
from data import create_dataset
from models import create_model
import os
from util.util import save_images
import numpy as np
from util.util import mkdir
from PIL import Image


def make_val_opt(opt):
    # hard-code some parameters for test
    opt.num_threads = 0   # test code only supports num_threads = 0
    opt.batch_size = 1    # test code only supports batch_size = 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True    # no flip; comment this line if results on flipped images are needed.
    opt.no_flip2 = True   # no flip; comment this line if results on flipped images are needed.

    opt.display_id = -1   # no visdom display; the test code saves the results to disk.
    opt.phase = 'val'
    opt.preprocess = 'none1'
    opt.isTrain = False
    opt.aspect_ratio = 1
    opt.eval = True

    return opt


def val(opt):
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    model = create_model(opt)      # create a model given opt.model and other options
    model.setup(opt)               # regular setup: load and print networks; create schedulers
    save_path = opt.results_dir
    mkdir(save_path)
    model.eval()

    for i, data in enumerate(dataset):
        if i >= opt.num_test:  # only apply our model to opt.num_test images.
            break
        model.set_input(data)         # unpack data from data loader
        pred = model.test(val=False)  # run inference and return pred

        img_path = model.get_image_paths()  # get image paths
        if i % 5 == 0:
            print('processing (%04d)-th image... %s' % (i, img_path))

        save_images(pred, save_path, img_path)


def pred_image(data_root, results_dir):
    opt = TestOptions().parse()  # get training options
    opt = make_val_opt(opt)
    opt.phase = 'test'
    opt.dataset_mode = 'changedetection'
    opt.n_class = 2
    opt.SA_mode = 'PAM'
    opt.arch = 'mynet3'
    opt.model = 'CDFA'
    opt.epoch = 'pam'
    opt.num_test = np.inf
    opt.name = 'pam'
    opt.dataroot = data_root
    opt.results_dir = results_dir

    val(opt)


if __name__ == '__main__':
    # define the data_root and the results_dir
    # note: data_root should have such structure:
    # ├─A
    # ├─B
    # A for before images, B for after images
    data_root = './samples'
    results_dir = './samples/output/'
    pred_image(data_root, results_dir)
@@ -1,52 +0,0 @@
import torch
import torch.nn.functional as F
from torch import nn


class BAM(nn.Module):
    """ Basic self-attention module
    """

    def __init__(self, in_dim, ds=8, activation=nn.ReLU):
        super(BAM, self).__init__()
        self.channel_in = in_dim
        self.key_channel = self.channel_in // 8
        self.activation = activation
        self.ds = ds  # downsampling factor before attention
        self.pool = nn.AvgPool2d(self.ds)
        print('ds: ', ds)
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input):
        """
        inputs :
            x : input feature maps (B X C X W X H)
        returns :
            out : self attention value + input feature
            attention: B X N X N (N is Width*Height)
        """
        x = self.pool(input)
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B X N X C', N = (W*H)/(ds*ds)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)  # B X C' X N
        energy = torch.bmm(proj_query, proj_key)  # B X N X N
        energy = (self.key_channel ** -.5) * energy  # scaled dot-product

        attention = self.softmax(energy)  # B X N X N

        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)  # B X C X N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)

        out = F.interpolate(out, [width * self.ds, height * self.ds])
        out = out + input

        return out
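
Quick shape check for the module above (a sketch; sizes are illustrative, and H and W must be divisible by ds because of the AvgPool2d / interpolate round trip):

import torch

bam = BAM(in_dim=64, ds=8)
x = torch.randn(2, 64, 64, 64)   # B x C x H x W
y = bam(x)                       # attention output + residual input
assert y.shape == x.shape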
@ -1,114 +0,0 @@
import torch
import itertools
from .base_model import BaseModel
from . import backbone
import torch.nn.functional as F
from . import loss


class CDF0Model(BaseModel):
    """
    change detection module:
    feature extractor
    contrastive loss
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.istest = opt.istest
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['f']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        self.visual_names = ['A', 'B', 'L', 'pred_L_show']  # visualizations for A and B
        if self.istest:
            self.visual_names = ['A', 'B', 'pred_L_show']
        self.visual_features = ['feat_A', 'feat_B']
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
        if self.isTrain:
            self.model_names = ['F']
        else:  # during test time, only load Gs
            self.model_names = ['F']
        self.ds = 1
        # define networks (both Generators and discriminators)
        self.n_class = 2
        self.netF = backbone.define_F(in_c=3, f_c=opt.f_c, type=opt.arch).to(self.device)

        if self.isTrain:
            # define loss functions
            self.criterionF = loss.BCL()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netF.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap domain A and domain B.
        """
        self.A = input['A'].to(self.device)
        self.B = input['B'].to(self.device)
        if not self.istest:
            self.L = input['L'].to(self.device).long()
        self.image_paths = input['A_paths']
        if self.isTrain:
            self.L_s = self.L.float()
            self.L_s = F.interpolate(self.L_s, size=torch.Size([self.A.shape[2] // self.ds, self.A.shape[3] // self.ds]), mode='nearest')
            self.L_s[self.L_s == 1] = -1  # change
            self.L_s[self.L_s == 0] = 1   # no change

    def test(self, val=False):
        """Forward function used in test time.
        This function wraps <forward> in no_grad() so we don't save intermediate steps for backprop.
        It also calls <compute_visuals> to produce additional visualization results.
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()
            if val:  # score
                from util.metrics import RunningMetrics
                metrics = RunningMetrics(self.n_class)
                pred = self.pred_L.long()

                metrics.update(self.L.detach().cpu().numpy(), pred.detach().cpu().numpy())
                scores = metrics.get_cm()
                return scores

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.feat_A = self.netF(self.A)  # f(A)
        self.feat_B = self.netF(self.B)  # f(B)

        self.dist = F.pairwise_distance(self.feat_A, self.feat_B, keepdim=True)
        # print(self.dist.shape)
        self.dist = F.interpolate(self.dist, size=self.A.shape[2:], mode='bilinear', align_corners=True)
        self.pred_L = (self.dist > 1).float()
        self.pred_L_show = self.pred_L.long()
        return self.pred_L

    def backward(self):
        """Calculate the loss for generator F"""
        self.loss_f = self.criterionF(self.dist, self.L_s)

        self.loss = self.loss_f
        if torch.isnan(self.loss):
            print(self.image_paths)

        self.loss.backward()

    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # forward
        self.forward()  # compute feat and dist

        self.optimizer_G.zero_grad()  # set G's gradients to zero
        self.backward()               # calculate gradients for G
        self.optimizer_G.step()       # update G's weights
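# A hedged sketch of the decision rule used in forward() above (not part of
# the original file): pixels whose feature distance exceeds 1 (half of the
# default BCL margin of 2.0) are labeled as changed. Toy tensors, not model
# outputs:
#
#     import torch
#     dist = torch.tensor([[0.2, 1.5], [0.9, 2.3]])
#     change_map = (dist > 1).float()   # -> [[0., 1.], [0., 1.]]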
@ -1,119 +0,0 @@
import torch
import itertools
from .base_model import BaseModel
from . import backbone
import torch.nn.functional as F
from . import loss


class CDFAModel(BaseModel):
    """
    change detection module:
    feature extractor + spatial-temporal self-attention
    contrastive loss
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['f']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        self.istest = (opt.phase == 'test')  # set unconditionally so the flag always exists
        self.visual_names = ['A', 'B', 'L', 'pred_L_show']  # visualizations for A and B
        if self.istest:
            self.visual_names = ['A', 'B', 'pred_L_show']  # visualizations for A and B

        self.visual_features = ['feat_A', 'feat_B']
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
        if self.isTrain:
            self.model_names = ['F', 'A']
        else:  # during test time, only load Gs
            self.model_names = ['F', 'A']
        self.ds = 1
        self.n_class = 2
        self.netF = backbone.define_F(in_c=3, f_c=opt.f_c, type=opt.arch).to(self.device)
        self.netA = backbone.CDSA(in_c=opt.f_c, ds=opt.ds, mode=opt.SA_mode).to(self.device)

        if self.isTrain:
            # define loss functions
            self.criterionF = loss.BCL()

            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netF.parameters(),
            ), lr=opt.lr * opt.lr_decay, betas=(opt.beta1, 0.999))
            self.optimizer_A = torch.optim.Adam(self.netA.parameters(), lr=opt.lr * 1, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_A)

    def set_input(self, input):
        self.A = input['A'].to(self.device)
        self.B = input['B'].to(self.device)
        if self.istest is False:
            if 'L' in input.keys():
                self.L = input['L'].to(self.device).long()

        self.image_paths = input['A_paths']
        if self.isTrain:
            self.L_s = self.L.float()
            self.L_s = F.interpolate(self.L_s, size=torch.Size([self.A.shape[2] // self.ds, self.A.shape[3] // self.ds]), mode='nearest')
            self.L_s[self.L_s == 1] = -1  # change
            self.L_s[self.L_s == 0] = 1   # no change

    def test(self, val=False):
        with torch.no_grad():
            self.forward()
            self.compute_visuals()
            if val:  # return scores
                from util.metrics import RunningMetrics
                metrics = RunningMetrics(self.n_class)
                pred = self.pred_L.long()

                metrics.update(self.L.detach().cpu().numpy(), pred.detach().cpu().numpy())
                scores = metrics.get_cm()
                return scores
            else:
                return self.pred_L.long()

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.feat_A = self.netF(self.A)  # f(A)
        self.feat_B = self.netF(self.B)  # f(B)

        self.feat_A, self.feat_B = self.netA(self.feat_A, self.feat_B)

        self.dist = F.pairwise_distance(self.feat_A, self.feat_B, keepdim=True)  # feature distance

        self.dist = F.interpolate(self.dist, size=self.A.shape[2:], mode='bilinear', align_corners=True)

        self.pred_L = (self.dist > 1).float()
        # self.pred_L = F.interpolate(self.pred_L, size=self.A.shape[2:], mode='nearest')
        self.pred_L_show = self.pred_L.long()

        return self.pred_L

    def backward(self):
        self.loss_f = self.criterionF(self.dist, self.L_s)

        self.loss = self.loss_f
        # print(self.loss)
        self.loss.backward()

    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # forward
        self.forward()  # compute feat and dist

        self.set_requires_grad([self.netF, self.netA], True)
        self.optimizer_G.zero_grad()  # set G's gradients to zero
        self.optimizer_A.zero_grad()
        self.backward()               # calculate gradients for G
        self.optimizer_G.step()       # update G's weights
        self.optimizer_A.step()
@ -1,168 +0,0 @@
import torch
import torch.nn.functional as F
from torch import nn


class _PAMBlock(nn.Module):
    '''
    The basic implementation for self-attention block/non-local block
    Input/Output:
        N * C * H * (2*W)
    Parameters:
        in_channels    : the dimension of the input feature map
        key_channels   : the dimension after the key/query transform
        value_channels : the dimension after the value transform
        scale          : choose the scale to partition the input feature maps
        ds             : downsampling scale
    '''
    def __init__(self, in_channels, key_channels, value_channels, scale=1, ds=1):
        super(_PAMBlock, self).__init__()
        self.scale = scale
        self.ds = ds
        self.pool = nn.AvgPool2d(self.ds)
        self.in_channels = in_channels
        self.key_channels = key_channels
        self.value_channels = value_channels

        self.f_key = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(self.key_channels)
        )
        self.f_query = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(self.key_channels)
        )
        self.f_value = nn.Conv2d(in_channels=self.in_channels, out_channels=self.value_channels,
                                 kernel_size=1, stride=1, padding=0)

    def forward(self, input):
        x = input
        if self.ds != 1:
            x = self.pool(input)
        # input shape: b, c, h, 2w
        batch_size, c, h, w = x.size(0), x.size(1), x.size(2), x.size(3) // 2

        local_y = []
        local_x = []
        step_h, step_w = h // self.scale, w // self.scale
        for i in range(0, self.scale):
            for j in range(0, self.scale):
                start_x, start_y = i * step_h, j * step_w
                end_x, end_y = min(start_x + step_h, h), min(start_y + step_w, w)
                if i == (self.scale - 1):
                    end_x = h
                if j == (self.scale - 1):
                    end_y = w
                local_x += [start_x, end_x]
                local_y += [start_y, end_y]

        value = self.f_value(x)
        query = self.f_query(x)
        key = self.f_key(x)

        value = torch.stack([value[:, :, :, :w], value[:, :, :, w:]], 4)  # B*N*H*W*2
        query = torch.stack([query[:, :, :, :w], query[:, :, :, w:]], 4)  # B*N*H*W*2
        key = torch.stack([key[:, :, :, :w], key[:, :, :, w:]], 4)        # B*N*H*W*2

        local_block_cnt = 2 * self.scale * self.scale

        # self-attention func
        def func(value_local, query_local, key_local):
            batch_size_new = value_local.size(0)
            h_local, w_local = value_local.size(2), value_local.size(3)
            value_local = value_local.contiguous().view(batch_size_new, self.value_channels, -1)

            query_local = query_local.contiguous().view(batch_size_new, self.key_channels, -1)
            query_local = query_local.permute(0, 2, 1)
            key_local = key_local.contiguous().view(batch_size_new, self.key_channels, -1)

            sim_map = torch.bmm(query_local, key_local)  # batch matrix multiplication
            sim_map = (self.key_channels ** -.5) * sim_map
            sim_map = F.softmax(sim_map, dim=-1)

            context_local = torch.bmm(value_local, sim_map.permute(0, 2, 1))
            # context_local = context_local.permute(0, 2, 1).contiguous()
            context_local = context_local.view(batch_size_new, self.value_channels, h_local, w_local, 2)
            return context_local

        # Parallel computing to speed up
        # reshape value_local, q, k
        v_list = [value[:, :, local_x[i]:local_x[i + 1], local_y[i]:local_y[i + 1]] for i in range(0, local_block_cnt, 2)]
        v_locals = torch.cat(v_list, dim=0)
        q_list = [query[:, :, local_x[i]:local_x[i + 1], local_y[i]:local_y[i + 1]] for i in range(0, local_block_cnt, 2)]
        q_locals = torch.cat(q_list, dim=0)
        k_list = [key[:, :, local_x[i]:local_x[i + 1], local_y[i]:local_y[i + 1]] for i in range(0, local_block_cnt, 2)]
        k_locals = torch.cat(k_list, dim=0)
        # print(v_locals.shape)
        context_locals = func(v_locals, q_locals, k_locals)

        context_list = []
        for i in range(0, self.scale):
            row_tmp = []
            for j in range(0, self.scale):
                left = batch_size * (j + i * self.scale)
                right = batch_size * (j + i * self.scale) + batch_size
                tmp = context_locals[left:right]
                row_tmp.append(tmp)
            context_list.append(torch.cat(row_tmp, 3))

        context = torch.cat(context_list, 2)
        context = torch.cat([context[:, :, :, :, 0], context[:, :, :, :, 1]], 3)

        if self.ds != 1:
            context = F.interpolate(context, [h * self.ds, 2 * w * self.ds])

        return context


class PAMBlock(_PAMBlock):
    def __init__(self, in_channels, key_channels=None, value_channels=None, scale=1, ds=1):
        if key_channels is None:
            key_channels = in_channels // 8
        if value_channels is None:
            value_channels = in_channels
        super(PAMBlock, self).__init__(in_channels, key_channels, value_channels, scale, ds)


class PAM(nn.Module):
    """
    PAM module
    """

    def __init__(self, in_channels, out_channels, sizes=([1]), ds=1):
        super(PAM, self).__init__()
        self.group = len(sizes)
        self.stages = []
        self.ds = ds  # output stride
        self.value_channels = out_channels
        self.key_channels = out_channels // 8

        self.stages = nn.ModuleList(
            [self._make_stage(in_channels, self.key_channels, self.value_channels, size, self.ds)
             for size in sizes])
        self.conv_bn = nn.Sequential(
            nn.Conv2d(in_channels * self.group, out_channels, kernel_size=1, padding=0, bias=False),
            # nn.BatchNorm2d(out_channels),
        )

    def _make_stage(self, in_channels, key_channels, value_channels, size, ds):
        return PAMBlock(in_channels, key_channels, value_channels, size, ds)

    def forward(self, feats):
        priors = [stage(feats) for stage in self.stages]

        # concat
        context = []
        for i in range(0, len(priors)):
            context += [priors[i]]
        output = self.conv_bn(torch.cat(context, 1))

        return output
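# A hedged usage sketch (not part of the original file): PAM expects the two
# temporal feature maps concatenated along the width axis (B x C x H x 2W) and
# returns a tensor of the same spatial layout. With in_channels == out_channels
# the channel count is preserved, matching how CDSA in backbone.py calls it.
if __name__ == '__main__':
    pam = PAM(in_channels=64, out_channels=64, sizes=[1, 2, 4, 8], ds=1)
    feats = torch.rand(1, 64, 32, 64)  # H=32, 2W=64
    out = pam(feats)
    print('PAM output shape:', tuple(out.shape))  # (1, 64, 32, 64)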
@ -1,67 +0,0 @@
"""This package contains modules related to objective functions, optimizations, and network architectures.

To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
    -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
    -- <set_input>: unpack data from dataset and apply preprocessing.
    -- <forward>: produce intermediate results.
    -- <optimize_parameters>: calculate loss, gradients, and update network weights.
    -- <modify_commandline_options>: (optionally) add model-specific options and set default options.

In the function <__init__>, you need to define four lists:
    -- self.loss_names (str list): specify the training losses that you want to plot and save.
    -- self.model_names (str list): define networks used in our training.
    -- self.visual_names (str list): specify the images that you want to display and save.
    -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage.

Now you can use the model class by specifying flag '--model dummy'.
See our template model class 'template_model.py' for more details.
"""

import importlib
from models.base_model import BaseModel


def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py".

    In the file, the class called DatasetNameModel() will
    be instantiated. It has to be a subclass of BaseModel,
    and it is case-insensitive.
    """
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() \
           and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        exit(0)

    return model


def get_option_setter(model_name):
    """Return the static method <modify_commandline_options> of the model class."""
    model_class = find_model_using_name(model_name)
    return model_class.modify_commandline_options


def create_model(opt):
    """Create a model given the option.

    This function wraps the model class.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from models import create_model
        >>> model = create_model(opt)
    """
    model = find_model_using_name(opt.model)
    instance = model(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance
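# A hedged sketch of the name resolution performed above (not part of the
# original file): for opt.model == 'CDFA' the import target is
# 'models.CDFA_model', and the lower-cased class-name match
# 'cdfamodel' == 'CDFAModel'.lower() selects CDFAModel from that module.
#
#     model_cls = find_model_using_name('CDFA')   # -> CDFAModel
#     model_cls = find_model_using_name('CDF0')   # -> CDF0Model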
@ -1,51 +0,0 @@
# coding: utf-8
import torch.nn as nn
import torch
from .mynet3 import F_mynet3
from .BAM import BAM
from .PAM2 import PAM as PAM


def define_F(in_c, f_c, type='unet'):
    if type == 'mynet3':
        print("using mynet3 backbone")
        return F_mynet3(backbone='resnet18', in_c=in_c, f_c=f_c, output_stride=32)
    else:
        raise NotImplementedError('no such F type!')


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


class CDSA(nn.Module):
    """self-attention module for change detection
    """
    def __init__(self, in_c, ds=1, mode='BAM'):
        super(CDSA, self).__init__()
        self.in_C = in_c
        self.ds = ds
        print('ds: ', self.ds)
        self.mode = mode
        if self.mode == 'BAM':
            self.Self_Att = BAM(self.in_C, ds=self.ds)
        elif self.mode == 'PAM':
            self.Self_Att = PAM(in_channels=self.in_C, out_channels=self.in_C, sizes=[1, 2, 4, 8], ds=self.ds)
        self.apply(weights_init)

    def forward(self, x1, x2):
        width = x1.shape[3]
        x = torch.cat((x1, x2), 3)
        x = self.Self_Att(x)
        return x[:, :, :, 0:width], x[:, :, :, width:]
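# A hedged usage sketch (not part of the original file): CDSA concatenates the
# two temporal feature maps along the width axis, runs self-attention over the
# joint map so each position can attend across both dates, then splits the
# result back into per-date features of the original shape.
if __name__ == '__main__':
    att = CDSA(in_c=64, ds=1, mode='PAM')
    f1 = torch.rand(1, 64, 32, 32)
    f2 = torch.rand(1, 64, 32, 32)
    o1, o2 = att(f1, f2)
    print(tuple(o1.shape), tuple(o2.shape))  # (1, 64, 32, 32) (1, 64, 32, 32)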
@ -1,333 +0,0 @@
import os
import torch
from collections import OrderedDict
from abc import ABC, abstractmethod
from torch.optim import lr_scheduler


def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                              opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine

    For 'linear', we keep the same learning rate for the first <opt.niter> epochs
    and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
    See https://pytorch.org/docs/stable/optim.html for more details.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
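# A hedged sketch of the 'linear' policy above (not part of the original file):
# with niter=100, niter_decay=100 and epoch_count=1, the multiplier stays at
# 1.0 through epoch 99 and then decays roughly linearly toward 0 by epoch 200.
#
#     from types import SimpleNamespace
#     opt = SimpleNamespace(lr_policy='linear', niter=100, niter_decay=100, epoch_count=1)
#     net = torch.nn.Linear(4, 2)
#     optim = torch.optim.Adam(net.parameters(), lr=2e-4)
#     sched = get_scheduler(optim, opt)
#     for epoch in range(200):
#         optim.step()
#         sched.step()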
class BaseModel(ABC):
    """This class is an abstract base class (ABC) for models.
    To create a subclass, you need to implement the following five functions:
        -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>: unpack data from dataset and apply preprocessing.
        -- <forward>: produce intermediate results.
        -- <optimize_parameters>: calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>
        Then, you need to define four lists:
            -- self.loss_names (str list): specify the training losses that you want to plot and save.
            -- self.model_names (str list): define networks used in our training.
            -- self.visual_names (str list): specify the images that you want to display and save.
            -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
        if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.visual_features = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'
        self.istest = True if opt.phase == 'test' else False  # in test mode there are no labeled samples

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new model-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): includes the data itself and its metadata information.
        """
        pass

    @abstractmethod
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        pass

    @abstractmethod
    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        pass

    def setup(self, opt):
        """Load and print networks; create schedulers

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        if self.isTrain:
            self.schedulers = [get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
            self.load_networks(load_suffix)
        self.print_networks(opt.verbose)

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    def train(self):
        """Make models train mode during train time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.train()

    def test(self):
        """Forward function used in test time.

        This function wraps <forward> in no_grad() so we don't save intermediate steps for backprop.
        It also calls <compute_visuals> to produce additional visualization results.
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization"""
        pass

    def get_image_paths(self):
        """Return image paths that are used to load current data"""
        return self.image_paths

    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch"""
        for scheduler in self.schedulers:
            if self.opt.lr_policy == 'plateau':
                scheduler.step(self.metric)
            else:
                scheduler.step()

        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save the images to an HTML"""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_current_losses(self):
        """Return training losses / errors. train.py will print out these errors on console, and save them to a file"""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
        return errors_ret

    def save_networks(self, epoch):
        """Save all the networks to the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)

                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    # torch.save(net.module.cpu().state_dict(), save_path)
                    torch.save(net.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    def get_visual(self, name):
        visual_ret = {}
        visual_ret[name] = getattr(self, name)
        return visual_ret

    def pred_large(self, A, B, input_size=256, stride=0):
        """
        Take the full-size bitemporal images and produce the prediction by tiling.
        The center part of each tile prediction is assumed accurate; the edge padding is (input_size - stride) / 2.
        :param A: tensor, N*C*H*W
        :param B: tensor, N*C*H*W
        :param input_size: int, image size fed to the network
        :param stride: int, stride between tiles during prediction
        :return: pred, tensor, N*1*H*W
        """
        import math
        import numpy as np
        n, c, h, w = A.shape
        assert A.shape == B.shape
        # number of tiles
        n_h = math.ceil((h - input_size) / stride) + 1
        n_w = math.ceil((w - input_size) / stride) + 1
        # recompute the padded height and width
        new_h = (n_h - 1) * stride + input_size
        new_w = (n_w - 1) * stride + input_size
        print("new_h: ", new_h)
        print("new_w: ", new_w)
        print("n_h: ", n_h)
        print("n_w: ", n_w)
        new_A = torch.zeros([n, c, new_h, new_w], dtype=torch.float32)
        new_B = torch.zeros([n, c, new_h, new_w], dtype=torch.float32)
        new_A[:, :, :h, :w] = A
        new_B[:, :, :h, :w] = B
        new_pred = torch.zeros([n, 1, new_h, new_w], dtype=torch.uint8)
        del A
        del B
        # slide over the padded canvas tile by tile
        for i in range(0, new_h - input_size + 1, stride):
            for j in range(0, new_w - input_size + 1, stride):
                left = j
                right = input_size + j
                top = i
                bottom = input_size + i
                patch_A = new_A[:, :, top:bottom, left:right]
                patch_B = new_B[:, :, top:bottom, left:right]
                # print(left, ' ', right, ' ', top, ' ', bottom)
                self.A = patch_A.to(self.device)
                self.B = patch_B.to(self.device)
                with torch.no_grad():
                    patch_pred = self.forward()
                new_pred[:, :, top:bottom, left:right] = patch_pred.detach().cpu()
        pred = new_pred[:, :, :h, :w]
        return pred

    def load_networks(self, epoch):
        """Load all the networks from the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                # if isinstance(net, torch.nn.DataParallel):
                #     net = net.module
                # net = net.module  # adapt to checkpoints saved with a module wrapper
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))

                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata
                # patch InstanceNorm checkpoints prior to 0.4
                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                    # print(key)
                net.load_state_dict(state_dict, strict=False)

    def print_networks(self, verbose):
        """Print the total number of parameters in the network and (if verbose) the network architecture

        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requires_grad=False for all the networks to avoid unnecessary computations
        Parameters:
            nets (network list)  -- a list of networks
            requires_grad (bool) -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad


if __name__ == '__main__':
    A = torch.rand([1, 3, 512, 512], dtype=torch.float32)
    B = torch.rand([1, 3, 512, 512], dtype=torch.float32)
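# A hedged continuation of the stub above (not part of the original file):
# BaseModel is abstract, so pred_large has to be called on a concrete model,
# e.g. one produced by models.create_model(opt). Note that the stride must be
# positive (the default of 0 is a placeholder). With input_size=256 and
# stride=256 the two 512x512 tensors are processed as four non-overlapping
# tiles; a smaller stride overlaps tiles so only tile centers need be trusted.
#
#     model = create_model(opt)          # assumes opt was parsed beforehand
#     model.setup(opt)
#     pred = model.pred_large(A, B, input_size=256, stride=256)
#     print(pred.shape)                  # torch.Size([1, 1, 512, 512])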
@ -1,29 +0,0 @@
import torch.nn as nn
import torch


class BCL(nn.Module):
    """
    batch-balanced contrastive loss
    no-change: 1
    change: -1
    """

    def __init__(self, margin=2.0):
        super(BCL, self).__init__()
        self.margin = margin

    def forward(self, distance, label):
        mask = (label != 255).float()  # build the ignore mask before remapping the 255 labels
        label[label == 255] = 1
        distance = distance * mask
        pos_num = torch.sum((label == 1).float()) + 0.0001
        neg_num = torch.sum((label == -1).float()) + 0.0001

        loss_1 = torch.sum((1 + label) / 2 * torch.pow(distance, 2)) / pos_num
        loss_2 = torch.sum((1 - label) / 2 * mask *
                           torch.pow(torch.clamp(self.margin - distance, min=0.0), 2)
                           ) / neg_num
        loss = loss_1 + loss_2
        return loss
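# A hedged toy check (not part of the original file): for a "no-change" pixel
# (label +1) the loss penalizes any nonzero distance, while for a "change"
# pixel (label -1) only distances below the margin of 2.0 are penalized.
if __name__ == '__main__':
    criterion = BCL(margin=2.0)
    distance = torch.tensor([[0.5, 2.5]])
    label = torch.tensor([[1.0, -1.0]])       # no-change, change
    print(float(criterion(distance, label)))  # ~0.25: 0.5^2 from the first pixel, 0 from the second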
@ -1,352 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import math


class F_mynet3(nn.Module):
    def __init__(self, backbone='resnet18', in_c=3, f_c=64, output_stride=8):
        self.in_c = in_c
        super(F_mynet3, self).__init__()
        self.module = mynet3(backbone=backbone, output_stride=output_stride, f_c=f_c, in_c=self.in_c)

    def forward(self, input):
        return self.module(input)


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def ResNet34(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
    """
    output, low_level_feat:
    512, 64
    """
    print(in_c)
    model = ResNet(BasicBlock, [3, 4, 6, 3], output_stride, BatchNorm, in_c=in_c)
    if in_c != 3:
        pretrained = False
    if pretrained:
        model._load_pretrained_model(model_urls['resnet34'])
    return model


def ResNet18(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
    """
    output, low_level_feat:
    512, 256, 128, 64, 64
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], output_stride, BatchNorm, in_c=in_c)
    if in_c != 3:
        pretrained = False
    if pretrained:
        model._load_pretrained_model(model_urls['resnet18'])
    return model


def ResNet50(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
    """
    output, low_level_feat:
    2048, 256
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], output_stride, BatchNorm, in_c=in_c)
    if in_c != 3:
        pretrained = False
    if pretrained:
        model._load_pretrained_model(model_urls['resnet50'])
    return model


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
        super(BasicBlock, self).__init__()

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               dilation=dilation, padding=dilation, bias=False)
        self.bn1 = BatchNorm(planes)

        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BatchNorm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               dilation=dilation, padding=dilation, bias=False)
        self.bn2 = BatchNorm(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True, in_c=3):
        self.inplanes = 64
        self.in_c = in_c
        print('in_c: ', self.in_c)
        super(ResNet, self).__init__()
        blocks = [1, 2, 4]
        if output_stride == 32:
            strides = [1, 2, 2, 2]
            dilations = [1, 1, 1, 1]
        elif output_stride == 16:
            strides = [1, 2, 2, 1]
            dilations = [1, 1, 1, 2]
        elif output_stride == 8:
            strides = [1, 2, 1, 1]
            dilations = [1, 1, 2, 4]
        elif output_stride == 4:
            strides = [1, 1, 1, 1]
            dilations = [1, 2, 4, 8]
        else:
            raise NotImplementedError

        # Modules
        self.conv1 = nn.Conv2d(self.in_c, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = BatchNorm(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
        self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
        # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
        self._init_weight()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))

        return nn.Sequential(*layers)

    def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation=blocks[0] * dilation,
                            downsample=downsample, BatchNorm=BatchNorm))
        self.inplanes = planes * block.expansion
        for i in range(1, len(blocks)):
            layers.append(block(self.inplanes, planes, stride=1,
                                dilation=blocks[i] * dilation, BatchNorm=BatchNorm))

        return nn.Sequential(*layers)

    def forward(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)   # | 4
        x = self.layer1(x)    # | 4
        low_level_feat2 = x   # | 4
        x = self.layer2(x)    # | 8
        low_level_feat3 = x
        x = self.layer3(x)    # | 16
        low_level_feat4 = x
        x = self.layer4(x)    # | 32
        return x, low_level_feat2, low_level_feat3, low_level_feat4

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _load_pretrained_model(self, model_path):
        pretrain_dict = model_zoo.load_url(model_path)
        model_dict = {}
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)


def build_backbone(backbone, output_stride, BatchNorm, in_c=3):
    if backbone == 'resnet50':
        return ResNet50(output_stride, BatchNorm, in_c=in_c)
    elif backbone == 'resnet34':
        return ResNet34(output_stride, BatchNorm, in_c=in_c)
    elif backbone == 'resnet18':
        return ResNet18(output_stride, BatchNorm, in_c=in_c)
    else:
        raise NotImplementedError


class DR(nn.Module):
    def __init__(self, in_d, out_d):
        super(DR, self).__init__()
        self.in_d = in_d
        self.out_d = out_d
        self.conv1 = nn.Conv2d(self.in_d, self.out_d, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.out_d)
        self.relu = nn.ReLU()

    def forward(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        return x


class Decoder(nn.Module):
    def __init__(self, fc, BatchNorm):
        super(Decoder, self).__init__()
        self.fc = fc
        self.dr2 = DR(64, 96)
        self.dr3 = DR(128, 96)
        self.dr4 = DR(256, 96)
        self.dr5 = DR(512, 96)
        self.last_conv = nn.Sequential(nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       BatchNorm(256),
                                       nn.ReLU(),
                                       nn.Dropout(0.5),
                                       nn.Conv2d(256, self.fc, kernel_size=1, stride=1, padding=0, bias=False),
                                       BatchNorm(self.fc),
                                       nn.ReLU(),
                                       )

        self._init_weight()

    def forward(self, x, low_level_feat2, low_level_feat3, low_level_feat4):
        # x1 = self.dr1(low_level_feat1)
        x2 = self.dr2(low_level_feat2)
        x3 = self.dr3(low_level_feat3)
        x4 = self.dr4(low_level_feat4)
        x = self.dr5(x)
        x = F.interpolate(x, size=x2.size()[2:], mode='bilinear', align_corners=True)
        # x2 = F.interpolate(x2, size=x3.size()[2:], mode='bilinear', align_corners=True)
        x3 = F.interpolate(x3, size=x2.size()[2:], mode='bilinear', align_corners=True)
        x4 = F.interpolate(x4, size=x2.size()[2:], mode='bilinear', align_corners=True)

        x = torch.cat((x, x2, x3, x4), dim=1)

        x = self.last_conv(x)

        return x

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


def build_decoder(fc, backbone, BatchNorm):
    return Decoder(fc, BatchNorm)


class mynet3(nn.Module):
    def __init__(self, backbone='resnet18', output_stride=16, f_c=64, freeze_bn=False, in_c=3):
        super(mynet3, self).__init__()
        print('arch: mynet3')
        BatchNorm = nn.BatchNorm2d
        self.backbone = build_backbone(backbone, output_stride, BatchNorm, in_c)
        self.decoder = build_decoder(f_c, backbone, BatchNorm)

        if freeze_bn:
            self.freeze_bn()

    def forward(self, input):
        x, f2, f3, f4 = self.backbone(input)
        x = self.decoder(x, f2, f3, f4)
        return x

    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
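# A hedged shape-check sketch (not part of the original file): with
# output_stride=32 on resnet18 the decoder upsamples everything back to the
# layer1 resolution, i.e. 1/4 of the input, with f_c output channels (the
# first run downloads the ImageNet weights via model_zoo).
if __name__ == '__main__':
    net = F_mynet3(backbone='resnet18', in_c=3, f_c=64, output_stride=32)
    x = torch.rand(1, 3, 256, 256)
    feats = net(x)
    print(tuple(feats.shape))  # (1, 64, 64, 64)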
@ -1 +0,0 @@
"""This package includes option modules: training options, test options, and basic options (used in both training and test)."""
@ -1,147 +0,0 @@
import argparse
import os
from util import util
import torch
import models
import data


class BaseOptions():
    """This class defines options used during both training and test time.

    It also implements several helper functions such as parsing, printing, and saving the options.
    It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
    """

    def __init__(self):
        """Reset the class; indicates the class hasn't been initialized"""
        self.initialized = False

    def initialize(self, parser):
        """Define the common options that are used in both training and test."""
        # basic parameters
        parser.add_argument('--dataroot', type=str, default='./LEVIR-CD', help='path to images (should have subfolders A, B, label)')
        parser.add_argument('--val_dataroot', type=str, default='./LEVIR-CD', help='path to images in the val phase (should have subfolders A, B, label)')

        parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2  0,2. use -1 for CPU')
        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        # model parameters
        parser.add_argument('--model', type=str, default='CDF0', help='chooses which model to use. [CDF0 | CDFA]')
        parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB')
        parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB')
        parser.add_argument('--arch', type=str, default='mynet3', help='feature extractor architecture | mynet3')
        parser.add_argument('--f_c', type=int, default=64, help='feature extractor channel num')
        parser.add_argument('--n_class', type=int, default=2, help='# of output pred channels: 2 for num of classes')
        parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
        parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
        parser.add_argument('--SA_mode', type=str, default='BAM', help='choose self attention mode for change detection, | ori | 1 | 2 | pyramid, ...')
        # dataset parameters
        parser.add_argument('--dataset_mode', type=str, default='changedetection', help='chooses how datasets are loaded. [changedetection | concat | list | json]')
        parser.add_argument('--val_dataset_mode', type=str, default='changedetection', help='chooses how datasets are loaded. [changedetection | concat | list | json]')
        parser.add_argument('--dataset_type', type=str, default='CD_LEVIR', help='chooses which datasets to load. [LEVIR | WHU]')
        parser.add_argument('--val_dataset_type', type=str, default='CD_LEVIR', help='chooses which datasets to load. [LEVIR | WHU]')
        parser.add_argument('--split', type=str, default='train', help='chooses which list-file to open when using listDataset. [train | val | test]')
        parser.add_argument('--val_split', type=str, default='val', help='chooses which list-file to open when using listDataset. [train | val | test]')
        parser.add_argument('--json_name', type=str, default='train_val_test', help='input the json name which contains the file names of images of different phases')
        parser.add_argument('--val_json_name', type=str, default='train_val_test', help='input the json name which contains the file names of images of different phases')
        parser.add_argument('--ds', type=int, default=1, help='self attention module downsample rate')
        parser.add_argument('--angle', type=int, default=0, help='rotate angle')
        parser.add_argument('--istest', type=bool, default=False, help='True for the case without label')

        parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
        parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
        parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
        parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
        parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
        parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
        parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | none]')
        parser.add_argument('--no_flip', type=bool, default=True, help='if specified, do not flip (left-right) the images for data augmentation')

        parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
        # additional parameters
        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        parser.add_argument('--load_iter', type=int, default=0, help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
        parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
        self.initialized = True
        return parser

    def gather_options(self):
        """Initialize our parser with basic options (only once).
        Add additional model-specific and dataset-specific options.
        These options are defined in the <modify_commandline_options> function
        in model and dataset classes.
        """
        if not self.initialized:  # check if it has been initialized
            parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
            parser = self.initialize(parser)

        # get the basic options
        opt, _ = parser.parse_known_args()

        # modify model-related parser options
        model_name = opt.model
        model_option_setter = models.get_option_setter(model_name)
        parser = model_option_setter(parser, self.isTrain)
        opt, _ = parser.parse_known_args()  # parse again with new defaults

        # modify dataset-related parser options
        dataset_name = opt.dataset_mode
        if dataset_name != 'concat':
            dataset_option_setter = data.get_option_setter(dataset_name)
            parser = dataset_option_setter(parser, self.isTrain)

        # save and return the parser
        self.parser = parser
        return parser.parse_args()

    def print_options(self, opt):
        """Print and save options

        It will print both current options and default values (if different).
        It will save options into a text file / [checkpoints_dir] / opt.txt
        """
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)

        # save to the disk
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, 'opt.txt')
        with open(file_name, 'wt') as opt_file:
            opt_file.write(message)
            opt_file.write('\n')

    def parse(self):
        """Parse our options, create checkpoints directory suffix, and set up gpu device."""
        opt = self.gather_options()
        opt.isTrain = self.isTrain  # train or test

        # process opt.suffix
        if opt.suffix:
            suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
            opt.name = opt.name + suffix

        self.print_options(opt)

        # set gpu ids
        str_ids = opt.gpu_ids.split(',')
        opt.gpu_ids = []
        for str_id in str_ids:
            id = int(str_id)
            if id >= 0:
                opt.gpu_ids.append(id)
        if len(opt.gpu_ids) > 0:
            torch.cuda.set_device(opt.gpu_ids[0])

        self.opt = opt
        return self.opt
@ -1,22 +0,0 @@
from .base_options import BaseOptions


class TestOptions(BaseOptions):
    """This class includes test options.

    It also includes shared options defined in BaseOptions.
    """

    def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)  # define shared options
        parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
        parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
        parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
        parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
        # Dropout and Batchnorm have different behavior during training and test.
        parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
        parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
        # To avoid cropping, the load_size should be the same as crop_size
        parser.set_defaults(load_size=parser.get_default('crop_size'))
        self.isTrain = False
        return parser
@@ -1,40 +0,0 @@
from .base_options import BaseOptions


class TrainOptions(BaseOptions):
    """This class includes training options.

    It also includes shared options defined in BaseOptions.
    """

    def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)
        # visdom and HTML visualization parameters
        parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
        parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
        parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
        parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
        parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
        parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
        parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
        parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
        parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        # network saving and loading parameters
        parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
        parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
        parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
        parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
        parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
        parser.add_argument('--lr_decay', type=float, default=1, help='learning rate decay for certain module ...')

        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        # training parameters
        parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
        parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
        parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
        parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
        parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
        parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')

        self.isTrain = True
        return parser
[18 image files deleted in this change (922 B to 155 KiB each); binary content not shown]
@@ -1,5 +0,0 @@

# train PAM
python ./train.py --save_epoch_freq 1 --angle 15 --dataroot path-to-LEVIR-CD-train --val_dataroot path-to-LEVIR-CD-val --name LEVIR-CDFAp0 --lr 0.001 --model CDFA --SA_mode PAM --batch_size 8 --load_size 256 --crop_size 256 --preprocess rotate_and_crop
@@ -1,11 +0,0 @@

# concat mode
list=train
lr=0.001
dataset_type=CD_data1,CD_data2,...,
dataset_type=LEVIR_CD1
val_dataset_type=LEVIR_CD1
dataset_mode=concat
name=project_name

python ./train.py --num_threads 4 --display_id 0 --dataset_type $dataset_type --val_dataset_type $val_dataset_type --save_epoch_freq 1 --niter 100 --angle 15 --niter_decay 100 --display_env FAp0 --SA_mode PAM --name $name --lr $lr --model CDFA --batch_size 4 --dataset_mode $dataset_mode --val_dataset_mode $dataset_mode --split $list --load_size 256 --crop_size 256 --preprocess resize_rotate_and_crop
@@ -1,9 +0,0 @@

# list mode
list=train
lr=0.001
dataset_mode=list
dataroot=path-to-dataroot
name=project_name
#
python ./train.py --num_threads 4 --display_id 0 --dataroot ${dataroot} --val_dataroot ${dataroot} --save_epoch_freq 1 --niter 100 --angle 25 --niter_decay 100 --display_env FAp0 --SA_mode PAM --name $name --lr $lr --model CDFA --batch_size 4 --dataset_mode $dataset_mode --val_dataset_mode $dataset_mode --split $list --load_size 256 --crop_size 256 --preprocess resize_rotate_and_crop
@@ -1,105 +0,0 @@
from data import create_dataset
from models import create_model
from util.util import save_images
import numpy as np
from util.util import mkdir
import argparse
from PIL import Image
import torchvision.transforms as transforms


def transform():
    transform_list = []
    transform_list += [transforms.ToTensor()]
    transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def val(opt):
    image_1_path = opt.image1_path
    image_2_path = opt.image2_path
    A_img = Image.open(image_1_path).convert('RGB')
    B_img = Image.open(image_2_path).convert('RGB')
    trans = transform()
    A = trans(A_img).unsqueeze(0)
    B = trans(B_img).unsqueeze(0)
    # dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    model = create_model(opt)  # create a model given opt.model and other options
    model.setup(opt)           # regular setup: load and print networks; create schedulers
    save_path = opt.results_dir
    mkdir(save_path)
    model.eval()
    data = {}
    data['A'] = A
    data['B'] = B
    data['A_paths'] = [image_1_path]

    model.set_input(data)         # unpack data from data loader
    pred = model.test(val=False)  # run inference and return the prediction

    img_path = [image_1_path]     # get image paths

    save_images(pred, save_path, img_path)


if __name__ == '__main__':
    # Invocation from outside:
    # python test.py --image1_path [path-to-img1] --image2_path [path-to-img2] --results_dir [path-to-result_dir]

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--image1_path', type=str, default='./samples/A/test_2_0000_0000.png',
                        help='path to images A')
    parser.add_argument('--image2_path', type=str, default='./samples/B/test_2_0000_0000.png',
                        help='path to images B')
    parser.add_argument('--results_dir', type=str, default='./samples/output/', help='saves results here.')

    parser.add_argument('--name', type=str, default='pam',
                        help='name of the experiment. It decides where to store samples and models')
    parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
    parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
    # model parameters
    parser.add_argument('--model', type=str, default='CDFA', help='chooses which model to use. [CDF0 | CDFA]')
    parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB')
    parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB')
    parser.add_argument('--arch', type=str, default='mynet3', help='feature extractor architecture | mynet3')
    parser.add_argument('--f_c', type=int, default=64, help='feature extractor channel num')
    parser.add_argument('--n_class', type=int, default=2, help='# of output pred channels: 2 for num of classes')

    parser.add_argument('--SA_mode', type=str, default='PAM',
                        help='choose self attention mode for change detection, | ori | 1 | 2 | pyramid, ...')
    # dataset parameters
    parser.add_argument('--dataset_mode', type=str, default='changedetection',
                        help='chooses how datasets are loaded. [changedetection | json]')
    parser.add_argument('--val_dataset_mode', type=str, default='changedetection',
                        help='chooses how datasets are loaded. [changedetection | json]')
    parser.add_argument('--split', type=str, default='train',
                        help='chooses which list-file to open when use listDataset. [train | val | test]')
    parser.add_argument('--ds', type=int, default=1, help='self attention module downsample rate')
    parser.add_argument('--angle', type=int, default=0, help='rotate angle')
    parser.add_argument('--istest', type=bool, default=False, help='True for the case without label')
    parser.add_argument('--serial_batches', action='store_true',
                        help='if true, takes images in order to make batches, otherwise takes them randomly')
    parser.add_argument('--num_threads', default=0, type=int, help='# threads for loading data')
    parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
    parser.add_argument('--load_size', type=int, default=256, help='scale images to this size')
    parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
    parser.add_argument('--max_dataset_size', type=int, default=float("inf"),
                        help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
    parser.add_argument('--preprocess', type=str, default='resize_and_crop',
                        help='scaling and cropping of images at load time [resize_and_crop | none]')
    parser.add_argument('--no_flip', type=bool, default=True,
                        help='if specified, do not flip (left-right) the images for data augmentation')
    parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
    parser.add_argument('--epoch', type=str, default='pam',
                        help='which epoch to load? set to latest to use latest cached model')
    parser.add_argument('--load_iter', type=int, default=0,
                        help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
    parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
    parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
    parser.add_argument('--isTrain', type=bool, default=False, help='whether the model is in training mode')
    parser.add_argument('--num_test', type=int, default=np.inf, help='how many test images to run')

    opt = parser.parse_args()
    val(opt)
@@ -1,181 +0,0 @@
import time
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.visualizer import Visualizer
import os
from util import html
from util.visualizer import save_images
from util.metrics import AverageMeter
import copy
import numpy as np
import torch
import random


def seed_torch(seed=2019):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


# set seeds
# seed_torch(2019)
ifSaveImage = False


def make_val_opt(opt):

    val_opt = copy.deepcopy(opt)
    val_opt.preprocess = ''
    # hard-code some parameters for validation
    val_opt.num_threads = 0   # the validation code only supports num_threads = 0
    val_opt.batch_size = 4    # validation batch size
    val_opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    val_opt.no_flip = True    # no flip; comment this line if results on flipped images are needed.
    val_opt.angle = 0
    val_opt.display_id = -1   # no visdom display; the validation code saves the results to a HTML file.
    val_opt.phase = 'val'
    val_opt.split = opt.val_split  # function in jsonDataset and ListDataset
    val_opt.isTrain = False
    val_opt.aspect_ratio = 1
    val_opt.results_dir = './results/'
    val_opt.dataroot = opt.val_dataroot
    val_opt.dataset_mode = opt.val_dataset_mode
    val_opt.dataset_type = opt.val_dataset_type
    val_opt.json_name = opt.val_json_name
    val_opt.eval = True

    val_opt.num_test = 2000
    return val_opt


def print_current_acc(log_name, epoch, score):
    """Print the current accuracy on the console and save it to the log file."""
    message = '(epoch: %d) ' % epoch
    for k, v in score.items():
        message += '%s: %.3f ' % (k, v)
    print(message)  # print the message
    with open(log_name, "a") as log_file:
        log_file.write('%s\n' % message)  # save the message


def val(opt, model):
    # note: relies on the module-level epoch, epoch_iter, visualizer, dataset_size, and metric_name set in __main__
    opt = make_val_opt(opt)
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    # model = create_model(opt)    # create a model given opt.model and other options
    # model.setup(opt)             # regular setup: load and print networks; create schedulers

    web_dir = os.path.join(opt.checkpoints_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch))  # define the website directory
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))
    model.eval()
    # create a logging file to store validation scores
    log_name = os.path.join(opt.checkpoints_dir, opt.name, 'val_log.txt')
    with open(log_name, "a") as log_file:
        now = time.strftime("%c")
        log_file.write('================ val acc (%s) ================\n' % now)

    running_metrics = AverageMeter()
    for i, data in enumerate(dataset):
        if i >= opt.num_test:  # only apply our model to opt.num_test images.
            break
        model.set_input(data)         # unpack data from data loader
        score = model.test(val=True)  # run inference
        running_metrics.update(score)
        visuals = model.get_current_visuals()  # get image results
        img_path = model.get_image_paths()     # get image paths
        if i % 5 == 0:
            print('processing (%04d)-th image... %s' % (i, img_path))
        if ifSaveImage:  # save images to an HTML file
            save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)

    score = running_metrics.get_scores()
    print_current_acc(log_name, epoch, score)
    if opt.display_id > 0:
        visualizer.plot_current_acc(epoch, float(epoch_iter) / dataset_size, score)
    webpage.save()  # save the HTML

    return score[metric_name]


metric_name = 'F1_1'


if __name__ == '__main__':
    opt = TrainOptions().parse()   # get training options
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)    # get the number of images in the dataset.
    print('The number of training images = %d' % dataset_size)

    model = create_model(opt)     # create a model given opt.model and other options
    model.setup(opt)              # regular setup: load and print networks; create schedulers
    visualizer = Visualizer(opt)  # create a visualizer that displays/saves images and plots
    total_iters = 0               # the total number of training iterations
    miou_best = 0
    n_epoch_bad = 0
    epoch_best = 0
    time_metric = AverageMeter()
    time_log_name = os.path.join(opt.checkpoints_dir, opt.name, 'time_log.txt')
    with open(time_log_name, "a") as log_file:
        now = time.strftime("%c")
        log_file.write('================ training time (%s) ================\n' % now)

    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):  # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()    # timer for data loading per iteration
        epoch_iter = 0                  # the number of training iterations in current epoch, reset to 0 every epoch
        model.train()
        # miou_current = val(opt, model)
        for i, data in enumerate(dataset):  # inner loop within one epoch
            iter_start_time = time.time()   # timer for computation per iteration
            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time
            visualizer.reset()
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size
            n_epoch = opt.niter + opt.niter_decay

            model.set_input(data)        # unpack data from dataset and apply preprocessing
            model.optimize_parameters()  # calculate loss functions, get gradients, update network weights
            if ifSaveImage:
                if total_iters % opt.display_freq == 0:  # display images on visdom and save images to a HTML file
                    save_result = total_iters % opt.update_html_freq == 0
                    model.compute_visuals()
                    visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
            if total_iters % opt.print_freq == 0:  # print training losses and save logging information to the disk
                losses = model.get_current_losses()
                t_comp = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)

            if total_iters % opt.save_latest_freq == 0:  # cache our latest model every <save_latest_freq> iterations
                print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
                save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
                model.save_networks(save_suffix)

            iter_data_time = time.time()

        t_epoch = time.time() - epoch_start_time
        time_metric.update(t_epoch)
        print_current_acc(time_log_name, epoch, {"current_t_epoch": t_epoch})

        if epoch % opt.save_epoch_freq == 0:  # cache our model every <save_epoch_freq> epochs
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
            model.save_networks('latest')
            miou_current = val(opt, model)

            if miou_current > miou_best:
                miou_best = miou_current
                epoch_best = epoch
                model.save_networks(str(epoch_best) + "_" + metric_name + '_' + '%0.5f' % miou_best)

        print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
        model.update_learning_rate()  # update learning rates at the end of every epoch.

    time_ave = time_metric.average()
    print_current_acc(time_log_name, epoch, {"ave_t_epoch": time_ave})
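Note how the best checkpoint is tagged: epoch number, metric name, and score joined by underscores. That tag is exactly what is later fed back in as --epoch (see opt.epoch = '78_F1_1_0.88780' in the evaluation script further down) so the saved weights can be located. A minimal sketch of the format:

    epoch_best, metric_name, miou_best = 78, 'F1_1', 0.8878
    ckpt_tag = str(epoch_best) + "_" + metric_name + '_' + '%0.5f' % miou_best
    print(ckpt_tag)  # -> '78_F1_1_0.88780'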
@@ -1 +0,0 @@
"""This package includes a miscellaneous collection of useful helper functions."""
@@ -1,86 +0,0 @@
import dominate
from dominate.tags import meta, h3, table, tr, td, p, a, img, br
import os


class HTML:
    """This HTML class allows us to save images and write texts into a single HTML file.

    It consists of functions such as <add_header> (add a text header to the HTML file),
    <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
    It is based on the Python library 'dominate', a library for creating and manipulating HTML documents using a DOM API.
    """

    def __init__(self, web_dir, title, refresh=0):
        """Initialize the HTML class

        Parameters:
            web_dir (str) -- a directory that stores the webpage. The HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
            title (str)   -- the webpage name
            refresh (int) -- how often the website refreshes itself; if 0, no refreshing
        """
        self.title = title
        self.web_dir = web_dir
        self.img_dir = os.path.join(self.web_dir, 'images')
        if not os.path.exists(self.web_dir):
            os.makedirs(self.web_dir)
        if not os.path.exists(self.img_dir):
            os.makedirs(self.img_dir)

        self.doc = dominate.document(title=title)
        if refresh > 0:
            with self.doc.head:
                meta(http_equiv="refresh", content=str(refresh))

    def get_image_dir(self):
        """Return the directory that stores images"""
        return self.img_dir

    def add_header(self, text):
        """Insert a header to the HTML file

        Parameters:
            text (str) -- the header text
        """
        with self.doc:
            h3(text)

    def add_images(self, ims, txts, links, width=400):
        """Add images to the HTML file

        Parameters:
            ims (str list)   -- a list of image paths
            txts (str list)  -- a list of image names shown on the website
            links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
        """
        self.t = table(border=1, style="table-layout: fixed;")  # insert a table
        self.doc.add(self.t)
        with self.t:
            with tr():
                for im, txt, link in zip(ims, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            with a(href=os.path.join('images', link)):
                                img(style="width:%dpx" % width, src=os.path.join('images', im))
                            br()
                            p(txt)

    def save(self):
        """Save the current content to the HTML file"""
        html_file = '%s/index.html' % self.web_dir
        f = open(html_file, 'wt')
        f.write(self.doc.render())
        f.close()


if __name__ == '__main__':  # we show an example usage here.
    html = HTML('web/', 'test_html')
    html.add_header('hello world')

    ims, txts, links = [], [], []
    for n in range(4):
        ims.append('image_%d.png' % n)
        txts.append('text_%d' % n)
        links.append('image_%d.png' % n)
    html.add_images(ims, txts, links)
    html.save()
@@ -1,176 +0,0 @@
# Adapted from score written by wkentaro
# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
import numpy as np

eps = np.finfo(float).eps


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.initialized = False
        self.val = None
        self.avg = None
        self.sum = None
        self.count = None

    def initialize(self, val, weight):
        self.val = val
        self.avg = val
        self.sum = val * weight
        self.count = weight
        self.initialized = True

    def update(self, val, weight=1):
        if not self.initialized:
            self.initialize(val, weight)
        else:
            self.add(val, weight)

    def add(self, val, weight):
        self.val = val
        self.sum += val * weight
        self.count += weight
        self.avg = self.sum / self.count

    def value(self):
        return self.val

    def average(self):
        return self.avg

    def get_scores(self):
        scores, cls_iu, m_1 = cm2score(self.sum)
        scores.update(cls_iu)
        scores.update(m_1)
        return scores


def cm2score(confusion_matrix):
    hist = confusion_matrix
    n_class = hist.shape[0]
    tp = np.diag(hist)
    sum_a1 = hist.sum(axis=1)
    sum_a0 = hist.sum(axis=0)
    # ---------------------------------------------------------------------- #
    # 1. Accuracy & Class Accuracy
    # ---------------------------------------------------------------------- #
    acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps)

    # recall (per-class accuracy over ground-truth rows)
    acc_cls_ = tp / (sum_a1 + np.finfo(np.float32).eps)

    # precision
    precision = tp / (sum_a0 + np.finfo(np.float32).eps)

    # F1 score
    F1 = 2 * acc_cls_ * precision / (acc_cls_ + precision + np.finfo(np.float32).eps)
    # ---------------------------------------------------------------------- #
    # 2. Mean IoU
    # ---------------------------------------------------------------------- #
    iu = tp / (sum_a1 + hist.sum(axis=0) - tp + np.finfo(np.float32).eps)
    mean_iu = np.nanmean(iu)

    cls_iu = dict(zip(range(n_class), iu))

    return {'Overall_Acc': acc,
            'Mean_IoU': mean_iu}, cls_iu, \
           {
            'precision_1': precision[1],
            'recall_1': acc_cls_[1],
            'F1_1': F1[1],
           }
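As a worked example of the formulas above, a minimal sketch assuming cm2score is in scope, with a binary confusion matrix whose rows are ground truth and columns are prediction (class 1 = "change"):

    import numpy as np

    hist = np.array([[50., 10.],
                     [ 5., 35.]])
    scores, cls_iu, m_1 = cm2score(hist)
    # precision_1 = 35/45 ~ 0.778, recall_1 = 35/40 = 0.875
    # F1_1 = 2*0.778*0.875 / (0.778+0.875) ~ 0.824
    # IoU_1 = 35/(40+45-35) = 0.70, Overall_Acc = 85/100 = 0.85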
class RunningMetrics(object):
    def __init__(self, num_classes):
        """
        Computes and stores the Metric values from the Confusion Matrix
        - overall accuracy
        - mean accuracy
        - mean IU
        - fwavacc
        For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix
        :param num_classes: <int> number of classes
        """
        self.num_classes = num_classes
        self.confusion_matrix = np.zeros((num_classes, num_classes))

    def __fast_hist(self, label_gt, label_pred):
        """
        Collect values for the Confusion Matrix
        For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix
        :param label_gt: <np.array> ground-truth
        :param label_pred: <np.array> prediction
        :return: <np.ndarray> values for confusion matrix
        """
        mask = (label_gt >= 0) & (label_gt < self.num_classes)
        hist = np.bincount(self.num_classes * label_gt[mask].astype(int) + label_pred[mask],
                           minlength=self.num_classes**2).reshape(self.num_classes, self.num_classes)
        return hist

    def update(self, label_gts, label_preds):
        """
        Accumulate the Confusion Matrix
        For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix
        :param label_gts: <np.ndarray> ground-truths
        :param label_preds: <np.ndarray> predictions
        :return:
        """
        for lt, lp in zip(label_gts, label_preds):
            self.confusion_matrix += self.__fast_hist(lt.flatten(), lp.flatten())

    def reset(self):
        """
        Reset the Confusion Matrix
        :return:
        """
        self.confusion_matrix = np.zeros((self.num_classes, self.num_classes))

    def get_cm(self):
        return self.confusion_matrix

    def get_scores(self):
        """
        Returns scores for:
        - overall accuracy
        - mean accuracy
        - mean IU
        - fwavacc
        For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix
        :return:
        """
        hist = self.confusion_matrix
        tp = np.diag(hist)
        sum_a1 = hist.sum(axis=1)
        sum_a0 = hist.sum(axis=0)

        # ---------------------------------------------------------------------- #
        # 1. Accuracy & Class Accuracy
        # ---------------------------------------------------------------------- #
        acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps)

        # recall
        acc_cls_ = tp / (sum_a1 + np.finfo(np.float32).eps)

        # precision
        precision = tp / (sum_a0 + np.finfo(np.float32).eps)
        # ---------------------------------------------------------------------- #
        # 2. Mean IoU
        # ---------------------------------------------------------------------- #
        iu = tp / (sum_a1 + hist.sum(axis=0) - tp + np.finfo(np.float32).eps)
        mean_iu = np.nanmean(iu)

        cls_iu = dict(zip(range(self.num_classes), iu))

        # F1 score
        F1 = 2 * acc_cls_ * precision / (acc_cls_ + precision + np.finfo(np.float32).eps)

        scores = {'Overall_Acc': acc,
                  'Mean_IoU': mean_iu}
        scores.update(cls_iu)
        scores.update({'precision_1': precision[1],
                       'recall_1': acc_cls_[1],
                       'F1_1': F1[1]})
        return scores
@@ -1,110 +0,0 @@
"""This module contains simple helper functions """
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import os
import ntpath


def save_images(images, img_dir, name):
    """Save images in img_dir, with name

    images: torch.float, B*C*H*W
    img_dir: str
    name: list [str]
    """
    for i, image in enumerate(images):
        print(image.shape)
        image_numpy = tensor2im(image.unsqueeze(0), normalize=False) * 255
        basename = os.path.basename(name[i])
        print('name:', basename)
        save_path = os.path.join(img_dir, basename)
        save_image(image_numpy, save_path)


def save_visuals(visuals, img_dir, name):
    """Save each visual in img_dir, keyed by the (extension-less) base name."""
    name = ntpath.basename(name)
    name = name.split(".")[0]
    print(name)
    # save images to the disk
    for label, image in visuals.items():
        image_numpy = tensor2im(image)
        img_path = os.path.join(img_dir, '%s_%s.png' % (name, label))
        save_image(image_numpy, img_path)


def tensor2im(input_image, imtype=np.uint8, normalize=True):
    """Converts a Tensor array into a numpy image array.

    Parameters:
        input_image (tensor) -- the input image tensor array
        imtype (type)        -- the desired type of the converted numpy array
    """
    if not isinstance(input_image, np.ndarray):
        if isinstance(input_image, torch.Tensor):  # get the data from a variable
            image_tensor = input_image.data
        else:
            return input_image
        image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))

        image_numpy = np.transpose(image_numpy, (1, 2, 0))
        if normalize:
            image_numpy = (image_numpy + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
    else:  # if it is a numpy array, do nothing
        image_numpy = input_image
    return image_numpy.astype(imtype)


def save_image(image_numpy, image_path):
    """Save a numpy image to the disk

    Parameters:
        image_numpy (numpy array) -- input numpy array
        image_path (str)          -- the path of the image
    """
    image_pil = Image.fromarray(image_numpy)
    image_pil.save(image_path)


def print_numpy(x, val=True, shp=False):
    """Print the mean, min, max, median, std, and size of a numpy array

    Parameters:
        val (bool) -- if print the values of the numpy array
        shp (bool) -- if print the shape of the numpy array
    """
    x = x.astype(np.float64)
    if shp:
        print('shape,', x.shape)
    if val:
        x = x.flatten()
        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
            np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))


def mkdirs(paths):
    """create empty directories if they don't exist

    Parameters:
        paths (str list) -- a list of directory paths
    """
    if isinstance(paths, list) and not isinstance(paths, str):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    """create a single empty directory if it didn't exist

    Parameters:
        path (str) -- a single directory path
    """
    if not os.path.exists(path):
        os.makedirs(path)
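A quick sanity check of tensor2im's contract, as a minimal sketch assuming tensor2im from this module is in scope: a normalized 1x3xHxW tensor in [-1, 1] comes back as an HxWx3 uint8 array.

    import torch

    t = torch.tanh(torch.randn(1, 3, 4, 4))  # fake network output in [-1, 1]
    img = tensor2im(t)                       # rescales to [0, 255] and moves channels last
    print(img.shape, img.dtype)              # (4, 4, 3) uint8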
@@ -1,246 +0,0 @@
import numpy as np
import os
import sys
import ntpath
import time
from . import util, html
from subprocess import Popen, PIPE
# from scipy.misc import imresize


if sys.version_info[0] == 2:
    VisdomExceptionBase = Exception
else:
    VisdomExceptionBase = ConnectionError


def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
    """Save images to the disk.

    Parameters:
        webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
        visuals (OrderedDict)    -- an ordered dictionary that stores (name, images (either tensor or numpy)) pairs
        image_path (str)         -- the string is used to create image paths
        aspect_ratio (float)     -- the aspect ratio of saved images
        width (int)              -- the images will be resized to width x width

    This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
    """
    image_dir = webpage.get_image_dir()
    short_path = ntpath.basename(image_path[0])
    name = os.path.splitext(short_path)[0]

    webpage.add_header(name)
    ims, txts, links = [], [], []

    for label, im_data in visuals.items():
        im = util.tensor2im(im_data)
        image_name = '%s_%s.png' % (name, label)
        save_path = os.path.join(image_dir, image_name)
        h, w, _ = im.shape
        # if aspect_ratio > 1.0:
        #     im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
        # if aspect_ratio < 1.0:
        #     im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
        util.save_image(im, save_path)

        ims.append(image_name)
        txts.append(label)
        links.append(image_name)
    webpage.add_images(ims, txts, links, width=width)


class Visualizer():
    """This class includes several functions that can display/save images and print/save logging information.

    It uses the Python library 'visdom' for display, and the Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
    """

    def __init__(self, opt):
        """Initialize the Visualizer class

        Parameters:
            opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
        Step 1: Cache the training/test options
        Step 2: connect to a visdom server
        Step 3: create an HTML object for saving HTML files
        Step 4: create a logging file to store training losses
        """
        self.opt = opt  # cache the option
        self.display_id = opt.display_id
        self.use_html = opt.isTrain and not opt.no_html
        self.win_size = opt.display_winsize
        self.name = opt.name
        self.port = opt.display_port
        self.saved = False
        if self.display_id > 0:  # connect to a visdom server given <display_port> and <display_server>
            import visdom
            self.ncols = opt.display_ncols
            self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
            if not self.vis.check_connection():
                self.create_visdom_connections()

        if self.use_html:  # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
            self.img_dir = os.path.join(self.web_dir, 'images')
            print('create web directory %s...' % self.web_dir)
            util.mkdirs([self.web_dir, self.img_dir])
        # create a logging file to store training losses
        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
        with open(self.log_name, "a") as log_file:
            now = time.strftime("%c")
            log_file.write('================ Training Loss (%s) ================\n' % now)

    def reset(self):
        """Reset the self.saved status"""
        self.saved = False

    def create_visdom_connections(self):
        """If the program could not connect to the Visdom server, this function will start a new server at port <self.port>"""
        cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
        print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
        print('Command: %s' % cmd)
        Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)

    def display_current_results(self, visuals, epoch, save_result):
        """Display current results on visdom; save current results to an HTML file.

        Parameters:
            visuals (OrderedDict) -- dictionary of images to display or save
            epoch (int)           -- the current epoch
            save_result (bool)    -- if save the current results to an HTML file
        """
        if self.display_id > 0:  # show images in the browser using visdom
            ncols = self.ncols
            if ncols > 0:  # show all the images in one visdom panel
                ncols = min(ncols, len(visuals))
                h, w = next(iter(visuals.values())).shape[:2]
                table_css = """<style>
                        table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
                        table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
                        </style>""" % (w, h)  # create a table css
                # create a table of images.
                title = self.name
                label_html = ''
                label_html_row = ''
                images = []
                idx = 0
                for label, image in visuals.items():
                    image_numpy = util.tensor2im(image)
                    label_html_row += '<td>%s</td>' % label
                    images.append(image_numpy.transpose([2, 0, 1]))
                    idx += 1
                    if idx % ncols == 0:
                        label_html += '<tr>%s</tr>' % label_html_row
                        label_html_row = ''
                white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
                while idx % ncols != 0:
                    images.append(white_image)
                    label_html_row += '<td></td>'
                    idx += 1
                if label_html_row != '':
                    label_html += '<tr>%s</tr>' % label_html_row
                try:
                    self.vis.images(images, nrow=ncols, win=self.display_id + 1,
                                    padding=2, opts=dict(title=title + ' images'))
                    label_html = '<table>%s</table>' % label_html
                    self.vis.text(table_css + label_html, win=self.display_id + 2,
                                  opts=dict(title=title + ' labels'))
                except VisdomExceptionBase:
                    self.create_visdom_connections()

            else:  # show each image in a separate visdom panel
                idx = 1
                try:
                    for label, image in visuals.items():
                        image_numpy = util.tensor2im(image)
                        self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
                                       win=self.display_id + idx)
                        idx += 1
                except VisdomExceptionBase:
                    self.create_visdom_connections()

        if self.use_html and (save_result or not self.saved):  # save images to an HTML file if they haven't been saved.
            self.saved = True
            # save images to the disk
            for label, image in visuals.items():
                image_numpy = util.tensor2im(image)
                img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
                util.save_image(image_numpy, img_path)

            # update website
            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
            for n in range(epoch, 0, -1):
                webpage.add_header('epoch [%d]' % n)
                ims, txts, links = [], [], []

                for label, image in visuals.items():
                    image_numpy = util.tensor2im(image)
                    img_path = 'epoch%.3d_%s.png' % (n, label)
                    ims.append(img_path)
                    txts.append(label)
                    links.append(img_path)
                webpage.add_images(ims, txts, links, width=self.win_size)
            webpage.save()

    def plot_current_losses(self, epoch, counter_ratio, losses):
        """Display the current losses on the visdom display: dictionary of error labels and values

        Parameters:
            epoch (int)           -- current epoch
            counter_ratio (float) -- progress (percentage) in the current epoch, between 0 and 1
            losses (OrderedDict)  -- training losses stored in the format of (name, float) pairs
        """
        if not hasattr(self, 'plot_data'):
            self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
        self.plot_data['X'].append(epoch + counter_ratio)
        self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
        try:
            self.vis.line(
                X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
                Y=np.array(self.plot_data['Y']),
                opts={
                    'title': self.name + ' loss over time',
                    'legend': self.plot_data['legend'],
                    'xlabel': 'epoch',
                    'ylabel': 'loss'},
                win=self.display_id)
        except VisdomExceptionBase:
            self.create_visdom_connections()

    def plot_current_acc(self, epoch, counter_ratio, acc):
        if not hasattr(self, 'acc_data'):
            self.acc_data = {'X': [], 'Y': [], 'legend': list(acc.keys())}
        self.acc_data['X'].append(epoch + counter_ratio)
        self.acc_data['Y'].append([acc[k] for k in self.acc_data['legend']])
        try:
            self.vis.line(
                X=np.stack([np.array(self.acc_data['X'])] * len(self.acc_data['legend']), 1),
                Y=np.array(self.acc_data['Y']),
                opts={
                    'title': self.name + ' acc over time',
                    'legend': self.acc_data['legend'],
                    'xlabel': 'epoch',
                    'ylabel': 'acc'},
                win=self.display_id + 3)
        except VisdomExceptionBase:
            self.create_visdom_connections()

    # losses: same format as |losses| of plot_current_losses
    def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
        """Print current losses on the console; also save the losses to the disk

        Parameters:
            epoch (int)          -- current epoch
            iters (int)          -- current training iteration during this epoch (reset to 0 at the end of every epoch)
            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
            t_comp (float)       -- computational time per data point (normalized by batch_size)
            t_data (float)       -- data loading time per data point (normalized by batch_size)
        """
        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
        for k, v in losses.items():
            message += '%s: %.3f ' % (k, v)

        print(message)  # print the message
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)  # save the message
@@ -1,96 +0,0 @@
import time
from options.test_options import TestOptions
from data import create_dataset
from models import create_model
import os
from util.util import save_visuals
from util.metrics import AverageMeter
import numpy as np
from util.util import mkdir


def make_val_opt(opt):

    # hard-code some parameters for test
    opt.num_threads = 0   # the test code only supports num_threads = 0
    opt.batch_size = 1    # the test code only supports batch_size = 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True    # no flip; comment this line if results on flipped images are needed.
    opt.no_flip2 = True   # no flip; comment this line if results on flipped images are needed.

    opt.display_id = -1   # no visdom display; the test code saves the results to a HTML file.
    opt.phase = 'val'
    opt.preprocess = 'none1'
    opt.isTrain = False
    opt.aspect_ratio = 1
    opt.eval = True

    return opt


def print_current_acc(log_name, epoch, score):
    """Print the current accuracy on the console and save it to the log file."""
    message = '(epoch: %s) ' % str(epoch)
    for k, v in score.items():
        message += '%s: %.3f ' % (k, v)
    print(message)  # print the message
    with open(log_name, "a") as log_file:
        log_file.write('%s\n' % message)  # save the message


def val(opt):

    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    model = create_model(opt)      # create a model given opt.model and other options
    model.setup(opt)               # regular setup: load and print networks; create schedulers
    save_path = os.path.join(opt.checkpoints_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch))
    mkdir(save_path)
    model.eval()
    # create a logging file to store validation scores
    log_name = os.path.join(opt.checkpoints_dir, opt.name, 'val1_log.txt')
    with open(log_name, "a") as log_file:
        now = time.strftime("%c")
        log_file.write('================ val acc (%s) ================\n' % now)

    running_metrics = AverageMeter()
    for i, data in enumerate(dataset):
        if i >= opt.num_test:  # only apply our model to opt.num_test images.
            break
        model.set_input(data)         # unpack data from data loader
        score = model.test(val=True)  # run inference and return the confusion matrix
        running_metrics.update(score)

        visuals = model.get_current_visuals()  # get image results
        img_path = model.get_image_paths()     # get image paths
        if i % 5 == 0:  # print progress every 5 images
            print('processing (%04d)-th image... %s' % (i, img_path))

        save_visuals(visuals, save_path, img_path[0])
    score = running_metrics.get_scores()
    print_current_acc(log_name, opt.epoch, score)


if __name__ == '__main__':
    opt = TestOptions().parse()  # get test options
    opt = make_val_opt(opt)
    opt.phase = 'val'
    opt.dataroot = 'path-to-LEVIR-CD-test'

    opt.dataset_mode = 'changedetection'

    opt.n_class = 2

    opt.SA_mode = 'PAM'
    opt.arch = 'mynet3'

    opt.model = 'CDFA'

    opt.name = 'pam'
    opt.results_dir = './results/'

    opt.epoch = '78_F1_1_0.88780'
    opt.num_test = np.inf

    val(opt)
plugins/ai_method/packages/__init__.py (new file)
@@ -0,0 +1,39 @@
from .models import create_model
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class Front:

    def __init__(self, model) -> None:
        self.model = model
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model = model.to(self.device)

    def __call__(self, inp1, inp2):
        # H x W x C numpy arrays -> 1 x C x H x W tensors
        inp1 = torch.from_numpy(inp1).to(self.device)
        inp1 = torch.transpose(inp1, 0, 2).transpose(1, 2).unsqueeze(0)
        inp2 = torch.from_numpy(inp2).to(self.device)
        inp2 = torch.transpose(inp2, 0, 2).transpose(1, 2).unsqueeze(0)

        out = self.model(inp1, inp2)
        out = out.sigmoid()
        out = out.cpu().detach().numpy()[0, 0]  # drop the batch and channel dims
        return out


def get_model(name):

    try:
        model = create_model(name, encoder_name="resnet34",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
                             encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
                             in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
                             classes=1,                      # model output channels (number of classes in your datasets)
                             siam_encoder=True,              # whether to use a siamese encoder
                             fusion_form='concat',           # the form of fusing features from two branches, e.g. concat, sum, diff, or abs_diff
                             )
    except Exception:
        # unknown model name or failed weight initialization; signal failure to the caller
        return None

    return Front(model)
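A minimal sketch of driving this wrapper; 'DPFCN' (one of the models added below) is assumed to be a name create_model accepts, and the 512x512 float32 inputs are illustrative:

    import numpy as np

    front = get_model('DPFCN')
    if front is not None:
        t1 = np.random.rand(512, 512, 3).astype(np.float32)  # image at time 1, H x W x C
        t2 = np.random.rand(512, 512, 3).astype(np.float32)  # image at time 2, H x W x C
        prob = front(t1, t2)   # H x W map of change probabilities in [0, 1]
        mask = prob > 0.5      # binary change mask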
plugins/ai_method/packages/models/DPFCN/__init__.py (new file)
@@ -0,0 +1 @@
from .model import DPFCN
plugins/ai_method/packages/models/DPFCN/decoder.py (new file)
@@ -0,0 +1,134 @@
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..base import modules as md
from ..base import Decoder


class DecoderBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            skip_channels,
            out_channels,
            use_batchnorm=True,
            attention_type=None,
    ):
        super().__init__()
        self.conv1 = md.Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels)
        self.conv2 = md.Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.attention2 = md.Attention(attention_type, in_channels=out_channels)

    def forward(self, x, skip=None):
        x = F.interpolate(x, scale_factor=2, mode="nearest")

        if skip is not None:
            x = torch.cat([x, skip], dim=1)
            x = self.attention1(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.attention2(x)
        return x


class CenterBlock(nn.Sequential):
    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        conv1 = md.Conv2dReLU(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        conv2 = md.Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        super().__init__(conv1, conv2)


class DPFCNDecoder(Decoder):
    def __init__(
            self,
            encoder_channels,
            decoder_channels,
            n_blocks=5,
            use_batchnorm=True,
            attention_type=None,
            center=False,
            fusion_form="concat",
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        encoder_channels = encoder_channels[1:]    # remove first skip with same spatial resolution
        encoder_channels = encoder_channels[::-1]  # reverse channels to start from head of encoder

        # computing blocks input and output channels
        head_channels = encoder_channels[0]
        in_channels = [head_channels] + list(decoder_channels[:-1])
        skip_channels = list(encoder_channels[1:]) + [0]
        out_channels = decoder_channels

        # adjust encoder channels according to fusion form
        self.fusion_form = fusion_form
        if self.fusion_form in self.FUSION_DIC["2to2_fusion"]:
            skip_channels = [ch * 2 for ch in skip_channels]
            in_channels[0] = in_channels[0] * 2
            head_channels = head_channels * 2

        if center:
            self.center = CenterBlock(
                head_channels, head_channels, use_batchnorm=use_batchnorm
            )
        else:
            self.center = nn.Identity()

        # combine decoder keyword arguments
        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
        blocks = [
            DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, *features):

        features = self.aggregation_layer(features[0], features[1],
                                          self.fusion_form, ignore_original_img=True)
        # features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder

        head = features[0]
        skips = features[1:]

        x = self.center(head)
        for i, decoder_block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = decoder_block(x, skip)

        return x
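To make the channel bookkeeping above concrete, a worked sketch assuming a resnet34 encoder, whose out_channels are (3, 64, 64, 128, 256, 512):

    encoder_channels = (3, 64, 64, 128, 256, 512)[1:][::-1]      # (512, 256, 128, 64, 64)
    decoder_channels = (256, 128, 64, 32, 16)

    head_channels = encoder_channels[0]                          # 512
    in_channels = [head_channels] + list(decoder_channels[:-1])  # [512, 256, 128, 64, 32]
    skip_channels = list(encoder_channels[1:]) + [0]             # [256, 128, 64, 64, 0]

    # with fusion_form='concat' the bitemporal features are stacked
    # channel-wise, so every skip (and the head) doubles:
    skip_channels = [ch * 2 for ch in skip_channels]             # [512, 256, 128, 128, 0]
    in_channels[0] = in_channels[0] * 2                          # 1024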
plugins/ai_method/packages/models/DPFCN/model.py (new file)
@@ -0,0 +1,72 @@
from typing import Optional, Union, List
from .decoder import DPFCNDecoder
from ..encoders import get_encoder
from ..base import SegmentationModel
from ..base import SegmentationHead, ClassificationHead


class DPFCN(SegmentationModel):

    def __init__(
            self,
            encoder_name: str = "resnet34",
            encoder_depth: int = 5,
            encoder_weights: Optional[str] = "imagenet",
            decoder_use_batchnorm: bool = True,
            decoder_channels: List[int] = (256, 128, 64, 32, 16),
            decoder_attention_type: Optional[str] = None,
            in_channels: int = 3,
            classes: int = 1,
            activation: Optional[Union[str, callable]] = None,
            aux_params: Optional[dict] = None,
            siam_encoder: bool = True,
            fusion_form: str = "concat",
            **kwargs
    ):
        super().__init__()

        self.siam_encoder = siam_encoder

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )

        if not self.siam_encoder:
            self.encoder_non_siam = get_encoder(
                encoder_name,
                in_channels=in_channels,
                depth=encoder_depth,
                weights=encoder_weights,
            )

        self.decoder = DPFCNDecoder(
            encoder_channels=self.encoder.out_channels,
            decoder_channels=decoder_channels,
            n_blocks=encoder_depth,
            use_batchnorm=decoder_use_batchnorm,
            center=True if encoder_name.startswith("vgg") else False,
            attention_type=decoder_attention_type,
            fusion_form=fusion_form,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels[-1],
            out_channels=classes,
            activation=activation,
            kernel_size=3,
        )

        if aux_params is not None:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )
        else:
            self.classification_head = None

        self.name = "u-{}".format(encoder_name)
        self.initialize()
plugins/ai_method/packages/models/DVCA/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .model import DVCA
|
179
plugins/ai_method/packages/models/DVCA/decoder.py
Normal file
@ -0,0 +1,179 @@
"""
BSD 3-Clause License

Copyright (c) Soumith Chintala 2016,
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""

import torch
from torch import nn
from torch.nn import functional as F

from ..base import Decoder

__all__ = ["DVCADecoder"]


class DVCADecoder(Decoder):
    def __init__(self, in_channels, out_channels=256, atrous_rates=(12, 24, 36), fusion_form="concat"):
        super().__init__()

        # adjust encoder channels according to fusion form
        if fusion_form in self.FUSION_DIC["2to2_fusion"]:
            in_channels = in_channels * 2

        self.aspp = nn.Sequential(
            ASPP(in_channels, out_channels, atrous_rates),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

        self.out_channels = out_channels
        self.fusion_form = fusion_form

    def forward(self, *features):
        x = self.fusion(features[0][-1], features[1][-1], self.fusion_form)
        x = self.aspp(x)
        return x


class ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        super().__init__(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )


class ASPPSeparableConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        super().__init__(
            SeparableConv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )


class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super().__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        size = x.shape[-2:]
        for mod in self:
            x = mod(x)
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)


class ASPP(nn.Module):
    def __init__(self, in_channels, out_channels, atrous_rates, separable=False):
        super(ASPP, self).__init__()
        modules = []
        modules.append(
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            )
        )

        rate1, rate2, rate3 = tuple(atrous_rates)
        ASPPConvModule = ASPPConv if not separable else ASPPSeparableConv

        modules.append(ASPPConvModule(in_channels, out_channels, rate1))
        modules.append(ASPPConvModule(in_channels, out_channels, rate2))
        modules.append(ASPPConvModule(in_channels, out_channels, rate3))
|
||||
modules.append(ASPPPooling(in_channels, out_channels))
|
||||
|
||||
self.convs = nn.ModuleList(modules)
|
||||
|
||||
self.project = nn.Sequential(
|
||||
nn.Conv2d(5 * out_channels, out_channels, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.5),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
res = []
|
||||
for conv in self.convs:
|
||||
res.append(conv(x))
|
||||
res = torch.cat(res, dim=1)
|
||||
return self.project(res)
|
||||
|
||||
|
||||
class SeparableConv2d(nn.Sequential):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
bias=True,
|
||||
):
|
||||
dephtwise_conv = nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=in_channels,
|
||||
bias=False,
|
||||
)
|
||||
pointwise_conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
bias=bias,
|
||||
)
|
||||
super().__init__(dephtwise_conv, pointwise_conv)
|
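
Note (editor): a quick shape check for the ASPP block above (sketch, not part of the diff). Five parallel branches — a 1x1 conv, three dilated 3x3 convs, and global pooling — each emit out_channels maps, which are concatenated and projected back down, leaving the spatial size unchanged.

import torch

# Editor's sketch, assuming ASPP from the decoder module above is in scope.
aspp = ASPP(in_channels=512, out_channels=256, atrous_rates=(12, 24, 36)).eval()
x = torch.randn(1, 512, 32, 32)   # e.g. the deepest encoder feature map
print(aspp(x).shape)              # torch.Size([1, 256, 32, 32])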

plugins/ai_method/packages/models/DVCA/model.py (new file, 69 lines)
@@ -0,0 +1,69 @@
from typing import Optional

import torch.nn as nn

from ..base import ClassificationHead, SegmentationHead, SegmentationModel
from ..encoders import get_encoder
from .decoder import DVCADecoder


class DVCA(SegmentationModel):

    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_channels: int = 256,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[str] = None,
        upsampling: int = 8,
        aux_params: Optional[dict] = None,
        siam_encoder: bool = True,
        fusion_form: str = "concat",
        **kwargs
    ):
        super().__init__()

        self.siam_encoder = siam_encoder

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
            output_stride=8,
        )

        if not self.siam_encoder:
            self.encoder_non_siam = get_encoder(
                encoder_name,
                in_channels=in_channels,
                depth=encoder_depth,
                weights=encoder_weights,
                output_stride=8,
            )

        self.decoder = DVCADecoder(
            in_channels=self.encoder.out_channels[-1],
            out_channels=decoder_channels,
            fusion_form=fusion_form,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=self.decoder.out_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=1,
            upsampling=upsampling,
        )

        if aux_params is not None:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )
        else:
            self.classification_head = None
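
Note (editor): sketch, not part of the diff. DVCA fuses only the deepest encoder features through ASPP, so with an output-stride-8 encoder the head's upsampling=8 restores the input resolution.

import torch
from plugins.ai_method.packages.models import DVCA

# Editor's sketch; encoder_weights=None keeps this offline.
model = DVCA(encoder_name="resnet34", encoder_weights=None, classes=2).eval()
t1, t2 = torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256)
with torch.no_grad():
    mask = model(t1, t2)
print(mask.shape)  # torch.Size([1, 2, 256, 256]): 256/8 features, upsampled x8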

plugins/ai_method/packages/models/RCNN/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .model import RCNN

plugins/ai_method/packages/models/RCNN/decoder.py (new file, 131 lines)
@@ -0,0 +1,131 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..base import Decoder


class Conv3x3GNReLU(nn.Module):
    def __init__(self, in_channels, out_channels, upsample=False):
        super().__init__()
        self.upsample = upsample
        self.block = nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False
            ),
            nn.GroupNorm(32, out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.block(x)
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
        return x


class FPNBlock(nn.Module):
    def __init__(self, pyramid_channels, skip_channels):
        super().__init__()
        self.skip_conv = nn.Conv2d(skip_channels, pyramid_channels, kernel_size=1)

    def forward(self, x, skip=None):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        skip = self.skip_conv(skip)
        x = x + skip
        return x


class SegmentationBlock(nn.Module):
    def __init__(self, in_channels, out_channels, n_upsamples=0):
        super().__init__()

        blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))]

        if n_upsamples > 1:
            for _ in range(1, n_upsamples):
                blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True))

        self.block = nn.Sequential(*blocks)

    def forward(self, x):
        return self.block(x)


class MergeBlock(nn.Module):
    def __init__(self, policy):
        super().__init__()
        if policy not in ["add", "cat"]:
            raise ValueError(
                "`merge_policy` must be one of: ['add', 'cat'], got {}".format(
                    policy
                )
            )
        self.policy = policy

    def forward(self, x):
        if self.policy == 'add':
            return sum(x)
        elif self.policy == 'cat':
            return torch.cat(x, dim=1)
        else:
            raise ValueError(
                "`merge_policy` must be one of: ['add', 'cat'], got {}".format(self.policy)
            )


class RCNNDecoder(Decoder):
    def __init__(
        self,
        encoder_channels,
        encoder_depth=5,
        pyramid_channels=256,
        segmentation_channels=128,
        dropout=0.2,
        merge_policy="add",
        fusion_form="concat",
    ):
        super().__init__()

        self.out_channels = segmentation_channels if merge_policy == "add" else segmentation_channels * 4
        if encoder_depth < 3:
            raise ValueError("Encoder depth for RCNN decoder cannot be less than 3, got {}.".format(encoder_depth))

        encoder_channels = encoder_channels[::-1]
        encoder_channels = encoder_channels[:encoder_depth + 1]
        # (512, 256, 128, 64, 64, 3)

        # adjust encoder channels according to fusion form
        self.fusion_form = fusion_form
        if self.fusion_form in self.FUSION_DIC["2to2_fusion"]:
            encoder_channels = [ch * 2 for ch in encoder_channels]

        self.p5 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=1)
        self.p4 = FPNBlock(pyramid_channels, encoder_channels[1])
        self.p3 = FPNBlock(pyramid_channels, encoder_channels[2])
        self.p2 = FPNBlock(pyramid_channels, encoder_channels[3])

        self.seg_blocks = nn.ModuleList([
            SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=n_upsamples)
            for n_upsamples in [3, 2, 1, 0]
        ])

        self.merge = MergeBlock(merge_policy)
        self.dropout = nn.Dropout2d(p=dropout, inplace=True)

    def forward(self, *features):

        features = self.aggregation_layer(features[0], features[1],
                                          self.fusion_form, ignore_original_img=True)
        c2, c3, c4, c5 = features[-4:]

        p5 = self.p5(c5)
        p4 = self.p4(p5, c4)
        p3 = self.p3(p4, c3)
        p2 = self.p2(p3, c2)

        feature_pyramid = [seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2])]
        x = self.merge(feature_pyramid)
        x = self.dropout(x)

        return x

plugins/ai_method/packages/models/RCNN/model.py (new file, 73 lines)
@@ -0,0 +1,73 @@
from typing import Optional, Union

from ..base import ClassificationHead, SegmentationHead, SegmentationModel
from ..encoders import get_encoder
from .decoder import RCNNDecoder


class RCNN(SegmentationModel):

    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_pyramid_channels: int = 256,
        decoder_segmentation_channels: int = 128,
        decoder_merge_policy: str = "add",
        decoder_dropout: float = 0.2,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[str] = None,
        upsampling: int = 4,
        aux_params: Optional[dict] = None,
        siam_encoder: bool = True,
        fusion_form: str = "concat",
        **kwargs
    ):
        super().__init__()

        self.siam_encoder = siam_encoder

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )

        if not self.siam_encoder:
            self.encoder_non_siam = get_encoder(
                encoder_name,
                in_channels=in_channels,
                depth=encoder_depth,
                weights=encoder_weights,
            )

        self.decoder = RCNNDecoder(
            encoder_channels=self.encoder.out_channels,
            encoder_depth=encoder_depth,
            pyramid_channels=decoder_pyramid_channels,
            segmentation_channels=decoder_segmentation_channels,
            dropout=decoder_dropout,
            merge_policy=decoder_merge_policy,
            fusion_form=fusion_form,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=self.decoder.out_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=1,
            upsampling=upsampling,
        )

        if aux_params is not None:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )
        else:
            self.classification_head = None

        self.name = "rcnn-{}".format(encoder_name)
        self.initialize()
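
Note (editor): sketch, not part of the diff. The RCNN head is an FPN-style decoder over fused bitemporal features; with merge_policy "add" the decoder emits 128 channels at 1/4 resolution, and the head's upsampling=4 restores the input size.

import torch
from plugins.ai_method.packages.models import RCNN

# Editor's sketch; encoder_weights=None keeps this offline.
model = RCNN(encoder_name="resnet34", encoder_weights=None, classes=2).eval()
t1, t2 = torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256)
with torch.no_grad():
    print(model(t1, t2).shape)  # torch.Size([1, 2, 256, 256])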

plugins/ai_method/packages/models/__init__.py (new file, 42 lines)
@@ -0,0 +1,42 @@
from .RCNN import RCNN
from .DVCA import DVCA
from .DPFCN import DPFCN

from . import encoders
from . import utils
from . import losses
from . import datasets

from .__version__ import __version__

from typing import Optional
import torch


def create_model(
    arch: str,
    encoder_name: str = "resnet34",
    encoder_weights: Optional[str] = "imagenet",
    in_channels: int = 3,
    classes: int = 1,
    **kwargs,
) -> torch.nn.Module:
    """Model wrapper: build any supported architecture from its name and parameters."""

    archs = [DVCA, DPFCN, RCNN]
    archs_dict = {a.__name__.lower(): a for a in archs}
    try:
        model_class = archs_dict[arch.lower()]
    except KeyError:
        raise KeyError("Wrong architecture type `{}`. Available options are: {}".format(
            arch, list(archs_dict.keys()),
        ))
    return model_class(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=classes,
        **kwargs,
    )
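
Note (editor): create_model resolves the architecture by lower-cased class name, so the valid arch strings here are "dvca", "dpfcn" and "rcnn". A minimal sketch, not part of the diff:

from plugins.ai_method.packages.models import create_model

# Editor's sketch: name-based construction, mirroring archs_dict above.
model = create_model("rcnn", encoder_name="resnet34",
                     encoder_weights=None, in_channels=3, classes=2)
print(type(model).__name__)  # RCNN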

plugins/ai_method/packages/models/__version__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
VERSION = (0, 1, 4)

__version__ = '.'.join(map(str, VERSION))

plugins/ai_method/packages/models/base/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from .model import SegmentationModel
from .decoder import Decoder

from .modules import (
    Conv2dReLU,
    Attention,
)

from .heads import (
    SegmentationHead,
    ClassificationHead,
)

plugins/ai_method/packages/models/base/decoder.py (new file, 33 lines)
@@ -0,0 +1,33 @@
import torch


class Decoder(torch.nn.Module):
    # TODO: support learnable fusion modules
    def __init__(self):
        super().__init__()
        self.FUSION_DIC = {"2to1_fusion": ["sum", "diff", "abs_diff"],
                           "2to2_fusion": ["concat"]}

    def fusion(self, x1, x2, fusion_form="concat"):
        """Specify the form of feature fusion"""
        if fusion_form == "concat":
            x = torch.cat([x1, x2], dim=1)
        elif fusion_form == "sum":
            x = x1 + x2
        elif fusion_form == "diff":
            x = x2 - x1
        elif fusion_form == "abs_diff":
            x = torch.abs(x1 - x2)
        else:
            raise ValueError('the fusion form "{}" is not defined'.format(fusion_form))

        return x

    def aggregation_layer(self, fea1, fea2, fusion_form="concat", ignore_original_img=True):
        """Aggregate features from siamese or non-siamese branches"""

        start_idx = 1 if ignore_original_img else 0
        aggregate_fea = [self.fusion(fea1[idx], fea2[idx], fusion_form)
                         for idx in range(start_idx, len(fea1))]

        return aggregate_fea
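
Note (editor): the four fusion forms differ in output width: "concat" doubles the channel count ("2to2_fusion") while "sum"/"diff"/"abs_diff" preserve it ("2to1_fusion") — which is exactly why the decoders above double their expected in_channels for "concat". A small sketch, not part of the diff, assuming the Decoder class above is in scope:

import torch

# Editor's sketch: channel behaviour of the fusion forms.
dec = Decoder()  # usable directly; fusion() needs no submodules
f1, f2 = torch.randn(1, 64, 32, 32), torch.randn(1, 64, 32, 32)
print(dec.fusion(f1, f2, "concat").shape)    # torch.Size([1, 128, 32, 32])
print(dec.fusion(f1, f2, "abs_diff").shape)  # torch.Size([1, 64, 32, 32])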

plugins/ai_method/packages/models/base/heads.py (new file, 24 lines)
@@ -0,0 +1,24 @@
import torch.nn as nn
from .modules import Flatten, Activation


class SegmentationHead(nn.Sequential):

    def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1, align_corners=True):
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        upsampling = nn.Upsample(scale_factor=upsampling, mode='bilinear', align_corners=align_corners) if upsampling > 1 else nn.Identity()
        activation = Activation(activation)
        super().__init__(conv2d, upsampling, activation)


class ClassificationHead(nn.Sequential):

    def __init__(self, in_channels, classes, pooling="avg", dropout=0.2, activation=None):
        if pooling not in ("max", "avg"):
            raise ValueError("Pooling should be one of ('max', 'avg'), got {}.".format(pooling))
        pool = nn.AdaptiveAvgPool2d(1) if pooling == 'avg' else nn.AdaptiveMaxPool2d(1)
        flatten = Flatten()
        dropout = nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity()
        linear = nn.Linear(in_channels, classes, bias=True)
        activation = Activation(activation)
        super().__init__(pool, flatten, dropout, linear, activation)

plugins/ai_method/packages/models/base/initialization.py (new file, 27 lines)
@@ -0,0 +1,27 @@
import torch.nn as nn


def initialize_decoder(module):
    for m in module.modules():

        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

        elif isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)


def initialize_head(module):
    for m in module.modules():
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

plugins/ai_method/packages/models/base/model.py (new file, 53 lines)
@@ -0,0 +1,53 @@
import torch
from . import initialization as init


class SegmentationModel(torch.nn.Module):

    def initialize(self):
        init.initialize_decoder(self.decoder)
        init.initialize_head(self.segmentation_head)
        if self.classification_head is not None:
            init.initialize_head(self.classification_head)

    def base_forward(self, x1, x2):
        """Sequentially pass `x1` and `x2` through the model's encoder, decoder and heads"""
        if self.siam_encoder:
            features = self.encoder(x1), self.encoder(x2)
        else:
            features = self.encoder(x1), self.encoder_non_siam(x2)

        decoder_output = self.decoder(*features)

        # TODO: features = self.fusion_policy(features)

        masks = self.segmentation_head(decoder_output)

        if self.classification_head is not None:
            raise AttributeError("`classification_head` is not supported now.")
            # labels = self.classification_head(features[-1])
            # return masks, labels

        return masks

    def forward(self, x1, x2):
        """Sequentially pass `x1` and `x2` through the model's encoder, decoder and heads"""
        return self.base_forward(x1, x2)

    def predict(self, x1, x2):
        """Inference method. Switch model to `eval` mode, call `.forward(x1, x2)` under `torch.no_grad()`.

        Args:
            x1, x2: 4D torch tensors with shape (batch_size, channels, height, width)

        Return:
            prediction: 4D torch tensor with shape (batch_size, classes, height, width)
        """
        if self.training:
            self.eval()

        with torch.no_grad():
            x = self.forward(x1, x2)

        return x
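
Note (editor): predict is the convenience entry point a plugin can call at inference time; it flips the model to eval (so BatchNorm and Dropout behave deterministically) and disables autograd. A sketch, not part of the diff:

import torch
from plugins.ai_method.packages.models import create_model

# Editor's sketch: tile-wise inference as a run_alg-style loop might use it.
model = create_model("dpfcn", encoder_weights=None, classes=2)
tile1, tile2 = torch.randn(1, 3, 512, 512), torch.randn(1, 3, 512, 512)
logits = model.predict(tile1, tile2)      # eval mode + no_grad
change = logits.argmax(dim=1).squeeze(0)  # (512, 512) per-pixel class map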

plugins/ai_method/packages/models/base/modules.py (new file, 242 lines)
@@ -0,0 +1,242 @@
import torch
import torch.nn as nn

try:
    from inplace_abn import InPlaceABN
except ImportError:
    InPlaceABN = None


class Conv2dReLU(nn.Sequential):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
    ):

        if use_batchnorm == "inplace" and InPlaceABN is None:
            raise RuntimeError(
                "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
                + "To install see: https://github.com/mapillary/inplace_abn"
            )

        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )
        relu = nn.ReLU(inplace=True)

        if use_batchnorm == "inplace":
            bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
            relu = nn.Identity()

        elif use_batchnorm and use_batchnorm != "inplace":
            bn = nn.BatchNorm2d(out_channels)

        else:
            bn = nn.Identity()

        super(Conv2dReLU, self).__init__(conv, bn, relu)


class SCSEModule(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super().__init__()
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction, in_channels, 1),
            nn.Sigmoid(),
        )
        self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())

    def forward(self, x):
        return x * self.cSE(x) + x * self.sSE(x)


class CBAMChannel(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super(CBAMChannel, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc = nn.Sequential(nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
                                nn.ReLU(),
                                nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return x * self.sigmoid(out)


class CBAMSpatial(nn.Module):
    def __init__(self, in_channels, kernel_size=7):
        super(CBAMSpatial, self).__init__()

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.conv1(out)
        return x * self.sigmoid(out)


class CBAM(nn.Module):
    """
    Woo S, Park J, Lee J Y, et al. CBAM: Convolutional block attention module[C]
    //Proceedings of the European Conference on Computer Vision (ECCV).
    """
    def __init__(self, in_channels, reduction=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.ChannelGate = CBAMChannel(in_channels, reduction)
        # pass both arguments so `kernel_size` is honored; the original call
        # CBAMSpatial(kernel_size) bound it to the (unused) in_channels slot
        self.SpatialGate = CBAMSpatial(in_channels, kernel_size)

    def forward(self, x):
        x = self.ChannelGate(x)
        x = self.SpatialGate(x)
        return x


class ECAM(nn.Module):
    """
    Ensemble Channel Attention Module for UNetPlusPlus.
    Fang S, Li K, Shao J, et al. SNUNet-CD: A Densely Connected Siamese Network for Change Detection of VHR Images[J].
    IEEE Geoscience and Remote Sensing Letters, 2021.
    Not completely consistent, to be improved.
    """
    def __init__(self, in_channels, out_channels, map_num=4):
        super(ECAM, self).__init__()
        self.ca1 = CBAMChannel(in_channels * map_num, reduction=16)
        self.ca2 = CBAMChannel(in_channels, reduction=16 // 4)
        self.up = nn.ConvTranspose2d(in_channels * map_num, in_channels * map_num, 2, stride=2)
        self.conv_final = nn.Conv2d(in_channels * map_num, out_channels, kernel_size=1)

    def forward(self, x):
        """
        x (list[tensor] or tuple(tensor))
        """
        out = torch.cat(x, 1)
        intra = torch.sum(torch.stack(x), dim=0)
        ca2 = self.ca2(intra)
        out = self.ca1(out) * (out + ca2.repeat(1, 4, 1, 1))
        out = self.up(out)
        out = self.conv_final(out)
        return out


class SEModule(nn.Module):
    """
    Hu J, Shen L, Sun G. Squeeze-and-excitation networks[C]
    //Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018: 7132-7141.
    """
    def __init__(self, in_channels, reduction=16):
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, in_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


class ArgMax(nn.Module):

    def __init__(self, dim=None):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return torch.argmax(x, dim=self.dim)


class Clamp(nn.Module):
    def __init__(self, min=0, max=1):
        super().__init__()
        self.min, self.max = min, max

    def forward(self, x):
        return torch.clamp(x, self.min, self.max)


class Activation(nn.Module):

    def __init__(self, name, **params):

        super().__init__()

        if name is None or name == 'identity':
            self.activation = nn.Identity(**params)
        elif name == 'sigmoid':
            self.activation = nn.Sigmoid()
        elif name == 'softmax2d':
            self.activation = nn.Softmax(dim=1, **params)
        elif name == 'softmax':
            self.activation = nn.Softmax(**params)
        elif name == 'logsoftmax':
            self.activation = nn.LogSoftmax(**params)
        elif name == 'tanh':
            self.activation = nn.Tanh()
        elif name == 'argmax':
            self.activation = ArgMax(**params)
        elif name == 'argmax2d':
            self.activation = ArgMax(dim=1, **params)
        elif name == 'clamp':
            self.activation = Clamp(**params)
        elif callable(name):
            self.activation = name(**params)
        else:
            raise ValueError('Activation should be callable/sigmoid/softmax/logsoftmax/tanh/None; got {}'.format(name))

    def forward(self, x):
        return self.activation(x)


class Attention(nn.Module):

    def __init__(self, name, **params):
        super().__init__()

        if name is None:
            self.attention = nn.Identity(**params)
        elif name == 'scse':
            self.attention = SCSEModule(**params)
        elif name == 'cbam_channel':
            self.attention = CBAMChannel(**params)
        elif name == 'cbam_spatial':
            self.attention = CBAMSpatial(**params)
        elif name == 'cbam':
            self.attention = CBAM(**params)
        elif name == 'se':
            self.attention = SEModule(**params)
        else:
            raise ValueError("Attention {} is not implemented".format(name))

    def forward(self, x):
        return self.attention(x)


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.shape[0], -1)
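
Note (editor): all of the attention blocks above are shape-preserving gates (output shape equals input shape), which is what lets decoders plug them in behind Attention(name, ...) without any re-wiring. A sketch, not part of the diff, assuming the classes above are in scope:

import torch

# Editor's sketch: every gate returns a tensor shaped like its input.
x = torch.randn(2, 64, 32, 32)
for gate in (SEModule(64), SCSEModule(64), CBAM(64)):
    assert gate(x).shape == x.shape
print(Attention('scse', in_channels=64)(x).shape)  # torch.Size([2, 64, 32, 32])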

plugins/ai_method/packages/models/encoders/__init__.py (new file, 113 lines)
@@ -0,0 +1,113 @@
import functools

import torch
import torch.utils.model_zoo as model_zoo

from ._preprocessing import preprocess_input
from .densenet import densenet_encoders
from .dpn import dpn_encoders
from .efficientnet import efficient_net_encoders
from .inceptionresnetv2 import inceptionresnetv2_encoders
from .inceptionv4 import inceptionv4_encoders
from .mobilenet import mobilenet_encoders
from .resnet import resnet_encoders
from .senet import senet_encoders
from .timm_efficientnet import timm_efficientnet_encoders
from .timm_gernet import timm_gernet_encoders
from .timm_mobilenetv3 import timm_mobilenetv3_encoders
from .timm_regnet import timm_regnet_encoders
from .timm_res2net import timm_res2net_encoders
from .timm_resnest import timm_resnest_encoders
from .timm_sknet import timm_sknet_encoders
from .timm_universal import TimmUniversalEncoder
from .vgg import vgg_encoders
from .xception import xception_encoders
from .swin_transformer import swin_transformer_encoders
from .mit_encoder import mit_encoders
# from .hrnet import hrnet_encoders

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

encoders = {}
encoders.update(resnet_encoders)
encoders.update(dpn_encoders)
encoders.update(vgg_encoders)
encoders.update(senet_encoders)
encoders.update(densenet_encoders)
encoders.update(inceptionresnetv2_encoders)
encoders.update(inceptionv4_encoders)
encoders.update(efficient_net_encoders)
encoders.update(mobilenet_encoders)
encoders.update(xception_encoders)
encoders.update(timm_efficientnet_encoders)
encoders.update(timm_resnest_encoders)
encoders.update(timm_res2net_encoders)
encoders.update(timm_regnet_encoders)
encoders.update(timm_sknet_encoders)
encoders.update(timm_mobilenetv3_encoders)
encoders.update(timm_gernet_encoders)
encoders.update(swin_transformer_encoders)
encoders.update(mit_encoders)
# encoders.update(hrnet_encoders)


def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):

    if name.startswith("tu-"):
        name = name[3:]
        encoder = TimmUniversalEncoder(
            name=name,
            in_channels=in_channels,
            depth=depth,
            output_stride=output_stride,
            pretrained=weights is not None,
            **kwargs
        )
        return encoder

    try:
        Encoder = encoders[name]["encoder"]
    except KeyError:
        raise KeyError("Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys())))

    params = encoders[name]["params"]
    params.update(depth=depth)
    encoder = Encoder(**params)

    if weights is not None:
        try:
            settings = encoders[name]["pretrained_settings"][weights]
        except KeyError:
            raise KeyError("Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
                weights, name, list(encoders[name]["pretrained_settings"].keys()),
            ))
        encoder.load_state_dict(model_zoo.load_url(settings["url"], map_location=torch.device(DEVICE)))

    encoder.set_in_channels(in_channels, pretrained=weights is not None)
    if output_stride != 32:
        encoder.make_dilated(output_stride)

    return encoder


def get_encoder_names():
    return list(encoders.keys())


def get_preprocessing_params(encoder_name, pretrained="imagenet"):
    settings = encoders[encoder_name]["pretrained_settings"]

    if pretrained not in settings.keys():
        raise ValueError("Available pretrained options {}".format(settings.keys()))

    formatted_settings = {}
    formatted_settings["input_space"] = settings[pretrained].get("input_space")
    formatted_settings["input_range"] = settings[pretrained].get("input_range")
    formatted_settings["mean"] = settings[pretrained].get("mean")
    formatted_settings["std"] = settings[pretrained].get("std")
    return formatted_settings


def get_preprocessing_fn(encoder_name, pretrained="imagenet"):
    params = get_preprocessing_params(encoder_name, pretrained=pretrained)
    return functools.partial(preprocess_input, **params)
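
Note (editor): get_encoder is the single factory every model above relies on; it returns a feature extractor whose forward yields depth + 1 feature maps, and get_preprocessing_fn pairs it with matching normalization. A sketch, not part of the diff:

import torch
from plugins.ai_method.packages.models.encoders import get_encoder

# Editor's sketch; weights=None keeps this offline.
enc = get_encoder("resnet34", in_channels=3, depth=5, weights=None)
feats = enc(torch.randn(1, 3, 256, 256))
print([f.shape[1] for f in feats])  # channel widths, e.g. [3, 64, 64, 128, 256, 512]
print(enc.out_channels)             # the same widths, truncated to depth + 1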

plugins/ai_method/packages/models/encoders/_base.py (new file, 53 lines)
@@ -0,0 +1,53 @@
import torch
import torch.nn as nn
from typing import List
from collections import OrderedDict

from . import _utils as utils


class EncoderMixin:
    """Add encoder functionality such as:
    - output channels specification of feature tensors (produced by encoder)
    - patching first convolution for arbitrary input channels
    """

    @property
    def out_channels(self):
        """Return channels dimensions for each tensor of forward output of encoder"""
        return self._out_channels[: self._depth + 1]

    def set_in_channels(self, in_channels, pretrained=True):
        """Change first convolution channels"""
        if in_channels == 3:
            return

        self._in_channels = in_channels
        if self._out_channels[0] == 3:
            self._out_channels = tuple([in_channels] + list(self._out_channels)[1:])

        utils.patch_first_conv(model=self, new_in_channels=in_channels, pretrained=pretrained)

    def get_stages(self):
        """Method should be overridden in encoder"""
        raise NotImplementedError

    def make_dilated(self, output_stride):

        if output_stride == 16:
            stage_list = [5, ]
            dilation_list = [2, ]

        elif output_stride == 8:
            stage_list = [4, 5]
            dilation_list = [2, 4]

        else:
            raise ValueError("Output stride should be 16 or 8, got {}.".format(output_stride))

        stages = self.get_stages()
        for stage_indx, dilation_rate in zip(stage_list, dilation_list):
            utils.replace_strides_with_dilation(
                module=stages[stage_indx],
                dilation_rate=dilation_rate,
            )

plugins/ai_method/packages/models/encoders/_preprocessing.py (new file, 23 lines)
@@ -0,0 +1,23 @@
import numpy as np


def preprocess_input(
    x, mean=None, std=None, input_space="RGB", input_range=None, **kwargs
):

    if input_space == "BGR":
        x = x[..., ::-1].copy()

    if input_range is not None:
        if x.max() > 1 and input_range[1] == 1:
            x = x / 255.0

    if mean is not None:
        mean = np.array(mean)
        x = x - mean

    if std is not None:
        std = np.array(std)
        x = x / std

    return x
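
Note (editor): preprocess_input works on plain numpy arrays in HWC layout, so a plugin can normalize raster blocks before converting to tensors. A sketch, not part of the diff, using ImageNet statistics:

import numpy as np

# Editor's sketch: normalize an 8-bit HWC block the way an ImageNet-pretrained
# encoder expects (rescale to [0, 1], then standardize per channel).
block = np.random.randint(0, 256, (256, 256, 3)).astype("float32")
out = preprocess_input(block,
                       mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225],
                       input_space="RGB",
                       input_range=[0, 1])
print(out.shape, float(out.mean()))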

plugins/ai_method/packages/models/encoders/_utils.py (new file, 59 lines)
@@ -0,0 +1,59 @@
import torch
import torch.nn as nn


def patch_first_conv(model, new_in_channels, default_in_channels=3, pretrained=True):
    """Change first convolution layer input channels.
    In case:
        in_channels == 1 or in_channels == 2 -> reuse original weights
        in_channels > 3 -> make random kaiming normal initialization
    """

    # get first conv
    for module in model.modules():
        if isinstance(module, nn.Conv2d) and module.in_channels == default_in_channels:
            break

    weight = module.weight.detach()
    module.in_channels = new_in_channels

    if not pretrained:
        module.weight = nn.parameter.Parameter(
            torch.Tensor(
                module.out_channels,
                new_in_channels // module.groups,
                *module.kernel_size
            )
        )
        module.reset_parameters()

    elif new_in_channels == 1:
        new_weight = weight.sum(1, keepdim=True)
        module.weight = nn.parameter.Parameter(new_weight)

    else:
        new_weight = torch.Tensor(
            module.out_channels,
            new_in_channels // module.groups,
            *module.kernel_size
        )

        for i in range(new_in_channels):
            new_weight[:, i] = weight[:, i % default_in_channels]

        new_weight = new_weight * (default_in_channels / new_in_channels)
        module.weight = nn.parameter.Parameter(new_weight)


def replace_strides_with_dilation(module, dilation_rate):
    """Patch Conv2d modules replacing strides with dilation"""
    for mod in module.modules():
        if isinstance(mod, nn.Conv2d):
            mod.stride = (1, 1)
            mod.dilation = (dilation_rate, dilation_rate)
            kh, kw = mod.kernel_size
            # pad by half the dilated kernel per dimension; the original used
            # kh for both sides, which is wrong for non-square kernels
            mod.padding = ((kh // 2) * dilation_rate, (kw // 2) * dilation_rate)

            # Kostyl for EfficientNet
            if hasattr(mod, "static_padding"):
                mod.static_padding = nn.Identity()
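
Note (editor): patch_first_conv is what makes in_channels != 3 work with pretrained weights: one channel sums the RGB filters, more than three tiles and rescales them. A sketch, not part of the diff; it assumes torchvision is installed (it is not a dependency of this package).

import torchvision

# Editor's sketch: widen a stock conv stem to 4 bands, reusing RGB statistics.
model = torchvision.models.resnet18()
patch_first_conv(model, new_in_channels=4, pretrained=True)
print(model.conv1.weight.shape)  # torch.Size([64, 4, 7, 7])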

plugins/ai_method/packages/models/encoders/densenet.py (new file, 146 lines)
@@ -0,0 +1,146 @@
""" Each encoder should have the following attributes and methods, and be inherited from `_base.EncoderMixin`

Attributes:

    _out_channels (list of int): specify number of channels for each encoder feature tensor
    _depth (int): specify number of stages in decoder (in other words number of downsampling operations)
    _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)

Methods:

    forward(self, x: torch.Tensor)
        produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
        shape NCHW (features should be sorted in descending order according to spatial resolution, starting
        with resolution same as input `x` tensor).

        Input: `x` with shape (1, 3, 64, 64)
        Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
                [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
                (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)

        also should support number of features according to specified depth, e.g. if depth = 5,
        number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
        depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
"""

import re
import torch.nn as nn

from pretrainedmodels.models.torchvision_models import pretrained_settings
from torchvision.models.densenet import DenseNet

from ._base import EncoderMixin


class TransitionWithSkip(nn.Module):

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        for module in self.module:
            x = module(x)
            if isinstance(module, nn.ReLU):
                skip = x
        return x, skip


class DenseNetEncoder(DenseNet, EncoderMixin):
    def __init__(self, out_channels, depth=5, **kwargs):
        super().__init__(**kwargs)
        self._out_channels = out_channels
        self._depth = depth
        self._in_channels = 3
        del self.classifier

    def make_dilated(self, output_stride):
        raise ValueError("DenseNet encoders do not support dilated mode "
                         "due to pooling operation for downsampling!")

    def get_stages(self):
        return [
            nn.Identity(),
            nn.Sequential(self.features.conv0, self.features.norm0, self.features.relu0),
            nn.Sequential(self.features.pool0, self.features.denseblock1,
                          TransitionWithSkip(self.features.transition1)),
            nn.Sequential(self.features.denseblock2, TransitionWithSkip(self.features.transition2)),
            nn.Sequential(self.features.denseblock3, TransitionWithSkip(self.features.transition3)),
            nn.Sequential(self.features.denseblock4, self.features.norm5)
        ]

    def forward(self, x):

        stages = self.get_stages()

        features = []
        for i in range(self._depth + 1):
            x = stages[i](x)
            if isinstance(x, (list, tuple)):
                x, skip = x
                features.append(skip)
            else:
                features.append(x)

        return features

    def load_state_dict(self, state_dict):
        pattern = re.compile(
            r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
        )
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]

        # remove linear
        state_dict.pop("classifier.bias", None)
        state_dict.pop("classifier.weight", None)

        super().load_state_dict(state_dict)


densenet_encoders = {
    "densenet121": {
        "encoder": DenseNetEncoder,
        "pretrained_settings": pretrained_settings["densenet121"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 1024),
            "num_init_features": 64,
            "growth_rate": 32,
            "block_config": (6, 12, 24, 16),
        },
    },
    "densenet169": {
        "encoder": DenseNetEncoder,
        "pretrained_settings": pretrained_settings["densenet169"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1280, 1664),
            "num_init_features": 64,
            "growth_rate": 32,
            "block_config": (6, 12, 32, 32),
        },
    },
    "densenet201": {
        "encoder": DenseNetEncoder,
        "pretrained_settings": pretrained_settings["densenet201"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1792, 1920),
            "num_init_features": 64,
            "growth_rate": 32,
            "block_config": (6, 12, 48, 32),
        },
    },
    "densenet161": {
        "encoder": DenseNetEncoder,
        "pretrained_settings": pretrained_settings["densenet161"],
        "params": {
            "out_channels": (3, 96, 384, 768, 2112, 2208),
            "num_init_features": 96,
            "growth_rate": 48,
            "block_config": (6, 12, 36, 24),
        },
    },
}

plugins/ai_method/packages/models/encoders/dpn.py (new file, 170 lines)
@@ -0,0 +1,170 @@
""" Each encoder should have the following attributes and methods, and be inherited from `_base.EncoderMixin`

Attributes:

    _out_channels (list of int): specify number of channels for each encoder feature tensor
    _depth (int): specify number of stages in decoder (in other words number of downsampling operations)
    _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)

Methods:

    forward(self, x: torch.Tensor)
        produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
        shape NCHW (features should be sorted in descending order according to spatial resolution, starting
        with resolution same as input `x` tensor).

        Input: `x` with shape (1, 3, 64, 64)
        Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
                [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
                (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)

        also should support number of features according to specified depth, e.g. if depth = 5,
        number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
        depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from pretrainedmodels.models.dpn import DPN
from pretrainedmodels.models.dpn import pretrained_settings

from ._base import EncoderMixin


class DPNEncoder(DPN, EncoderMixin):
    def __init__(self, stage_idxs, out_channels, depth=5, **kwargs):
        super().__init__(**kwargs)
        self._stage_idxs = stage_idxs
        self._depth = depth
        self._out_channels = out_channels
        self._in_channels = 3

        del self.last_linear

    def get_stages(self):
        return [
            nn.Identity(),
            nn.Sequential(self.features[0].conv, self.features[0].bn, self.features[0].act),
            nn.Sequential(self.features[0].pool, self.features[1 : self._stage_idxs[0]]),
            self.features[self._stage_idxs[0] : self._stage_idxs[1]],
            self.features[self._stage_idxs[1] : self._stage_idxs[2]],
            self.features[self._stage_idxs[2] : self._stage_idxs[3]],
        ]

    def forward(self, x):

        stages = self.get_stages()

        features = []
        for i in range(self._depth + 1):
            x = stages[i](x)
            if isinstance(x, (list, tuple)):
                features.append(F.relu(torch.cat(x, dim=1), inplace=True))
            else:
                features.append(x)

        return features

    def load_state_dict(self, state_dict, **kwargs):
        state_dict.pop("last_linear.bias", None)
        state_dict.pop("last_linear.weight", None)
        super().load_state_dict(state_dict, **kwargs)


dpn_encoders = {
    "dpn68": {
        "encoder": DPNEncoder,
        "pretrained_settings": pretrained_settings["dpn68"],
        "params": {
            "stage_idxs": (4, 8, 20, 24),
            "out_channels": (3, 10, 144, 320, 704, 832),
            "groups": 32,
            "inc_sec": (16, 32, 32, 64),
            "k_r": 128,
            "k_sec": (3, 4, 12, 3),
            "num_classes": 1000,
            "num_init_features": 10,
            "small": True,
            "test_time_pool": True,
        },
    },
    "dpn68b": {
        "encoder": DPNEncoder,
        "pretrained_settings": pretrained_settings["dpn68b"],
        "params": {
            "stage_idxs": (4, 8, 20, 24),
            "out_channels": (3, 10, 144, 320, 704, 832),
            "b": True,
            "groups": 32,
            "inc_sec": (16, 32, 32, 64),
            "k_r": 128,
            "k_sec": (3, 4, 12, 3),
            "num_classes": 1000,
            "num_init_features": 10,
            "small": True,
            "test_time_pool": True,
        },
    },
    "dpn92": {
        "encoder": DPNEncoder,
        "pretrained_settings": pretrained_settings["dpn92"],
        "params": {
            "stage_idxs": (4, 8, 28, 32),
            "out_channels": (3, 64, 336, 704, 1552, 2688),
            "groups": 32,
            "inc_sec": (16, 32, 24, 128),
            "k_r": 96,
            "k_sec": (3, 4, 20, 3),
            "num_classes": 1000,
            "num_init_features": 64,
            "test_time_pool": True,
        },
    },
    "dpn98": {
        "encoder": DPNEncoder,
        "pretrained_settings": pretrained_settings["dpn98"],
        "params": {
            "stage_idxs": (4, 10, 30, 34),
            "out_channels": (3, 96, 336, 768, 1728, 2688),
            "groups": 40,
            "inc_sec": (16, 32, 32, 128),
            "k_r": 160,
            "k_sec": (3, 6, 20, 3),
            "num_classes": 1000,
            "num_init_features": 96,
            "test_time_pool": True,
        },
    },
    "dpn107": {
        "encoder": DPNEncoder,
        "pretrained_settings": pretrained_settings["dpn107"],
        "params": {
            "stage_idxs": (5, 13, 33, 37),
            "out_channels": (3, 128, 376, 1152, 2432, 2688),
            "groups": 50,
            "inc_sec": (20, 64, 64, 128),
            "k_r": 200,
            "k_sec": (4, 8, 20, 3),
            "num_classes": 1000,
            "num_init_features": 128,
            "test_time_pool": True,
        },
    },
    "dpn131": {
        "encoder": DPNEncoder,
        "pretrained_settings": pretrained_settings["dpn131"],
        "params": {
            "stage_idxs": (5, 13, 41, 45),
            "out_channels": (3, 128, 352, 832, 1984, 2688),
            "groups": 40,
            "inc_sec": (16, 32, 32, 128),
            "k_r": 160,
            "k_sec": (4, 8, 28, 3),
            "num_classes": 1000,
            "num_init_features": 128,
            "test_time_pool": True,
        },
    },
}

plugins/ai_method/packages/models/encoders/efficientnet.py (new file, 178 lines)
@@ -0,0 +1,178 @@
""" Each encoder should have the following attributes and methods, and be inherited from `_base.EncoderMixin`

Attributes:

    _out_channels (list of int): specify number of channels for each encoder feature tensor
    _depth (int): specify number of stages in decoder (in other words number of downsampling operations)
    _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)

Methods:

    forward(self, x: torch.Tensor)
        produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
        shape NCHW (features should be sorted in descending order according to spatial resolution, starting
        with resolution same as input `x` tensor).

        Input: `x` with shape (1, 3, 64, 64)
        Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
                [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
                (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)

        also should support number of features according to specified depth, e.g. if depth = 5,
        number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
        depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
"""
import torch.nn as nn
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import url_map, url_map_advprop, get_model_params

from ._base import EncoderMixin


class EfficientNetEncoder(EfficientNet, EncoderMixin):
    def __init__(self, stage_idxs, out_channels, model_name, depth=5):

        blocks_args, global_params = get_model_params(model_name, override_params=None)
        super().__init__(blocks_args, global_params)

        self._stage_idxs = stage_idxs
        self._out_channels = out_channels
        self._depth = depth
        self._in_channels = 3

        del self._fc

    def get_stages(self):
        return [
            nn.Identity(),
            nn.Sequential(self._conv_stem, self._bn0, self._swish),
            self._blocks[:self._stage_idxs[0]],
            self._blocks[self._stage_idxs[0]:self._stage_idxs[1]],
            self._blocks[self._stage_idxs[1]:self._stage_idxs[2]],
            self._blocks[self._stage_idxs[2]:],
        ]

    def forward(self, x):
        stages = self.get_stages()

        block_number = 0.
        drop_connect_rate = self._global_params.drop_connect_rate

        features = []
        for i in range(self._depth + 1):

            # Identity and Sequential stages
            if i < 2:
                x = stages[i](x)

            # Block stages need drop_connect rate
            else:
                for module in stages[i]:
                    drop_connect = drop_connect_rate * block_number / len(self._blocks)
                    block_number += 1.
                    x = module(x, drop_connect)

            features.append(x)

        return features

    def load_state_dict(self, state_dict, **kwargs):
        state_dict.pop("_fc.bias", None)
        state_dict.pop("_fc.weight", None)
        super().load_state_dict(state_dict, **kwargs)


def _get_pretrained_settings(encoder):
    pretrained_settings = {
        "imagenet": {
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
            "url": url_map[encoder],
            "input_space": "RGB",
            "input_range": [0, 1],
        },
        "advprop": {
            "mean": [0.5, 0.5, 0.5],
            "std": [0.5, 0.5, 0.5],
            "url": url_map_advprop[encoder],
            "input_space": "RGB",
            "input_range": [0, 1],
        }
    }
    return pretrained_settings


efficient_net_encoders = {
    "efficientnet-b0": {
        "encoder": EfficientNetEncoder,
        "pretrained_settings": _get_pretrained_settings("efficientnet-b0"),
        "params": {
            "out_channels": (3, 32, 24, 40, 112, 320),
            "stage_idxs": (3, 5, 9, 16),
            "model_name": "efficientnet-b0",
        },
    },
    "efficientnet-b1": {
        "encoder": EfficientNetEncoder,
        "pretrained_settings": _get_pretrained_settings("efficientnet-b1"),
        "params": {
            "out_channels": (3, 32, 24, 40, 112, 320),
            "stage_idxs": (5, 8, 16, 23),
            "model_name": "efficientnet-b1",
        },
    },
    "efficientnet-b2": {
        "encoder": EfficientNetEncoder,
        "pretrained_settings": _get_pretrained_settings("efficientnet-b2"),
        "params": {
            "out_channels": (3, 32, 24, 48, 120, 352),
            "stage_idxs": (5, 8, 16, 23),
            "model_name": "efficientnet-b2",
        },
    },
    "efficientnet-b3": {
        "encoder": EfficientNetEncoder,
        "pretrained_settings": _get_pretrained_settings("efficientnet-b3"),
        "params": {
            "out_channels": (3, 40, 32, 48, 136, 384),
            "stage_idxs": (5, 8, 18, 26),
            "model_name": "efficientnet-b3",
        },
    },
    "efficientnet-b4": {
        "encoder": EfficientNetEncoder,
        "pretrained_settings": _get_pretrained_settings("efficientnet-b4"),
        "params": {
            "out_channels": (3, 48, 32, 56, 160, 448),
            "stage_idxs": (6, 10, 22, 32),
            "model_name": "efficientnet-b4",
        },
    },
    "efficientnet-b5": {
        "encoder": EfficientNetEncoder,
        "pretrained_settings": _get_pretrained_settings("efficientnet-b5"),
        "params": {
            "out_channels": (3, 48, 40, 64, 176, 512),
            "stage_idxs": (8, 13, 27, 39),
            "model_name": "efficientnet-b5",
        },
    },
    "efficientnet-b6": {
        "encoder": EfficientNetEncoder,
        "pretrained_settings": _get_pretrained_settings("efficientnet-b6"),
        "params": {
            "out_channels": (3, 56, 40, 72, 200, 576),
            "stage_idxs": (9, 15, 31, 45),
            "model_name": "efficientnet-b6",
        },
    },
    "efficientnet-b7": {
        "encoder": EfficientNetEncoder,
        "pretrained_settings": _get_pretrained_settings("efficientnet-b7"),
        "params": {
            "out_channels": (3, 64, 48, 80, 224, 640),
            "stage_idxs": (11, 18, 38, 55),
            "model_name": "efficientnet-b7",
        },
    },
}
@@ -0,0 +1,90 @@
""" Each encoder should have the following attributes and methods and be inherited from `_base.EncoderMixin`

Attributes:

    _out_channels (list of int): specify number of channels for each encoder feature tensor
    _depth (int): specify number of stages in decoder (in other words number of downsampling operations)
    _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)

Methods:

    forward(self, x: torch.Tensor)
        produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
        shape NCHW (features should be sorted in descending order according to spatial resolution, starting
        with resolution same as input `x` tensor).

        Input: `x` with shape (1, 3, 64, 64)
        Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
                [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
                (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)

        also should support number of features according to specified depth, e.g. if depth = 5,
        number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
        depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
"""

import torch.nn as nn
from pretrainedmodels.models.inceptionresnetv2 import InceptionResNetV2
from pretrainedmodels.models.inceptionresnetv2 import pretrained_settings

from ._base import EncoderMixin


class InceptionResNetV2Encoder(InceptionResNetV2, EncoderMixin):
    def __init__(self, out_channels, depth=5, **kwargs):
        super().__init__(**kwargs)

        self._out_channels = out_channels
        self._depth = depth
        self._in_channels = 3

        # correct paddings
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m.kernel_size == (3, 3):
                    m.padding = (1, 1)
            if isinstance(m, nn.MaxPool2d):
                m.padding = (1, 1)

        # remove linear layers
        del self.avgpool_1a
        del self.last_linear

    def make_dilated(self, output_stride):
        raise ValueError("InceptionResNetV2 encoder does not support dilated mode "
                         "due to pooling operation for downsampling!")

    def get_stages(self):
        return [
            nn.Identity(),
            nn.Sequential(self.conv2d_1a, self.conv2d_2a, self.conv2d_2b),
            nn.Sequential(self.maxpool_3a, self.conv2d_3b, self.conv2d_4a),
            nn.Sequential(self.maxpool_5a, self.mixed_5b, self.repeat),
            nn.Sequential(self.mixed_6a, self.repeat_1),
            nn.Sequential(self.mixed_7a, self.repeat_2, self.block8, self.conv2d_7b),
        ]

    def forward(self, x):
        stages = self.get_stages()

        features = []
        for i in range(self._depth + 1):
            x = stages[i](x)
            features.append(x)

        return features

    def load_state_dict(self, state_dict, **kwargs):
        state_dict.pop("last_linear.bias", None)
        state_dict.pop("last_linear.weight", None)
        super().load_state_dict(state_dict, **kwargs)


inceptionresnetv2_encoders = {
    "inceptionresnetv2": {
        "encoder": InceptionResNetV2Encoder,
        "pretrained_settings": pretrained_settings["inceptionresnetv2"],
        "params": {"out_channels": (3, 64, 192, 320, 1088, 1536), "num_classes": 1000},
    }
}
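
# A minimal usage sketch (assumed, not part of this plugin's API): each entry
# in the registry above pairs an encoder class with its constructor params.
if __name__ == "__main__":
    import torch

    entry = inceptionresnetv2_encoders["inceptionresnetv2"]
    encoder = entry["encoder"](**entry["params"])
    feats = encoder(torch.randn(1, 3, 64, 64))
    print([f.shape for f in feats])  # six feature maps, strides 1 through 32
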
plugins/ai_method/packages/models/encoders/inceptionv4.py (new file, 93 lines)
@@ -0,0 +1,93 @@
""" Each encoder should have the following attributes and methods and be inherited from `_base.EncoderMixin`

Attributes:

    _out_channels (list of int): specify number of channels for each encoder feature tensor
    _depth (int): specify number of stages in decoder (in other words number of downsampling operations)
    _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)

Methods:

    forward(self, x: torch.Tensor)
        produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
        shape NCHW (features should be sorted in descending order according to spatial resolution, starting
        with resolution same as input `x` tensor).

        Input: `x` with shape (1, 3, 64, 64)
        Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
                [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
                (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)

        also should support number of features according to specified depth, e.g. if depth = 5,
        number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
        depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
"""

import torch.nn as nn
from pretrainedmodels.models.inceptionv4 import InceptionV4, BasicConv2d
from pretrainedmodels.models.inceptionv4 import pretrained_settings

from ._base import EncoderMixin


class InceptionV4Encoder(InceptionV4, EncoderMixin):
    def __init__(self, stage_idxs, out_channels, depth=5, **kwargs):
        super().__init__(**kwargs)
        self._stage_idxs = stage_idxs
        self._out_channels = out_channels
        self._depth = depth
        self._in_channels = 3

        # correct paddings
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m.kernel_size == (3, 3):
                    m.padding = (1, 1)
            if isinstance(m, nn.MaxPool2d):
                m.padding = (1, 1)

        # remove linear layers
        del self.last_linear

    def make_dilated(self, output_stride):
        raise ValueError("InceptionV4 encoder does not support dilated mode "
                         "due to pooling operation for downsampling!")

    def get_stages(self):
        return [
            nn.Identity(),
            self.features[: self._stage_idxs[0]],
            self.features[self._stage_idxs[0]: self._stage_idxs[1]],
            self.features[self._stage_idxs[1]: self._stage_idxs[2]],
            self.features[self._stage_idxs[2]: self._stage_idxs[3]],
            self.features[self._stage_idxs[3]:],
        ]

    def forward(self, x):
        stages = self.get_stages()

        features = []
        for i in range(self._depth + 1):
            x = stages[i](x)
            features.append(x)

        return features

    def load_state_dict(self, state_dict, **kwargs):
        state_dict.pop("last_linear.bias", None)
        state_dict.pop("last_linear.weight", None)
        super().load_state_dict(state_dict, **kwargs)


inceptionv4_encoders = {
    "inceptionv4": {
        "encoder": InceptionV4Encoder,
        "pretrained_settings": pretrained_settings["inceptionv4"],
        "params": {
            "stage_idxs": (3, 5, 9, 15),
            "out_channels": (3, 64, 192, 384, 1024, 1536),
            "num_classes": 1001,
        },
    }
}
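
# Note: a minimal illustration (for exposition only) of how `stage_idxs`
# partitions the flat `self.features` Sequential in `get_stages` above:
# boundaries (3, 5, 9, 15) yield the slices [0:3], [3:5], [5:9], [9:15], [15:].
if __name__ == "__main__":
    stage_idxs = (3, 5, 9, 15)
    bounds = (0, *stage_idxs, None)
    print(list(zip(bounds, bounds[1:])))  # [(0, 3), (3, 5), (5, 9), (9, 15), (15, None)]
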
plugins/ai_method/packages/models/encoders/mit_encoder.py (new file, 192 lines)
@@ -0,0 +1,192 @@
# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from pretrainedmodels.models.torchvision_models import pretrained_settings

from ._base import EncoderMixin
from .mix_transformer import MixVisionTransformer


class MixVisionTransformerEncoder(MixVisionTransformer, EncoderMixin):
    def __init__(self, out_channels, depth=5, **kwargs):
        super().__init__(**kwargs)
        self._depth = depth
        self._out_channels = out_channels
        self._in_channels = 3

    def get_stages(self):
        return [nn.Identity()]

    def forward(self, x):
        stages = self.get_stages()

        features = []
        for stage in stages:
            x = stage(x)
            features.append(x)
        outs = self.forward_features(x)
        add_feature = F.interpolate(outs[0], scale_factor=2)
        features = features + [add_feature] + outs
        return features

    def load_state_dict(self, state_dict, **kwargs):
        new_state_dict = {}
        if state_dict.get('state_dict'):
            state_dict = state_dict['state_dict']
            for k, v in state_dict.items():
                if k.startswith('backbone'):
                    new_state_dict[k.replace('backbone.', '')] = v
        else:
            new_state_dict = deepcopy(state_dict)
        super().load_state_dict(new_state_dict, **kwargs)


# https://github.com/NVlabs/SegFormer
new_settings = {
    "mit-b0": {
        "imagenet": "https://lino.local.server/mit_b0.pth"
    },
    "mit-b1": {
        "imagenet": "https://lino.local.server/mit_b1.pth"
    },
    "mit-b2": {
        "imagenet": "https://lino.local.server/mit_b2.pth"
    },
    "mit-b3": {
        "imagenet": "https://lino.local.server/mit_b3.pth"
    },
    "mit-b4": {
        "imagenet": "https://lino.local.server/mit_b4.pth"
    },
    "mit-b5": {
        "imagenet": "https://lino.local.server/mit_b5.pth"
    },
}

pretrained_settings = deepcopy(pretrained_settings)
for model_name, sources in new_settings.items():
    if model_name not in pretrained_settings:
        pretrained_settings[model_name] = {}

    for source_name, source_url in sources.items():
        pretrained_settings[model_name][source_name] = {
            "url": source_url,
            'input_size': [3, 224, 224],
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'num_classes': 1000
        }

mit_encoders = {
    "mit-b0": {
        "encoder": MixVisionTransformerEncoder,
        "pretrained_settings": pretrained_settings["mit-b0"],
        "params": {
            "patch_size": 4,
            "embed_dims": [32, 64, 160, 256],
            "num_heads": [1, 2, 5, 8],
            "mlp_ratios": [4, 4, 4, 4],
            "qkv_bias": True,
            "norm_layer": partial(nn.LayerNorm, eps=1e-6),
            "depths": [2, 2, 2, 2],
            "sr_ratios": [8, 4, 2, 1],
            "drop_rate": 0.0,
            "drop_path_rate": 0.1,
            "out_channels": (3, 32, 32, 64, 160, 256)
        }
    },
    "mit-b1": {
        "encoder": MixVisionTransformerEncoder,
        "pretrained_settings": pretrained_settings["mit-b1"],
        "params": {
            "patch_size": 4,
            "embed_dims": [64, 128, 320, 512],
            "num_heads": [1, 2, 5, 8],
            "mlp_ratios": [4, 4, 4, 4],
            "qkv_bias": True,
            "norm_layer": partial(nn.LayerNorm, eps=1e-6),
            "depths": [2, 2, 2, 2],
            "sr_ratios": [8, 4, 2, 1],
            "drop_rate": 0.0,
            "drop_path_rate": 0.1,
            "out_channels": (3, 64, 64, 128, 320, 512)
        }
    },
    "mit-b2": {
        "encoder": MixVisionTransformerEncoder,
        "pretrained_settings": pretrained_settings["mit-b2"],
        "params": {
            "patch_size": 4,
            "embed_dims": [64, 128, 320, 512],
            "num_heads": [1, 2, 5, 8],
            "mlp_ratios": [4, 4, 4, 4],
            "qkv_bias": True,
            "norm_layer": partial(nn.LayerNorm, eps=1e-6),
            "depths": [3, 4, 6, 3],
            "sr_ratios": [8, 4, 2, 1],
            "drop_rate": 0.0,
            "drop_path_rate": 0.1,
            "out_channels": (3, 64, 64, 128, 320, 512)
        }
    },
    "mit-b3": {
        "encoder": MixVisionTransformerEncoder,
        "pretrained_settings": pretrained_settings["mit-b3"],
        "params": {
            "patch_size": 4,
            "embed_dims": [64, 128, 320, 512],
            "num_heads": [1, 2, 5, 8],
            "mlp_ratios": [4, 4, 4, 4],
            "qkv_bias": True,
            "norm_layer": partial(nn.LayerNorm, eps=1e-6),
            "depths": [3, 4, 18, 3],
            "sr_ratios": [8, 4, 2, 1],
            "drop_rate": 0.0,
            "drop_path_rate": 0.1,
            "out_channels": (3, 64, 64, 128, 320, 512)
        }
    },
    "mit-b4": {
        "encoder": MixVisionTransformerEncoder,
        "pretrained_settings": pretrained_settings["mit-b4"],
        "params": {
            "patch_size": 4,
            "embed_dims": [64, 128, 320, 512],
            "num_heads": [1, 2, 5, 8],
            "mlp_ratios": [4, 4, 4, 4],
            "qkv_bias": True,
            "norm_layer": partial(nn.LayerNorm, eps=1e-6),
            "depths": [3, 8, 27, 3],
            "sr_ratios": [8, 4, 2, 1],
            "drop_rate": 0.0,
            "drop_path_rate": 0.1,
            "out_channels": (3, 64, 64, 128, 320, 512)
        }
    },
    "mit-b5": {
        "encoder": MixVisionTransformerEncoder,
        "pretrained_settings": pretrained_settings["mit-b5"],
        "params": {
            "patch_size": 4,
            "embed_dims": [64, 128, 320, 512],
            "num_heads": [1, 2, 5, 8],
            "mlp_ratios": [4, 4, 4, 4],
            "qkv_bias": True,
            "norm_layer": partial(nn.LayerNorm, eps=1e-6),
            "depths": [3, 6, 40, 3],
            "sr_ratios": [8, 4, 2, 1],
            "drop_rate": 0.0,
            "drop_path_rate": 0.1,
            "out_channels": (3, 64, 64, 128, 320, 512)
        }
    },
}
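
# A minimal shape check (assumed usage): `forward` above returns six maps so
# that five-stage decoders work, namely the identity input, an upsampled copy
# of the stride-4 output, then the four transformer stage outputs.
if __name__ == "__main__":
    import torch

    entry = mit_encoders["mit-b0"]
    encoder = entry["encoder"](**entry["params"])
    feats = encoder(torch.randn(1, 3, 224, 224))
    print([f.shape[-1] for f in feats])  # [224, 112, 56, 28, 14, 7]
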
plugins/ai_method/packages/models/encoders/mix_transformer.py (new file, 361 lines)
@@ -0,0 +1,361 @@
# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
import math


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

        return x


class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x, H, W


class MixVisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths

        # patch_embed
        self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
                                              embed_dim=embed_dims[0])
        self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
                                              embed_dim=embed_dims[1])
        self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
                                              embed_dim=embed_dims[2])
        self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
                                              embed_dim=embed_dims[3])

        # transformer encoder
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0
        self.block1 = nn.ModuleList([Block(
            dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])
            for i in range(depths[0])])
        self.norm1 = norm_layer(embed_dims[0])

        cur += depths[0]
        self.block2 = nn.ModuleList([Block(
            dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[1])
            for i in range(depths[1])])
        self.norm2 = norm_layer(embed_dims[1])

        cur += depths[1]
        self.block3 = nn.ModuleList([Block(
            dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[2])
            for i in range(depths[2])])
        self.norm3 = norm_layer(embed_dims[2])

        cur += depths[2]
        self.block4 = nn.ModuleList([Block(
            dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[3])
            for i in range(depths[3])])
        self.norm4 = norm_layer(embed_dims[3])

        # classification head
        # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def reset_drop_path(self, drop_path_rate):
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0
        for i in range(self.depths[0]):
            self.block1[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[0]
        for i in range(self.depths[1]):
            self.block2[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[1]
        for i in range(self.depths[2]):
            self.block3[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[2]
        for i in range(self.depths[3]):
            self.block4[i].drop_path.drop_prob = dpr[cur + i]

    def freeze_patch_emb(self):
        self.patch_embed1.requires_grad = False

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}  # has pos_embed may be better

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        B = x.shape[0]
        outs = []

        # stage 1
        x, H, W = self.patch_embed1(x)
        for i, blk in enumerate(self.block1):
            x = blk(x, H, W)
        x = self.norm1(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 2
        x, H, W = self.patch_embed2(x)
        for i, blk in enumerate(self.block2):
            x = blk(x, H, W)
        x = self.norm2(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 3
        x, H, W = self.patch_embed3(x)
        for i, blk in enumerate(self.block3):
            x = blk(x, H, W)
        x = self.norm3(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 4
        x, H, W = self.patch_embed4(x)
        for i, blk in enumerate(self.block4):
            x = blk(x, H, W)
        x = self.norm4(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        return outs

    def forward(self, x):
        x = self.forward_features(x)
        # x = self.head(x)

        return x


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)

        return x
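
# A minimal shape check (for exposition only): with sr_ratio > 1 the
# spatial-reduction conv in `Attention` shrinks keys/values from H*W tokens
# to (H/sr)*(W/sr) tokens, while queries, and hence the output, keep the
# full sequence length.
if __name__ == "__main__":
    attn = Attention(dim=64, num_heads=2, sr_ratio=8)
    tokens = torch.randn(2, 56 * 56, 64)
    print(attn(tokens, H=56, W=56).shape)  # torch.Size([2, 3136, 64])
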
plugins/ai_method/packages/models/encoders/mobilenet.py (new file, 83 lines)
@@ -0,0 +1,83 @@
""" Each encoder should have the following attributes and methods and be inherited from `_base.EncoderMixin`

Attributes:

    _out_channels (list of int): specify number of channels for each encoder feature tensor
    _depth (int): specify number of stages in decoder (in other words number of downsampling operations)
    _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)

Methods:

    forward(self, x: torch.Tensor)
        produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
        shape NCHW (features should be sorted in descending order according to spatial resolution, starting
        with resolution same as input `x` tensor).

        Input: `x` with shape (1, 3, 64, 64)
        Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
                [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
                (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)

        also should support number of features according to specified depth, e.g. if depth = 5,
        number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
        depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
"""

import torchvision
import torch.nn as nn

from ._base import EncoderMixin


class MobileNetV2Encoder(torchvision.models.MobileNetV2, EncoderMixin):

    def __init__(self, out_channels, depth=5, **kwargs):
        super().__init__(**kwargs)
        self._depth = depth
        self._out_channels = out_channels
        self._in_channels = 3
        del self.classifier

    def get_stages(self):
        return [
            nn.Identity(),
            self.features[:2],
            self.features[2:4],
            self.features[4:7],
            self.features[7:14],
            self.features[14:],
        ]

    def forward(self, x):
        stages = self.get_stages()

        features = []
        for i in range(self._depth + 1):
            x = stages[i](x)
            features.append(x)

        return features

    def load_state_dict(self, state_dict, **kwargs):
        state_dict.pop("classifier.1.bias", None)
        state_dict.pop("classifier.1.weight", None)
        super().load_state_dict(state_dict, **kwargs)


mobilenet_encoders = {
    "mobilenet_v2": {
        "encoder": MobileNetV2Encoder,
        "pretrained_settings": {
            "imagenet": {
                "mean": [0.485, 0.456, 0.406],
                "std": [0.229, 0.224, 0.225],
                "url": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
                "input_space": "RGB",
                "input_range": [0, 1],
            },
        },
        "params": {
            "out_channels": (3, 16, 24, 32, 96, 1280),
        },
    },
}
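
# A minimal sketch (assumed usage; downloading the checkpoint needs network
# access): the `load_state_dict` override above pops the classifier weights,
# so the stock torchvision checkpoint loads cleanly into the headless encoder.
if __name__ == "__main__":
    import torch

    entry = mobilenet_encoders["mobilenet_v2"]
    encoder = entry["encoder"](**entry["params"])
    url = entry["pretrained_settings"]["imagenet"]["url"]
    state = torch.hub.load_state_dict_from_url(url, progress=False)
    encoder.load_state_dict(state)
    print("loaded", sum(p.numel() for p in encoder.parameters()), "parameters")
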
plugins/ai_method/packages/models/encoders/resnet.py (new file, 238 lines)
@@ -0,0 +1,238 @@
""" Each encoder should have the following attributes and methods and be inherited from `_base.EncoderMixin`

Attributes:

    _out_channels (list of int): specify number of channels for each encoder feature tensor
    _depth (int): specify number of stages in decoder (in other words number of downsampling operations)
    _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)

Methods:

    forward(self, x: torch.Tensor)
        produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
        shape NCHW (features should be sorted in descending order according to spatial resolution, starting
        with resolution same as input `x` tensor).

        Input: `x` with shape (1, 3, 64, 64)
        Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
                [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
                (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)

        also should support number of features according to specified depth, e.g. if depth = 5,
        number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
        depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
"""
from copy import deepcopy

import torch.nn as nn

from torchvision.models.resnet import ResNet
from torchvision.models.resnet import BasicBlock
from torchvision.models.resnet import Bottleneck
from pretrainedmodels.models.torchvision_models import pretrained_settings

from ._base import EncoderMixin


class ResNetEncoder(ResNet, EncoderMixin):
    def __init__(self, out_channels, depth=5, **kwargs):
        super().__init__(**kwargs)
        self._depth = depth
        self._out_channels = out_channels
        self._in_channels = 3

        del self.fc
        del self.avgpool

    def get_stages(self):
        return [
            nn.Identity(),
            nn.Sequential(self.conv1, self.bn1, self.relu),
            nn.Sequential(self.maxpool, self.layer1),
            self.layer2,
            self.layer3,
            self.layer4,
        ]

    def forward(self, x):
        stages = self.get_stages()

        features = []
        for i in range(self._depth + 1):
            x = stages[i](x)
            features.append(x)

        return features

    def load_state_dict(self, state_dict, **kwargs):
        state_dict.pop("fc.bias", None)
        state_dict.pop("fc.weight", None)
        super().load_state_dict(state_dict, **kwargs)


new_settings = {
    "resnet18": {
        "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth",
        "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth"
    },
    "resnet50": {
        "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth",
        "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth"
    },
    "resnext50_32x4d": {
        "imagenet": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
        "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext50_32x4-ddb3e555.pth",
        "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext50_32x4-72679e44.pth",
    },
    "resnext101_32x4d": {
        "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth",
        "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth"
    },
    "resnext101_32x8d": {
        "imagenet": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
        "instagram": "https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth",
        "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth",
        "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth",
    },
    "resnext101_32x16d": {
        "instagram": "https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth",
        "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth",
        "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth",
    },
    "resnext101_32x32d": {
        "instagram": "https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth",
    },
    "resnext101_32x48d": {
        "instagram": "https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth",
    }
}

pretrained_settings = deepcopy(pretrained_settings)
for model_name, sources in new_settings.items():
    if model_name not in pretrained_settings:
        pretrained_settings[model_name] = {}

    for source_name, source_url in sources.items():
        pretrained_settings[model_name][source_name] = {
            "url": source_url,
            'input_size': [3, 224, 224],
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'num_classes': 1000
        }


resnet_encoders = {
    "resnet18": {
        "encoder": ResNetEncoder,
        "pretrained_settings": pretrained_settings["resnet18"],
        "params": {
            "out_channels": (3, 64, 64, 128, 256, 512),
            "block": BasicBlock,
            "layers": [2, 2, 2, 2],
        },
    },
    "resnet34": {
        "encoder": ResNetEncoder,
        "pretrained_settings": pretrained_settings["resnet34"],
        "params": {
            "out_channels": (3, 64, 64, 128, 256, 512),
            "block": BasicBlock,
            "layers": [3, 4, 6, 3],
        },
    },
    "resnet50": {
        "encoder": ResNetEncoder,
        "pretrained_settings": pretrained_settings["resnet50"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": Bottleneck,
            "layers": [3, 4, 6, 3],
        },
    },
    "resnet101": {
        "encoder": ResNetEncoder,
        "pretrained_settings": pretrained_settings["resnet101"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": Bottleneck,
            "layers": [3, 4, 23, 3],
        },
    },
    "resnet152": {
        "encoder": ResNetEncoder,
        "pretrained_settings": pretrained_settings["resnet152"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": Bottleneck,
            "layers": [3, 8, 36, 3],
        },
    },
    "resnext50_32x4d": {
        "encoder": ResNetEncoder,
        "pretrained_settings": pretrained_settings["resnext50_32x4d"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": Bottleneck,
            "layers": [3, 4, 6, 3],
            "groups": 32,
            "width_per_group": 4,
        },
    },
    "resnext101_32x4d": {
        "encoder": ResNetEncoder,
        "pretrained_settings": pretrained_settings["resnext101_32x4d"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": Bottleneck,
            "layers": [3, 4, 23, 3],
            "groups": 32,
            "width_per_group": 4,
        },
    },
    "resnext101_32x8d": {
        "encoder": ResNetEncoder,
        "pretrained_settings": pretrained_settings["resnext101_32x8d"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": Bottleneck,
            "layers": [3, 4, 23, 3],
            "groups": 32,
            "width_per_group": 8,
        },
    },
    "resnext101_32x16d": {
        "encoder": ResNetEncoder,
        "pretrained_settings": pretrained_settings["resnext101_32x16d"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": Bottleneck,
            "layers": [3, 4, 23, 3],
            "groups": 32,
            "width_per_group": 16,
        },
    },
    "resnext101_32x32d": {
        "encoder": ResNetEncoder,
        "pretrained_settings": pretrained_settings["resnext101_32x32d"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": Bottleneck,
            "layers": [3, 4, 23, 3],
            "groups": 32,
            "width_per_group": 32,
        },
    },
    "resnext101_32x48d": {
        "encoder": ResNetEncoder,
        "pretrained_settings": pretrained_settings["resnext101_32x48d"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": Bottleneck,
            "layers": [3, 4, 23, 3],
            "groups": 32,
            "width_per_group": 48,
        },
    },
}
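
# A minimal sketch (assumed usage): per the module docstring, a smaller
# `depth` yields fewer feature maps (depth + 1 tensors).
if __name__ == "__main__":
    import torch

    entry = resnet_encoders["resnet18"]
    encoder = entry["encoder"](depth=3, **entry["params"])
    feats = encoder(torch.randn(1, 3, 64, 64))
    print(len(feats), [f.shape[-1] for f in feats])  # 4 [64, 32, 16, 8]
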
plugins/ai_method/packages/models/encoders/senet.py (new file, 174 lines)
@@ -0,0 +1,174 @@
""" Each encoder should have the following attributes and methods and be inherited from `_base.EncoderMixin`

Attributes:

    _out_channels (list of int): specify number of channels for each encoder feature tensor
    _depth (int): specify number of stages in decoder (in other words number of downsampling operations)
    _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)

Methods:

    forward(self, x: torch.Tensor)
        produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
        shape NCHW (features should be sorted in descending order according to spatial resolution, starting
        with resolution same as input `x` tensor).

        Input: `x` with shape (1, 3, 64, 64)
        Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
                [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
                (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)

        also should support number of features according to specified depth, e.g. if depth = 5,
        number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
        depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
"""

import torch.nn as nn

from pretrainedmodels.models.senet import (
    SENet,
    SEBottleneck,
    SEResNetBottleneck,
    SEResNeXtBottleneck,
    pretrained_settings,
)
from ._base import EncoderMixin


class SENetEncoder(SENet, EncoderMixin):
    def __init__(self, out_channels, depth=5, **kwargs):
        super().__init__(**kwargs)

        self._out_channels = out_channels
        self._depth = depth
        self._in_channels = 3

        del self.last_linear
        del self.avg_pool

    def get_stages(self):
        return [
            nn.Identity(),
            self.layer0[:-1],
            nn.Sequential(self.layer0[-1], self.layer1),
            self.layer2,
            self.layer3,
            self.layer4,
        ]

    def forward(self, x):
        stages = self.get_stages()

        features = []
        for i in range(self._depth + 1):
            x = stages[i](x)
            features.append(x)

        return features

    def load_state_dict(self, state_dict, **kwargs):
        state_dict.pop("last_linear.bias", None)
        state_dict.pop("last_linear.weight", None)
        super().load_state_dict(state_dict, **kwargs)


senet_encoders = {
    "senet154": {
        "encoder": SENetEncoder,
        "pretrained_settings": pretrained_settings["senet154"],
        "params": {
            "out_channels": (3, 128, 256, 512, 1024, 2048),
            "block": SEBottleneck,
            "dropout_p": 0.2,
            "groups": 64,
            "layers": [3, 8, 36, 3],
            "num_classes": 1000,
            "reduction": 16,
        },
    },
    "se_resnet50": {
        "encoder": SENetEncoder,
        "pretrained_settings": pretrained_settings["se_resnet50"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": SEResNetBottleneck,
            "layers": [3, 4, 6, 3],
            "downsample_kernel_size": 1,
            "downsample_padding": 0,
            "dropout_p": None,
            "groups": 1,
            "inplanes": 64,
            "input_3x3": False,
            "num_classes": 1000,
            "reduction": 16,
        },
    },
    "se_resnet101": {
        "encoder": SENetEncoder,
        "pretrained_settings": pretrained_settings["se_resnet101"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": SEResNetBottleneck,
            "layers": [3, 4, 23, 3],
            "downsample_kernel_size": 1,
            "downsample_padding": 0,
            "dropout_p": None,
            "groups": 1,
            "inplanes": 64,
            "input_3x3": False,
            "num_classes": 1000,
            "reduction": 16,
        },
    },
    "se_resnet152": {
        "encoder": SENetEncoder,
        "pretrained_settings": pretrained_settings["se_resnet152"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": SEResNetBottleneck,
            "layers": [3, 8, 36, 3],
            "downsample_kernel_size": 1,
            "downsample_padding": 0,
            "dropout_p": None,
            "groups": 1,
            "inplanes": 64,
            "input_3x3": False,
            "num_classes": 1000,
            "reduction": 16,
        },
    },
    "se_resnext50_32x4d": {
        "encoder": SENetEncoder,
        "pretrained_settings": pretrained_settings["se_resnext50_32x4d"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": SEResNeXtBottleneck,
            "layers": [3, 4, 6, 3],
            "downsample_kernel_size": 1,
            "downsample_padding": 0,
            "dropout_p": None,
            "groups": 32,
            "inplanes": 64,
            "input_3x3": False,
            "num_classes": 1000,
            "reduction": 16,
        },
    },
    "se_resnext101_32x4d": {
        "encoder": SENetEncoder,
        "pretrained_settings": pretrained_settings["se_resnext101_32x4d"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": SEResNeXtBottleneck,
            "layers": [3, 4, 23, 3],
            "downsample_kernel_size": 1,
            "downsample_padding": 0,
            "dropout_p": None,
            "groups": 32,
            "inplanes": 64,
            "input_3x3": False,
            "num_classes": 1000,
            "reduction": 16,
        },
    },
}
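
# A minimal shape check (assumed usage): `get_stages` above splits `layer0`
# so that its trailing pooling layer moves into stage 2, keeping the stride-2
# feature as its own map.
if __name__ == "__main__":
    import torch

    entry = senet_encoders["se_resnet50"]
    encoder = entry["encoder"](**entry["params"])
    feats = encoder(torch.randn(1, 3, 64, 64))
    print([f.shape[-1] for f in feats])  # [64, 32, 16, 8, 4, 2]
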
plugins/ai_method/packages/models/encoders/swin_transformer.py (new file, 196 lines)
@@ -0,0 +1,196 @@
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from collections import OrderedDict
from pretrainedmodels.models.torchvision_models import pretrained_settings

from ._base import EncoderMixin
from .swin_transformer_model import SwinTransformer


class SwinTransformerEncoder(SwinTransformer, EncoderMixin):
    def __init__(self, out_channels, depth=5, **kwargs):
        super().__init__(**kwargs)
        self._depth = depth
        self._out_channels = out_channels
        self._in_channels = 3

    def get_stages(self):
        return [nn.Identity()]

    def feature_forward(self, x):
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
        else:
            x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        outs = []
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x_out)

                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs.append(out)
        return outs

    def forward(self, x):
        stages = self.get_stages()

        features = []
        for stage in stages:
            x = stage(x)
            features.append(x)
        outs = self.feature_forward(x)

        # Note: an additional interpolated feature is inserted to accommodate
        # five-stage decoders; it is ignored by decoders with fewer stages.
        add_feature = F.interpolate(outs[0], scale_factor=2)
        features = features + [add_feature] + outs
        return features

    def load_state_dict(self, state_dict, **kwargs):

        new_state_dict = OrderedDict()

        if 'state_dict' in state_dict:
            _state_dict = state_dict['state_dict']
        elif 'model' in state_dict:
            _state_dict = state_dict['model']
        else:
            _state_dict = state_dict

        for k, v in _state_dict.items():
            if k.startswith('backbone.'):
                new_state_dict[k[9:]] = v
            else:
                new_state_dict[k] = v

        # Note: in the Swin segmentation model, `attn_mask` is no longer a class
        # attribute (to support multi-scale inputs), a norm layer is added for
        # each output, and the head layer is removed.
        kwargs.update({'strict': False})
        super().load_state_dict(new_state_dict, **kwargs)


# https://github.com/microsoft/Swin-Transformer
new_settings = {
    "Swin-T": {
        "imagenet": "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth",
        "ADE20k": "https://github.com/SwinTransformer/storage/releases/download/v1.0.1/upernet_swin_tiny_patch4_window7_512x512.pth"
    },
    "Swin-S": {
        "imagenet": "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth",
        "ADE20k": "https://github.com/SwinTransformer/storage/releases/download/v1.0.1/upernet_swin_small_patch4_window7_512x512.pth"
    },
    "Swin-B": {
        "imagenet": "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth",
        "imagenet-22k": "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth",
        "ADE20k": "https://github.com/SwinTransformer/storage/releases/download/v1.0.1/upernet_swin_base_patch4_window7_512x512.pth"
    },
    "Swin-L": {
        "imagenet-22k": "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth"
    },
}

pretrained_settings = deepcopy(pretrained_settings)
for model_name, sources in new_settings.items():
    if model_name not in pretrained_settings:
        pretrained_settings[model_name] = {}

    for source_name, source_url in sources.items():
        pretrained_settings[model_name][source_name] = {
            "url": source_url,
            'input_size': [3, 224, 224],
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'num_classes': 1000
        }

swin_transformer_encoders = {
    "Swin-T": {
        "encoder": SwinTransformerEncoder,
        "pretrained_settings": pretrained_settings["Swin-T"],
        "params": {
            "embed_dim": 96,
            "out_channels": (3, 96, 96, 192, 384, 768),
            "depths": [2, 2, 6, 2],
            "num_heads": [3, 6, 12, 24],
            "window_size": 7,
            "ape": False,
            "drop_path_rate": 0.3,
            "patch_norm": True,
            "use_checkpoint": False
        }
    },
    "Swin-S": {
        "encoder": SwinTransformerEncoder,
        "pretrained_settings": pretrained_settings["Swin-S"],
        "params": {
            "embed_dim": 96,
            "out_channels": (3, 96, 96, 192, 384, 768),
            "depths": [2, 2, 18, 2],
            "num_heads": [3, 6, 12, 24],
            "window_size": 7,
            "ape": False,
            "drop_path_rate": 0.3,
            "patch_norm": True,
            "use_checkpoint": False
        }
    },
    "Swin-B": {
        "encoder": SwinTransformerEncoder,
        "pretrained_settings": pretrained_settings["Swin-B"],
        "params": {
            "embed_dim": 128,
            "out_channels": (3, 128, 128, 256, 512, 1024),
            "depths": [2, 2, 18, 2],
            "num_heads": [4, 8, 16, 32],
            "window_size": 7,
            "ape": False,
            "drop_path_rate": 0.3,
            "patch_norm": True,
            "use_checkpoint": False
        }
    },
    "Swin-L": {
        "encoder": SwinTransformerEncoder,
        "pretrained_settings": pretrained_settings["Swin-L"],
        "params": {
            "embed_dim": 192,
            "out_channels": (3, 192, 192, 384, 768, 1536),
            "depths": [2, 2, 18, 2],
            "num_heads": [6, 12, 24, 48],
            "window_size": 7,
            "ape": False,
            "drop_path_rate": 0.3,
            "patch_norm": True,
            "use_checkpoint": False
        }
    },
}

if __name__ == "__main__":
    import torch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input = torch.randn(1, 3, 256, 256).to(device)

    model = SwinTransformerEncoder(2, window_size=8)
    # print(model)

    res = model.forward(input)
    for i in res:
        print(i.shape)
@@ -0,0 +1,626 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu, Yutong Lin, Yixuan Wei
# --------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from timm.models.layers import DropPath, to_2tuple, trunc_normal_


class Mlp(nn.Module):
    """ Multilayer perceptron."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
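
# Note: `window_reverse` exactly inverts `window_partition` whenever H and W
# are multiples of `window_size`; a minimal round-trip sketch (for exposition
# only):
#
#     x = torch.randn(2, 14, 14, 32)
#     windows = window_partition(x, 7)                   # (8, 7, 7, 32)
#     assert torch.equal(window_reverse(windows, 7, 14, 14), x)
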
class WindowAttention(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """ Forward function.

        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwinTransformerBlock(nn.Module):
    """ Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
            mask_matrix: Attention mask for cyclic shift.
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class PatchMerging(nn.Module):
    """ Patch Merging Layer

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
"""
|
||||
|
||||
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
||||
self.norm = norm_layer(4 * dim)
|
||||
|
||||
def forward(self, x, H, W):
|
||||
""" Forward function.
|
||||
|
||||
Args:
|
||||
x: Input feature, tensor size (B, H*W, C).
|
||||
H, W: Spatial resolution of the input feature.
|
||||
"""
|
||||
B, L, C = x.shape
|
||||
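Usage sketch (illustrative, not part of the diff): a minimal way to drive the SwinTransformer backbone above, assuming the helpers defined earlier in this module (window_partition, window_reverse, Mlp, DropPath, to_2tuple, trunc_normal_) are in scope; the input size is arbitrary.

import torch

backbone = SwinTransformer(
    pretrain_img_size=224,
    embed_dim=96,
    depths=[2, 2, 6, 2],
    num_heads=[3, 6, 12, 24],
    out_indices=(0, 1, 2, 3),
)
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))

# Four pyramid maps at strides 4/8/16/32 with 96/192/384/768 channels:
# (1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)
for f in feats:
    print(f.shape)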
assert L == H * W, "input feature has wrong size"
|
||||
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
# padding
|
||||
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
||||
if pad_input:
|
||||
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
||||
|
||||
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
||||
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
||||
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
||||
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
||||
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
||||
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
||||
|
||||
x = self.norm(x)
|
||||
x = self.reduction(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class BasicLayer(nn.Module):
|
||||
""" A basic Swin Transformer layer for one stage.
|
||||
|
||||
Args:
|
||||
dim (int): Number of feature channels
|
||||
depth (int): Depths of this stage.
|
||||
num_heads (int): Number of attention head.
|
||||
window_size (int): Local window size. Default: 7.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
depth,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
norm_layer=nn.LayerNorm,
|
||||
downsample=None,
|
||||
use_checkpoint=False):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.shift_size = window_size // 2
|
||||
self.depth = depth
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
# build blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
SwinTransformerBlock(
|
||||
dim=dim,
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop,
|
||||
attn_drop=attn_drop,
|
||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
norm_layer=norm_layer)
|
||||
for i in range(depth)])
|
||||
|
||||
# patch merging layer
|
||||
if downsample is not None:
|
||||
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
def forward(self, x, H, W):
|
||||
""" Forward function.
|
||||
|
||||
Args:
|
||||
x: Input feature, tensor size (B, H*W, C).
|
||||
H, W: Spatial resolution of the input feature.
|
||||
"""
|
||||
|
||||
# calculate attention mask for SW-MSA
|
||||
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
||||
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
||||
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
||||
h_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size, -self.shift_size),
|
||||
slice(-self.shift_size, None))
|
||||
w_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size, -self.shift_size),
|
||||
slice(-self.shift_size, None))
|
||||
cnt = 0
|
||||
for h in h_slices:
|
||||
for w in w_slices:
|
||||
img_mask[:, h, w, :] = cnt
|
||||
cnt += 1
|
||||
|
||||
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
||||
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
||||
|
||||
for blk in self.blocks:
|
||||
blk.H, blk.W = H, W
|
||||
if self.use_checkpoint:
|
||||
x = checkpoint.checkpoint(blk, x, attn_mask)
|
||||
else:
|
||||
x = blk(x, attn_mask)
|
||||
if self.downsample is not None:
|
||||
x_down = self.downsample(x, H, W)
|
||||
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
||||
return x, H, W, x_down, Wh, Ww
|
||||
else:
|
||||
return x, H, W, x, H, W
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
""" Image to Patch Embedding
|
||||
|
||||
Args:
|
||||
patch_size (int): Patch token size. Default: 4.
|
||||
in_chans (int): Number of input image channels. Default: 3.
|
||||
embed_dim (int): Number of linear projection output channels. Default: 96.
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
||||
"""
|
||||
|
||||
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
||||
super().__init__()
|
||||
patch_size = to_2tuple(patch_size)
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
if norm_layer is not None:
|
||||
self.norm = norm_layer(embed_dim)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
# padding
|
||||
_, _, H, W = x.size()
|
||||
if W % self.patch_size[1] != 0:
|
||||
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
||||
if H % self.patch_size[0] != 0:
|
||||
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
||||
|
||||
x = self.proj(x) # B C Wh Ww
|
||||
if self.norm is not None:
|
||||
Wh, Ww = x.size(2), x.size(3)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.norm(x)
|
||||
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SwinTransformer(nn.Module):
|
||||
""" Swin Transformer backbone.
|
||||
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
||||
https://arxiv.org/pdf/2103.14030
|
||||
|
||||
Args:
|
||||
pretrain_img_size (int): Input image size for training the pretrained model,
|
||||
used in absolute postion embedding. Default 224.
|
||||
patch_size (int | tuple(int)): Patch size. Default: 4.
|
||||
in_chans (int): Number of input image channels. Default: 3.
|
||||
embed_dim (int): Number of linear projection output channels. Default: 96.
|
||||
depths (tuple[int]): Depths of each Swin Transformer stage.
|
||||
num_heads (tuple[int]): Number of attention head of each stage.
|
||||
window_size (int): Window size. Default: 7.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
||||
drop_rate (float): Dropout rate.
|
||||
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
||||
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
||||
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
||||
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
||||
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
||||
out_indices (Sequence[int]): Output from which stages.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters.
|
||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrain_img_size=224,
|
||||
patch_size=4,
|
||||
in_chans=3,
|
||||
embed_dim=96,
|
||||
depths=[2, 2, 6, 2],
|
||||
num_heads=[3, 6, 12, 24],
|
||||
window_size=7,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.2,
|
||||
norm_layer=nn.LayerNorm,
|
||||
ape=False,
|
||||
patch_norm=True,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
frozen_stages=-1,
|
||||
use_checkpoint=False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.pretrain_img_size = pretrain_img_size
|
||||
self.num_layers = len(depths)
|
||||
self.embed_dim = embed_dim
|
||||
self.ape = ape
|
||||
self.patch_norm = patch_norm
|
||||
self.out_indices = out_indices
|
||||
self.frozen_stages = frozen_stages
|
||||
|
||||
# split image into non-overlapping patches
|
||||
self.patch_embed = PatchEmbed(
|
||||
patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
|
||||
norm_layer=norm_layer if self.patch_norm else None)
|
||||
|
||||
# absolute position embedding
|
||||
if self.ape:
|
||||
pretrain_img_size = to_2tuple(pretrain_img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
|
||||
|
||||
self.absolute_pos_embed = nn.Parameter(
|
||||
torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
|
||||
trunc_normal_(self.absolute_pos_embed, std=.02)
|
||||
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
# stochastic depth
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
||||
|
||||
# build layers
|
||||
self.layers = nn.ModuleList()
|
||||
for i_layer in range(self.num_layers):
|
||||
layer = BasicLayer(
|
||||
dim=int(embed_dim * 2 ** i_layer),
|
||||
depth=depths[i_layer],
|
||||
num_heads=num_heads[i_layer],
|
||||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
||||
norm_layer=norm_layer,
|
||||
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
||||
use_checkpoint=use_checkpoint)
|
||||
self.layers.append(layer)
|
||||
|
||||
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
||||
self.num_features = num_features
|
||||
|
||||
# add a norm layer for each output
|
||||
for i_layer in out_indices:
|
||||
layer = norm_layer(num_features[i_layer])
|
||||
layer_name = f'norm{i_layer}'
|
||||
self.add_module(layer_name, layer)
|
||||
|
||||
self.init_weights() # the pre-trained model will be loaded later if needed
|
||||
self._freeze_stages()
|
||||
|
||||
def _freeze_stages(self):
|
||||
if self.frozen_stages >= 0:
|
||||
self.patch_embed.eval()
|
||||
for param in self.patch_embed.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if self.frozen_stages >= 1 and self.ape:
|
||||
self.absolute_pos_embed.requires_grad = False
|
||||
|
||||
if self.frozen_stages >= 2:
|
||||
self.pos_drop.eval()
|
||||
for i in range(0, self.frozen_stages - 1):
|
||||
m = self.layers[i]
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
"""Initialize the weights in backbone.
|
||||
|
||||
Args:
|
||||
pretrained (str, optional): Path to pre-trained weights.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def _init_weights(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
if isinstance(pretrained, str):
|
||||
self.apply(_init_weights)
|
||||
# logger = get_root_logger()
|
||||
# load_checkpoint(self, pretrained, strict=False, logger=logger)
|
||||
pass
|
||||
elif pretrained is None:
|
||||
self.apply(_init_weights)
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
x = self.patch_embed(x)
|
||||
|
||||
Wh, Ww = x.size(2), x.size(3)
|
||||
if self.ape:
|
||||
# interpolate the position embedding to the corresponding size
|
||||
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
|
||||
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
|
||||
else:
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.pos_drop(x)
|
||||
|
||||
outs = []
|
||||
for i in range(self.num_layers):
|
||||
layer = self.layers[i]
|
||||
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
||||
|
||||
if i in self.out_indices:
|
||||
norm_layer = getattr(self, f'norm{i}')
|
||||
x_out = norm_layer(x_out)
|
||||
|
||||
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
|
||||
outs.append(out)
|
||||
|
||||
return tuple(outs)
|
382
plugins/ai_method/packages/models/encoders/timm_efficientnet.py
Normal file
@ -0,0 +1,382 @@
from functools import partial

import torch
import torch.nn as nn

from timm.models.efficientnet import EfficientNet
from timm.models.efficientnet import decode_arch_def, round_channels, default_cfgs
from timm.models.layers.activations import Swish

from ._base import EncoderMixin


def get_efficientnet_kwargs(channel_multiplier=1.0, depth_multiplier=1.0, drop_rate=0.2):
    """Creates an EfficientNet model.

    Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
    Paper: https://arxiv.org/abs/1905.11946

    EfficientNet params
    name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
    'efficientnet-b0': (1.0, 1.0, 224, 0.2),
    'efficientnet-b1': (1.0, 1.1, 240, 0.2),
    'efficientnet-b2': (1.1, 1.2, 260, 0.3),
    'efficientnet-b3': (1.2, 1.4, 300, 0.3),
    'efficientnet-b4': (1.4, 1.8, 380, 0.4),
    'efficientnet-b5': (1.6, 2.2, 456, 0.4),
    'efficientnet-b6': (1.8, 2.6, 528, 0.5),
    'efficientnet-b7': (2.0, 3.1, 600, 0.5),
    'efficientnet-b8': (2.2, 3.6, 672, 0.5),
    'efficientnet-l2': (4.3, 5.3, 800, 0.5),

    Args:
        channel_multiplier: multiplier to number of channels per layer
        depth_multiplier: multiplier to number of repeats per stage
    """
    arch_def = [
        ['ds_r1_k3_s1_e1_c16_se0.25'],
        ['ir_r2_k3_s2_e6_c24_se0.25'],
        ['ir_r2_k5_s2_e6_c40_se0.25'],
        ['ir_r3_k3_s2_e6_c80_se0.25'],
        ['ir_r3_k5_s1_e6_c112_se0.25'],
        ['ir_r4_k5_s2_e6_c192_se0.25'],
        ['ir_r1_k3_s1_e6_c320_se0.25'],
    ]
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def, depth_multiplier),
        num_features=round_channels(1280, channel_multiplier, 8, None),
        stem_size=32,
        round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
        act_layer=Swish,
        drop_rate=drop_rate,
        drop_path_rate=0.2,
    )
    return model_kwargs


def gen_efficientnet_lite_kwargs(channel_multiplier=1.0, depth_multiplier=1.0, drop_rate=0.2):
    """Creates an EfficientNet-Lite model.

    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite
    Paper: https://arxiv.org/abs/1905.11946

    EfficientNet params
    name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
    'efficientnet-lite0': (1.0, 1.0, 224, 0.2),
    'efficientnet-lite1': (1.0, 1.1, 240, 0.2),
    'efficientnet-lite2': (1.1, 1.2, 260, 0.3),
    'efficientnet-lite3': (1.2, 1.4, 280, 0.3),
    'efficientnet-lite4': (1.4, 1.8, 300, 0.3),

    Args:
        channel_multiplier: multiplier to number of channels per layer
        depth_multiplier: multiplier to number of repeats per stage
    """
    arch_def = [
        ['ds_r1_k3_s1_e1_c16'],
        ['ir_r2_k3_s2_e6_c24'],
        ['ir_r2_k5_s2_e6_c40'],
        ['ir_r3_k3_s2_e6_c80'],
        ['ir_r3_k5_s1_e6_c112'],
        ['ir_r4_k5_s2_e6_c192'],
        ['ir_r1_k3_s1_e6_c320'],
    ]
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True),
        num_features=1280,
        stem_size=32,
        fix_stem=True,
        round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
        act_layer=nn.ReLU6,
        drop_rate=drop_rate,
        drop_path_rate=0.2,
    )
    return model_kwargs


class EfficientNetBaseEncoder(EfficientNet, EncoderMixin):

    def __init__(self, stage_idxs, out_channels, depth=5, **kwargs):
        super().__init__(**kwargs)

        self._stage_idxs = stage_idxs
        self._out_channels = out_channels
        self._depth = depth
        self._in_channels = 3

        del self.classifier

    def get_stages(self):
        return [
            nn.Identity(),
            nn.Sequential(self.conv_stem, self.bn1, self.act1),
            self.blocks[:self._stage_idxs[0]],
            self.blocks[self._stage_idxs[0]:self._stage_idxs[1]],
            self.blocks[self._stage_idxs[1]:self._stage_idxs[2]],
            self.blocks[self._stage_idxs[2]:],
        ]

    def forward(self, x):
        stages = self.get_stages()

        features = []
        for i in range(self._depth + 1):
            x = stages[i](x)
            features.append(x)

        return features

    def load_state_dict(self, state_dict, **kwargs):
        state_dict.pop("classifier.bias", None)
        state_dict.pop("classifier.weight", None)
        super().load_state_dict(state_dict, **kwargs)


class EfficientNetEncoder(EfficientNetBaseEncoder):

    def __init__(self, stage_idxs, out_channels, depth=5, channel_multiplier=1.0, depth_multiplier=1.0, drop_rate=0.2):
        kwargs = get_efficientnet_kwargs(channel_multiplier, depth_multiplier, drop_rate)
        super().__init__(stage_idxs, out_channels, depth, **kwargs)


class EfficientNetLiteEncoder(EfficientNetBaseEncoder):

    def __init__(self, stage_idxs, out_channels, depth=5, channel_multiplier=1.0, depth_multiplier=1.0, drop_rate=0.2):
        kwargs = gen_efficientnet_lite_kwargs(channel_multiplier, depth_multiplier, drop_rate)
        super().__init__(stage_idxs, out_channels, depth, **kwargs)


def prepare_settings(settings):
    return {
        "mean": settings["mean"],
        "std": settings["std"],
        "url": settings["url"],
        "input_range": (0, 1),
        "input_space": "RGB",
    }


timm_efficientnet_encoders = {

    "timm-efficientnet-b0": {
        "encoder": EfficientNetEncoder,
        "pretrained_settings": {
            "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b0"]),
            "advprop": prepare_settings(default_cfgs["tf_efficientnet_b0_ap"]),
            "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b0_ns"]),
        },
        "params": {
            "out_channels": (3, 32, 24, 40, 112, 320),
            "stage_idxs": (2, 3, 5),
            "channel_multiplier": 1.0,
            "depth_multiplier": 1.0,
            "drop_rate": 0.2,
        },
    },

    "timm-efficientnet-b1": {
        "encoder": EfficientNetEncoder,
        "pretrained_settings": {
            "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b1"]),
            "advprop": prepare_settings(default_cfgs["tf_efficientnet_b1_ap"]),
            "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b1_ns"]),
        },
        "params": {
            "out_channels": (3, 32, 24, 40, 112, 320),
            "stage_idxs": (2, 3, 5),
            "channel_multiplier": 1.0,
            "depth_multiplier": 1.1,
            "drop_rate": 0.2,
        },
    },

    "timm-efficientnet-b2": {
        "encoder": EfficientNetEncoder,
        "pretrained_settings": {
            "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b2"]),
            "advprop": prepare_settings(default_cfgs["tf_efficientnet_b2_ap"]),
            "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b2_ns"]),
        },
        "params": {
            "out_channels": (3, 32, 24, 48, 120, 352),
            "stage_idxs": (2, 3, 5),
            "channel_multiplier": 1.1,
            "depth_multiplier": 1.2,
            "drop_rate": 0.3,
        },
    },

    "timm-efficientnet-b3": {
        "encoder": EfficientNetEncoder,
        "pretrained_settings": {
            "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b3"]),
            "advprop": prepare_settings(default_cfgs["tf_efficientnet_b3_ap"]),
            "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b3_ns"]),
        },
        "params": {
            "out_channels": (3, 40, 32, 48, 136, 384),
            "stage_idxs": (2, 3, 5),
            "channel_multiplier": 1.2,
            "depth_multiplier": 1.4,
            "drop_rate": 0.3,
        },
    },

    "timm-efficientnet-b4": {
        "encoder": EfficientNetEncoder,
        "pretrained_settings": {
            "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b4"]),
            "advprop": prepare_settings(default_cfgs["tf_efficientnet_b4_ap"]),
            "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b4_ns"]),
        },
        "params": {
            "out_channels": (3, 48, 32, 56, 160, 448),
            "stage_idxs": (2, 3, 5),
            "channel_multiplier": 1.4,
            "depth_multiplier": 1.8,
            "drop_rate": 0.4,
        },
    },

    "timm-efficientnet-b5": {
        "encoder": EfficientNetEncoder,
        "pretrained_settings": {
            "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b5"]),
            "advprop": prepare_settings(default_cfgs["tf_efficientnet_b5_ap"]),
            "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b5_ns"]),
        },
        "params": {
            "out_channels": (3, 48, 40, 64, 176, 512),
            "stage_idxs": (2, 3, 5),
            "channel_multiplier": 1.6,
            "depth_multiplier": 2.2,
            "drop_rate": 0.4,
        },
    },

    "timm-efficientnet-b6": {
        "encoder": EfficientNetEncoder,
        "pretrained_settings": {
            "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b6"]),
            "advprop": prepare_settings(default_cfgs["tf_efficientnet_b6_ap"]),
            "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b6_ns"]),
        },
        "params": {
            "out_channels": (3, 56, 40, 72, 200, 576),
            "stage_idxs": (2, 3, 5),
            "channel_multiplier": 1.8,
            "depth_multiplier": 2.6,
            "drop_rate": 0.5,
        },
    },

    "timm-efficientnet-b7": {
        "encoder": EfficientNetEncoder,
        "pretrained_settings": {
            "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b7"]),
            "advprop": prepare_settings(default_cfgs["tf_efficientnet_b7_ap"]),
            "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b7_ns"]),
        },
        "params": {
            "out_channels": (3, 64, 48, 80, 224, 640),
            "stage_idxs": (2, 3, 5),
            "channel_multiplier": 2.0,
            "depth_multiplier": 3.1,
            "drop_rate": 0.5,
        },
    },

    "timm-efficientnet-b8": {
        "encoder": EfficientNetEncoder,
        "pretrained_settings": {
            "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b8"]),
            "advprop": prepare_settings(default_cfgs["tf_efficientnet_b8_ap"]),
        },
        "params": {
            "out_channels": (3, 72, 56, 88, 248, 704),
            "stage_idxs": (2, 3, 5),
            "channel_multiplier": 2.2,
            "depth_multiplier": 3.6,
            "drop_rate": 0.5,
        },
    },

    "timm-efficientnet-l2": {
        "encoder": EfficientNetEncoder,
        "pretrained_settings": {
            "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_l2_ns"]),
        },
        "params": {
            "out_channels": (3, 136, 104, 176, 480, 1376),
            "stage_idxs": (2, 3, 5),
            "channel_multiplier": 4.3,
            "depth_multiplier": 5.3,
            "drop_rate": 0.5,
        },
    },

    "timm-tf_efficientnet_lite0": {
        "encoder": EfficientNetLiteEncoder,
        "pretrained_settings": {
            "imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite0"]),
        },
        "params": {
            "out_channels": (3, 32, 24, 40, 112, 320),
            "stage_idxs": (2, 3, 5),
            "channel_multiplier": 1.0,
            "depth_multiplier": 1.0,
            "drop_rate": 0.2,
        },
    },

    "timm-tf_efficientnet_lite1": {
        "encoder": EfficientNetLiteEncoder,
        "pretrained_settings": {
            "imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite1"]),
        },
        "params": {
            "out_channels": (3, 32, 24, 40, 112, 320),
            "stage_idxs": (2, 3, 5),
            "channel_multiplier": 1.0,
            "depth_multiplier": 1.1,
            "drop_rate": 0.2,
        },
    },

    "timm-tf_efficientnet_lite2": {
        "encoder": EfficientNetLiteEncoder,
        "pretrained_settings": {
            "imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite2"]),
        },
        "params": {
            "out_channels": (3, 32, 24, 48, 120, 352),
            "stage_idxs": (2, 3, 5),
            "channel_multiplier": 1.1,
            "depth_multiplier": 1.2,
            "drop_rate": 0.3,
        },
    },

    "timm-tf_efficientnet_lite3": {
        "encoder": EfficientNetLiteEncoder,
        "pretrained_settings": {
            "imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite3"]),
        },
        "params": {
            "out_channels": (3, 32, 32, 48, 136, 384),
            "stage_idxs": (2, 3, 5),
            "channel_multiplier": 1.2,
            "depth_multiplier": 1.4,
            "drop_rate": 0.3,
        },
    },

    "timm-tf_efficientnet_lite4": {
        "encoder": EfficientNetLiteEncoder,
        "pretrained_settings": {
            "imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite4"]),
        },
        "params": {
            "out_channels": (3, 32, 32, 56, 160, 448),
            "stage_idxs": (2, 3, 5),
            "channel_multiplier": 1.4,
            "depth_multiplier": 1.8,
            "drop_rate": 0.4,
        },
    },
}
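Usage sketch (illustrative): an encoder can be built straight from the registry above; with a timm version matching these imports, the forward pass yields one feature map per stage with channels as in params["out_channels"].

import torch

entry = timm_efficientnet_encoders["timm-efficientnet-b0"]
encoder = entry["encoder"](depth=5, **entry["params"])

feats = encoder(torch.randn(1, 3, 224, 224))
# Six maps at strides 1/2/4/8/16/32, channels (3, 32, 24, 40, 112, 320).
assert [f.shape[1] for f in feats] == list(entry["params"]["out_channels"])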
plugins/ai_method/packages/models/encoders/timm_gernet.py
@ -0,0 +1,124 @@
from timm.models import ByoModelCfg, ByoBlockCfg, ByobNet

from ._base import EncoderMixin
import torch.nn as nn


class GERNetEncoder(ByobNet, EncoderMixin):
    def __init__(self, out_channels, depth=5, **kwargs):
        super().__init__(**kwargs)
        self._depth = depth
        self._out_channels = out_channels
        self._in_channels = 3

        del self.head

    def get_stages(self):
        return [
            nn.Identity(),
            self.stem,
            self.stages[0],
            self.stages[1],
            self.stages[2],
            nn.Sequential(self.stages[3], self.stages[4], self.final_conv)
        ]

    def forward(self, x):
        stages = self.get_stages()

        features = []
        for i in range(self._depth + 1):
            x = stages[i](x)
            features.append(x)

        return features

    def load_state_dict(self, state_dict, **kwargs):
        state_dict.pop("head.fc.weight", None)
        state_dict.pop("head.fc.bias", None)
        super().load_state_dict(state_dict, **kwargs)


regnet_weights = {
    'timm-gernet_s': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_s-756b4751.pth',
    },
    'timm-gernet_m': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_m-0873c53a.pth',
    },
    'timm-gernet_l': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_l-f31e2e8d.pth',
    },
}

pretrained_settings = {}
for model_name, sources in regnet_weights.items():
    pretrained_settings[model_name] = {}
    for source_name, source_url in sources.items():
        pretrained_settings[model_name][source_name] = {
            "url": source_url,
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'num_classes': 1000
        }

timm_gernet_encoders = {
    'timm-gernet_s': {
        'encoder': GERNetEncoder,
        "pretrained_settings": pretrained_settings["timm-gernet_s"],
        'params': {
            'out_channels': (3, 13, 48, 48, 384, 1920),
            'cfg': ByoModelCfg(
                blocks=(
                    ByoBlockCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.),
                    ByoBlockCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.),
                    ByoBlockCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4),
                    ByoBlockCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.),
                    ByoBlockCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.),
                ),
                stem_chs=13,
                stem_pool=None,
                num_features=1920,
            )
        },
    },
    'timm-gernet_m': {
        'encoder': GERNetEncoder,
        "pretrained_settings": pretrained_settings["timm-gernet_m"],
        'params': {
            'out_channels': (3, 32, 128, 192, 640, 2560),
            'cfg': ByoModelCfg(
                blocks=(
                    ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
                    ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
                    ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
                    ByoBlockCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.),
                    ByoBlockCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.),
                ),
                stem_chs=32,
                stem_pool=None,
                num_features=2560,
            )
        },
    },
    'timm-gernet_l': {
        'encoder': GERNetEncoder,
        "pretrained_settings": pretrained_settings["timm-gernet_l"],
        'params': {
            'out_channels': (3, 32, 128, 192, 640, 2560),
            'cfg': ByoModelCfg(
                blocks=(
                    ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
                    ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
                    ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
                    ByoBlockCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.),
                    ByoBlockCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.),
                ),
                stem_chs=32,
                stem_pool=None,
                num_features=2560,
            )
        },
    },
}
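Illustrative check (a sketch, assuming a timm version where ByoModelCfg accepts these fields): the last three modules are fused into a single stage in get_stages(), so the encoder emits exactly depth + 1 = 6 feature maps matching the out_channels tuple.

import torch

entry = timm_gernet_encoders["timm-gernet_s"]
encoder = entry["encoder"](**entry["params"])

feats = encoder(torch.randn(1, 3, 224, 224))
assert len(feats) == 6 and feats[-1].shape[1] == 1920  # final_conv width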
plugins/ai_method/packages/models/encoders/timm_mobilenetv3.py
@ -0,0 +1,175 @@
import timm
import numpy as np
import torch.nn as nn

from ._base import EncoderMixin


def _make_divisible(x, divisible_by=8):
    return int(np.ceil(x * 1. / divisible_by) * divisible_by)


class MobileNetV3Encoder(nn.Module, EncoderMixin):
    def __init__(self, model_name, width_mult, depth=5, **kwargs):
        super().__init__()
        if "large" not in model_name and "small" not in model_name:
            raise ValueError(
                'MobileNetV3 wrong model name {}'.format(model_name)
            )

        self._mode = "small" if "small" in model_name else "large"
        self._depth = depth
        self._out_channels = self._get_channels(self._mode, width_mult)
        self._in_channels = 3

        # minimal models replace hardswish with relu
        self.model = timm.create_model(
            model_name=model_name,
            scriptable=True,  # torch.jit scriptable
            exportable=True,  # onnx export
            features_only=True,
        )

    def _get_channels(self, mode, width_mult):
        if mode == "small":
            channels = [16, 16, 24, 48, 576]
        else:
            channels = [16, 24, 40, 112, 960]
        channels = [3,] + [_make_divisible(x * width_mult) for x in channels]
        return tuple(channels)

    def get_stages(self):
        if self._mode == 'small':
            return [
                nn.Identity(),
                nn.Sequential(
                    self.model.conv_stem,
                    self.model.bn1,
                    self.model.act1,
                ),
                self.model.blocks[0],
                self.model.blocks[1],
                self.model.blocks[2:4],
                self.model.blocks[4:],
            ]
        elif self._mode == 'large':
            return [
                nn.Identity(),
                nn.Sequential(
                    self.model.conv_stem,
                    self.model.bn1,
                    self.model.act1,
                    self.model.blocks[0],
                ),
                self.model.blocks[1],
                self.model.blocks[2],
                self.model.blocks[3:5],
                self.model.blocks[5:],
            ]
        else:
            raise ValueError('MobileNetV3 mode should be small or large, got {}'.format(self._mode))

    def forward(self, x):
        stages = self.get_stages()

        features = []
        for i in range(self._depth + 1):
            x = stages[i](x)
            features.append(x)

        return features

    def load_state_dict(self, state_dict, **kwargs):
        state_dict.pop('conv_head.weight', None)
        state_dict.pop('conv_head.bias', None)
        state_dict.pop('classifier.weight', None)
        state_dict.pop('classifier.bias', None)
        self.model.load_state_dict(state_dict, **kwargs)


mobilenetv3_weights = {
    'tf_mobilenetv3_large_075': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth'
    },
    'tf_mobilenetv3_large_100': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth'
    },
    'tf_mobilenetv3_large_minimal_100': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth'
    },
    'tf_mobilenetv3_small_075': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth'
    },
    'tf_mobilenetv3_small_100': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth'
    },
    'tf_mobilenetv3_small_minimal_100': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth'
    },
}

pretrained_settings = {}
for model_name, sources in mobilenetv3_weights.items():
    pretrained_settings[model_name] = {}
    for source_name, source_url in sources.items():
        pretrained_settings[model_name][source_name] = {
            "url": source_url,
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'input_space': 'RGB',
        }


timm_mobilenetv3_encoders = {
    'timm-mobilenetv3_large_075': {
        'encoder': MobileNetV3Encoder,
        'pretrained_settings': pretrained_settings['tf_mobilenetv3_large_075'],
        'params': {
            'model_name': 'tf_mobilenetv3_large_075',
            'width_mult': 0.75
        }
    },
    'timm-mobilenetv3_large_100': {
        'encoder': MobileNetV3Encoder,
        'pretrained_settings': pretrained_settings['tf_mobilenetv3_large_100'],
        'params': {
            'model_name': 'tf_mobilenetv3_large_100',
            'width_mult': 1.0
        }
    },
    'timm-mobilenetv3_large_minimal_100': {
        'encoder': MobileNetV3Encoder,
        'pretrained_settings': pretrained_settings['tf_mobilenetv3_large_minimal_100'],
        'params': {
            'model_name': 'tf_mobilenetv3_large_minimal_100',
            'width_mult': 1.0
        }
    },
    'timm-mobilenetv3_small_075': {
        'encoder': MobileNetV3Encoder,
        'pretrained_settings': pretrained_settings['tf_mobilenetv3_small_075'],
        'params': {
            'model_name': 'tf_mobilenetv3_small_075',
            'width_mult': 0.75
        }
    },
    'timm-mobilenetv3_small_100': {
        'encoder': MobileNetV3Encoder,
        'pretrained_settings': pretrained_settings['tf_mobilenetv3_small_100'],
        'params': {
            'model_name': 'tf_mobilenetv3_small_100',
            'width_mult': 1.0
        }
    },
    'timm-mobilenetv3_small_minimal_100': {
        'encoder': MobileNetV3Encoder,
        'pretrained_settings': pretrained_settings['tf_mobilenetv3_small_minimal_100'],
        'params': {
            'model_name': 'tf_mobilenetv3_small_minimal_100',
            'width_mult': 1.0
        }
    },
}
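Worked example of the channel rounding above: for the small variant at width_mult=0.75, each raw stage width is scaled and then rounded up to a multiple of 8, which is exactly the tuple _get_channels returns (prefixed with the 3 input channels).

raw = [16, 16, 24, 48, 576]                           # "small" stage widths
scaled = [c * 0.75 for c in raw]                      # [12.0, 12.0, 18.0, 36.0, 432.0]
rounded = [3] + [_make_divisible(c) for c in scaled]  # round up to multiples of 8
assert tuple(rounded) == (3, 16, 16, 24, 40, 432)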
plugins/ai_method/packages/models/encoders/timm_regnet.py
@ -0,0 +1,332 @@
from ._base import EncoderMixin
from timm.models.regnet import RegNet
import torch.nn as nn


class RegNetEncoder(RegNet, EncoderMixin):
    def __init__(self, out_channels, depth=5, **kwargs):
        super().__init__(**kwargs)
        self._depth = depth
        self._out_channels = out_channels
        self._in_channels = 3

        del self.head

    def get_stages(self):
        return [
            nn.Identity(),
            self.stem,
            self.s1,
            self.s2,
            self.s3,
            self.s4,
        ]

    def forward(self, x):
        stages = self.get_stages()

        features = []
        for i in range(self._depth + 1):
            x = stages[i](x)
            features.append(x)

        return features

    def load_state_dict(self, state_dict, **kwargs):
        state_dict.pop("head.fc.weight", None)
        state_dict.pop("head.fc.bias", None)
        super().load_state_dict(state_dict, **kwargs)


regnet_weights = {
    'timm-regnetx_002': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_002-e7e85e5c.pth',
    },
    'timm-regnetx_004': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_004-7d0e9424.pth',
    },
    'timm-regnetx_006': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_006-85ec1baa.pth',
    },
    'timm-regnetx_008': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_008-d8b470eb.pth',
    },
    'timm-regnetx_016': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_016-65ca972a.pth',
    },
    'timm-regnetx_032': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_032-ed0c7f7e.pth',
    },
    'timm-regnetx_040': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_040-73c2a654.pth',
    },
    'timm-regnetx_064': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_064-29278baa.pth',
    },
    'timm-regnetx_080': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_080-7c7fcab1.pth',
    },
    'timm-regnetx_120': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth',
    },
    'timm-regnetx_160': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth',
    },
    'timm-regnetx_320': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth',
    },
    'timm-regnety_002': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth',
    },
    'timm-regnety_004': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth',
    },
    'timm-regnety_006': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth',
    },
    'timm-regnety_008': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth',
    },
    'timm-regnety_016': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth',
    },
    'timm-regnety_032': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth'
    },
    'timm-regnety_040': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth'
    },
    'timm-regnety_064': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth'
    },
    'timm-regnety_080': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth',
    },
    'timm-regnety_120': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth',
    },
    'timm-regnety_160': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_160-d64013cd.pth',
    },
    'timm-regnety_320': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'
    }
}

pretrained_settings = {}
for model_name, sources in regnet_weights.items():
    pretrained_settings[model_name] = {}
    for source_name, source_url in sources.items():
        pretrained_settings[model_name][source_name] = {
            "url": source_url,
            'input_size': [3, 224, 224],
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'num_classes': 1000
        }

# at this point I am too lazy to copy configs, so I just used the same configs from timm's repo


def _mcfg(**kwargs):
    cfg = dict(se_ratio=0., bottle_ratio=1., stem_width=32)
    cfg.update(**kwargs)
    return cfg


timm_regnet_encoders = {
    'timm-regnetx_002': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnetx_002"],
        'params': {
            'out_channels': (3, 32, 24, 56, 152, 368),
            'cfg': _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13)
        },
    },
    'timm-regnetx_004': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnetx_004"],
        'params': {
            'out_channels': (3, 32, 32, 64, 160, 384),
            'cfg': _mcfg(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22)
        },
    },
    'timm-regnetx_006': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnetx_006"],
        'params': {
            'out_channels': (3, 32, 48, 96, 240, 528),
            'cfg': _mcfg(w0=48, wa=36.97, wm=2.24, group_w=24, depth=16)
        },
    },
    'timm-regnetx_008': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnetx_008"],
        'params': {
            'out_channels': (3, 32, 64, 128, 288, 672),
            'cfg': _mcfg(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16)
        },
    },
    'timm-regnetx_016': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnetx_016"],
        'params': {
            'out_channels': (3, 32, 72, 168, 408, 912),
            'cfg': _mcfg(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18)
        },
    },
    'timm-regnetx_032': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnetx_032"],
        'params': {
            'out_channels': (3, 32, 96, 192, 432, 1008),
            'cfg': _mcfg(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25)
        },
    },
    'timm-regnetx_040': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnetx_040"],
        'params': {
            'out_channels': (3, 32, 80, 240, 560, 1360),
            'cfg': _mcfg(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23)
        },
    },
    'timm-regnetx_064': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnetx_064"],
        'params': {
            'out_channels': (3, 32, 168, 392, 784, 1624),
            'cfg': _mcfg(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17)
        },
    },
    'timm-regnetx_080': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnetx_080"],
        'params': {
            'out_channels': (3, 32, 80, 240, 720, 1920),
            'cfg': _mcfg(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23)
        },
    },
    'timm-regnetx_120': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnetx_120"],
        'params': {
            'out_channels': (3, 32, 224, 448, 896, 2240),
            'cfg': _mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19)
        },
    },
    'timm-regnetx_160': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnetx_160"],
        'params': {
            'out_channels': (3, 32, 256, 512, 896, 2048),
            'cfg': _mcfg(w0=216, wa=55.59, wm=2.1, group_w=128, depth=22)
        },
    },
    'timm-regnetx_320': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnetx_320"],
        'params': {
            'out_channels': (3, 32, 336, 672, 1344, 2520),
            'cfg': _mcfg(w0=320, wa=69.86, wm=2.0, group_w=168, depth=23)
        },
    },
    # regnety
    'timm-regnety_002': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnety_002"],
        'params': {
            'out_channels': (3, 32, 24, 56, 152, 368),
            'cfg': _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13, se_ratio=0.25)
        },
    },
    'timm-regnety_004': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnety_004"],
        'params': {
            'out_channels': (3, 32, 48, 104, 208, 440),
            'cfg': _mcfg(w0=48, wa=27.89, wm=2.09, group_w=8, depth=16, se_ratio=0.25)
        },
    },
    'timm-regnety_006': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnety_006"],
        'params': {
            'out_channels': (3, 32, 48, 112, 256, 608),
            'cfg': _mcfg(w0=48, wa=32.54, wm=2.32, group_w=16, depth=15, se_ratio=0.25)
        },
    },
    'timm-regnety_008': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnety_008"],
        'params': {
            'out_channels': (3, 32, 64, 128, 320, 768),
            'cfg': _mcfg(w0=56, wa=38.84, wm=2.4, group_w=16, depth=14, se_ratio=0.25)
        },
    },
    'timm-regnety_016': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnety_016"],
        'params': {
            'out_channels': (3, 32, 48, 120, 336, 888),
            'cfg': _mcfg(w0=48, wa=20.71, wm=2.65, group_w=24, depth=27, se_ratio=0.25)
        },
    },
    'timm-regnety_032': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnety_032"],
        'params': {
            'out_channels': (3, 32, 72, 216, 576, 1512),
            'cfg': _mcfg(w0=80, wa=42.63, wm=2.66, group_w=24, depth=21, se_ratio=0.25)
        },
    },
    'timm-regnety_040': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnety_040"],
        'params': {
            'out_channels': (3, 32, 128, 192, 512, 1088),
            'cfg': _mcfg(w0=96, wa=31.41, wm=2.24, group_w=64, depth=22, se_ratio=0.25)
        },
    },
    'timm-regnety_064': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnety_064"],
        'params': {
            'out_channels': (3, 32, 144, 288, 576, 1296),
            'cfg': _mcfg(w0=112, wa=33.22, wm=2.27, group_w=72, depth=25, se_ratio=0.25)
        },
    },
    'timm-regnety_080': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnety_080"],
        'params': {
            'out_channels': (3, 32, 168, 448, 896, 2016),
            'cfg': _mcfg(w0=192, wa=76.82, wm=2.19, group_w=56, depth=17, se_ratio=0.25)
        },
    },
    'timm-regnety_120': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnety_120"],
        'params': {
            'out_channels': (3, 32, 224, 448, 896, 2240),
            'cfg': _mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, se_ratio=0.25)
        },
    },
    'timm-regnety_160': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnety_160"],
        'params': {
            'out_channels': (3, 32, 224, 448, 1232, 3024),
            'cfg': _mcfg(w0=200, wa=106.23, wm=2.48, group_w=112, depth=18, se_ratio=0.25)
        },
    },
    'timm-regnety_320': {
        'encoder': RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnety_320"],
        'params': {
            'out_channels': (3, 32, 232, 696, 1392, 3712),
            'cfg': _mcfg(w0=232, wa=115.89, wm=2.53, group_w=232, depth=20, se_ratio=0.25)
        },
    },
}
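Small illustration: _mcfg simply overlays per-model values on the shared defaults, e.g. for regnetx_002.

cfg = _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13)
assert cfg == {
    'se_ratio': 0.0, 'bottle_ratio': 1.0, 'stem_width': 32,
    'w0': 24, 'wa': 36.44, 'wm': 2.49, 'group_w': 8, 'depth': 13,
}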
plugins/ai_method/packages/models/encoders/timm_res2net.py
@ -0,0 +1,163 @@
|
||||
from ._base import EncoderMixin
|
||||
from timm.models.resnet import ResNet
|
||||
from timm.models.res2net import Bottle2neck
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Res2NetEncoder(ResNet, EncoderMixin):
|
||||
def __init__(self, out_channels, depth=5, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._depth = depth
|
||||
self._out_channels = out_channels
|
||||
self._in_channels = 3
|
||||
|
||||
del self.fc
|
||||
del self.global_pool
|
||||
|
||||
def get_stages(self):
|
||||
return [
|
||||
nn.Identity(),
|
||||
nn.Sequential(self.conv1, self.bn1, self.act1),
|
||||
nn.Sequential(self.maxpool, self.layer1),
|
||||
self.layer2,
|
||||
self.layer3,
|
||||
self.layer4,
|
||||
]
|
||||
|
||||
def make_dilated(self, output_stride):
|
||||
raise ValueError("Res2Net encoders do not support dilated mode")
|
||||
|
||||
def forward(self, x):
|
||||
stages = self.get_stages()
|
||||
|
||||
features = []
|
||||
for i in range(self._depth + 1):
|
||||
x = stages[i](x)
|
||||
features.append(x)
|
||||
|
||||
return features
|
||||
|
||||
def load_state_dict(self, state_dict, **kwargs):
|
||||
state_dict.pop("fc.bias", None)
|
||||
state_dict.pop("fc.weight", None)
|
||||
super().load_state_dict(state_dict, **kwargs)
|
||||
|
||||
|
||||
res2net_weights = {
    'timm-res2net50_26w_4s': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_4s-06e79181.pth'
    },
    'timm-res2net50_48w_2s': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_48w_2s-afed724a.pth'
    },
    'timm-res2net50_14w_8s': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_14w_8s-6527dddc.pth',
    },
    'timm-res2net50_26w_6s': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_6s-19041792.pth',
    },
    'timm-res2net50_26w_8s': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_8s-2c7c9f12.pth',
    },
    'timm-res2net101_26w_4s': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net101_26w_4s-02a759a1.pth',
    },
    'timm-res2next50': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next50_4s-6ef7e7bf.pth',
    }
}

pretrained_settings = {}
for model_name, sources in res2net_weights.items():
    pretrained_settings[model_name] = {}
    for source_name, source_url in sources.items():
        pretrained_settings[model_name][source_name] = {
            "url": source_url,
            'input_size': [3, 224, 224],
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'num_classes': 1000
        }


timm_res2net_encoders = {
    'timm-res2net50_26w_4s': {
        'encoder': Res2NetEncoder,
        "pretrained_settings": pretrained_settings["timm-res2net50_26w_4s"],
        'params': {
            'out_channels': (3, 64, 256, 512, 1024, 2048),
            'block': Bottle2neck,
            'layers': [3, 4, 6, 3],
            'base_width': 26,
            'block_args': {'scale': 4}
        },
    },
    'timm-res2net101_26w_4s': {
        'encoder': Res2NetEncoder,
        "pretrained_settings": pretrained_settings["timm-res2net101_26w_4s"],
        'params': {
            'out_channels': (3, 64, 256, 512, 1024, 2048),
            'block': Bottle2neck,
            'layers': [3, 4, 23, 3],
            'base_width': 26,
            'block_args': {'scale': 4}
        },
    },
    'timm-res2net50_26w_6s': {
        'encoder': Res2NetEncoder,
        "pretrained_settings": pretrained_settings["timm-res2net50_26w_6s"],
        'params': {
            'out_channels': (3, 64, 256, 512, 1024, 2048),
            'block': Bottle2neck,
            'layers': [3, 4, 6, 3],
            'base_width': 26,
            'block_args': {'scale': 6}
        },
    },
    'timm-res2net50_26w_8s': {
        'encoder': Res2NetEncoder,
        "pretrained_settings": pretrained_settings["timm-res2net50_26w_8s"],
        'params': {
            'out_channels': (3, 64, 256, 512, 1024, 2048),
            'block': Bottle2neck,
            'layers': [3, 4, 6, 3],
            'base_width': 26,
            'block_args': {'scale': 8}
        },
    },
    'timm-res2net50_48w_2s': {
        'encoder': Res2NetEncoder,
        "pretrained_settings": pretrained_settings["timm-res2net50_48w_2s"],
        'params': {
            'out_channels': (3, 64, 256, 512, 1024, 2048),
            'block': Bottle2neck,
            'layers': [3, 4, 6, 3],
            'base_width': 48,
            'block_args': {'scale': 2}
        },
    },
    'timm-res2net50_14w_8s': {
        'encoder': Res2NetEncoder,
        "pretrained_settings": pretrained_settings["timm-res2net50_14w_8s"],
        'params': {
            'out_channels': (3, 64, 256, 512, 1024, 2048),
            'block': Bottle2neck,
            'layers': [3, 4, 6, 3],
            'base_width': 14,
            'block_args': {'scale': 8}
        },
    },
    'timm-res2next50': {
        'encoder': Res2NetEncoder,
        "pretrained_settings": pretrained_settings["timm-res2next50"],
        'params': {
            'out_channels': (3, 64, 256, 512, 1024, 2048),
            'block': Bottle2neck,
            'layers': [3, 4, 6, 3],
            'base_width': 4,
            'cardinality': 8,
            'block_args': {'scale': 4}
        },
    }
}
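For orientation, a quick sketch of the forward contract above: with the default depth of 5, the encoder returns one feature map per stage. The import path is assumed from the file's location in this diff, and the shapes are illustrative for a 3×224×224 input (not a test from the repo):

    import torch
    from plugins.ai_method.packages.models.encoders.timm_res2net import timm_res2net_encoders

    entry = timm_res2net_encoders['timm-res2net50_26w_4s']
    encoder = entry['encoder'](**entry['params'])  # Res2NetEncoder, randomly initialized

    features = encoder(torch.randn(1, 3, 224, 224))
    # Six feature maps whose channel counts match entry['params']['out_channels']:
    #   [1, 3, 224, 224]     stage 0: identity
    #   [1, 64, 112, 112]    stage 1: stem (conv1 + bn1 + act1)
    #   [1, 256, 56, 56]     stage 2: maxpool + layer1
    #   [1, 512, 28, 28]     stage 3: layer2
    #   [1, 1024, 14, 14]    stage 4: layer3
    #   [1, 2048, 7, 7]      stage 5: layer4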
plugins/ai_method/packages/models/encoders/timm_resnest.py  (new file, 208 lines)
@ -0,0 +1,208 @@
from ._base import EncoderMixin
from timm.models.resnet import ResNet
from timm.models.resnest import ResNestBottleneck
import torch.nn as nn


class ResNestEncoder(ResNet, EncoderMixin):
    # Same staged feature-extraction pattern as the other timm encoders
    # in this directory, built on timm's ResNet with ResNeSt bottlenecks.
    def __init__(self, out_channels, depth=5, **kwargs):
        super().__init__(**kwargs)
        self._depth = depth
        self._out_channels = out_channels
        self._in_channels = 3

        # The classification head is unused for feature extraction.
        del self.fc
        del self.global_pool

    def get_stages(self):
        return [
            nn.Identity(),
            nn.Sequential(self.conv1, self.bn1, self.act1),
            nn.Sequential(self.maxpool, self.layer1),
            self.layer2,
            self.layer3,
            self.layer4,
        ]

    def make_dilated(self, output_stride):
        raise ValueError("ResNest encoders do not support dilated mode")

    def forward(self, x):
        stages = self.get_stages()

        features = []
        for i in range(self._depth + 1):
            x = stages[i](x)
            features.append(x)

        return features

    def load_state_dict(self, state_dict, **kwargs):
        # Pretrained checkpoints include classifier weights that the
        # encoder no longer has; drop them before loading.
        state_dict.pop("fc.bias", None)
        state_dict.pop("fc.weight", None)
        super().load_state_dict(state_dict, **kwargs)


resnest_weights = {
    'timm-resnest14d': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth'
    },
    'timm-resnest26d': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth'
    },
    'timm-resnest50d': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50-528c19ca.pth',
    },
    'timm-resnest101e': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest101-22405ba7.pth',
    },
    'timm-resnest200e': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest200-75117900.pth',
    },
    'timm-resnest269e': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest269-0cc87c48.pth',
    },
    'timm-resnest50d_4s2x40d': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_4s2x40d-41d14ed0.pth',
    },
    'timm-resnest50d_1s4x24d': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_1s4x24d-d4a4f76f.pth',
    }
}

pretrained_settings = {}
for model_name, sources in resnest_weights.items():
    pretrained_settings[model_name] = {}
    for source_name, source_url in sources.items():
        pretrained_settings[model_name][source_name] = {
            "url": source_url,
            'input_size': [3, 224, 224],
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'num_classes': 1000
        }


timm_resnest_encoders = {
    'timm-resnest14d': {
        'encoder': ResNestEncoder,
        "pretrained_settings": pretrained_settings["timm-resnest14d"],
        'params': {
            'out_channels': (3, 64, 256, 512, 1024, 2048),
            'block': ResNestBottleneck,
            'layers': [1, 1, 1, 1],
            'stem_type': 'deep',
            'stem_width': 32,
            'avg_down': True,
            'base_width': 64,
            'cardinality': 1,
            'block_args': {'radix': 2, 'avd': True, 'avd_first': False}
        },
    },
    'timm-resnest26d': {
        'encoder': ResNestEncoder,
        "pretrained_settings": pretrained_settings["timm-resnest26d"],
        'params': {
            'out_channels': (3, 64, 256, 512, 1024, 2048),
            'block': ResNestBottleneck,
            'layers': [2, 2, 2, 2],
            'stem_type': 'deep',
            'stem_width': 32,
            'avg_down': True,
            'base_width': 64,
            'cardinality': 1,
            'block_args': {'radix': 2, 'avd': True, 'avd_first': False}
        },
    },
    'timm-resnest50d': {
        'encoder': ResNestEncoder,
        "pretrained_settings": pretrained_settings["timm-resnest50d"],
        'params': {
            'out_channels': (3, 64, 256, 512, 1024, 2048),
            'block': ResNestBottleneck,
            'layers': [3, 4, 6, 3],
            'stem_type': 'deep',
            'stem_width': 32,
            'avg_down': True,
            'base_width': 64,
            'cardinality': 1,
            'block_args': {'radix': 2, 'avd': True, 'avd_first': False}
        },
    },
    'timm-resnest101e': {
        'encoder': ResNestEncoder,
        "pretrained_settings": pretrained_settings["timm-resnest101e"],
        'params': {
            'out_channels': (3, 128, 256, 512, 1024, 2048),
            'block': ResNestBottleneck,
            'layers': [3, 4, 23, 3],
            'stem_type': 'deep',
            'stem_width': 64,
            'avg_down': True,
            'base_width': 64,
            'cardinality': 1,
            'block_args': {'radix': 2, 'avd': True, 'avd_first': False}
        },
    },
    'timm-resnest200e': {
        'encoder': ResNestEncoder,
        "pretrained_settings": pretrained_settings["timm-resnest200e"],
        'params': {
            'out_channels': (3, 128, 256, 512, 1024, 2048),
            'block': ResNestBottleneck,
            'layers': [3, 24, 36, 3],
            'stem_type': 'deep',
            'stem_width': 64,
            'avg_down': True,
            'base_width': 64,
            'cardinality': 1,
            'block_args': {'radix': 2, 'avd': True, 'avd_first': False}
        },
    },
    'timm-resnest269e': {
        'encoder': ResNestEncoder,
        "pretrained_settings": pretrained_settings["timm-resnest269e"],
        'params': {
            'out_channels': (3, 128, 256, 512, 1024, 2048),
            'block': ResNestBottleneck,
            'layers': [3, 30, 48, 8],
            'stem_type': 'deep',
            'stem_width': 64,
            'avg_down': True,
            'base_width': 64,
            'cardinality': 1,
            'block_args': {'radix': 2, 'avd': True, 'avd_first': False}
        },
    },
    'timm-resnest50d_4s2x40d': {
        'encoder': ResNestEncoder,
        "pretrained_settings": pretrained_settings["timm-resnest50d_4s2x40d"],
        'params': {
            'out_channels': (3, 64, 256, 512, 1024, 2048),
            'block': ResNestBottleneck,
            'layers': [3, 4, 6, 3],
            'stem_type': 'deep',
            'stem_width': 32,
            'avg_down': True,
            'base_width': 40,
            'cardinality': 2,
            'block_args': {'radix': 4, 'avd': True, 'avd_first': True}
        },
    },
    'timm-resnest50d_1s4x24d': {
        'encoder': ResNestEncoder,
        "pretrained_settings": pretrained_settings["timm-resnest50d_1s4x24d"],
        'params': {
            'out_channels': (3, 64, 256, 512, 1024, 2048),
            'block': ResNestBottleneck,
            'layers': [3, 4, 6, 3],
            'stem_type': 'deep',
            'stem_width': 32,
            'avg_down': True,
            'base_width': 24,
            'cardinality': 4,
            'block_args': {'radix': 1, 'avd': True, 'avd_first': True}
        },
    }
}
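Finally, the pretrained_settings blocks above record the preprocessing the ImageNet checkpoints expect: a 224×224 input, pixel values scaled into [0, 1], then channel-wise normalization. A minimal torchvision pipeline built from one entry might look like this — illustrative only, with the import path assumed from the file's location; the plugin's own preprocessing may differ:

    from torchvision import transforms
    from plugins.ai_method.packages.models.encoders.timm_resnest import pretrained_settings

    cfg = pretrained_settings['timm-resnest50d']['imagenet']
    preprocess = transforms.Compose([
        transforms.Resize(cfg['input_size'][1:]),               # [224, 224]
        transforms.ToTensor(),                                  # scales pixels into input_range [0, 1]
        transforms.Normalize(mean=cfg['mean'], std=cfg['std']),
    ])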