diff --git a/plugins/ai_method/__init__.py b/plugins/ai_method/__init__.py index 9fd1ddb..719c5bd 100644 --- a/plugins/ai_method/__init__.py +++ b/plugins/ai_method/__init__.py @@ -1,5 +1,6 @@ from misc.utils import Register AI_METHOD = Register('AI Method') -from .packages.sta_net import STANet + from .main import AIPlugin +from .basic_cd import DPFCN, DVCA,RCNN \ No newline at end of file diff --git a/plugins/ai_method/basic_cd.py b/plugins/ai_method/basic_cd.py index b2d4330..4c69963 100644 --- a/plugins/ai_method/basic_cd.py +++ b/plugins/ai_method/basic_cd.py @@ -1,210 +1,183 @@ -from PyQt5.QtCore import QSettings, pyqtSignal, Qt -from PyQt5.QtGui import QIcon, QTextBlock, QTextCursor -from PyQt5.QtWidgets import QDialog, QLabel,QComboBox, QDialogButtonBox, QFileDialog, QHBoxLayout, QMessageBox, QProgressBar, QPushButton, QTextEdit, QVBoxLayout -import subprocess -import threading -import sys +from . import AI_METHOD +from plugins.misc import AlgFrontend from rscder.utils.icons import IconInstance +from rscder.utils.project import PairLayer +from osgeo import gdal, gdal_array import os -import sys -import ai_method.subprcess_python as sp -from ai_method import AI_METHOD -from misc.main import AlgFrontend -import abc +from rscder.utils.project import Project +from rscder.utils.geomath import geo2imageRC, imageRC2geo +import math +from .packages import get_model -class AIMethodDialog(QDialog): +class BasicAICD(AlgFrontend): - stage_end = pyqtSignal(int) - stage_log = pyqtSignal(str) + @staticmethod + def get_icon(): + return IconInstance().ARITHMETIC3 + + @staticmethod + def run_alg(pth1: str, pth2: str, layer_parent: PairLayer, send_message=None, model=None, *args, **kargs): + + if model is None and send_message is not None: + send_message.emit('未能加载模型!') + return - ENV = 'base' - name = '' - stages = [] - setting_widget = None + ds1: gdal.Dataset = gdal.Open(pth1) + ds2: gdal.Dataset = gdal.Open(pth2) - def __init__(self, parent = None) -> None: - super().__init__(parent) + - # self.setting_widget:AlgFrontend = setting_widget + cell_size = (512, 512) + xsize = layer_parent.size[0] + ysize = layer_parent.size[1] - vlayout = QVBoxLayout() + band = ds1.RasterCount + yblocks = ysize // cell_size[1] + xblocks = xsize // cell_size[0] - hlayout = QHBoxLayout() - select_label = QLabel('模式选择:') - self.select_mode = QComboBox() - self.select_mode.addItem('----------', 'NoValue') - for stage in self.stages: - self.select_mode.addItem(stage[1], stage[0]) + driver = gdal.GetDriverByName('GTiff') + out_tif = os.path.join(Project().other_path, 'temp.tif') + out_ds = driver.Create(out_tif, xsize, ysize, 1, gdal.GDT_Float32) + geo = layer_parent.grid.geo + proj = layer_parent.grid.proj + out_ds.SetGeoTransform(geo) + out_ds.SetProjection(proj) - setting_btn = QPushButton(IconInstance().SELECT, '配置') + max_diff = 0 + min_diff = math.inf - hlayout.addWidget(select_label) - hlayout.addWidget(self.select_mode, 2) - hlayout.addWidget(setting_btn) - self.setting_args = [] + start1x, start1y = geo2imageRC(ds1.GetGeoTransform( + ), layer_parent.mask.xy[0], layer_parent.mask.xy[1]) + end1x, end1y = geo2imageRC(ds1.GetGeoTransform( + ), layer_parent.mask.xy[2], layer_parent.mask.xy[3]) - def show_setting(): - if self.setting_widget is None: - return - dialog = QDialog(parent) - vlayout = QVBoxLayout() - dialog.setLayout(vlayout) - widget = self.setting_widget.get_widget(dialog) - vlayout.addWidget(widget) + start2x, start2y = geo2imageRC(ds2.GetGeoTransform( + ), layer_parent.mask.xy[0], 
layer_parent.mask.xy[1]) + end2x, end2y = geo2imageRC(ds2.GetGeoTransform( + ), layer_parent.mask.xy[2], layer_parent.mask.xy[3]) - dbtn_box = QDialogButtonBox(dialog) - dbtn_box.setStandardButtons(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) - dbtn_box.button(QDialogButtonBox.Ok).setText('确定') - dbtn_box.button(QDialogButtonBox.Cancel).setText('取消') - dbtn_box.button(QDialogButtonBox.Ok).clicked.connect(dialog.accept) - dbtn_box.button(QDialogButtonBox.Cancel).clicked.connect(dialog.reject) - vlayout.addWidget(dbtn_box, 1, Qt.AlignRight) - dialog.setMinimumHeight(500) - dialog.setMinimumWidth(900) - if dialog.exec_() == QDialog.Accepted: - self.setting_args = self.setting_widget.get_params(widget) + for j in range(yblocks + 1): # 该改这里了 + if send_message is not None: + send_message.emit(f'计算{j}/{yblocks}') + for i in range(xblocks +1): - setting_btn.clicked.connect(show_setting) + + block_xy1 = (start1x + i * cell_size[0], start1y+j * cell_size[1]) + block_xy2 = (start2x + i * cell_size[0], start2y+j * cell_size[1]) + block_xy = (i * cell_size[0], j * cell_size[1]) + + if block_xy1[1] > end1y or block_xy2[1] > end2y: + break + if block_xy1[0] > end1x or block_xy2[0] > end2x: + break + block_size = list(cell_size) - btnbox = QDialogButtonBox(self) - btnbox.setStandardButtons(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) - btnbox.button(QDialogButtonBox.Ok).setText('确定') - btnbox.button(QDialogButtonBox.Cancel).setText('取消') + if block_xy[1] + block_size[1] > ysize: + block_xy[1] = (ysize - block_size[1]) + if block_xy[0] + block_size[0] > xsize: + block_xy[0] = ( xsize - block_size[0]) - processbar = QProgressBar(self) - processbar.setMaximum(1) - processbar.setMinimum(0) - processbar.setEnabled(False) - def run_stage(): - self.log = f'开始{self.current_stage}...\n' - btnbox.button(QDialogButtonBox.Cancel).setEnabled(True) - processbar.setMaximum(0) - processbar.setValue(0) - processbar.setEnabled(True) - q = threading.Thread(target=self.run_stage, args=(self.current_stage,)) - q.start() + if block_xy1[1] + block_size[1] > end1y: + block_xy1[1] = (end1y - block_size[1]) + if block_xy1[0] + block_size[0] > end1x: + block_xy1[0] = (end1x - block_size[0]) + if block_xy2[1] + block_size[1] > end2y: + block_xy2[1] = (end2y - block_size[1]) + if block_xy2[0] + block_size[0] > end2x: + block_xy2[0] = (end2x - block_size[0]) + + # if block_size1[0] * block_size1[1] == 0 or block_size2[0] * block_size2[1] == 0: + # continue + + block_data1 = ds1.ReadAsArray(*block_xy1, *block_size) + block_data2 = ds2.ReadAsArray(*block_xy2, *block_size) + # if block_data1.shape[0] == 0: + # continue + if band == 1: + block_data1 = block_data1[None, ...] + block_data2 = block_data2[None, ...] 
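Note on the tiled loop in the new BasicAICD.run_alg above: the sketch below shows the window pattern it implements in isolation, reading both rasters window by window and shifting the final row and column of windows back inside the raster instead of reading past the edge. This is a minimal sketch under stated assumptions, not the plugin API: window origins are kept as lists so they can be clamped in place, `model` stands for any callable returning a 2-D numpy score map for a window, both rasters are assumed to be at least one window in size, and the paths and function names are illustrative.

```python
from osgeo import gdal


def iter_windows(xsize, ysize, tile=(512, 512)):
    """Yield [x_off, y_off] window origins covering an xsize-by-ysize raster.
    Origins in the last row/column are shifted back so that every full-size
    window stays inside the raster (assumes the raster is >= one tile)."""
    for j in range(ysize // tile[1] + 1):
        for i in range(xsize // tile[0] + 1):
            xy = [i * tile[0], j * tile[1]]   # lists, so they can be clamped in place
            if xy[0] >= xsize or xy[1] >= ysize:
                continue
            xy[0] = min(xy[0], xsize - tile[0])
            xy[1] = min(xy[1], ysize - tile[1])
            yield xy


def tiled_inference(path1, path2, out_path, model, tile=(512, 512)):
    """Run `model` window by window over two co-registered rasters and write
    the per-window 2-D score maps into a single-band float GeoTIFF."""
    ds1, ds2 = gdal.Open(path1), gdal.Open(path2)
    xsize, ysize, bands = ds1.RasterXSize, ds1.RasterYSize, ds1.RasterCount
    out = gdal.GetDriverByName('GTiff').Create(out_path, xsize, ysize, 1, gdal.GDT_Float32)
    out.SetGeoTransform(ds1.GetGeoTransform())
    out.SetProjection(ds1.GetProjection())
    for x_off, y_off in iter_windows(xsize, ysize, tile):
        b1 = ds1.ReadAsArray(x_off, y_off, tile[0], tile[1])
        b2 = ds2.ReadAsArray(x_off, y_off, tile[0], tile[1])
        if bands == 1:                        # keep a (bands, h, w) layout
            b1, b2 = b1[None, ...], b2[None, ...]
        out.GetRasterBand(1).WriteArray(model(b1, b2), x_off, y_off)
    out.FlushCache()
    return out_path
```

Keeping every window a fixed 512 x 512 and moving the origin back at the edges means each block fed to the model has the same shape, at the cost of re-scoring a thin overlap strip along the right and bottom.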
+ + block_diff = model(block_data1, block_data2) + + out_ds.GetRasterBand(1).WriteArray(block_diff, *block_xy) + if send_message is not None: + + send_message.emit(f'完成{j}/{yblocks}') + del ds2 + del ds1 + out_ds.FlushCache() + del out_ds + if send_message is not None: + send_message.emit('归一化概率中...') + temp_in_ds = gdal.Open(out_tif) + + out_normal_tif = os.path.join(Project().cmi_path, '{}_{}_cmi.tif'.format( + layer_parent.name, int(np.random.rand() * 100000))) + out_normal_ds = driver.Create( + out_normal_tif, xsize, ysize, 1, gdal.GDT_Byte) + out_normal_ds.SetGeoTransform(geo) + out_normal_ds.SetProjection(proj) + # hist = np.zeros(256, dtype=np.int32) + for j in range(yblocks+1): + block_xy = (0, j * cell_size[1]) + if block_xy[1] > ysize: + break + block_size = (xsize, cell_size[1]) + if block_xy[1] + block_size[1] > ysize: + block_size = (xsize, ysize - block_xy[1]) + block_data = temp_in_ds.ReadAsArray(*block_xy, *block_size) + block_data = (block_data - min_diff) / (max_diff - min_diff) * 255 + block_data = block_data.astype(np.uint8) + out_normal_ds.GetRasterBand(1).WriteArray(block_data, *block_xy) + # hist_t, _ = np.histogram(block_data, bins=256, range=(0, 256)) + # hist += hist_t + # print(hist) + del temp_in_ds + del out_normal_ds + try: + os.remove(out_tif) + except: pass + if send_message is not None: + send_message.emit('计算完成') + return out_normal_tif - btnbox.button(QDialogButtonBox.Cancel).setEnabled(False) - btnbox.accepted.connect(run_stage) - btnbox.rejected.connect(self._stage_stop) - self.processbar = processbar - vlayout.addLayout(hlayout) - self.detail = QTextEdit(self) - vlayout.addWidget(self.detail) - vlayout.addWidget(processbar) - vlayout.addWidget(btnbox) - self.detail.setReadOnly(True) - # self.detail.copyAvailable(True) - self.detail.setText(f'等待开始...') - self.setLayout(vlayout) +@AI_METHOD.register +class DVCA(BasicAICD): - self.setMinimumHeight(500) - self.setMinimumWidth(500) - self.setWindowIcon(IconInstance().AI_DETECT) - self.setWindowTitle(self.name) + @staticmethod + def get_name(): + return 'DVCA' - self.stage_end.connect(self._stage_end) - self.log = f'等待开始...\n' - self.stage_log.connect(self._stage_log) - self.p = None + @staticmethod + def run_alg(pth1: str, pth2: str, layer_parent: PairLayer, send_message=None, *args, **kargs): + model = get_model('DVCA') + return BasicAICD.run_alg(pth1, pth2, layer_parent, send_message, model, *args, **kargs) - @property - def current_stage(self): - if self.select_mode.currentData() == 'NoValue': - return None - return self.select_mode.currentData() - @property - def _python_base(self): - return os.path.abspath(os.path.join(os.path.dirname(sys.executable), '..', '..')) - @property - def activate_env(self): - script = os.path.join(self._python_base, 'Scripts','activate') - if self.ENV == 'base': - return script - else: - return script + ' ' + self.ENV +@AI_METHOD.register +class DPFCN(BasicAICD): - @property - def python_path(self): - if self.ENV == 'base': - return self._python_base - return os.path.join(self._python_base, 'envs', self.ENV, 'python.exe') + @staticmethod + def get_name(): + return 'DPFCN' + + @staticmethod + def run_alg(pth1: str, pth2: str, layer_parent: PairLayer, send_message=None, *args, **kargs): + model = get_model('DPFCN') + return BasicAICD.run_alg(pth1, pth2, layer_parent, send_message, model, *args, **kargs) - @property - def workdir(self): - raise NotImplementedError() - - @abc.abstractmethod - def stage_script(self, stage): - return None - - def _stage_log(self, l): - self.log += 
l + '\n' - self.detail.setText(self.log) - cursor = self.detail.textCursor() - # QTextCursor - cursor.movePosition(QTextCursor.End) - self.detail.setTextCursor(cursor) - def closeEvent(self, e) -> None: - if self.p is None: - return super().closeEvent(e) - dialog = QMessageBox(QMessageBox.Warning, '警告', '关闭窗口将停止' + self.current_stage + '!', QMessageBox.Ok | QMessageBox.Cancel) - dialog.button(QMessageBox.Ok).clicked.connect(dialog.accepted) - dialog.button(QMessageBox.Cancel).clicked.connect(dialog.rejected) - dialog.show() - r = dialog.exec() - # print(r) - # print(QMessageBox.Rejected) - # print(QMessageBox.Accepted) - if r == QMessageBox.Cancel: - e.ignore() - return - return super().closeEvent(e) - def _stage_stop(self): - if self.p is not None: - try: - self.stage_log.emit(f'用户停止{self.stage}...') - self.p.kill() - except: - pass +@AI_METHOD.register +class RCNN(BasicAICD): - def _stage_end(self, c): - self.processbar.setMaximum(1) - self.processbar.setValue(0) - self.processbar.setEnabled(False) - self.log += '完成!' - self.detail.setText(self.log) + @staticmethod + def get_name(): + return 'RCNN' - def run_stage(self, stage): - if stage is None: - return - ss = self.stage_script(stage) - if ss is None: - self.stage_log.emit(f'开始{stage}时未发现脚本') - self.stage_end.emit(1) - return - - if self.workdir is None: - self.stage_log.emit(f'未配置工作目录!') - self.stage_end.emit(2) - return - args = [ss, *self.setting_args] - self.p = sp.SubprocessWraper(self.python_path, self.workdir, args, self.activate_env) - - for line in self.p.run(): - self.stage_log.emit(line) - self.stage_end.emit(self.p.returncode) - - -from collections import OrderedDict -def option_to_gui(parent, options:OrderedDict): - for key in options: - pass - -def gui_to_option(widget, options:OrderedDict): - pass \ No newline at end of file + @staticmethod + def run_alg(pth1: str, pth2: str, layer_parent: PairLayer, send_message=None, *args, **kargs): + model = get_model('RCNN') + return BasicAICD.run_alg(pth1, pth2, layer_parent, send_message, model, *args, **kargs) \ No newline at end of file diff --git a/plugins/ai_method/main.py b/plugins/ai_method/main.py index 1b493cc..4672a3c 100644 --- a/plugins/ai_method/main.py +++ b/plugins/ai_method/main.py @@ -1,27 +1,194 @@ -from rscder.plugins.basic import BasicPlugin -from rscder.gui.actions import ActionManager -from ai_method import AI_METHOD -from PyQt5.QtCore import Qt -from PyQt5.QtWidgets import QAction, QToolBar, QMenu, QDialog, QHBoxLayout, QVBoxLayout, QPushButton,QWidget,QLabel,QLineEdit,QPushButton,QComboBox,QDialogButtonBox -from rscder.utils.icons import IconInstance -from plugins.misc.main import AlgFrontend from functools import partial +from threading import Thread +from plugins.misc.main import AlgFrontend +from rscder.gui.actions import ActionManager +from rscder.plugins.basic import BasicPlugin +from PyQt5.QtWidgets import QAction, QToolBar, QMenu, QDialog, QHBoxLayout, QVBoxLayout, QPushButton,QWidget,QLabel,QLineEdit,QPushButton,QComboBox,QDialogButtonBox + +from rscder.gui.layercombox import PairLayerCombox +from rscder.utils.icons import IconInstance +from filter_collection import FILTER +from . 
import AI_METHOD +from thres import THRES +from misc import table_layer, AlgSelectWidget +from follow import FOLLOW +import os +class AICDMethod(QDialog): + def __init__(self,parent=None, alg:AlgFrontend=None): + super(AICDMethod, self).__init__(parent) + self.alg = alg + self.setWindowTitle('AI变化检测:{}'.format(alg.get_name())) + self.setWindowIcon(IconInstance().LOGO) + self.initUI() + self.setMinimumWidth(500) + + def initUI(self): + #图层 + self.layer_combox = PairLayerCombox(self) + layerbox = QHBoxLayout() + layerbox.addWidget(self.layer_combox) + + self.filter_select = AlgSelectWidget(self, FILTER) + self.param_widget = self.alg.get_widget(self) + self.unsupervised_menu = self.param_widget + self.thres_select = AlgSelectWidget(self, THRES) + + self.ok_button = QPushButton('确定', self) + self.ok_button.setIcon(IconInstance().OK) + self.ok_button.clicked.connect(self.accept) + self.ok_button.setDefault(True) + + self.cancel_button = QPushButton('取消', self) + self.cancel_button.setIcon(IconInstance().CANCEL) + self.cancel_button.clicked.connect(self.reject) + self.cancel_button.setDefault(False) + buttonbox=QDialogButtonBox(self) + buttonbox.addButton(self.ok_button,QDialogButtonBox.NoRole) + buttonbox.addButton(self.cancel_button,QDialogButtonBox.NoRole) + buttonbox.setCenterButtons(True) + + totalvlayout=QVBoxLayout() + totalvlayout.addLayout(layerbox) + totalvlayout.addWidget(self.filter_select) + if self.param_widget is not None: + totalvlayout.addWidget(self.param_widget) + totalvlayout.addWidget(self.thres_select) + totalvlayout.addStretch(1) + hbox = QHBoxLayout() + hbox.addStretch(1) + hbox.addWidget(buttonbox) + totalvlayout.addLayout(hbox) + # totalvlayout.addStretch() + + self.setLayout(totalvlayout) + +@FOLLOW.register +class AICDFollow(AlgFrontend): + + @staticmethod + def get_name(): + return 'AI变化检测' + + @staticmethod + def get_icon(): + return IconInstance().UNSUPERVISED + + @staticmethod + def get_widget(parent=None): + widget = QWidget(parent) + layer_combox = PairLayerCombox(widget) + layer_combox.setObjectName('layer_combox') + + filter_select = AlgSelectWidget(widget, FILTER) + filter_select.setObjectName('filter_select') + ai_select = AlgSelectWidget(widget, AI_METHOD) + ai_select.setObjectName('ai_select') + thres_select = AlgSelectWidget(widget, THRES) + thres_select.setObjectName('thres_select') + + totalvlayout=QVBoxLayout() + totalvlayout.addWidget(layer_combox) + totalvlayout.addWidget(filter_select) + totalvlayout.addWidget(ai_select) + totalvlayout.addWidget(thres_select) + totalvlayout.addStretch() + + widget.setLayout(totalvlayout) + + return widget + + @staticmethod + def get_params(widget:QWidget=None): + if widget is None: + return dict() + + layer_combox = widget.findChild(PairLayerCombox, 'layer_combox') + filter_select = widget.findChild(AlgSelectWidget, 'filter_select') + ai_select = widget.findChild(AlgSelectWidget, 'ai_select') + thres_select = widget.findChild(AlgSelectWidget, 'thres_select') + + layer1=layer_combox.layer1 + pth1 = layer_combox.layer1.path + pth2 = layer_combox.layer2.path + + falg, fparams = filter_select.get_alg_and_params() + cdalg, cdparams = ai_select.get_alg_and_params() + thalg, thparams = thres_select.get_alg_and_params() + + if cdalg is None or thalg is None: + return dict() + + return dict( + layer1=layer1, + pth1 = pth1, + pth2 = pth2, + falg = falg, + fparams = fparams, + cdalg = cdalg, + cdparams = cdparams, + thalg = thalg, + thparams = thparams, + ) + + @staticmethod + def run_alg(layer1=None, + pth1 = None, + pth2 = 
None, + falg = None, + fparams = None, + cdalg = None, + cdparams = None, + thalg = None, + thparams = None, + send_message = None): + + if cdalg is None or thalg is None: + return + + name = layer1.name + method_info = dict() + if falg is not None: + pth1 = falg.run_alg(pth1, name=name, send_message= send_message, **fparams) + pth2 = falg.run_alg(pth2, name=name, send_message= send_message, **fparams) + method_info['滤波算法'] = falg.get_name() + else: + method_info['滤波算法'] = '无' + + cdpth = cdalg.run_alg(pth1, pth2, layer1.layer_parent, send_message= send_message,**cdparams) + + if falg is not None: + try: + os.remove(pth1) + os.remove(pth2) + # send_message.emit('删除临时文件') + except: + # send_message.emit('删除临时文件失败!') + pass + + thpth, th = thalg.run_alg(cdpth, name=name, send_message= send_message, **thparams) + + method_info['变化检测算法'] = cdalg.get_name() + method_info['二值化算法'] = thalg.get_name() + + table_layer(thpth,layer1,name, cdpath=cdpth, th=th, method_info=method_info, send_message = send_message) + class AIPlugin(BasicPlugin): + @staticmethod def info(): return { - 'name': 'AI 变化检测', - 'author': 'RSC', + 'name': 'AIPlugin', + 'description': 'AIPlugin', + 'author': 'RSCDER', 'version': '1.0.0', - 'description': 'AI 变化检测', - 'category': 'Ai method' } + def set_action(self): - ai_menu = ActionManager().ai_menu - # ai_menu.setIcon(IconInstance().UNSUPERVISED) - # ActionManager().change_detection_menu.addMenu(ai_menu) + AI_menu = QMenu('&AI变化检测', self.mainwindow) + AI_menu.setIcon(IconInstance().AI_DETECT) + ActionManager().change_detection_menu.addMenu(AI_menu) toolbar = ActionManager().add_toolbar('AI method') for key in AI_METHOD.keys(): alg:AlgFrontend = AI_METHOD[key] @@ -30,15 +197,58 @@ class AIPlugin(BasicPlugin): else: name = alg.get_name() - action = QAction(alg.get_icon(), name, ai_menu) + action = QAction(alg.get_icon(), name, AI_menu) func = partial(self.run_cd, alg) action.triggered.connect(func) toolbar.addAction(action) - ai_menu.addAction(action) + AI_menu.addAction(action) - def run_cd(self, alg:AlgFrontend): - dialog = alg.get_widget(self.mainwindow) - dialog.setWindowModality(Qt.NonModal) + def run_cd(self, alg): + # print(alg.get_name()) + dialog = AICDMethod(self.mainwindow, alg) dialog.show() - # dialog.exec() \ No newline at end of file + + if dialog.exec_() == QDialog.Accepted: + t = Thread(target=self.run_cd_alg, args=(dialog,)) + t.start() + + def run_cd_alg(self, w:AICDMethod): + + layer1=w.layer_combox.layer1 + pth1 = w.layer_combox.layer1.path + pth2 = w.layer_combox.layer2.path + name = layer1.layer_parent.name + + falg, fparams = w.filter_select.get_alg_and_params() + cdalg = w.alg + cdparams = w.alg.get_params(w.param_widget) + thalg, thparams = w.thres_select.get_alg_and_params() + + if cdalg is None or thalg is None: + return + method_info = dict() + if falg is not None: + pth1 = falg.run_alg(pth1, name=name, send_message=self.send_message, **fparams) + pth2 = falg.run_alg(pth2, name=name, send_message=self.send_message, **fparams) + method_info['滤波算法'] = falg.get_name() + + cdpth = cdalg.run_alg(pth1, pth2, layer1.layer_parent, send_message=self.send_message,**cdparams) + + if falg is not None: + try: + os.remove(pth1) + os.remove(pth2) + # send_message.emit('删除临时文件') + except: + # send_message.emit('删除临时文件失败!') + pass + + thpth, th = thalg.run_alg(cdpth, name=name, send_message=self.send_message, **thparams) + + method_info['变化检测算法'] = cdalg.get_name() + method_info['二值化算法'] = thalg.get_name() + + table_layer(thpth,layer1,name, cdpath=cdpth, th=th, 
method_info=method_info, send_message = self.send_message) + # table_layer(thpth,layer1,name,self.send_message) + \ No newline at end of file diff --git a/plugins/ai_method/packages/STA_net/.gitignore b/plugins/ai_method/packages/STA_net/.gitignore deleted file mode 100644 index 1dcbc95..0000000 --- a/plugins/ai_method/packages/STA_net/.gitignore +++ /dev/null @@ -1,19 +0,0 @@ -.DS_Store -checkpoints/ -results/ -result/ -build/ -*.pth -*/*.pth -*/*/*.pth -torch.egg-info/ -*/*.pyc -*/**/*.pyc -*/**/**/*.pyc -*/**/**/**/*.pyc -*/**/**/**/**/*.pyc -*/*.so* -*/**/*.so* -*/**/*.dylib* -.idea - diff --git a/plugins/ai_method/packages/STA_net/LICENSE b/plugins/ai_method/packages/STA_net/LICENSE deleted file mode 100644 index df82bbd..0000000 --- a/plugins/ai_method/packages/STA_net/LICENSE +++ /dev/null @@ -1,25 +0,0 @@ -BSD 2-Clause License - -Copyright (c) 2020, justchenhao -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/plugins/ai_method/packages/STA_net/README.md b/plugins/ai_method/packages/STA_net/README.md deleted file mode 100644 index 7464b9c..0000000 --- a/plugins/ai_method/packages/STA_net/README.md +++ /dev/null @@ -1,233 +0,0 @@ -# STANet for remote sensing image change detection - -It is the implementation of the paper: A Spatial-Temporal Attention-Based Method and a New Dataset for Remote Sensing Image Change Detection. - -Here, we provide the pytorch implementation of the spatial-temporal attention neural network (STANet) for remote sensing image change detection. - -![image-20200601213320103](/src/stanet-overview.png) - -## Change log -20210112: -- add the pretraining weight of PAM. [baidupan link](https://pan.baidu.com/s/1O1kg7JWunqd87ajtVMM6pg), code: 2rja - -20201105: - -- add a demo for quick start. -- add more dataset loader modes. -- enhance the image augmentation module (crop and rotation). 
- -20200601: - -- first commit - -## Prerequisites - -- windows or Linux -- Python 3.6+ -- CPU or NVIDIA GPU -- CUDA 9.0+ -- PyTorch > 1.0 -- visdom - -## Installation - -Clone this repo: - -```bash -git clone https://github.com/justchenhao/STANet -cd STANet -``` - -Install [PyTorch](http://pytorch.org/) 1.0+ and other dependencies (e.g., torchvision, [visdom](https://github.com/facebookresearch/visdom) and [dominate](https://github.com/Knio/dominate)) - -## Quick Start - -You can run a demo to get started. - -```bash -python demo.py -``` - -The input samples are in `samples`. After successfully run this script, you can find the predicted results in `samples/output`. - -## Prepare Datasets - -### download the change detection dataset - -You could download the LEVIR-CD at https://justchenhao.github.io/LEVIR/; - -The path list in the downloaded folder is as follows: - -``` -path to LEVIR-CD: - ├─train - │ ├─A - │ ├─B - │ ├─label - ├─val - │ ├─A - │ ├─B - │ ├─label - ├─test - │ ├─A - │ ├─B - │ ├─label -``` - -where A contains images of pre-phase, B contains images of post-phase, and label contains label maps. - -### cut bitemporal image pairs - -The original image in LEVIR-CD has a size of 1024 * 1024, which will consume too much memory when training. Therefore, we can cut the origin images into smaller patches (e.g., 256 * 256, or 512 * 512). In our paper, we cut the original image into patches of 256 * 256 size without overlapping. - -Make sure that the corresponding patch samples in the A, B, and label subfolders have the same name. - -## Train - -### Monitor training status - -To view training results and loss plots, run this script and click the URL [http://localhost:8097](http://localhost:8097/). - -```bash -python -m visdom.server -``` - -### train with our base method - -Run the following script: - -```bash -python ./train.py --save_epoch_freq 1 --angle 15 --dataroot path-to-LEVIR-CD-train --val_dataroot path-to-LEVIR-CD-val --name LEVIR-CDF0 --lr 0.001 --model CDF0 --batch_size 8 --load_size 256 --crop_size 256 --preprocess rotate_and_crop -``` - -Once finished, you could find the best model and the log files in the project folder. - -### train with Basic spatial-temporal Attention Module (BAM) method - -```bash -python ./train.py --save_epoch_freq 1 --angle 15 --dataroot path-to-LEVIR-CD-train --val_dataroot path-to-LEVIR-CD-val --name LEVIR-CDFA0 --lr 0.001 --model CDFA --SA_mode BAM --batch_size 8 --load_size 256 --crop_size 256 --preprocess rotate_and_crop -``` - -### train with Pyramid spatial-temporal Attention Module (PAM) method - -```bash -python ./train.py --save_epoch_freq 1 --angle 15 --dataroot path-to-LEVIR-CD-train --val_dataroot path-to-LEVIR-CD-val --name LEVIR-CDFAp0 --lr 0.001 --model --SA_mode PAM CDFA --batch_size 8 --load_size 256 --crop_size 256 --preprocess rotate_and_crop -``` - -## Test - -You could edit the file val.py, for example: - -```python -if __name__ == '__main__': - opt = TestOptions().parse() # get training options - opt = make_val_opt(opt) - opt.phase = 'test' - opt.dataroot = 'path-to-LEVIR-CD-test' # data root - opt.dataset_mode = 'changedetection' - opt.n_class = 2 - opt.SA_mode = 'PAM' # BAM | PAM - opt.arch = 'mynet3' - opt.model = 'CDFA' # model type - opt.name = 'LEVIR-CDFAp0' # project name - opt.results_dir = './results/' # save predicted images - opt.epoch = 'best-epoch-in-val' # which epoch to test - opt.num_test = np.inf - val(opt) -``` - -then run the script: `python val.py`. 
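A small illustration for the "cut bitemporal image pairs" section above, which describes splitting the 1024 x 1024 LEVIR-CD tiles into non-overlapping 256 x 256 patches with matching names in A, B and label but gives no code. The folder layout follows the README; the output naming scheme and paths are only placeholders.

```python
import os
from PIL import Image


def cut_patches(src_root, dst_root, patch=256, subdirs=('A', 'B', 'label')):
    """Split every image under src_root/{A,B,label} into non-overlapping
    patch x patch tiles, keeping identical file names across the sub-folders."""
    for sub in subdirs:
        os.makedirs(os.path.join(dst_root, sub), exist_ok=True)
    for name in sorted(os.listdir(os.path.join(src_root, 'A'))):
        stem, ext = os.path.splitext(name)
        for sub in subdirs:
            img = Image.open(os.path.join(src_root, sub, name))
            w, h = img.size
            for y in range(0, h - patch + 1, patch):
                for x in range(0, w - patch + 1, patch):
                    tile = img.crop((x, y, x + patch, y + patch))
                    tile.save(os.path.join(dst_root, sub, f'{stem}_{y}_{x}{ext}'))


# cut_patches('path-to-LEVIR-CD/train', 'path-to-LEVIR-CD-256/train')
```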
Once finished, you can find the prediction log file in the project directory and predicted image files in the result directory. - -## Using other dataset mode - -### List mode - -```bash -list=train -lr=0.001 -dataset_mode=list -dataroot=path-to-dataroot -name=project_name - -python ./train.py --num_threads 4 --display_id 0 --dataroot ${dataroot} --val_dataroot ${dataroot} --save_epoch_freq 1 --niter 100 --angle 15 --niter_decay 100 --display_env FAp0 --SA_mode PAM --name $name --lr $lr --model CDFA --batch_size 4 --dataset_mode $dataset_mode --val_dataset_mode $dataset_mode --split $list --load_size 256 --crop_size 256 --preprocess resize_rotate_and_crop -``` - -In this case, the data structure should be the following: - -``` -""" -data structure --dataroot - ├─A - ├─train1.png - ... - ├─B - ├─train1.png - ... - ├─label - ├─train1.png - ... - └─list - ├─val.txt - ├─test.txt - └─train.txt - -# In list/train.txt, each low writes the filename of each sample, - # for example: - list/train.txt - train1.png - train2.png - ... -""" -``` - -### Concat mode for loading multiple datasets (each default mode is List) - -```bash -list=train -lr=0.001 -dataset_type=CD_data1,CD_data2,..., -val_dataset_type=CD_data -dataset_mode=concat -name=project_name - -python ./train.py --num_threads 4 --display_id 0 --dataset_type $dataset_type --val_dataset_type $val_dataset_type --save_epoch_freq 1 --niter 100 --angle 15 --niter_decay 100 --display_env FAp0 --SA_mode PAM --name $name --lr $lr --model CDFA --batch_size 4 --dataset_mode $dataset_mode --val_dataset_mode $dataset_mode --split $list --load_size 256 --crop_size 256 --preprocess resize_rotate_and_crop -``` - -Note, in this case, you should modify the `get_dataset_info` in `data/data_config.py` to add the corresponding ` dataset_name` and `dataroot` in it. - -```python -if dataset_type == 'LEVIR_CD': - root = 'path-to-LEVIR_CD-dataroot' -elif ... -# add more dataset ... -``` - -## Other TIPS - -For more Training/Testing guides, you could see the option files in the `./options/` folder. - -## Citation - -If you use this code for your research, please cite our papers. - -``` -@Article{rs12101662, -AUTHOR = {Chen, Hao and Shi, Zhenwei}, -TITLE = {A Spatial-Temporal Attention-Based Method and a New Dataset for Remote Sensing Image Change Detection}, -JOURNAL = {Remote Sensing}, -VOLUME = {12}, -YEAR = {2020}, -NUMBER = {10}, -ARTICLE-NUMBER = {1662}, -URL = {https://www.mdpi.com/2072-4292/12/10/1662}, -ISSN = {2072-4292}, -DOI = {10.3390/rs12101662} -} -``` - -## Acknowledgments - -Our code is inspired by [pytorch-CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). 
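For the list mode described above, a small helper that writes list/train.txt, list/val.txt and list/test.txt with one sample file name per row, taken from the A sub-folder. The 80/10/10 split and the seed are illustrative, not part of the original repository.

```python
import os
import random


def write_split_lists(dataroot, splits=(('train', 0.8), ('val', 0.1), ('test', 0.1)), seed=0):
    """Write list/<split>.txt files, one sample file name per row, using the
    contents of the A sub-folder as the sample list."""
    names = sorted(os.listdir(os.path.join(dataroot, 'A')))
    random.Random(seed).shuffle(names)
    os.makedirs(os.path.join(dataroot, 'list'), exist_ok=True)
    start = 0
    for idx, (split, frac) in enumerate(splits):
        # the last split takes whatever remains after rounding
        stop = len(names) if idx == len(splits) - 1 else start + int(round(frac * len(names)))
        with open(os.path.join(dataroot, 'list', split + '.txt'), 'w') as f:
            f.write('\n'.join(names[start:stop]) + '\n')
        start = stop


# write_split_lists('path-to-dataroot')
```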
- - diff --git a/plugins/ai_method/packages/STA_net/STA.zip b/plugins/ai_method/packages/STA_net/STA.zip deleted file mode 100644 index c796184..0000000 Binary files a/plugins/ai_method/packages/STA_net/STA.zip and /dev/null differ diff --git a/plugins/ai_method/packages/STA_net/SimCRNN/Change_Detection_in_Multisource_VHR_Images_via_Deep_Siamese_Convolutional_Multiple-Layers_Recurrent_Neural_Network.pdf b/plugins/ai_method/packages/STA_net/SimCRNN/Change_Detection_in_Multisource_VHR_Images_via_Deep_Siamese_Convolutional_Multiple-Layers_Recurrent_Neural_Network.pdf deleted file mode 100644 index f92ad93..0000000 Binary files a/plugins/ai_method/packages/STA_net/SimCRNN/Change_Detection_in_Multisource_VHR_Images_via_Deep_Siamese_Convolutional_Multiple-Layers_Recurrent_Neural_Network.pdf and /dev/null differ diff --git a/plugins/ai_method/packages/STA_net/SimCRNN/SiamCRNN-master.zip b/plugins/ai_method/packages/STA_net/SimCRNN/SiamCRNN-master.zip deleted file mode 100644 index 852fff7..0000000 Binary files a/plugins/ai_method/packages/STA_net/SimCRNN/SiamCRNN-master.zip and /dev/null differ diff --git a/plugins/ai_method/packages/STA_net/data/__init__.py b/plugins/ai_method/packages/STA_net/data/__init__.py deleted file mode 100644 index 487cce3..0000000 --- a/plugins/ai_method/packages/STA_net/data/__init__.py +++ /dev/null @@ -1,108 +0,0 @@ -import importlib -import torch.utils.data -from data.base_dataset import BaseDataset -import torch -from torch.utils.data import ConcatDataset -from data.data_config import get_dataset_info - - -def find_dataset_using_name(dataset_name): - """Import the module "data/[dataset_name]_dataset.py". - - In the file, the class called DatasetNameDataset() will - be instantiated. It has to be a subclass of BaseDataset, - and it is case-insensitive. - """ - dataset_filename = "data." + dataset_name + "_dataset" - datasetlib = importlib.import_module(dataset_filename) - - dataset = None - target_dataset_name = dataset_name.replace('_', '') + 'dataset' - for name, cls in datasetlib.__dict__.items(): - if name.lower() == target_dataset_name.lower() \ - and issubclass(cls, BaseDataset): - dataset = cls - - if dataset is None: - raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) - - return dataset - - -def get_option_setter(dataset_name): - """Return the static method of the dataset class.""" - dataset_class = find_dataset_using_name(dataset_name) - return dataset_class.modify_commandline_options - - -def create_dataset(opt): - """Create a dataset given the option. - - This function wraps the class CustomDatasetDataLoader. - This is the main interface between this package and 'train.py'/'test.py' - - Example: - >>> from data import create_dataset - >>> dataset = create_dataset(opt) - """ - data_loader = CustomDatasetDataLoader(opt) - dataset = data_loader.load_data() - return dataset - - -def create_single_dataset(opt, dataset_type_): - # return dataset_class - dataset_class = find_dataset_using_name('list') - # get dataset root - opt.dataroot = get_dataset_info(dataset_type_) - return dataset_class(opt) - - -class CustomDatasetDataLoader(): - """Wrapper class of Dataset class that performs multi-threaded data loading""" - - def __init__(self, opt): - """Initialize this class - Step 1: create a dataset instance given the name [dataset_mode] - Step 2: create a multi-threaded data loader. 
- """ - self.opt = opt - print(opt.dataset_mode) - if opt.dataset_mode == 'concat': - # 叠加多个数据集 - datasets = [] - # 获取concat的多个数据集列表 - self.dataset_type = opt.dataset_type.split(',') - # 去除“,”的影响 - if self.dataset_type[-1] == '': - self.dataset_type = self.dataset_type[:-1] - for dataset_type_ in self.dataset_type: - dataset_ = create_single_dataset(opt, dataset_type_) - datasets.append(dataset_) - self.dataset = ConcatDataset(datasets) - else: - dataset_class = find_dataset_using_name(opt.dataset_mode) - - self.dataset = dataset_class(opt) - - print("dataset [%s] was created" % type(self.dataset).__name__) - self.dataloader = torch.utils.data.DataLoader( - self.dataset, - batch_size=opt.batch_size, - shuffle=not opt.serial_batches, - num_workers=int(opt.num_threads), - drop_last=True) - - def load_data(self): - return self - - def __len__(self): - """Return the number of data in the dataset""" - return min(len(self.dataset), self.opt.max_dataset_size) - - def __iter__(self): - """Return a batch of data""" - for i, data in enumerate(self.dataloader): - if i * self.opt.batch_size >= self.opt.max_dataset_size: - break - yield data diff --git a/plugins/ai_method/packages/STA_net/data/base_dataset.py b/plugins/ai_method/packages/STA_net/data/base_dataset.py deleted file mode 100644 index 228a128..0000000 --- a/plugins/ai_method/packages/STA_net/data/base_dataset.py +++ /dev/null @@ -1,189 +0,0 @@ -"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets. - -It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. -""" -import random -import numpy as np -import torch.utils.data as data -from PIL import Image,ImageFilter -import torchvision.transforms as transforms -from abc import ABC, abstractmethod -import math - - -class BaseDataset(data.Dataset, ABC): - """This class is an abstract base class (ABC) for datasets. - - To create a subclass, you need to implement the following four functions: - -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). - -- <__len__>: return the size of dataset. - -- <__getitem__>: get a data point. - -- : (optionally) add dataset-specific options and set default options. - """ - - def __init__(self, opt): - """Initialize the class; save the options in the class - - Parameters: - opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions - """ - self.opt = opt - self.root = opt.dataroot - - @staticmethod - def modify_commandline_options(parser, is_train): - """Add new dataset-specific options, and rewrite default values for existing options. - - Parameters: - parser -- original option parser - is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. - - Returns: - the modified parser. - """ - return parser - - @abstractmethod - def __len__(self): - """Return the total number of images in the dataset.""" - return 0 - - @abstractmethod - def __getitem__(self, index): - """Return a data point and its metadata information. - - Parameters: - index - - a random integer for data indexing - - Returns: - a dictionary of data with their names. It ususally contains the data itself and its metadata information. 
- """ - pass - - -def get_params(opt, size, test=False): - w, h = size - new_h = h - new_w = w - angle = 0 - if opt.preprocess == 'resize_and_crop': - new_h = new_w = opt.load_size - if 'rotate' in opt.preprocess and test is False: - angle = random.uniform(0, opt.angle) - # print(angle) - new_w = int(new_w * math.cos(angle*math.pi/180) \ - + new_h*math.sin(angle*math.pi/180)) - new_h = int(new_h * math.cos(angle*math.pi/180) \ - + new_w*math.sin(angle*math.pi/180)) - new_w = min(new_w,new_h) - new_h = min(new_w,new_h) - # print(new_h,new_w) - x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) - y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) - # print('x,y: ',x,y) - flip = random.random() > 0.5 # left-right - return {'crop_pos': (x, y), 'flip': flip, 'angle': angle} - - -def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, - convert=True, normalize=True, test=False): - transform_list = [] - if grayscale: - transform_list.append(transforms.Grayscale(1)) - if 'resize' in opt.preprocess: - osize = [opt.load_size, opt.load_size] - transform_list.append(transforms.Resize(osize, method)) - # gaussian blur - if 'blur' in opt.preprocess: - transform_list.append(transforms.Lambda(lambda img: __blur(img))) - - if 'rotate' in opt.preprocess and test==False: - if params is None: - transform_list.append(transforms.RandomRotation(5)) - else: - degree = params['angle'] - transform_list.append(transforms.Lambda(lambda img: __rotate(img, degree))) - - if 'crop' in opt.preprocess: - if params is None: - transform_list.append(transforms.RandomCrop(opt.crop_size)) - else: - transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], - opt.crop_size))) - - if not opt.no_flip: - if params is None: - transform_list.append(transforms.RandomHorizontalFlip()) - elif params['flip']: - transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) - if convert: - transform_list += [transforms.ToTensor()] - if normalize: - transform_list += [transforms.Normalize((0.5, 0.5, 0.5), - (0.5, 0.5, 0.5))] - return transforms.Compose(transform_list) - -def __blur(img): - if img.mode == 'RGB': - img = img.filter(ImageFilter.GaussianBlur(radius=random.random())) - return img - -def __rotate(img, degree): - if img.mode =='RGB': - # set img padding == 128 - img2 = img.convert('RGBA') - rot = img2.rotate(degree,expand=1) - fff = Image.new('RGBA', rot.size, (128,) * 4) # 灰色 - out = Image.composite(rot, fff, rot) - img = out.convert(img.mode) - return img - else: - # set label padding == 0 - img2 = img.convert('RGBA') - rot = img2.rotate(degree,expand=1) - # a white image same size as rotated image - fff = Image.new('RGBA', rot.size, (255,) * 4) - # create a composite image using the alpha layer of rot as a mask - out = Image.composite(rot, fff, rot) - img = out.convert(img.mode) - return img - - -def __crop(img, pos, size): - - ow, oh = img.size - x1, y1 = pos - tw = th = size - # print('imagesize:',ow,oh) - # only 图像尺寸大于截取尺寸才截取,否则要padding - if (ow > tw and oh > th): - return img.crop((x1, y1, x1 + tw, y1 + th)) - - size = [size, size] - if img.mode == 'RGB': - new_image = Image.new('RGB', size, (128, 128, 128)) - new_image.paste(img, (int((1+size[1] - img.size[0]) / 2), - int((1+size[0] - img.size[1]) / 2))) - - return new_image - else: - new_image = Image.new(img.mode, size, 255) - # upper left corner - new_image.paste(img, (int((1 + size[1] - img.size[0]) / 2), - int((1 + size[0] - img.size[1]) / 2))) - return new_image - -def __flip(img, 
flip): - if flip: - return img.transpose(Image.FLIP_LEFT_RIGHT) - return img - - -def __print_size_warning(ow, oh, w, h): - """Print warning information about image size(only print once)""" - if not hasattr(__print_size_warning, 'has_printed'): - print("The image size needs to be a multiple of 4. " - "The loaded image size was (%d, %d), so it was adjusted to " - "(%d, %d). This adjustment will be done to all images " - "whose sizes are not multiples of 4" % (ow, oh, w, h)) - __print_size_warning.has_printed = True diff --git a/plugins/ai_method/packages/STA_net/data/changedetection_dataset.py b/plugins/ai_method/packages/STA_net/data/changedetection_dataset.py deleted file mode 100644 index 27d14ba..0000000 --- a/plugins/ai_method/packages/STA_net/data/changedetection_dataset.py +++ /dev/null @@ -1,63 +0,0 @@ -from data.base_dataset import BaseDataset, get_transform, get_params -from data.image_folder import make_dataset -from PIL import Image -import os -import numpy as np - - -class ChangeDetectionDataset(BaseDataset): - """This dataset class can load a set of images specified by the path --dataroot /path/to/data. - - datafolder-tree - dataroot:. - ├─A - ├─B - ├─label - """ - - def __init__(self, opt): - BaseDataset.__init__(self, opt) - folder_A = 'A' - folder_B = 'B' - folder_L = 'label' - self.istest = False - if opt.phase == 'test': - self.istest = True - self.A_paths = sorted(make_dataset(os.path.join(opt.dataroot, folder_A), opt.max_dataset_size)) - self.B_paths = sorted(make_dataset(os.path.join(opt.dataroot, folder_B), opt.max_dataset_size)) - if not self.istest: - self.L_paths = sorted(make_dataset(os.path.join(opt.dataroot, folder_L), opt.max_dataset_size)) - - print(self.A_paths) - - def __getitem__(self, index): - A_path = self.A_paths[index] - B_path = self.B_paths[index] - A_img = Image.open(A_path).convert('RGB') - B_img = Image.open(B_path).convert('RGB') - - transform_params = get_params(self.opt, A_img.size, test=self.istest) - # apply the same transform to A B L - transform = get_transform(self.opt, transform_params, test=self.istest) - - A = transform(A_img) - B = transform(B_img) - - if self.istest: - return {'A': A, 'A_paths': A_path, 'B': B, 'B_paths': B_path} - - L_path = self.L_paths[index] - tmp = np.array(Image.open(L_path), dtype=np.uint32)/255 - L_img = Image.fromarray(tmp) - transform_L = get_transform(self.opt, transform_params, method=Image.NEAREST, normalize=False, - test=self.istest) - L = transform_L(L_img) - - - return {'A': A, 'A_paths': A_path, - 'B': B, 'B_paths': B_path, - 'L': L, 'L_paths': L_path} - - def __len__(self): - """Return the total number of images in the dataset.""" - return len(self.A_paths) diff --git a/plugins/ai_method/packages/STA_net/data/data_config.py b/plugins/ai_method/packages/STA_net/data/data_config.py deleted file mode 100644 index b38e0da..0000000 --- a/plugins/ai_method/packages/STA_net/data/data_config.py +++ /dev/null @@ -1,12 +0,0 @@ - - -def get_dataset_info(dataset_type): - """define dataset_name and its dataroot""" - root = '' - if dataset_type == 'LEVIR_CD': - root = 'path-to-LEVIR_CD-dataroot' - # add more dataset ... 
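get_dataset_info above is the hook the README points to for registering extra datasets in concat mode; a sketch of an added branch, with a placeholder dataset name and root path.

```python
def get_dataset_info(dataset_type):
    """define dataset_name and its dataroot"""
    if dataset_type == 'LEVIR_CD':
        root = 'path-to-LEVIR_CD-dataroot'
    elif dataset_type == 'MY_CD_DATA':        # hypothetical added dataset
        root = 'path-to-MY_CD_DATA-dataroot'  # placeholder root
    else:
        raise TypeError("not define the %s" % dataset_type)
    return root
```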
- else: - raise TypeError("not define the %s" % dataset_type) - - return root diff --git a/plugins/ai_method/packages/STA_net/data/image_folder.py b/plugins/ai_method/packages/STA_net/data/image_folder.py deleted file mode 100644 index e95d8ec..0000000 --- a/plugins/ai_method/packages/STA_net/data/image_folder.py +++ /dev/null @@ -1,66 +0,0 @@ -"""A modified image folder class - -We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) -so that this class can load images from both current directory and its subdirectories. -""" - -import torch.utils.data as data - -from PIL import Image -import os -import os.path - - -IMG_EXTENSIONS = [ - '.jpg', '.JPG', '.jpeg', '.JPEG', - '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 'tif', 'tiff' -] - - -def is_image_file(filename): - return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) - -def make_dataset(dir, max_dataset_size=float("inf")): - images = [] - assert os.path.isdir(dir), '%s is not a valid directory' % dir - - for root, _, fnames in sorted(os.walk(dir)): - for fname in fnames: - if is_image_file(fname): - path = os.path.join(root, fname) - images.append(path) - return images[:min(max_dataset_size, len(images))] - - -def default_loader(path): - return Image.open(path).convert('RGB') - - -class ImageFolder(data.Dataset): - - def __init__(self, root, transform=None, return_paths=False, - loader=default_loader): - imgs = make_dataset(root) - if len(imgs) == 0: - raise(RuntimeError("Found 0 images in: " + root + "\n" - "Supported image extensions are: " + - ",".join(IMG_EXTENSIONS))) - - self.root = root - self.imgs = imgs - self.transform = transform - self.return_paths = return_paths - self.loader = loader - - def __getitem__(self, index): - path = self.imgs[index] - img = self.loader(path) - if self.transform is not None: - img = self.transform(img) - if self.return_paths: - return img, path - else: - return img - - def __len__(self): - return len(self.imgs) diff --git a/plugins/ai_method/packages/STA_net/data/list_dataset.py b/plugins/ai_method/packages/STA_net/data/list_dataset.py deleted file mode 100644 index f6fe947..0000000 --- a/plugins/ai_method/packages/STA_net/data/list_dataset.py +++ /dev/null @@ -1,101 +0,0 @@ -import os -import collections -import numpy as np -from data.base_dataset import BaseDataset, get_transform, get_params -from PIL import Image - - -class listDataset(BaseDataset): - """ - data structure - -dataroot - ├─A - ├─train1.png - ... - ├─B - ├─train1.png - ... - ├─label - ├─train1.png - ... - └─list - ├─val.txt - ├─test.txt - └─train.txt - - # In list/train.txt, each low writes the filename of each sample, - # for example: - list/train.txt - train1.png - train2.png - ... 
- """ - def __init__(self, opt): - BaseDataset.__init__(self, opt) - self.split = opt.split - self.files = collections.defaultdict(list) - self.istest = False if opt.phase == 'train' else True # 是否为测试/验证;若是,对数据不做尺度变换和旋转变换; - - path = os.path.join(self.root, 'list', self.split + '.txt') - file_list = tuple(open(path, 'r')) - file_list = [id_.rstrip() for id_ in file_list] - self.files[self.split] = file_list - # print(file_list) - - def __len__(self): - return len(self.files[self.split]) - - def __getitem__(self, index): - paths = self.files[self.split][index] - - path = paths.split(" ") - - A_path = os.path.join(self.root,'A', path[0]) - B_path = os.path.join(self.root,'B', path[0]) - - L_path = os.path.join(self.root,'label', path[0]) - - A = Image.open(A_path).convert('RGB') - B = Image.open(B_path).convert('RGB') - - tmp = np.array(Image.open(L_path), dtype=np.uint32) / 255 - # print(tmp.max()) - L = Image.fromarray(tmp) - - transform_params = get_params(self.opt, A.size, self.istest) - - transform = get_transform(self.opt, transform_params, test=self.istest) - transform_L = get_transform(self.opt, transform_params, method=Image.NEAREST, normalize=False, - test=self.istest) # 标签不做归一化 - A = transform(A) - B = transform(B) - - L = transform_L(L) - - return {'A': A, 'A_paths': A_path, 'B': B, 'B_paths': B_path, 'L': L, 'L_paths': L_path} - - -# Leave code for debugging purposes - -if __name__ == '__main__': - from options.train_options import TrainOptions - opt = TrainOptions().parse() - opt.dataroot = r'I:\data\change_detection\LEVIR-CD-r' - opt.split = 'train' - opt.load_size = 500 - opt.crop_size = 500 - opt.batch_size = 1 - opt.dataset_mode = 'list' - from data import create_dataset - dataset = create_dataset(opt) - import matplotlib.pyplot as plt - from util.util import tensor2im - - for i, data in enumerate(dataset): - A = data['A'] - L = data['L'] - A = tensor2im(A) - color = tensor2im(L)[:,:,0]*255 - plt.imshow(A) - plt.show() - diff --git a/plugins/ai_method/packages/STA_net/demo.py b/plugins/ai_method/packages/STA_net/demo.py deleted file mode 100644 index c306526..0000000 --- a/plugins/ai_method/packages/STA_net/demo.py +++ /dev/null @@ -1,82 +0,0 @@ -import time -from options.test_options import TestOptions -from data import create_dataset -from models import create_model -import os -from util.util import save_images -import numpy as np -from util.util import mkdir -from PIL import Image - - -def make_val_opt(opt): - - # hard-code some parameters for test - opt.num_threads = 0 # test code only supports num_threads = 1 - opt.batch_size = 1 # test code only supports batch_size = 1 - opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. - opt.no_flip = True # no flip; comment this line if results on flipped images are needed. - opt.no_flip2 = True # no flip; comment this line if results on flipped images are needed. - - opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file. 
- opt.phase = 'val' - opt.preprocess = 'none1' - opt.isTrain = False - opt.aspect_ratio = 1 - opt.eval = True - - return opt - - - -def val(opt): - - dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options - model = create_model(opt) # create a model given opt.model and other options - model.setup(opt) # regular setup: load and print networks; create schedulers - save_path = opt.results_dir - mkdir(save_path) - model.eval() - - for i, data in enumerate(dataset): - if i >= opt.num_test: # only apply our model to opt.num_test images. - break - model.set_input(data) # unpack data from data loader - pred = model.test(val=False) # run inference return pred - - img_path = model.get_image_paths() # get image paths - if i % 5 == 0: # save images to an HTML file - print('processing (%04d)-th image... %s' % (i, img_path)) - - save_images(pred, save_path, img_path) - - -def pred_image(data_root, results_dir): - opt = TestOptions().parse() # get training options - opt = make_val_opt(opt) - opt.phase = 'test' - opt.dataset_mode = 'changedetection' - opt.n_class = 2 - opt.SA_mode = 'PAM' - opt.arch = 'mynet3' - opt.model = 'CDFA' - opt.epoch = 'pam' - opt.num_test = np.inf - opt.name = 'pam' - opt.dataroot = data_root - opt.results_dir = results_dir - - val(opt) - - -if __name__ == '__main__': - # define the data_root and the results_dir - # note: - # data_root should have such structure: - # ├─A - # ├─B - # A for before images - # B for after images - data_root = './samples' - results_dir = './samples/output/' - pred_image(data_root, results_dir) diff --git a/plugins/ai_method/packages/STA_net/models/BAM.py b/plugins/ai_method/packages/STA_net/models/BAM.py deleted file mode 100644 index 5602234..0000000 --- a/plugins/ai_method/packages/STA_net/models/BAM.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch -import torch.nn.functional as F -from torch import nn - - -class BAM(nn.Module): - """ Basic self-attention module - """ - - def __init__(self, in_dim, ds=8, activation=nn.ReLU): - super(BAM, self).__init__() - self.chanel_in = in_dim - self.key_channel = self.chanel_in //8 - self.activation = activation - self.ds = ds # - self.pool = nn.AvgPool2d(self.ds) - print('ds: ',ds) - self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) - self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) - self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) - self.gamma = nn.Parameter(torch.zeros(1)) - - self.softmax = nn.Softmax(dim=-1) # - - def forward(self, input): - """ - inputs : - x : input feature maps( B X C X W X H) - returns : - out : self attention value + input feature - attention: B X N X N (N is Width*Height) - """ - x = self.pool(input) - m_batchsize, C, width, height = x.size() - proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1) # B X C X (N)/(ds*ds) - proj_key = self.key_conv(x).view(m_batchsize, -1, width * height) # B X C x (*W*H)/(ds*ds) - energy = torch.bmm(proj_query, proj_key) # transpose check - energy = (self.key_channel**-.5) * energy - - attention = self.softmax(energy) # BX (N) X (N)/(ds*ds)/(ds*ds) - - proj_value = self.value_conv(x).view(m_batchsize, -1, width * height) # B X C X N - - out = torch.bmm(proj_value, attention.permute(0, 2, 1)) - out = out.view(m_batchsize, C, width, height) - - out = F.interpolate(out, [width*self.ds,height*self.ds]) - out = out + input - - return out - - diff --git 
a/plugins/ai_method/packages/STA_net/models/CDF0_model.py b/plugins/ai_method/packages/STA_net/models/CDF0_model.py deleted file mode 100644 index fa8cfec..0000000 --- a/plugins/ai_method/packages/STA_net/models/CDF0_model.py +++ /dev/null @@ -1,114 +0,0 @@ -import torch -import itertools -from .base_model import BaseModel -from . import backbone -import torch.nn.functional as F -from . import loss - - -class CDF0Model(BaseModel): - """ - change detection module: - feature extractor - contrastive loss - """ - @staticmethod - def modify_commandline_options(parser, is_train=True): - return parser - - def __init__(self, opt): - BaseModel.__init__(self, opt) - self.istest = opt.istest - # specify the training losses you want to print out. The training/test scripts will call - self.loss_names = ['f'] - # specify the images you want to save/display. The training/test scripts will call - self.visual_names = ['A', 'B', 'L', 'pred_L_show'] # visualizations for A and B - if self.istest: - self.visual_names = ['A', 'B', 'pred_L_show'] - self.visual_features = ['feat_A', 'feat_B'] - # specify the models you want to save to the disk. The training/test scripts will call and . - if self.isTrain: - self.model_names = ['F'] - else: # during test time, only load Gs - self.model_names = ['F'] - self.ds=1 - # define networks (both Generators and discriminators) - self.n_class = 2 - self.netF = backbone.define_F(in_c=3, f_c=opt.f_c, type=opt.arch).to(self.device) - - if self.isTrain: - # define loss functions - self.criterionF = loss.BCL() - # initialize optimizers; schedulers will be automatically created by function . - self.optimizer_G = torch.optim.Adam(itertools.chain(self.netF.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) - self.optimizers.append(self.optimizer_G) - - def set_input(self, input): - """Unpack input data from the dataloader and perform necessary pre-processing steps. - - Parameters: - input (dict): include the data itself and its metadata information. - - The option 'direction' can be used to swap domain A and domain B. - """ - self.A = input['A'].to(self.device) - self.B = input['B'].to(self.device) - if not self.istest: - self.L = input['L'].to(self.device).long() - self.image_paths = input['A_paths'] - if self.isTrain: - self.L_s = self.L.float() - self.L_s = F.interpolate(self.L_s, size=torch.Size([self.A.shape[2]//self.ds, self.A.shape[3]//self.ds]),mode='nearest') - self.L_s[self.L_s == 1] = -1 # change - self.L_s[self.L_s == 0] = 1 # no change - - - def test(self, val=False): - """Forward function used in test time. 
- This function wraps function in no_grad() so we don't save intermediate steps for backprop - It also calls to produce additional visualization results - """ - with torch.no_grad(): - self.forward() - self.compute_visuals() - if val: # score - from util.metrics import RunningMetrics - metrics = RunningMetrics(self.n_class) - pred = self.pred_L.long() - - metrics.update(self.L.detach().cpu().numpy(), pred.detach().cpu().numpy()) - scores = metrics.get_cm() - return scores - - - def forward(self): - """Run forward pass; called by both functions and .""" - self.feat_A = self.netF(self.A) # f(A) - self.feat_B = self.netF(self.B) # f(B) - - self.dist = F.pairwise_distance(self.feat_A, self.feat_B, keepdim=True) - # print(self.dist.shape) - self.dist = F.interpolate(self.dist, size=self.A.shape[2:], mode='bilinear',align_corners=True) - self.pred_L = (self.dist > 1).float() - self.pred_L_show = self.pred_L.long() - return self.pred_L - - def backward(self): - """Calculate the loss for generators F and L""" - # print(self.weight) - self.loss_f = self.criterionF(self.dist, self.L_s) - - self.loss = self.loss_f - if torch.isnan(self.loss): - print(self.image_paths) - - self.loss.backward() - - def optimize_parameters(self): - """Calculate losses, gradients, and update network weights; called in every training iteration""" - # forward - self.forward() # compute feat and dist - - self.optimizer_G.zero_grad() # set G's gradients to zero - self.backward() # calculate graidents for G - self.optimizer_G.step() # udpate G's weights diff --git a/plugins/ai_method/packages/STA_net/models/CDFA_model.py b/plugins/ai_method/packages/STA_net/models/CDFA_model.py deleted file mode 100644 index 8b1d063..0000000 --- a/plugins/ai_method/packages/STA_net/models/CDFA_model.py +++ /dev/null @@ -1,119 +0,0 @@ -import torch -import itertools -from .base_model import BaseModel -from . import backbone -import torch.nn.functional as F -from . import loss - - -class CDFAModel(BaseModel): - """ - change detection module: - feature extractor+ spatial-temporal-self-attention - contrastive loss - """ - @staticmethod - def modify_commandline_options(parser, is_train=True): - return parser - def __init__(self, opt): - - BaseModel.__init__(self, opt) - # specify the training losses you want to print out. The training/test scripts will call - self.loss_names = ['f'] - # specify the images you want to save/display. The training/test scripts will call - if opt.phase == 'test': - self.istest = True - self.visual_names = ['A', 'B', 'L', 'pred_L_show'] # visualizations for A and B - if self.istest: - self.visual_names = ['A', 'B', 'pred_L_show'] # visualizations for A and B - - self.visual_features = ['feat_A','feat_B'] - # specify the models you want to save to the disk. The training/test scripts will call and . - if self.isTrain: - self.model_names = ['F','A'] - else: # during test time, only load Gs - self.model_names = ['F','A'] - self.istest = False - self.ds = 1 - self.n_class =2 - self.netF = backbone.define_F(in_c=3, f_c=opt.f_c, type=opt.arch).to(self.device) - self.netA = backbone.CDSA(in_c=opt.f_c, ds=opt.ds, mode=opt.SA_mode).to(self.device) - - if self.isTrain: - # define loss functions - self.criterionF = loss.BCL() - - # initialize optimizers; schedulers will be automatically created by function . 
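The forward passes above turn the two feature maps into a change map by measuring a per-pixel feature distance, upsampling it to the input size and thresholding at the margin of 1 used by the contrastive loss. The sketch below restates that step on its own, writing the distance explicitly as a channel-wise L2 norm in place of the model's F.pairwise_distance call; tensor shapes and names are illustrative.

```python
import torch
import torch.nn.functional as F


def distance_change_map(feat_a, feat_b, out_size, margin=1.0):
    """feat_a, feat_b: (N, C, h, w) feature maps from the shared extractor;
    out_size: (H, W) of the input images. Returns a binary change map."""
    dist = torch.norm(feat_a - feat_b, p=2, dim=1, keepdim=True)           # (N, 1, h, w)
    dist = F.interpolate(dist, size=out_size, mode='bilinear', align_corners=True)
    return (dist > margin).long()                                          # 1 = change


# usage sketch: pred = distance_change_map(netF(img_a), netF(img_b), img_a.shape[2:])
```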
- self.optimizer_G = torch.optim.Adam(itertools.chain( - self.netF.parameters(), - ), lr=opt.lr*opt.lr_decay, betas=(opt.beta1, 0.999)) - self.optimizer_A = torch.optim.Adam(self.netA.parameters(), lr=opt.lr*1, betas=(opt.beta1, 0.999)) - self.optimizers.append(self.optimizer_G) - self.optimizers.append(self.optimizer_A) - - - def set_input(self, input): - self.A = input['A'].to(self.device) - self.B = input['B'].to(self.device) - if self.istest is False: - if 'L' in input.keys(): - self.L = input['L'].to(self.device).long() - - self.image_paths = input['A_paths'] - if self.isTrain: - self.L_s = self.L.float() - self.L_s = F.interpolate(self.L_s, size=torch.Size([self.A.shape[2]//self.ds, self.A.shape[3]//self.ds]),mode='nearest') - self.L_s[self.L_s == 1] = -1 # change - self.L_s[self.L_s == 0] = 1 # no change - - - def test(self, val=False): - with torch.no_grad(): - self.forward() - self.compute_visuals() - if val: # 返回score - from util.metrics import RunningMetrics - metrics = RunningMetrics(self.n_class) - pred = self.pred_L.long() - - metrics.update(self.L.detach().cpu().numpy(), pred.detach().cpu().numpy()) - scores = metrics.get_cm() - return scores - else: - return self.pred_L.long() - - def forward(self): - """Run forward pass; called by both functions and .""" - self.feat_A = self.netF(self.A) # f(A) - self.feat_B = self.netF(self.B) # f(B) - - self.feat_A, self.feat_B = self.netA(self.feat_A,self.feat_B) - - self.dist = F.pairwise_distance(self.feat_A, self.feat_B, keepdim=True) # 特征距离 - - self.dist = F.interpolate(self.dist, size=self.A.shape[2:], mode='bilinear',align_corners=True) - - self.pred_L = (self.dist > 1).float() - # self.pred_L = F.interpolate(self.pred_L, size=self.A.shape[2:], mode='nearest') - self.pred_L_show = self.pred_L.long() - - return self.pred_L - - def backward(self): - self.loss_f = self.criterionF(self.dist, self.L_s) - - self.loss = self.loss_f - # print(self.loss) - self.loss.backward() - - def optimize_parameters(self): - """Calculate losses, gradients, and update network weights; called in every training iteration""" - # forward - self.forward() # compute feat and dist - - self.set_requires_grad([self.netF, self.netA], True) - self.optimizer_G.zero_grad() # set G's gradients to zero - self.optimizer_A.zero_grad() - self.backward() # calculate graidents for G - self.optimizer_G.step() # udpate G's weights - self.optimizer_A.step() \ No newline at end of file diff --git a/plugins/ai_method/packages/STA_net/models/PAM2.py b/plugins/ai_method/packages/STA_net/models/PAM2.py deleted file mode 100644 index e93cbc2..0000000 --- a/plugins/ai_method/packages/STA_net/models/PAM2.py +++ /dev/null @@ -1,168 +0,0 @@ -import torch -import torch.nn.functional as F -from torch import nn - - -class _PAMBlock(nn.Module): - ''' - The basic implementation for self-attention block/non-local block - Input/Output: - N * C * H * (2*W) - Parameters: - in_channels : the dimension of the input feature map - key_channels : the dimension after the key/query transform - value_channels : the dimension after the value transform - scale : choose the scale to partition the input feature maps - ds : downsampling scale - ''' - def __init__(self, in_channels, key_channels, value_channels, scale=1, ds=1): - super(_PAMBlock, self).__init__() - self.scale = scale - self.ds = ds - self.pool = nn.AvgPool2d(self.ds) - self.in_channels = in_channels - self.key_channels = key_channels - self.value_channels = value_channels - - self.f_key = nn.Sequential( - 
nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, - kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(self.key_channels) - ) - self.f_query = nn.Sequential( - nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, - kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(self.key_channels) - ) - self.f_value = nn.Conv2d(in_channels=self.in_channels, out_channels=self.value_channels, - kernel_size=1, stride=1, padding=0) - - - - def forward(self, input): - x = input - if self.ds != 1: - x = self.pool(input) - # input shape: b,c,h,2w - batch_size, c, h, w = x.size(0), x.size(1), x.size(2), x.size(3)//2 - - - local_y = [] - local_x = [] - step_h, step_w = h//self.scale, w//self.scale - for i in range(0, self.scale): - for j in range(0, self.scale): - start_x, start_y = i*step_h, j*step_w - end_x, end_y = min(start_x+step_h, h), min(start_y+step_w, w) - if i == (self.scale-1): - end_x = h - if j == (self.scale-1): - end_y = w - local_x += [start_x, end_x] - local_y += [start_y, end_y] - - value = self.f_value(x) - query = self.f_query(x) - key = self.f_key(x) - - value = torch.stack([value[:, :, :, :w], value[:,:,:,w:]], 4) # B*N*H*W*2 - query = torch.stack([query[:, :, :, :w], query[:,:,:,w:]], 4) # B*N*H*W*2 - key = torch.stack([key[:, :, :, :w], key[:,:,:,w:]], 4) # B*N*H*W*2 - - - local_block_cnt = 2*self.scale*self.scale - - # self-attention func - def func(value_local, query_local, key_local): - batch_size_new = value_local.size(0) - h_local, w_local = value_local.size(2), value_local.size(3) - value_local = value_local.contiguous().view(batch_size_new, self.value_channels, -1) - - query_local = query_local.contiguous().view(batch_size_new, self.key_channels, -1) - query_local = query_local.permute(0, 2, 1) - key_local = key_local.contiguous().view(batch_size_new, self.key_channels, -1) - - sim_map = torch.bmm(query_local, key_local) # batch matrix multiplication - sim_map = (self.key_channels**-.5) * sim_map - sim_map = F.softmax(sim_map, dim=-1) - - context_local = torch.bmm(value_local, sim_map.permute(0,2,1)) - # context_local = context_local.permute(0, 2, 1).contiguous() - context_local = context_local.view(batch_size_new, self.value_channels, h_local, w_local, 2) - return context_local - - # Parallel Computing to speed up - # reshape value_local, q, k - v_list = [value[:,:,local_x[i]:local_x[i+1],local_y[i]:local_y[i+1]] for i in range(0, local_block_cnt, 2)] - v_locals = torch.cat(v_list,dim=0) - q_list = [query[:,:,local_x[i]:local_x[i+1],local_y[i]:local_y[i+1]] for i in range(0, local_block_cnt, 2)] - q_locals = torch.cat(q_list,dim=0) - k_list = [key[:,:,local_x[i]:local_x[i+1],local_y[i]:local_y[i+1]] for i in range(0, local_block_cnt, 2)] - k_locals = torch.cat(k_list,dim=0) - # print(v_locals.shape) - context_locals = func(v_locals,q_locals,k_locals) - - context_list = [] - for i in range(0, self.scale): - row_tmp = [] - for j in range(0, self.scale): - left = batch_size*(j+i*self.scale) - right = batch_size*(j+i*self.scale) + batch_size - tmp = context_locals[left:right] - row_tmp.append(tmp) - context_list.append(torch.cat(row_tmp, 3)) - - context = torch.cat(context_list, 2) - context = torch.cat([context[:,:,:,:,0],context[:,:,:,:,1]],3) - - - if self.ds !=1: - context = F.interpolate(context, [h*self.ds, 2*w*self.ds]) - - return context - - -class PAMBlock(_PAMBlock): - def __init__(self, in_channels, key_channels=None, value_channels=None, scale=1, ds=1): - if key_channels == None: - key_channels = in_channels//8 - if 
value_channels == None: - value_channels = in_channels - super(PAMBlock, self).__init__(in_channels,key_channels,value_channels,scale,ds) - - -class PAM(nn.Module): - """ - PAM module - """ - - def __init__(self, in_channels, out_channels, sizes=([1]), ds=1): - super(PAM, self).__init__() - self.group = len(sizes) - self.stages = [] - self.ds = ds # output stride - self.value_channels = out_channels - self.key_channels = out_channels // 8 - - - self.stages = nn.ModuleList( - [self._make_stage(in_channels, self.key_channels, self.value_channels, size, self.ds) - for size in sizes]) - self.conv_bn = nn.Sequential( - nn.Conv2d(in_channels * self.group, out_channels, kernel_size=1, padding=0,bias=False), - # nn.BatchNorm2d(out_channels), - ) - - def _make_stage(self, in_channels, key_channels, value_channels, size, ds): - return PAMBlock(in_channels,key_channels,value_channels,size,ds) - - def forward(self, feats): - priors = [stage(feats) for stage in self.stages] - - # concat - context = [] - for i in range(0, len(priors)): - context += [priors[i]] - output = self.conv_bn(torch.cat(context, 1)) - - return output diff --git a/plugins/ai_method/packages/STA_net/models/__init__.py b/plugins/ai_method/packages/STA_net/models/__init__.py deleted file mode 100644 index fc01113..0000000 --- a/plugins/ai_method/packages/STA_net/models/__init__.py +++ /dev/null @@ -1,67 +0,0 @@ -"""This package contains modules related to objective functions, optimizations, and network architectures. - -To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. -You need to implement the following five functions: - -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). - -- : unpack data from dataset and apply preprocessing. - -- : produce intermediate results. - -- : calculate loss, gradients, and update network weights. - -- : (optionally) add model-specific options and set default options. - -In the function <__init__>, you need to define four lists: - -- self.loss_names (str list): specify the training losses that you want to plot and save. - -- self.model_names (str list): define networks used in our training. - -- self.visual_names (str list): specify the images that you want to display and save. - -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. - -Now you can use the model class by specifying flag '--model dummy'. -See our template model class 'template_model.py' for more details. -""" - -import importlib -from models.base_model import BaseModel - - -def find_model_using_name(model_name): - """Import the module "models/[model_name]_model.py". - - In the file, the class called DatasetNameModel() will - be instantiated. It has to be a subclass of BaseModel, - and it is case-insensitive. - """ - model_filename = "models." + model_name + "_model" - modellib = importlib.import_module(model_filename) - model = None - target_model_name = model_name.replace('_', '') + 'model' - for name, cls in modellib.__dict__.items(): - if name.lower() == target_model_name.lower() \ - and issubclass(cls, BaseModel): - model = cls - - if model is None: - print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." 
% (model_filename, target_model_name)) - exit(0) - - return model - - -def get_option_setter(model_name): - """Return the static method of the model class.""" - model_class = find_model_using_name(model_name) - return model_class.modify_commandline_options - - -def create_model(opt): - """Create a model given the option. - - This function warps the class CustomDatasetDataLoader. - This is the main interface between this package and 'train.py'/'test.py' - - Example: - >>> from models import create_model - >>> model = create_model(opt) - """ - model = find_model_using_name(opt.model) - instance = model(opt) - print("model [%s] was created" % type(instance).__name__) - return instance diff --git a/plugins/ai_method/packages/STA_net/models/backbone.py b/plugins/ai_method/packages/STA_net/models/backbone.py deleted file mode 100644 index fc608ec..0000000 --- a/plugins/ai_method/packages/STA_net/models/backbone.py +++ /dev/null @@ -1,51 +0,0 @@ -# coding: utf-8 -import torch.nn as nn -import torch -from .mynet3 import F_mynet3 -from .BAM import BAM -from .PAM2 import PAM as PAM - - - -def define_F(in_c, f_c, type='unet'): - if type == 'mynet3': - print("using mynet3 backbone") - return F_mynet3(backbone='resnet18', in_c=in_c,f_c=f_c, output_stride=32) - else: - NotImplementedError('no such F type!') - -def weights_init(m): - classname = m.__class__.__name__ - if classname.find('Conv') != -1: - nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find('BatchNorm') != -1: - nn.init.normal_(m.weight.data, 1.0, 0.02) - nn.init.constant_(m.bias.data, 0) - - - -class CDSA(nn.Module): - """self attention module for change detection - - """ - def __init__(self, in_c, ds=1, mode='BAM'): - super(CDSA, self).__init__() - self.in_C = in_c - self.ds = ds - print('ds: ',self.ds) - self.mode = mode - if self.mode == 'BAM': - self.Self_Att = BAM(self.in_C, ds=self.ds) - elif self.mode == 'PAM': - self.Self_Att = PAM(in_channels=self.in_C, out_channels=self.in_C, sizes=[1,2,4,8],ds=self.ds) - self.apply(weights_init) - - def forward(self, x1, x2): - height = x1.shape[3] - x = torch.cat((x1, x2), 3) - x = self.Self_Att(x) - return x[:,:,:,0:height], x[:,:,:,height:] - - - - diff --git a/plugins/ai_method/packages/STA_net/models/base_model.py b/plugins/ai_method/packages/STA_net/models/base_model.py deleted file mode 100644 index 78cb731..0000000 --- a/plugins/ai_method/packages/STA_net/models/base_model.py +++ /dev/null @@ -1,333 +0,0 @@ -import os -import torch -from collections import OrderedDict -from abc import ABC, abstractmethod -from torch.optim import lr_scheduler - - -def get_scheduler(optimizer, opt): - """Return a learning rate scheduler - - Parameters: - optimizer -- the optimizer of the network - opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  - opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine - - For 'linear', we keep the same learning rate for the first epochs - and linearly decay the rate to zero over the next epochs. - For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. - See https://pytorch.org/docs/stable/optim.html for more details. 
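Concretely, with the defaults from train_options.py further down (epoch_count=1, niter=100, niter_decay=100), the 'linear' policy multiplies the base learning rate by 1.0 for epochs 0 through 99 and then decays it linearly, reaching roughly 1% of the initial value by epoch 199. A standalone sketch of that multiplier, mirroring the lambda_rule that follows (the function name is illustrative):

def linear_lr_multiplier(epoch, epoch_count=1, niter=100, niter_decay=100):
    # 1.0 for the first `niter` epochs, then a straight line towards zero
    # over the following `niter_decay` epochs
    return 1.0 - max(0, epoch + epoch_count - niter) / float(niter_decay + 1)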
- """ - if opt.lr_policy == 'linear': - def lambda_rule(epoch): - lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) - return lr_l - scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) - elif opt.lr_policy == 'step': - scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) - elif opt.lr_policy == 'plateau': - scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) - elif opt.lr_policy == 'cosine': - scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0) - else: - return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) - return scheduler - - -class BaseModel(ABC): - """This class is an abstract base class (ABC) for models. - To create a subclass, you need to implement the following five functions: - -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). - -- : unpack data from dataset and apply preprocessing. - -- : produce intermediate results. - -- : calculate losses, gradients, and update network weights. - -- : (optionally) add model-specific options and set default options. - """ - - def __init__(self, opt): - """Initialize the BaseModel class. - - Parameters: - opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions - - When creating your custom class, you need to implement your own initialization. - In this fucntion, you should first call - Then, you need to define four lists: - -- self.loss_names (str list): specify the training losses that you want to plot and save. - -- self.model_names (str list): specify the images that you want to display and save. - -- self.visual_names (str list): define networks used in our training. - -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. - """ - self.opt = opt - self.gpu_ids = opt.gpu_ids - self.isTrain = opt.isTrain - self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU - self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir - if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark. - torch.backends.cudnn.benchmark = True - self.loss_names = [] - self.model_names = [] - self.visual_names = [] - self.visual_features = [] - self.optimizers = [] - self.image_paths = [] - self.metric = 0 # used for learning rate policy 'plateau' - self.istest = True if opt.phase == 'test' else False # 如果是测试,该模式下,没有标注样本; - - @staticmethod - def modify_commandline_options(parser, is_train): - """Add new model-specific options, and rewrite default values for existing options. - - Parameters: - parser -- original option parser - is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. - - Returns: - the modified parser. - """ - return parser - - @abstractmethod - def set_input(self, input): - """Unpack input data from the dataloader and perform necessary pre-processing steps. - - Parameters: - input (dict): includes the data itself and its metadata information. 
- """ - pass - - @abstractmethod - def forward(self): - """Run forward pass; called by both functions and .""" - pass - - @abstractmethod - def optimize_parameters(self): - """Calculate losses, gradients, and update network weights; called in every training iteration""" - pass - - def setup(self, opt): - """Load and print networks; create schedulers - - Parameters: - opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions - """ - if self.isTrain: - self.schedulers = [get_scheduler(optimizer, opt) for optimizer in self.optimizers] - if not self.isTrain or opt.continue_train: - load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch - self.load_networks(load_suffix) - self.print_networks(opt.verbose) - - def eval(self): - """Make models eval mode during test time""" - for name in self.model_names: - if isinstance(name, str): - net = getattr(self, 'net' + name) - net.eval() - - - def train(self): - """Make models train mode during train time""" - for name in self.model_names: - if isinstance(name, str): - net = getattr(self, 'net' + name) - net.train() - - def test(self): - """Forward function used in test time. - - This function wraps function in no_grad() so we don't save intermediate steps for backprop - It also calls to produce additional visualization results - """ - with torch.no_grad(): - self.forward() - self.compute_visuals() - - def compute_visuals(self): - """Calculate additional output images for visdom and HTML visualization""" - pass - - def get_image_paths(self): - """ Return image paths that are used to load current data""" - return self.image_paths - - def update_learning_rate(self): - """Update learning rates for all the networks; called at the end of every epoch""" - for scheduler in self.schedulers: - if self.opt.lr_policy == 'plateau': - scheduler.step(self.metric) - else: - scheduler.step() - - lr = self.optimizers[0].param_groups[0]['lr'] - print('learning rate = %.7f' % lr) - - def get_current_visuals(self): - """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" - visual_ret = OrderedDict() - for name in self.visual_names: - if isinstance(name, str): - visual_ret[name] = getattr(self, name) - return visual_ret - - def get_current_losses(self): - """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" - errors_ret = OrderedDict() - for name in self.loss_names: - if isinstance(name, str): - errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number - return errors_ret - - def save_networks(self, epoch): - """Save all the networks to the disk. 
- - Parameters: - epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) - """ - for name in self.model_names: - if isinstance(name, str): - save_filename = '%s_net_%s.pth' % (epoch, name) - save_path = os.path.join(self.save_dir, save_filename) - net = getattr(self, 'net' + name) - - if len(self.gpu_ids) > 0 and torch.cuda.is_available(): - # torch.save(net.module.cpu().state_dict(), save_path) - torch.save(net.cpu().state_dict(), save_path) - net.cuda(self.gpu_ids[0]) - else: - torch.save(net.cpu().state_dict(), save_path) - - def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): - """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" - key = keys[i] - if i + 1 == len(keys): # at the end, pointing to a parameter/buffer - if module.__class__.__name__.startswith('InstanceNorm') and \ - (key == 'running_mean' or key == 'running_var'): - if getattr(module, key) is None: - state_dict.pop('.'.join(keys)) - if module.__class__.__name__.startswith('InstanceNorm') and \ - (key == 'num_batches_tracked'): - state_dict.pop('.'.join(keys)) - else: - self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) - - def get_visual(self, name): - visual_ret = {} - visual_ret[name] = getattr(self, name) - return visual_ret - - def pred_large(self, A, B, input_size=256, stride=0): - """ - 输入前后时相的大图,获得预测结果 - 假定预测结果中心部分为准确,边缘padding = (input_size-stride)/2 - :param A: tensor, N*C*H*W - :param B: tensor, N*C*H*W - :param input_size: int, 输入网络的图像size - :param stride: int, 预测时的跨步 - :return: pred, tensor, N*1*H*W - """ - import math - import numpy as np - n, c, h, w = A.shape - assert A.shape == B.shape - # 分块数量 - n_h = math.ceil((h - input_size) / stride) + 1 - n_w = math.ceil((w - input_size) / stride) + 1 - # 重新计算长宽 - new_h = (n_h - 1) * stride + input_size - new_w = (n_w - 1) * stride + input_size - print("new_h: ", new_h) - print("new_w: ", new_w) - print("n_h: ", n_h) - print("n_w: ", n_w) - new_A = torch.zeros([n, c, new_h, new_w], dtype=torch.float32) - new_B = torch.zeros([n, c, new_h, new_w], dtype=torch.float32) - new_A[:, :, :h, :w] = A - new_B[:, :, :h, :w] = B - new_pred = torch.zeros([n, 1, new_h, new_w], dtype=torch.uint8) - del A - del B - # - for i in range(0, new_h - input_size + 1, stride): - for j in range(0, new_w - input_size + 1, stride): - left = j - right = input_size + j - top = i - bottom = input_size + i - patch_A = new_A[:, :, top:bottom, left:right] - patch_B = new_B[:, :, top:bottom, left:right] - # print(left,' ',right,' ', top,' ', bottom) - self.A = patch_A.to(self.device) - self.B = patch_B.to(self.device) - with torch.no_grad(): - patch_pred = self.forward() - new_pred[:, :, top:bottom, left:right] = patch_pred.detach().cpu() - pred = new_pred[:, :, :h, :w] - return pred - - def load_networks(self, epoch): - """Load all the networks from the disk. 
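The pred_large helper above performs sliding-window inference on image pairs larger than the training crop: it pads both images up to a whole number of strides, runs the network patch by patch, writes each patch prediction into the padded canvas, and finally crops back to the original height and width. A minimal sketch of the same tiling arithmetic, assuming a positive stride (names are illustrative):

import math

def tile_origins(h, w, input_size=256, stride=128):
    """Top-left corners of the patches a sliding-window predictor would visit,
    plus the padded canvas size the stitched prediction is written into."""
    n_h = math.ceil((h - input_size) / stride) + 1
    n_w = math.ceil((w - input_size) / stride) + 1
    new_h = (n_h - 1) * stride + input_size
    new_w = (n_w - 1) * stride + input_size
    origins = [(top, left)
               for top in range(0, new_h - input_size + 1, stride)
               for left in range(0, new_w - input_size + 1, stride)]
    return origins, (new_h, new_w)

After stitching, the result is cropped back with pred[:, :, :h, :w], exactly as pred_large does.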
- - Parameters: - epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) - """ - for name in self.model_names: - if isinstance(name, str): - load_filename = '%s_net_%s.pth' % (epoch, name) - load_path = os.path.join(self.save_dir, load_filename) - net = getattr(self, 'net' + name) - # if isinstance(net, torch.nn.DataParallel): - # net = net.module - # net = net.module # 适配保存的module - print('loading the model from %s' % load_path) - # if you are using PyTorch newer than 0.4 (e.g., built from - # GitHub source), you can remove str() on self.device - state_dict = torch.load(load_path, map_location=str(self.device)) - - if hasattr(state_dict, '_metadata'): - del state_dict._metadata - # patch InstanceNorm checkpoints prior to 0.4 - for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop - self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) - # print(key) - net.load_state_dict(state_dict,strict=False) - - def print_networks(self, verbose): - """Print the total number of parameters in the network and (if verbose) network architecture - - Parameters: - verbose (bool) -- if verbose: print the network architecture - """ - print('---------- Networks initialized -------------') - for name in self.model_names: - if isinstance(name, str): - net = getattr(self, 'net' + name) - num_params = 0 - for param in net.parameters(): - num_params += param.numel() - if verbose: - print(net) - print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) - print('-----------------------------------------------') - - def set_requires_grad(self, nets, requires_grad=False): - """Set requies_grad=Fasle for all the networks to avoid unnecessary computations - Parameters: - nets (network list) -- a list of networks - requires_grad (bool) -- whether the networks require gradients or not - """ - if not isinstance(nets, list): - nets = [nets] - for net in nets: - if net is not None: - for param in net.parameters(): - param.requires_grad = requires_grad - - - - -if __name__ == '__main__': - - A = torch.rand([1,3,512,512],dtype=torch.float32) - B = torch.rand([1,3,512,512],dtype=torch.float32) diff --git a/plugins/ai_method/packages/STA_net/models/loss.py b/plugins/ai_method/packages/STA_net/models/loss.py deleted file mode 100644 index 9907174..0000000 --- a/plugins/ai_method/packages/STA_net/models/loss.py +++ /dev/null @@ -1,29 +0,0 @@ -import torch.nn as nn -import torch - - -class BCL(nn.Module): - """ - batch-balanced contrastive loss - no-change,1 - change,-1 - """ - - def __init__(self, margin=2.0): - super(BCL, self).__init__() - self.margin = margin - - def forward(self, distance, label): - label[label==255] = 1 - mask = (label != 255).float() - distance = distance * mask - pos_num = torch.sum((label==1).float())+0.0001 - neg_num = torch.sum((label==-1).float())+0.0001 - - loss_1 = torch.sum((1+label) / 2 * torch.pow(distance, 2)) /pos_num - loss_2 = torch.sum((1-label) / 2 * mask * - torch.pow(torch.clamp(self.margin - distance, min=0.0), 2) - ) / neg_num - loss = loss_1 + loss_2 - return loss - diff --git a/plugins/ai_method/packages/STA_net/models/mynet3.py b/plugins/ai_method/packages/STA_net/models/mynet3.py deleted file mode 100644 index 22585b5..0000000 --- a/plugins/ai_method/packages/STA_net/models/mynet3.py +++ /dev/null @@ -1,352 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.model_zoo as model_zoo -import math - - -class F_mynet3(nn.Module): - def 
__init__(self, backbone='resnet18',in_c=3, f_c=64, output_stride=8): - self.in_c = in_c - super(F_mynet3, self).__init__() - self.module = mynet3(backbone=backbone, output_stride=output_stride, f_c=f_c, in_c=self.in_c) - def forward(self, input): - return self.module(input) - - -def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=dilation, groups=groups, bias=False, dilation=dilation) - -model_urls = { - 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', - 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', - 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', - 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', - 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', -} - -def ResNet34(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3): - """ - output, low_level_feat: - 512, 64 - """ - print(in_c) - model = ResNet(BasicBlock, [3, 4, 6, 3], output_stride, BatchNorm, in_c=in_c) - if in_c != 3: - pretrained = False - if pretrained: - model._load_pretrained_model(model_urls['resnet34']) - return model - -def ResNet18(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3): - """ - output, low_level_feat: - 512, 256, 128, 64, 64 - """ - model = ResNet(BasicBlock, [2, 2, 2, 2], output_stride, BatchNorm, in_c=in_c) - if in_c !=3: - pretrained=False - if pretrained: - model._load_pretrained_model(model_urls['resnet18']) - return model - -def ResNet50(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3): - """ - output, low_level_feat: - 2048, 256 - """ - model = ResNet(Bottleneck, [3, 4, 6, 3], output_stride, BatchNorm, in_c=in_c) - if in_c !=3: - pretrained=False - if pretrained: - model._load_pretrained_model(model_urls['resnet50']) - return model - - -class BasicBlock(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): - super(BasicBlock, self).__init__() - - self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, - dilation=dilation, padding=dilation, bias=False) - self.bn1 = BatchNorm(planes) - - self.relu = nn.ReLU(inplace=True) - self.conv2 = conv3x3(planes, planes) - self.bn2 = BatchNorm(planes) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): - super(Bottleneck, self).__init__() - self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) - self.bn1 = BatchNorm(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, - dilation=dilation, padding=dilation, bias=False) - self.bn2 = BatchNorm(planes) - self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) - self.bn3 = BatchNorm(planes * 4) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - self.dilation = dilation - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = 
self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - - -class ResNet(nn.Module): - - def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True, in_c=3): - - self.inplanes = 64 - self.in_c = in_c - print('in_c: ',self.in_c) - super(ResNet, self).__init__() - blocks = [1, 2, 4] - if output_stride == 32: - strides = [1, 2, 2, 2] - dilations = [1, 1, 1, 1] - elif output_stride == 16: - strides = [1, 2, 2, 1] - dilations = [1, 1, 1, 2] - elif output_stride == 8: - strides = [1, 2, 1, 1] - dilations = [1, 1, 2, 4] - elif output_stride == 4: - strides = [1, 1, 1, 1] - dilations = [1, 2, 4, 8] - else: - raise NotImplementedError - - # Modules - self.conv1 = nn.Conv2d(self.in_c, 64, kernel_size=7, stride=2, padding=3, - bias=False) - self.bn1 = BatchNorm(64) - self.relu = nn.ReLU(inplace=True) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm) - self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm) - self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm) - self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) - # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) - self._init_weight() - - - def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): - downsample = None - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - nn.Conv2d(self.inplanes, planes * block.expansion, - kernel_size=1, stride=stride, bias=False), - BatchNorm(planes * block.expansion), - ) - - layers = [] - layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)) - self.inplanes = planes * block.expansion - for i in range(1, blocks): - layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)) - - return nn.Sequential(*layers) - - def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): - downsample = None - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - nn.Conv2d(self.inplanes, planes * block.expansion, - kernel_size=1, stride=stride, bias=False), - BatchNorm(planes * block.expansion), - ) - - layers = [] - layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, - downsample=downsample, BatchNorm=BatchNorm)) - self.inplanes = planes * block.expansion - for i in range(1, len(blocks)): - layers.append(block(self.inplanes, planes, stride=1, - dilation=blocks[i]*dilation, BatchNorm=BatchNorm)) - - return nn.Sequential(*layers) - - def forward(self, input): - x = self.conv1(input) - x = self.bn1(x) - x = self.relu(x) - x = self.maxpool(x) # | 4 - x = self.layer1(x) # | 4 - low_level_feat2 = x # | 4 - x = self.layer2(x) # | 8 - low_level_feat3 = x - x = self.layer3(x) # | 16 - low_level_feat4 = x - x = self.layer4(x) # | 32 - return x, low_level_feat2, low_level_feat3, low_level_feat4 - - def _init_weight(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, 
math.sqrt(2. / n)) - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - - def _load_pretrained_model(self, model_path): - pretrain_dict = model_zoo.load_url(model_path) - model_dict = {} - state_dict = self.state_dict() - for k, v in pretrain_dict.items(): - if k in state_dict: - model_dict[k] = v - state_dict.update(model_dict) - self.load_state_dict(state_dict) - - - -def build_backbone(backbone, output_stride, BatchNorm, in_c=3): - if backbone == 'resnet50': - return ResNet50(output_stride, BatchNorm, in_c=in_c) - elif backbone == 'resnet34': - return ResNet34(output_stride, BatchNorm, in_c=in_c) - elif backbone == 'resnet18': - return ResNet18(output_stride, BatchNorm, in_c=in_c) - else: - raise NotImplementedError - - - -class DR(nn.Module): - def __init__(self, in_d, out_d): - super(DR, self).__init__() - self.in_d = in_d - self.out_d = out_d - self.conv1 = nn.Conv2d(self.in_d, self.out_d, 1, bias=False) - self.bn1 = nn.BatchNorm2d(self.out_d) - self.relu = nn.ReLU() - - def forward(self, input): - x = self.conv1(input) - x = self.bn1(x) - x = self.relu(x) - return x - - -class Decoder(nn.Module): - def __init__(self, fc, BatchNorm): - super(Decoder, self).__init__() - self.fc = fc - self.dr2 = DR(64, 96) - self.dr3 = DR(128, 96) - self.dr4 = DR(256, 96) - self.dr5 = DR(512, 96) - self.last_conv = nn.Sequential(nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1, bias=False), - BatchNorm(256), - nn.ReLU(), - nn.Dropout(0.5), - nn.Conv2d(256, self.fc, kernel_size=1, stride=1, padding=0, bias=False), - BatchNorm(self.fc), - nn.ReLU(), - ) - - self._init_weight() - - - def forward(self, x,low_level_feat2, low_level_feat3, low_level_feat4): - - # x1 = self.dr1(low_level_feat1) - x2 = self.dr2(low_level_feat2) - x3 = self.dr3(low_level_feat3) - x4 = self.dr4(low_level_feat4) - x = self.dr5(x) - x = F.interpolate(x, size=x2.size()[2:], mode='bilinear', align_corners=True) - # x2 = F.interpolate(x2, size=x3.size()[2:], mode='bilinear', align_corners=True) - x3 = F.interpolate(x3, size=x2.size()[2:], mode='bilinear', align_corners=True) - x4 = F.interpolate(x4, size=x2.size()[2:], mode='bilinear', align_corners=True) - - x = torch.cat((x, x2, x3, x4), dim=1) - - x = self.last_conv(x) - - return x - - def _init_weight(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - torch.nn.init.kaiming_normal_(m.weight) - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - - -def build_decoder(fc, backbone, BatchNorm): - return Decoder(fc, BatchNorm) - - -class mynet3(nn.Module): - def __init__(self, backbone='resnet18', output_stride=16, f_c=64, freeze_bn=False, in_c=3): - super(mynet3, self).__init__() - print('arch: mynet3') - BatchNorm = nn.BatchNorm2d - self.backbone = build_backbone(backbone, output_stride, BatchNorm, in_c) - self.decoder = build_decoder(f_c, backbone, BatchNorm) - - if freeze_bn: - self.freeze_bn() - - def forward(self, input): - x, f2, f3, f4 = self.backbone(input) - x = self.decoder(x, f2, f3, f4) - return x - - def freeze_bn(self): - for m in self.modules(): - if isinstance(m, nn.BatchNorm2d): - m.eval() - - diff --git a/plugins/ai_method/packages/STA_net/options/__init__.py b/plugins/ai_method/packages/STA_net/options/__init__.py deleted file mode 100644 index e7eedeb..0000000 --- a/plugins/ai_method/packages/STA_net/options/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""This package options includes option modules: training options, test options, and basic options (used in both training and 
test).""" diff --git a/plugins/ai_method/packages/STA_net/options/base_options.py b/plugins/ai_method/packages/STA_net/options/base_options.py deleted file mode 100644 index c49714f..0000000 --- a/plugins/ai_method/packages/STA_net/options/base_options.py +++ /dev/null @@ -1,147 +0,0 @@ -import argparse -import os -from util import util -import torch -import models -import data - - -class BaseOptions(): - """This class defines options used during both training and test time. - - It also implements several helper functions such as parsing, printing, and saving the options. - It also gathers additional options defined in functions in both dataset class and model class. - """ - - def __init__(self): - """Reset the class; indicates the class hasn't been initailized""" - self.initialized = False - - def initialize(self, parser): - """Define the common options that are used in both training and test.""" - # basic parameters - parser.add_argument('--dataroot', type=str, default='./LEVIR-CD', help='path to images (should have subfolders A, B, label)') - parser.add_argument('--val_dataroot', type=str, default='./LEVIR-CD', help='path to images in the val phase (should have subfolders A, B, label)') - - parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') - parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') - parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') - # model parameters - parser.add_argument('--model', type=str, default='CDF0', help='chooses which model to use. [CDF0 | CDFA]') - parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB ') - parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB') - parser.add_argument('--arch', type=str, default='mynet3', help='feature extractor architecture | mynet3') - parser.add_argument('--f_c', type=int, default=64, help='feature extractor channel num') - parser.add_argument('--n_class', type=int, default=2, help='# of output pred channels: 2 for num of classes') - parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') - parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') - parser.add_argument('--SA_mode', type=str, default='BAM', help='choose self attention mode for change detection, | ori |1 | 2 |pyramid, ...') - # dataset parameters - parser.add_argument('--dataset_mode', type=str, default='changedetection', help='chooses how datasets are loaded. [changedetection | concat | list | json]') - parser.add_argument('--val_dataset_mode', type=str, default='changedetection', help='chooses how datasets are loaded. [changedetection | concat| list | json]') - parser.add_argument('--dataset_type', type=str, default='CD_LEVIR', help='chooses which datasets too load. [LEVIR | WHU ]') - parser.add_argument('--val_dataset_type', type=str, default='CD_LEVIR', help='chooses which datasets too load. [LEVIR | WHU ]') - parser.add_argument('--split', type=str, default='train', help='chooses wihch list-file to open when use listDataset. [train | val | test]') - parser.add_argument('--val_split', type=str, default='val', help='chooses wihch list-file to open when use listDataset. 
[train | val | test]') - parser.add_argument('--json_name', type=str, default='train_val_test', help='input the json name which contain the file names of images of different phase') - parser.add_argument('--val_json_name', type=str, default='train_val_test', help='input the json name which contain the file names of images of different phase') - parser.add_argument('--ds', type=int, default='1', help='self attention module downsample rate') - parser.add_argument('--angle', type=int, default=0, help='rotate angle') - parser.add_argument('--istest', type=bool, default=False, help='True for the case without label') - - parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') - parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') - parser.add_argument('--batch_size', type=int, default=1, help='input batch size') - parser.add_argument('--load_size', type=int, default=286, help='scale images to this size') - parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size') - parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') - parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | none]') - parser.add_argument('--no_flip', type=bool, default=True, help='if specified, do not flip(left-right) the images for data augmentation') - - parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') - # additional parameters - parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') - parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]') - parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') - parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') - self.initialized = True - return parser - - def gather_options(self): - """Initialize our parser with basic options(only once). - Add additional model-specific and dataset-specific options. - These options are defined in the function - in model and dataset classes. 
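gather_options (continued below) follows a two-stage argparse pattern: parse only the basic options with parse_known_args, use the chosen --model and --dataset_mode to look up their option setters, let those setters add or override arguments, and then parse the full command line again. A minimal standalone sketch of the pattern, with a placeholder option name rather than the real setters:

import argparse

def parse_with_component_options(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='CDF0')
    # first pass: only the options defined so far, ignore anything unknown
    opt, _ = parser.parse_known_args(argv)
    # second pass: the selected component registers its own options,
    # then the full command line is parsed again
    if opt.model == 'CDFA':
        parser.add_argument('--extra_attention_opt', default='BAM')  # placeholder
    return parser.parse_args(argv)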
- """ - if not self.initialized: # check if it has been initialized - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser = self.initialize(parser) - - # get the basic options - opt, _ = parser.parse_known_args() - - # modify model-related parser options - model_name = opt.model - model_option_setter = models.get_option_setter(model_name) - parser = model_option_setter(parser, self.isTrain) - opt, _ = parser.parse_known_args() # parse again with new defaults - - # modify dataset-related parser options - dataset_name = opt.dataset_mode - if dataset_name != 'concat': - dataset_option_setter = data.get_option_setter(dataset_name) - parser = dataset_option_setter(parser, self.isTrain) - - # save and return the parser - self.parser = parser - return parser.parse_args() - - def print_options(self, opt): - """Print and save options - - It will print both current options and default values(if different). - It will save options into a text file / [checkpoints_dir] / opt.txt - """ - message = '' - message += '----------------- Options ---------------\n' - for k, v in sorted(vars(opt).items()): - comment = '' - default = self.parser.get_default(k) - if v != default: - comment = '\t[default: %s]' % str(default) - message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) - message += '----------------- End -------------------' - print(message) - - # save to the disk - expr_dir = os.path.join(opt.checkpoints_dir, opt.name) - util.mkdirs(expr_dir) - file_name = os.path.join(expr_dir, 'opt.txt') - with open(file_name, 'wt') as opt_file: - opt_file.write(message) - opt_file.write('\n') - - def parse(self): - """Parse our options, create checkpoints directory suffix, and set up gpu device.""" - opt = self.gather_options() - opt.isTrain = self.isTrain # train or test - - # process opt.suffix - if opt.suffix: - suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' - opt.name = opt.name + suffix - - self.print_options(opt) - - # set gpu ids - str_ids = opt.gpu_ids.split(',') - opt.gpu_ids = [] - for str_id in str_ids: - id = int(str_id) - if id >= 0: - opt.gpu_ids.append(id) - if len(opt.gpu_ids) > 0: - torch.cuda.set_device(opt.gpu_ids[0]) - - self.opt = opt - return self.opt diff --git a/plugins/ai_method/packages/STA_net/options/test_options.py b/plugins/ai_method/packages/STA_net/options/test_options.py deleted file mode 100644 index 3c7ac56..0000000 --- a/plugins/ai_method/packages/STA_net/options/test_options.py +++ /dev/null @@ -1,22 +0,0 @@ -from .base_options import BaseOptions - - -class TestOptions(BaseOptions): - """This class includes test options. - - It also includes shared options defined in BaseOptions. - """ - - def initialize(self, parser): - parser = BaseOptions.initialize(self, parser) # define shared options - parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') - parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') - parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') - parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') - # Dropout and Batchnorm has different behavioir during training and test. 
- parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') - parser.add_argument('--num_test', type=int, default=50, help='how many test images to run') - # To avoid cropping, the load_size should be the same as crop_size - parser.set_defaults(load_size=parser.get_default('crop_size')) - self.isTrain = False - return parser diff --git a/plugins/ai_method/packages/STA_net/options/train_options.py b/plugins/ai_method/packages/STA_net/options/train_options.py deleted file mode 100644 index e9af838..0000000 --- a/plugins/ai_method/packages/STA_net/options/train_options.py +++ /dev/null @@ -1,40 +0,0 @@ -from .base_options import BaseOptions - - -class TrainOptions(BaseOptions): - """This class includes training options. - - It also includes shared options defined in BaseOptions. - """ - - def initialize(self, parser): - parser = BaseOptions.initialize(self, parser) - # visdom and HTML visualization parameters - parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen') - parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.') - parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') - parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') - parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') - parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') - parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') - parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') - parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') - # network saving and loading parameters - parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') - parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') - parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') - parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') - parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') - parser.add_argument('--lr_decay', type=float, default=1, help='learning rate decay for certain module ...') - - parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') - # training parameters - parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate') - parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero') - parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') - parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') - parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. 
[linear | step | plateau | cosine]') - parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') - - self.isTrain = True - return parser diff --git a/plugins/ai_method/packages/STA_net/remotesensing-12-01662-v2.pdf b/plugins/ai_method/packages/STA_net/remotesensing-12-01662-v2.pdf deleted file mode 100644 index ea335b7..0000000 Binary files a/plugins/ai_method/packages/STA_net/remotesensing-12-01662-v2.pdf and /dev/null differ diff --git a/plugins/ai_method/packages/STA_net/samples/A/test_102_0512_0000.png b/plugins/ai_method/packages/STA_net/samples/A/test_102_0512_0000.png deleted file mode 100644 index 279ca3c..0000000 Binary files a/plugins/ai_method/packages/STA_net/samples/A/test_102_0512_0000.png and /dev/null differ diff --git a/plugins/ai_method/packages/STA_net/samples/A/test_121_0768_0256.png b/plugins/ai_method/packages/STA_net/samples/A/test_121_0768_0256.png deleted file mode 100644 index c0c5761..0000000 Binary files a/plugins/ai_method/packages/STA_net/samples/A/test_121_0768_0256.png and /dev/null differ diff --git a/plugins/ai_method/packages/STA_net/samples/A/test_2_0000_0000.png b/plugins/ai_method/packages/STA_net/samples/A/test_2_0000_0000.png deleted file mode 100644 index 3f654fe..0000000 Binary files a/plugins/ai_method/packages/STA_net/samples/A/test_2_0000_0000.png and /dev/null differ diff --git a/plugins/ai_method/packages/STA_net/samples/A/test_2_0000_0512.png b/plugins/ai_method/packages/STA_net/samples/A/test_2_0000_0512.png deleted file mode 100644 index 2cc547d..0000000 Binary files a/plugins/ai_method/packages/STA_net/samples/A/test_2_0000_0512.png and /dev/null differ diff --git a/plugins/ai_method/packages/STA_net/samples/A/test_55_0256_0000.png b/plugins/ai_method/packages/STA_net/samples/A/test_55_0256_0000.png deleted file mode 100644 index f4a5056..0000000 Binary files a/plugins/ai_method/packages/STA_net/samples/A/test_55_0256_0000.png and /dev/null differ diff --git a/plugins/ai_method/packages/STA_net/samples/A/test_77_0512_0256.png b/plugins/ai_method/packages/STA_net/samples/A/test_77_0512_0256.png deleted file mode 100644 index f5ee7d4..0000000 Binary files a/plugins/ai_method/packages/STA_net/samples/A/test_77_0512_0256.png and /dev/null differ diff --git a/plugins/ai_method/packages/STA_net/samples/B/test_102_0512_0000.png b/plugins/ai_method/packages/STA_net/samples/B/test_102_0512_0000.png deleted file mode 100644 index 168b791..0000000 Binary files a/plugins/ai_method/packages/STA_net/samples/B/test_102_0512_0000.png and /dev/null differ diff --git a/plugins/ai_method/packages/STA_net/samples/B/test_121_0768_0256.png b/plugins/ai_method/packages/STA_net/samples/B/test_121_0768_0256.png deleted file mode 100644 index 9cdd95f..0000000 Binary files a/plugins/ai_method/packages/STA_net/samples/B/test_121_0768_0256.png and /dev/null differ diff --git a/plugins/ai_method/packages/STA_net/samples/B/test_2_0000_0000.png b/plugins/ai_method/packages/STA_net/samples/B/test_2_0000_0000.png deleted file mode 100644 index a340401..0000000 Binary files a/plugins/ai_method/packages/STA_net/samples/B/test_2_0000_0000.png and /dev/null differ diff --git a/plugins/ai_method/packages/STA_net/samples/B/test_2_0000_0512.png b/plugins/ai_method/packages/STA_net/samples/B/test_2_0000_0512.png deleted file mode 100644 index 2316bee..0000000 Binary files a/plugins/ai_method/packages/STA_net/samples/B/test_2_0000_0512.png and /dev/null differ diff --git 
a/plugins/ai_method/packages/STA_net/samples/B/test_55_0256_0000.png b/plugins/ai_method/packages/STA_net/samples/B/test_55_0256_0000.png deleted file mode 100644 index 1e2cd61..0000000 Binary files a/plugins/ai_method/packages/STA_net/samples/B/test_55_0256_0000.png and /dev/null differ diff --git a/plugins/ai_method/packages/STA_net/samples/B/test_77_0512_0256.png b/plugins/ai_method/packages/STA_net/samples/B/test_77_0512_0256.png deleted file mode 100644 index ce451b0..0000000 Binary files a/plugins/ai_method/packages/STA_net/samples/B/test_77_0512_0256.png and /dev/null differ diff --git a/plugins/ai_method/packages/STA_net/samples/label/test_102_0512_0000.png b/plugins/ai_method/packages/STA_net/samples/label/test_102_0512_0000.png deleted file mode 100644 index 7a97ed1..0000000 Binary files a/plugins/ai_method/packages/STA_net/samples/label/test_102_0512_0000.png and /dev/null differ diff --git a/plugins/ai_method/packages/STA_net/samples/label/test_121_0768_0256.png b/plugins/ai_method/packages/STA_net/samples/label/test_121_0768_0256.png deleted file mode 100644 index 319e411..0000000 Binary files a/plugins/ai_method/packages/STA_net/samples/label/test_121_0768_0256.png and /dev/null differ diff --git a/plugins/ai_method/packages/STA_net/samples/label/test_2_0000_0000.png b/plugins/ai_method/packages/STA_net/samples/label/test_2_0000_0000.png deleted file mode 100644 index cfa8b2a..0000000 Binary files a/plugins/ai_method/packages/STA_net/samples/label/test_2_0000_0000.png and /dev/null differ diff --git a/plugins/ai_method/packages/STA_net/samples/label/test_2_0000_0512.png b/plugins/ai_method/packages/STA_net/samples/label/test_2_0000_0512.png deleted file mode 100644 index b806851..0000000 Binary files a/plugins/ai_method/packages/STA_net/samples/label/test_2_0000_0512.png and /dev/null differ diff --git a/plugins/ai_method/packages/STA_net/samples/label/test_55_0256_0000.png b/plugins/ai_method/packages/STA_net/samples/label/test_55_0256_0000.png deleted file mode 100644 index 31dff95..0000000 Binary files a/plugins/ai_method/packages/STA_net/samples/label/test_55_0256_0000.png and /dev/null differ diff --git a/plugins/ai_method/packages/STA_net/samples/label/test_77_0512_0256.png b/plugins/ai_method/packages/STA_net/samples/label/test_77_0512_0256.png deleted file mode 100644 index 1afe1f0..0000000 Binary files a/plugins/ai_method/packages/STA_net/samples/label/test_77_0512_0256.png and /dev/null differ diff --git a/plugins/ai_method/packages/STA_net/scripts/run.sh b/plugins/ai_method/packages/STA_net/scripts/run.sh deleted file mode 100644 index 2153835..0000000 --- a/plugins/ai_method/packages/STA_net/scripts/run.sh +++ /dev/null @@ -1,5 +0,0 @@ - - -# train PAM -python ./train.py --save_epoch_freq 1 --angle 15 --dataroot path-to-LEVIR-CD-train --val_dataroot path-to-LEVIR-CD-val --name LEVIR-CDFAp0 --lr 0.001 --model --SA_mode PAM CDFA --batch_size 8 --load_size 256 --crop_size 256 --preprocess rotate_and_crop - diff --git a/plugins/ai_method/packages/STA_net/scripts/run_concat_mode.sh b/plugins/ai_method/packages/STA_net/scripts/run_concat_mode.sh deleted file mode 100644 index f5a32db..0000000 --- a/plugins/ai_method/packages/STA_net/scripts/run_concat_mode.sh +++ /dev/null @@ -1,11 +0,0 @@ - -# concat mode -list=train -lr=0.001 -dataset_type=CD_data1,CD_data2,..., -dataset_type=LEVIR_CD1 -val_dataset_type=LEVIR_CD1 -dataset_mode=concat -name=project_name - -python ./train.py --num_threads 4 --display_id 0 --dataset_type $dataset_type --val_dataset_type $val_dataset_type 
--save_epoch_freq 1 --niter 100 --angle 15 --niter_decay 100 --display_env FAp0 --SA_mode PAM --name $name --lr $lr --model CDFA --batch_size 4 --dataset_mode $dataset_mode --val_dataset_mode $dataset_mode --split $list --load_size 256 --crop_size 256 --preprocess resize_rotate_and_crop \ No newline at end of file diff --git a/plugins/ai_method/packages/STA_net/scripts/run_list_mode.sh b/plugins/ai_method/packages/STA_net/scripts/run_list_mode.sh deleted file mode 100644 index bc9da92..0000000 --- a/plugins/ai_method/packages/STA_net/scripts/run_list_mode.sh +++ /dev/null @@ -1,9 +0,0 @@ - -# List mode -list=train -lr=0.001 -dataset_mode=list -dataroot=path-to-dataroot -name=project_name -# -python ./train.py --num_threads 4 --display_id 0 --dataroot ${dataroot} --val_dataroot ${dataroot} --save_epoch_freq 1 --niter 100 --angle 25 --niter_decay 100 --display_env FAp0 --SA_mode PAM --name $name --lr $lr --model CDFA --batch_size 4 --dataset_mode $dataset_mode --val_dataset_mode $dataset_mode --split $list --load_size 256 --crop_size 256 --preprocess resize_rotate_and_crop diff --git a/plugins/ai_method/packages/STA_net/src/stanet-overview.png b/plugins/ai_method/packages/STA_net/src/stanet-overview.png deleted file mode 100644 index 28b7ee9..0000000 Binary files a/plugins/ai_method/packages/STA_net/src/stanet-overview.png and /dev/null differ diff --git a/plugins/ai_method/packages/STA_net/test.py b/plugins/ai_method/packages/STA_net/test.py deleted file mode 100644 index facf7a3..0000000 --- a/plugins/ai_method/packages/STA_net/test.py +++ /dev/null @@ -1,105 +0,0 @@ -from data import create_dataset -from models import create_model -from util.util import save_images -import numpy as np -from util.util import mkdir -import argparse -from PIL import Image -import torchvision.transforms as transforms - - -def transform(): - transform_list = [] - transform_list += [transforms.ToTensor()] - transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] - return transforms.Compose(transform_list) - - -def val(opt): - image_1_path = opt.image1_path - image_2_path = opt.image2_path - A_img = Image.open(image_1_path).convert('RGB') - B_img = Image.open(image_2_path).convert('RGB') - trans = transform() - A = trans(A_img).unsqueeze(0) - B = trans(B_img).unsqueeze(0) - # dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options - model = create_model(opt) # create a model given opt.model and other options - model.setup(opt) # regular setup: load and print networks; create schedulers - save_path = opt.results_dir - mkdir(save_path) - model.eval() - data = {} - data['A']= A - data['B'] = B - data['A_paths'] = [image_1_path] - - model.set_input(data) # unpack data from data loader - pred = model.test(val=False) # run inference return pred - - img_path = [image_1_path] # get image paths - - save_images(pred, save_path, img_path) - - - -if __name__ == '__main__': - # 从外界调用方式: - # python test.py --image1_path [path-to-img1] --image2_path [path-to-img2] --results_dir [path-to-result_dir] - - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - - parser.add_argument('--image1_path', type=str, default='./samples/A/test_2_0000_0000.png', - help='path to images A') - parser.add_argument('--image2_path', type=str, default='./samples/B/test_2_0000_0000.png', - help='path to images B') - parser.add_argument('--results_dir', type=str, default='./samples/output/', help='saves results here.') - - parser.add_argument('--name', type=str, 
default='pam', - help='name of the experiment. It decides where to store samples and models') - parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') - parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') - # model parameters - parser.add_argument('--model', type=str, default='CDFA', help='chooses which model to use. [CDF0 | CDFA]') - parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB ') - parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB') - parser.add_argument('--arch', type=str, default='mynet3', help='feature extractor architecture | mynet3') - parser.add_argument('--f_c', type=int, default=64, help='feature extractor channel num') - parser.add_argument('--n_class', type=int, default=2, help='# of output pred channels: 2 for num of classes') - - parser.add_argument('--SA_mode', type=str, default='PAM', - help='choose self attention mode for change detection, | ori |1 | 2 |pyramid, ...') - # dataset parameters - parser.add_argument('--dataset_mode', type=str, default='changedetection', - help='chooses how datasets are loaded. [changedetection | json]') - parser.add_argument('--val_dataset_mode', type=str, default='changedetection', - help='chooses how datasets are loaded. [changedetection | json]') - parser.add_argument('--split', type=str, default='train', - help='chooses wihch list-file to open when use listDataset. [train | val | test]') - parser.add_argument('--ds', type=int, default='1', help='self attention module downsample rate') - parser.add_argument('--angle', type=int, default=0, help='rotate angle') - parser.add_argument('--istest', type=bool, default=False, help='True for the case without label') - parser.add_argument('--serial_batches', action='store_true', - help='if true, takes images in order to make batches, otherwise takes them randomly') - parser.add_argument('--num_threads', default=0, type=int, help='# threads for loading data') - parser.add_argument('--batch_size', type=int, default=1, help='input batch size') - parser.add_argument('--load_size', type=int, default=256, help='scale images to this size') - parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size') - parser.add_argument('--max_dataset_size', type=int, default=float("inf"), - help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') - parser.add_argument('--preprocess', type=str, default='resize_and_crop', - help='scaling and cropping of images at load time [resize_and_crop | none]') - parser.add_argument('--no_flip', type=bool, default=True, - help='if specified, do not flip(left-right) the images for data augmentation') - parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') - parser.add_argument('--epoch', type=str, default='pam', - help='which epoch to load? set to latest to use latest cached model') - parser.add_argument('--load_iter', type=int, default='0', - help='which iteration to load? 
if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]') - parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') - parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') - parser.add_argument('--isTrain', type=bool, default=False, help='is or not') - parser.add_argument('--num_test', type=int, default=np.inf, help='how many test images to run') - - opt = parser.parse_args() - val(opt) diff --git a/plugins/ai_method/packages/STA_net/train.py b/plugins/ai_method/packages/STA_net/train.py deleted file mode 100644 index bb8879f..0000000 --- a/plugins/ai_method/packages/STA_net/train.py +++ /dev/null @@ -1,181 +0,0 @@ -import time -from options.train_options import TrainOptions -from data import create_dataset -from models import create_model -from util.visualizer import Visualizer -import os -from util import html -from util.visualizer import save_images -from util.metrics import AverageMeter -import copy -import numpy as np -import torch -import random - - -def seed_torch(seed=2019): - random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - - -# set seeds -# seed_torch(2019) -ifSaveImage = False - -def make_val_opt(opt): - - val_opt = copy.deepcopy(opt) - val_opt.preprocess = '' # - # hard-code some parameters for test - val_opt.num_threads = 0 # test code only supports num_threads = 1 - val_opt.batch_size = 4 # test code only supports batch_size = 1 - val_opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. - val_opt.no_flip = True # no flip; comment this line if results on flipped images are needed. - val_opt.angle = 0 - val_opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file. 
- val_opt.phase = 'val' - val_opt.split = opt.val_split # function in jsonDataset and ListDataset - val_opt.isTrain = False - val_opt.aspect_ratio = 1 - val_opt.results_dir = './results/' - val_opt.dataroot = opt.val_dataroot - val_opt.dataset_mode = opt.val_dataset_mode - val_opt.dataset_type = opt.val_dataset_type - val_opt.json_name = opt.val_json_name - val_opt.eval = True - - val_opt.num_test = 2000 - return val_opt - - -def print_current_acc(log_name, epoch, score): - """print current acc on console; also save the losses to the disk - Parameters: - """ - message = '(epoch: %d) ' % epoch - for k, v in score.items(): - message += '%s: %.3f ' % (k, v) - print(message) # print the message - with open(log_name, "a") as log_file: - log_file.write('%s\n' % message) # save the message - - -def val(opt, model): - opt = make_val_opt(opt) - dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options - # model = create_model(opt) # create a model given opt.model and other options - # model.setup(opt) # regular setup: load and print networks; create schedulers - - web_dir = os.path.join(opt.checkpoints_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch)) # define the website directory - webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch)) - model.eval() - # create a logging file to store training losses - log_name = os.path.join(opt.checkpoints_dir, opt.name, 'val_log.txt') - with open(log_name, "a") as log_file: - now = time.strftime("%c") - log_file.write('================ val acc (%s) ================\n' % now) - - running_metrics = AverageMeter() - for i, data in enumerate(dataset): - if i >= opt.num_test: # only apply our model to opt.num_test images. - break - model.set_input(data) # unpack data from data loader - score = model.test(val=True) # run inference - running_metrics.update(score) - visuals = model.get_current_visuals() # get image results - img_path = model.get_image_paths() # get image paths - if i % 5 == 0: # save images to an HTML file - print('processing (%04d)-th image... %s' % (i, img_path)) - if ifSaveImage: - save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize) - - score = running_metrics.get_scores() - print_current_acc(log_name, epoch, score) - if opt.display_id > 0: - visualizer.plot_current_acc(epoch, float(epoch_iter) / dataset_size, score) - webpage.save() # save the HTML - - return score[metric_name] - -metric_name = 'F1_1' - - -if __name__ == '__main__': - opt = TrainOptions().parse() # get training options - dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options - dataset_size = len(dataset) # get the number of images in the dataset. 
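One cross-reference worth making explicit before the training loop below: when validation improves, the removed train.py saves the network under a tag built from the best epoch, the metric name and the score, and the val.py entry point further down in this diff loads exactly such a tag. A tiny illustration of the naming convention (the numbers are taken from the hard-coded opt.epoch in val.py, not newly invented):

metric_name = 'F1_1'
epoch_best, f1_best = 78, 0.88780
tag = str(epoch_best) + '_' + metric_name + '_' + '%0.5f' % f1_best
print(tag)   # '78_F1_1_0.88780' -- the same form as opt.epoch in val.py below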
- print('The number of training images = %d' % dataset_size) - - model = create_model(opt) # create a model given opt.model and other options - model.setup(opt) # regular setup: load and print networks; create schedulers - visualizer = Visualizer(opt) # create a visualizer that display/save images and plots - total_iters = 0 # the total number of training iterations - miou_best = 0 - n_epoch_bad = 0 - epoch_best = 0 - time_metric = AverageMeter() - time_log_name = os.path.join(opt.checkpoints_dir, opt.name, 'time_log.txt') - with open(time_log_name, "a") as log_file: - now = time.strftime("%c") - log_file.write('================ training time (%s) ================\n' % now) - - for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): # outer loop for different epochs; we save the model by , + - epoch_start_time = time.time() # timer for entire epoch - iter_data_time = time.time() # timer for data loading per iteration - epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch - model.train() - # miou_current = val(opt, model) - for i, data in enumerate(dataset): # inner loop within one epoch - iter_start_time = time.time() # timer for computation per iteration - if total_iters % opt.print_freq == 0: - t_data = iter_start_time - iter_data_time - visualizer.reset() - total_iters += opt.batch_size - epoch_iter += opt.batch_size - n_epoch = opt.niter + opt.niter_decay - - model.set_input(data) # unpack data from dataset and apply preprocessing - model.optimize_parameters() # calculate loss functions, get gradients, update network weights - if ifSaveImage: - if total_iters % opt.display_freq == 0: # display images on visdom and save images to a HTML file - save_result = total_iters % opt.update_html_freq == 0 - model.compute_visuals() - visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) - if total_iters % opt.print_freq == 0: # print training losses and save logging information to the disk - losses = model.get_current_losses() - t_comp = (time.time() - iter_start_time) / opt.batch_size - visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data) - if opt.display_id > 0: - visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses) - - if total_iters % opt.save_latest_freq == 0: # cache our latest model every iterations - print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters)) - save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest' - model.save_networks(save_suffix) - - iter_data_time = time.time() - - t_epoch = time.time()-epoch_start_time - time_metric.update(t_epoch) - print_current_acc(time_log_name, epoch,{"current_t_epoch": t_epoch}) - - - if epoch % opt.save_epoch_freq == 0: # cache our model every epochs - print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters)) - model.save_networks('latest') - miou_current = val(opt, model) - - if miou_current > miou_best: - miou_best = miou_current - epoch_best = epoch - model.save_networks(str(epoch_best)+"_"+metric_name+'_'+'%0.5f'% miou_best) - - - print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) - model.update_learning_rate() # update learning rates at the end of every epoch. 
- - time_ave = time_metric.average() - print_current_acc(time_log_name, epoch, {"ave_t_epoch": time_ave}) diff --git a/plugins/ai_method/packages/STA_net/util/__init__.py b/plugins/ai_method/packages/STA_net/util/__init__.py deleted file mode 100644 index ae36f63..0000000 --- a/plugins/ai_method/packages/STA_net/util/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""This package includes a miscellaneous collection of useful helper functions.""" diff --git a/plugins/ai_method/packages/STA_net/util/html.py b/plugins/ai_method/packages/STA_net/util/html.py deleted file mode 100644 index cc3262a..0000000 --- a/plugins/ai_method/packages/STA_net/util/html.py +++ /dev/null @@ -1,86 +0,0 @@ -import dominate -from dominate.tags import meta, h3, table, tr, td, p, a, img, br -import os - - -class HTML: - """This HTML class allows us to save images and write texts into a single HTML file. - - It consists of functions such as (add a text header to the HTML file), - (add a row of images to the HTML file), and (save the HTML to the disk). - It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. - """ - - def __init__(self, web_dir, title, refresh=0): - """Initialize the HTML classes - - Parameters: - web_dir (str) -- a directory that stores the webpage. HTML file will be created at /index.html; images will be saved at 0: - with self.doc.head: - meta(http_equiv="refresh", content=str(refresh)) - - def get_image_dir(self): - """Return the directory that stores images""" - return self.img_dir - - def add_header(self, text): - """Insert a header to the HTML file - - Parameters: - text (str) -- the header text - """ - with self.doc: - h3(text) - - def add_images(self, ims, txts, links, width=400): - """add images to the HTML file - - Parameters: - ims (str list) -- a list of image paths - txts (str list) -- a list of image names shown on the website - links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page - """ - self.t = table(border=1, style="table-layout: fixed;") # Insert a table - self.doc.add(self.t) - with self.t: - with tr(): - for im, txt, link in zip(ims, txts, links): - with td(style="word-wrap: break-word;", halign="center", valign="top"): - with p(): - with a(href=os.path.join('images', link)): - img(style="width:%dpx" % width, src=os.path.join('images', im)) - br() - p(txt) - - def save(self): - """save the current content to the HMTL file""" - html_file = '%s/index.html' % self.web_dir - f = open(html_file, 'wt') - f.write(self.doc.render()) - f.close() - - -if __name__ == '__main__': # we show an example usage here. 
- html = HTML('web/', 'test_html') - html.add_header('hello world') - - ims, txts, links = [], [], [] - for n in range(4): - ims.append('image_%d.png' % n) - txts.append('text_%d' % n) - links.append('image_%d.png' % n) - html.add_images(ims, txts, links) - html.save() diff --git a/plugins/ai_method/packages/STA_net/util/metrics.py b/plugins/ai_method/packages/STA_net/util/metrics.py deleted file mode 100644 index 7e823ee..0000000 --- a/plugins/ai_method/packages/STA_net/util/metrics.py +++ /dev/null @@ -1,176 +0,0 @@ -# Adapted from score written by wkentaro -# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py -import numpy as np -eps=np.finfo(float).eps - - - -class AverageMeter(object): - """Computes and stores the average and current value""" - def __init__(self): - self.initialized = False - self.val = None - self.avg = None - self.sum = None - self.count = None - - def initialize(self, val, weight): - self.val = val - self.avg = val - self.sum = val * weight - self.count = weight - self.initialized = True - - def update(self, val, weight=1): - if not self.initialized: - self.initialize(val, weight) - else: - self.add(val, weight) - - def add(self, val, weight): - self.val = val - self.sum += val * weight - self.count += weight - self.avg = self.sum / self.count - - def value(self): - return self.val - - def average(self): - return self.avg - - def get_scores(self): - scores, cls_iu, m_1 = cm2score(self.sum) - scores.update(cls_iu) - scores.update(m_1) - return scores - - -def cm2score(confusion_matrix): - hist = confusion_matrix - n_class = hist.shape[0] - tp = np.diag(hist) - sum_a1 = hist.sum(axis=1) - sum_a0 = hist.sum(axis=0) - # ---------------------------------------------------------------------- # - # 1. Accuracy & Class Accuracy - # ---------------------------------------------------------------------- # - acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps) - - acc_cls_ = tp / (sum_a1 + np.finfo(np.float32).eps) - - # precision - precision = tp / (sum_a0 + np.finfo(np.float32).eps) - - # F1 score - F1 = 2*acc_cls_ * precision / (acc_cls_ + precision + np.finfo(np.float32).eps) - # ---------------------------------------------------------------------- # - # 2. 
Mean IoU - # ---------------------------------------------------------------------- # - iu = tp / (sum_a1 + hist.sum(axis=0) - tp + np.finfo(np.float32).eps) - mean_iu = np.nanmean(iu) - - cls_iu = dict(zip(range(n_class), iu)) - - - - return {'Overall_Acc': acc, - 'Mean_IoU': mean_iu}, cls_iu, \ - { - 'precision_1': precision[1], - 'recall_1': acc_cls_[1], - 'F1_1': F1[1],} - - -class RunningMetrics(object): - def __init__(self, num_classes): - """ - Computes and stores the Metric values from Confusion Matrix - - overall accuracy - - mean accuracy - - mean IU - - fwavacc - For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix - :param num_classes: number of classes - """ - self.num_classes = num_classes - self.confusion_matrix = np.zeros((num_classes, num_classes)) - - def __fast_hist(self, label_gt, label_pred): - """ - Collect values for Confusion Matrix - For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix - :param label_gt: ground-truth - :param label_pred: prediction - :return: values for confusion matrix - """ - mask = (label_gt >= 0) & (label_gt < self.num_classes) - hist = np.bincount(self.num_classes * label_gt[mask].astype(int) + label_pred[mask], - minlength=self.num_classes**2).reshape(self.num_classes, self.num_classes) - return hist - - def update(self, label_gts, label_preds): - """ - Compute Confusion Matrix - For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix - :param label_gts: ground-truths - :param label_preds: predictions - :return: - """ - for lt, lp in zip(label_gts, label_preds): - self.confusion_matrix += self.__fast_hist(lt.flatten(), lp.flatten()) - - def reset(self): - """ - Reset Confusion Matrix - :return: - """ - self.confusion_matrix = np.zeros((self.num_classes, self.num_classes)) - - def get_cm(self): - return self.confusion_matrix - - def get_scores(self): - """ - Returns score about: - - overall accuracy - - mean accuracy - - mean IU - - fwavacc - For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix - :return: - """ - hist = self.confusion_matrix - tp = np.diag(hist) - sum_a1 = hist.sum(axis=1) - sum_a0 = hist.sum(axis=0) - - # ---------------------------------------------------------------------- # - # 1. Accuracy & Class Accuracy - # ---------------------------------------------------------------------- # - acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps) - - # recall - acc_cls_ = tp / (sum_a1 + np.finfo(np.float32).eps) - - # precision - precision = tp / (sum_a0 + np.finfo(np.float32).eps) - # ---------------------------------------------------------------------- # - # 2. 
Mean IoU - # ---------------------------------------------------------------------- # - iu = tp / (sum_a1 + hist.sum(axis=0) - tp + np.finfo(np.float32).eps) - mean_iu = np.nanmean(iu) - - cls_iu = dict(zip(range(self.num_classes), iu)) - - # F1 score - F1 = 2 * acc_cls_ * precision / (acc_cls_ + precision + np.finfo(np.float32).eps) - - scores = {'Overall_Acc': acc, - 'Mean_IoU': mean_iu} - scores.update(cls_iu) - scores.update({'precision_1': precision[1], - 'recall_1': acc_cls_[1], - 'F1_1': F1[1]}) - return scores - diff --git a/plugins/ai_method/packages/STA_net/util/util.py b/plugins/ai_method/packages/STA_net/util/util.py deleted file mode 100644 index 5ce0314..0000000 --- a/plugins/ai_method/packages/STA_net/util/util.py +++ /dev/null @@ -1,110 +0,0 @@ -"""This module contains simple helper functions """ -from __future__ import print_function -import torch -import numpy as np -from PIL import Image -import os -import ntpath - - -def save_images(images, img_dir, name): - """save images in img_dir, with name - iamges: torch.float, B*C*H*W - img_dir: str - name: list [str] - """ - for i, image in enumerate(images): - print(image.shape) - image_numpy = tensor2im(image.unsqueeze(0),normalize=False)*255 - basename = os.path.basename(name[i]) - print('name:', basename) - save_path = os.path.join(img_dir,basename) - save_image(image_numpy,save_path) - - -def save_visuals(visuals,img_dir,name): - """ - """ - name = ntpath.basename(name) - name = name.split(".")[0] - print(name) - # save images to the disk - for label, image in visuals.items(): - image_numpy = tensor2im(image) - img_path = os.path.join(img_dir, '%s_%s.png' % (name, label)) - save_image(image_numpy, img_path) - - -def tensor2im(input_image, imtype=np.uint8, normalize=True): - """"Converts a Tensor array into a numpy image array. 
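To make the bookkeeping in the metrics code removed above concrete: AverageMeter sums confusion matrices across batches, and get_scores() derives the reported numbers from that summed matrix via cm2score. A small self-contained check of the same formulas for the binary case (class 1 = change); the toy confusion matrix here is invented purely for illustration:

import numpy as np

# rows = ground truth, cols = prediction; class 1 is "change"
hist = np.array([[90, 10],
                 [ 5, 20]], dtype=float)
eps = np.finfo(np.float32).eps

tp = np.diag(hist)                              # [90, 20]
precision = tp / (hist.sum(axis=0) + eps)       # per-class precision
recall = tp / (hist.sum(axis=1) + eps)          # per-class recall (acc_cls_ above)
f1 = 2 * precision * recall / (precision + recall + eps)
iou = tp / (hist.sum(axis=1) + hist.sum(axis=0) - tp + eps)

print(precision[1], recall[1], f1[1], iou[1])
# ~0.667, 0.800, ~0.727, ~0.571 -- the values reported as precision_1 / recall_1 / F1_1 and per-class IoU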
- - Parameters: - input_image (tensor) -- the input image tensor array - imtype (type) -- the desired type of the converted numpy array - """ - if not isinstance(input_image, np.ndarray): - if isinstance(input_image, torch.Tensor): # get the data from a variable - image_tensor = input_image.data - else: - return input_image - image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array - if image_numpy.shape[0] == 1: # grayscale to RGB - image_numpy = np.tile(image_numpy, (3, 1, 1)) - - image_numpy = np.transpose(image_numpy, (1, 2, 0)) - if normalize: - image_numpy = (image_numpy + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling - - else: # if it is a numpy array, do nothing - image_numpy = input_image - return image_numpy.astype(imtype) - - -def save_image(image_numpy, image_path): - """Save a numpy image to the disk - - Parameters: - image_numpy (numpy array) -- input numpy array - image_path (str) -- the path of the image - """ - image_pil = Image.fromarray(image_numpy) - image_pil.save(image_path) - - -def print_numpy(x, val=True, shp=False): - """Print the mean, min, max, median, std, and size of a numpy array - - Parameters: - val (bool) -- if print the values of the numpy array - shp (bool) -- if print the shape of the numpy array - """ - x = x.astype(np.float64) - if shp: - print('shape,', x.shape) - if val: - x = x.flatten() - print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( - np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) - - -def mkdirs(paths): - """create empty directories if they don't exist - - Parameters: - paths (str list) -- a list of directory paths - """ - if isinstance(paths, list) and not isinstance(paths, str): - for path in paths: - mkdir(path) - else: - mkdir(paths) - - -def mkdir(path): - """create a single empty directory if it didn't exist - - Parameters: - path (str) -- a single directory path - """ - if not os.path.exists(path): - os.makedirs(path) diff --git a/plugins/ai_method/packages/STA_net/util/visualizer.py b/plugins/ai_method/packages/STA_net/util/visualizer.py deleted file mode 100644 index 99b930d..0000000 --- a/plugins/ai_method/packages/STA_net/util/visualizer.py +++ /dev/null @@ -1,246 +0,0 @@ -import numpy as np -import os -import sys -import ntpath -import time -from . import util, html -from subprocess import Popen, PIPE -# from scipy.misc import imresize - - -if sys.version_info[0] == 2: - VisdomExceptionBase = Exception -else: - VisdomExceptionBase = ConnectionError - - -def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): - """Save images to the disk. - - Parameters: - webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details) - visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs - image_path (str) -- the string is used to create image paths - aspect_ratio (float) -- the aspect ratio of saved images - width (int) -- the images will be resized to width x width - - This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. 
- """ - image_dir = webpage.get_image_dir() - short_path = ntpath.basename(image_path[0]) - name = os.path.splitext(short_path)[0] - - webpage.add_header(name) - ims, txts, links = [], [], [] - - for label, im_data in visuals.items(): - im = util.tensor2im(im_data) - image_name = '%s_%s.png' % (name, label) - save_path = os.path.join(image_dir, image_name) - h, w, _ = im.shape - # if aspect_ratio > 1.0: - # im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic') - # if aspect_ratio < 1.0: - # im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic') - util.save_image(im, save_path) - - ims.append(image_name) - txts.append(label) - links.append(image_name) - webpage.add_images(ims, txts, links, width=width) - - -class Visualizer(): - """This class includes several functions that can display/save images and print/save logging information. - - It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. - """ - - def __init__(self, opt): - """Initialize the Visualizer class - - Parameters: - opt -- stores all the experiment flags; needs to be a subclass of BaseOptions - Step 1: Cache the training/test options - Step 2: connect to a visdom server - Step 3: create an HTML object for saveing HTML filters - Step 4: create a logging file to store training losses - """ - self.opt = opt # cache the option - self.display_id = opt.display_id - self.use_html = opt.isTrain and not opt.no_html - self.win_size = opt.display_winsize - self.name = opt.name - self.port = opt.display_port - self.saved = False - if self.display_id > 0: # connect to a visdom server given and - import visdom - self.ncols = opt.display_ncols - self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env) - if not self.vis.check_connection(): - self.create_visdom_connections() - - if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/ - self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') - self.img_dir = os.path.join(self.web_dir, 'images') - print('create web directory %s...' % self.web_dir) - util.mkdirs([self.web_dir, self.img_dir]) - # create a logging file to store training losses - self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') - with open(self.log_name, "a") as log_file: - now = time.strftime("%c") - log_file.write('================ Training Loss (%s) ================\n' % now) - - def reset(self): - """Reset the self.saved status""" - self.saved = False - - def create_visdom_connections(self): - """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """ - cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port - print('\n\nCould not connect to Visdom server. \n Trying to start a server....') - print('Command: %s' % cmd) - Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE) - - def display_current_results(self, visuals, epoch, save_result): - """Display current results on visdom; save current results to an HTML file. 
- - Parameters: - visuals (OrderedDict) - - dictionary of images to display or save - epoch (int) - - the current epoch - save_result (bool) - - if save the current results to an HTML file - """ - if self.display_id > 0: # show images in the browser using visdom - ncols = self.ncols - if ncols > 0: # show all the images in one visdom panel - ncols = min(ncols, len(visuals)) - h, w = next(iter(visuals.values())).shape[:2] - table_css = """""" % (w, h) # create a table css - # create a table of images. - title = self.name - label_html = '' - label_html_row = '' - images = [] - idx = 0 - for label, image in visuals.items(): - image_numpy = util.tensor2im(image) - label_html_row += '%s' % label - images.append(image_numpy.transpose([2, 0, 1])) - idx += 1 - if idx % ncols == 0: - label_html += '%s' % label_html_row - label_html_row = '' - white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 - while idx % ncols != 0: - images.append(white_image) - label_html_row += '' - idx += 1 - if label_html_row != '': - label_html += '%s' % label_html_row - try: - self.vis.images(images, nrow=ncols, win=self.display_id + 1, - padding=2, opts=dict(title=title + ' images')) - label_html = '%s
' % label_html - self.vis.text(table_css + label_html, win=self.display_id + 2, - opts=dict(title=title + ' labels')) - except VisdomExceptionBase: - self.create_visdom_connections() - - else: # show each image in a separate visdom panel; - idx = 1 - try: - for label, image in visuals.items(): - image_numpy = util.tensor2im(image) - self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), - win=self.display_id + idx) - idx += 1 - except VisdomExceptionBase: - self.create_visdom_connections() - - if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved. - self.saved = True - # save images to the disk - for label, image in visuals.items(): - image_numpy = util.tensor2im(image) - img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) - util.save_image(image_numpy, img_path) - - # update website - webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1) - for n in range(epoch, 0, -1): - webpage.add_header('epoch [%d]' % n) - ims, txts, links = [], [], [] - - for label, image_numpy in visuals.items(): - image_numpy = util.tensor2im(image) - img_path = 'epoch%.3d_%s.png' % (n, label) - ims.append(img_path) - txts.append(label) - links.append(img_path) - webpage.add_images(ims, txts, links, width=self.win_size) - webpage.save() - - def plot_current_losses(self, epoch, counter_ratio, losses): - """display the current losses on visdom display: dictionary of error labels and values - - Parameters: - epoch (int) -- current epoch - counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1 - losses (OrderedDict) -- training losses stored in the format of (name, float) pairs - """ - if not hasattr(self, 'plot_data'): - self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} - self.plot_data['X'].append(epoch + counter_ratio) - self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) - try: - self.vis.line( - X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), - Y=np.array(self.plot_data['Y']), - opts={ - 'title': self.name + ' loss over time', - 'legend': self.plot_data['legend'], - 'xlabel': 'epoch', - 'ylabel': 'loss'}, - win=self.display_id) - except VisdomExceptionBase: - self.create_visdom_connections() - - def plot_current_acc(self, epoch, counter_ratio, acc): - if not hasattr(self, 'acc_data'): - self.acc_data = {'X': [], 'Y': [], 'legend': list(acc.keys())} - self.acc_data['X'].append(epoch + counter_ratio) - self.acc_data['Y'].append([acc[k] for k in self.acc_data['legend']]) - try: - self.vis.line( - X=np.stack([np.array(self.acc_data['X'])] * len(self.acc_data['legend']), 1), - Y=np.array(self.acc_data['Y']), - opts={ - 'title': self.name + ' acc over time', - 'legend': self.acc_data['legend'], - 'xlabel': 'epoch', - 'ylabel': 'acc'}, - win=self.display_id+3) - except VisdomExceptionBase: - self.create_visdom_connections() - - # losses: same format as |losses| of plot_current_losses - def print_current_losses(self, epoch, iters, losses, t_comp, t_data): - """print current losses on console; also save the losses to the disk - - Parameters: - epoch (int) -- current epoch - iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) - losses (OrderedDict) -- training losses stored in the format of (name, float) pairs - t_comp (float) -- computational time per data point (normalized by batch_size) - t_data (float) -- data loading time per data point (normalized by 
batch_size) - """ - message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) - for k, v in losses.items(): - message += '%s: %.3f ' % (k, v) - - print(message) # print the message - with open(self.log_name, "a") as log_file: - log_file.write('%s\n' % message) # save the message diff --git a/plugins/ai_method/packages/STA_net/val.py b/plugins/ai_method/packages/STA_net/val.py deleted file mode 100644 index 494e88a..0000000 --- a/plugins/ai_method/packages/STA_net/val.py +++ /dev/null @@ -1,96 +0,0 @@ -import time -from options.test_options import TestOptions -from data import create_dataset -from models import create_model -import os -from util.util import save_visuals -from util.metrics import AverageMeter -import numpy as np -from util.util import mkdir - -def make_val_opt(opt): - - # hard-code some parameters for test - opt.num_threads = 0 # test code only supports num_threads = 1 - opt.batch_size = 1 # test code only supports batch_size = 1 - opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. - opt.no_flip = True # no flip; comment this line if results on flipped images are needed. - opt.no_flip2 = True # no flip; comment this line if results on flipped images are needed. - - opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file. - opt.phase = 'val' - opt.preprocess = 'none1' - opt.isTrain = False - opt.aspect_ratio = 1 - opt.eval = True - - return opt - - -def print_current_acc(log_name, epoch, score): - """print current acc on console; also save the losses to the disk - Parameters: - """ - - message = '(epoch: %s) ' % str(epoch) - for k, v in score.items(): - message += '%s: %.3f ' % (k, v) - print(message) # print the message - with open(log_name, "a") as log_file: - log_file.write('%s\n' % message) # save the message - - -def val(opt): - - dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options - model = create_model(opt) # create a model given opt.model and other options - model.setup(opt) # regular setup: load and print networks; create schedulers - save_path = os.path.join(opt.checkpoints_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch)) - mkdir(save_path) - model.eval() - # create a logging file to store training losses - log_name = os.path.join(opt.checkpoints_dir, opt.name, 'val1_log.txt') - with open(log_name, "a") as log_file: - now = time.strftime("%c") - log_file.write('================ val acc (%s) ================\n' % now) - - running_metrics = AverageMeter() - for i, data in enumerate(dataset): - if i >= opt.num_test: # only apply our model to opt.num_test images. - break - model.set_input(data) # unpack data from data loader - score = model.test(val=True) # run inference return confusion_matrix - running_metrics.update(score) - - visuals = model.get_current_visuals() # get image results - img_path = model.get_image_paths() # get image paths - if i % 5 == 0: # save images to an HTML file - print('processing (%04d)-th image... 
%s' % (i, img_path)) - - save_visuals(visuals,save_path,img_path[0]) - score = running_metrics.get_scores() - print_current_acc(log_name, opt.epoch, score) - - -if __name__ == '__main__': - opt = TestOptions().parse() # get training options - opt = make_val_opt(opt) - opt.phase = 'val' - opt.dataroot = 'path-to-LEVIR-CD-test' - - opt.dataset_mode = 'changedetection' - - opt.n_class = 2 - - opt.SA_mode = 'PAM' - opt.arch = 'mynet3' - - opt.model = 'CDFA' - - opt.name = 'pam' - opt.results_dir = './results/' - - opt.epoch = '78_F1_1_0.88780' - opt.num_test = np.inf - - val(opt) diff --git a/plugins/ai_method/packages/__init__.py b/plugins/ai_method/packages/__init__.py new file mode 100644 index 0000000..b4939be --- /dev/null +++ b/plugins/ai_method/packages/__init__.py @@ -0,0 +1,39 @@ +from .models import create_model +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + + +class Front: + + def __init__(self, model) -> None: + self.model = model + self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + self.model = model.to(self.device) + + def __call__(self, inp1, inp2): + inp1 = torch.from_numpy(inp1).to(self.device) + inp1 = torch.transpose(inp1, 0, 2).transpose(1, 2).unsqueeze(0) + inp2 = torch.from_numpy(inp2).to(self.device) + inp2 = torch.transpose(inp2, 0, 2).transpose(1, 2).unsqueeze(0) + + out = self.model(inp1, inp2) + out = out.sigmoid() + out = out.cpu().detach().numpy()[0,0] + return out + +def get_model(name): + + try: + model = create_model(name, encoder_name="resnet34", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7 + encoder_weights="imagenet", # use `imagenet` pre-trained weights for encoder initialization + in_channels=3, # model input channels (1 for gray-scale images, 3 for RGB, etc.) + classes=1, # model output channels (number of classes in your datasets) + siam_encoder=True, # whether to use a siamese encoder + fusion_form='concat', # the form of fusing features from two branches. e.g. concat, sum, diff, or abs_diff.) 
+ ) + except: + return None + + return Front(model) \ No newline at end of file diff --git a/plugins/ai_method/packages/models/DPFCN/__init__.py b/plugins/ai_method/packages/models/DPFCN/__init__.py new file mode 100644 index 0000000..3f0a73f --- /dev/null +++ b/plugins/ai_method/packages/models/DPFCN/__init__.py @@ -0,0 +1 @@ +from .model import DPFCN \ No newline at end of file diff --git a/plugins/ai_method/packages/models/DPFCN/decoder.py b/plugins/ai_method/packages/models/DPFCN/decoder.py new file mode 100644 index 0000000..1056e86 --- /dev/null +++ b/plugins/ai_method/packages/models/DPFCN/decoder.py @@ -0,0 +1,134 @@ +import sys +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..base import modules as md +from ..base import Decoder + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + skip_channels, + out_channels, + use_batchnorm=True, + attention_type=None, + ): + super().__init__() + self.conv1 = md.Conv2dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels) + self.conv2 = md.Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.attention2 = md.Attention(attention_type, in_channels=out_channels) + + def forward(self, x, skip=None): + x = F.interpolate(x, scale_factor=2, mode="nearest") + + if skip is not None: + x = torch.cat([x, skip], dim=1) + x = self.attention1(x) + x = self.conv1(x) + x = self.conv2(x) + x = self.attention2(x) + return x + + +class CenterBlock(nn.Sequential): + def __init__(self, in_channels, out_channels, use_batchnorm=True): + conv1 = md.Conv2dReLU( + in_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + conv2 = md.Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + super().__init__(conv1, conv2) + + +class DPFCNDecoder(Decoder): + def __init__( + self, + encoder_channels, + decoder_channels, + n_blocks=5, + use_batchnorm=True, + attention_type=None, + center=False, + fusion_form="concat", + ): + super().__init__() + + if n_blocks != len(decoder_channels): + raise ValueError( + "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( + n_blocks, len(decoder_channels) + ) + ) + + encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution + encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder + + # computing blocks input and output channels + head_channels = encoder_channels[0] + in_channels = [head_channels] + list(decoder_channels[:-1]) + skip_channels = list(encoder_channels[1:]) + [0] + out_channels = decoder_channels + + # adjust encoder channels according to fusion form + self.fusion_form = fusion_form + if self.fusion_form in self.FUSION_DIC["2to2_fusion"]: + skip_channels = [ch*2 for ch in skip_channels] + in_channels[0] = in_channels[0] * 2 + head_channels = head_channels * 2 + + if center: + self.center = CenterBlock( + head_channels, head_channels, use_batchnorm=use_batchnorm + ) + else: + self.center = nn.Identity() + + # combine decoder keyword arguments + kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type) + blocks = [ + DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) + for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels) + ] + self.blocks = 
nn.ModuleList(blocks) + + def forward(self, *features): + + features = self.aggregation_layer(features[0], features[1], + self.fusion_form, ignore_original_img=True) + # features = features[1:] # remove first skip with same spatial resolution + features = features[::-1] # reverse channels to start from head of encoder + + head = features[0] + skips = features[1:] + + x = self.center(head) + for i, decoder_block in enumerate(self.blocks): + skip = skips[i] if i < len(skips) else None + x = decoder_block(x, skip) + + return x diff --git a/plugins/ai_method/packages/models/DPFCN/model.py b/plugins/ai_method/packages/models/DPFCN/model.py new file mode 100644 index 0000000..39e1d82 --- /dev/null +++ b/plugins/ai_method/packages/models/DPFCN/model.py @@ -0,0 +1,72 @@ +from typing import Optional, Union, List +from .decoder import DPFCNDecoder +from ..encoders import get_encoder +from ..base import SegmentationModel +from ..base import SegmentationHead, ClassificationHead + + +class DPFCN(SegmentationModel): + + + def __init__( + self, + encoder_name: str = "resnet34", + encoder_depth: int = 5, + encoder_weights: Optional[str] = "imagenet", + decoder_use_batchnorm: bool = True, + decoder_channels: List[int] = (256, 128, 64, 32, 16), + decoder_attention_type: Optional[str] = None, + in_channels: int = 3, + classes: int = 1, + activation: Optional[Union[str, callable]] = None, + aux_params: Optional[dict] = None, + siam_encoder: bool = True, + fusion_form: str = "concat", + **kwargs + ): + super().__init__() + + self.siam_encoder = siam_encoder + + self.encoder = get_encoder( + encoder_name, + in_channels=in_channels, + depth=encoder_depth, + weights=encoder_weights, + ) + + if not self.siam_encoder: + self.encoder_non_siam = get_encoder( + encoder_name, + in_channels=in_channels, + depth=encoder_depth, + weights=encoder_weights, + ) + + self.decoder = DPFCNDecoder( + encoder_channels=self.encoder.out_channels, + decoder_channels=decoder_channels, + n_blocks=encoder_depth, + use_batchnorm=decoder_use_batchnorm, + center=True if encoder_name.startswith("vgg") else False, + attention_type=decoder_attention_type, + fusion_form=fusion_form, + ) + + self.segmentation_head = SegmentationHead( + in_channels=decoder_channels[-1], + out_channels=classes, + activation=activation, + kernel_size=3, + ) + + if aux_params is not None: + self.classification_head = ClassificationHead( + in_channels=self.encoder.out_channels[-1], **aux_params + ) + else: + self.classification_head = None + + self.name = "u-{}".format(encoder_name) + self.initialize() + diff --git a/plugins/ai_method/packages/models/DVCA/__init__.py b/plugins/ai_method/packages/models/DVCA/__init__.py new file mode 100644 index 0000000..0124b36 --- /dev/null +++ b/plugins/ai_method/packages/models/DVCA/__init__.py @@ -0,0 +1 @@ +from .model import DVCA \ No newline at end of file diff --git a/plugins/ai_method/packages/models/DVCA/decoder.py b/plugins/ai_method/packages/models/DVCA/decoder.py new file mode 100644 index 0000000..72fb393 --- /dev/null +++ b/plugins/ai_method/packages/models/DVCA/decoder.py @@ -0,0 +1,179 @@ +""" +BSD 3-Clause License + +Copyright (c) Soumith Chintala 2016, +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. 
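Before the DVCA decoder below, it is worth sketching how the Front wrapper from packages/__init__.py above is meant to be driven: get_model(name) builds one of the new architectures through create_model, and Front then handles device placement, the HWC-to-NCHW transpose, the sigmoid, and the conversion back to a single-channel numpy map. A hedged usage sketch only; the model name, array shapes, threshold and absolute import path are assumptions (inside the plugin the package is imported relatively), not fixed by this diff:

import numpy as np
from plugins.ai_method.packages import get_model

# two co-registered float32 chips, height x width x 3, as read from the rasters
chip1 = np.random.rand(512, 512, 3).astype(np.float32)
chip2 = np.random.rand(512, 512, 3).astype(np.float32)

front = get_model('dpfcn')          # returns None if create_model raises, so check it
if front is not None:
    prob = front(chip1, chip2)      # (512, 512) change probabilities in [0, 1]
    change_mask = prob > 0.5        # illustrative threshold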
+ +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +import torch +from torch import nn +from torch.nn import functional as F + +from ..base import Decoder + +__all__ = ["DVCADecoder"] + + +class DVCADecoder(Decoder): + def __init__(self, in_channels, out_channels=256, atrous_rates=(12, 24, 36), fusion_form="concat"): + super().__init__() + + # adjust encoder channels according to fusion form + if fusion_form in self.FUSION_DIC["2to2_fusion"]: + in_channels = in_channels * 2 + + self.aspp = nn.Sequential( + ASPP(in_channels, out_channels, atrous_rates), + nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + + self.out_channels = out_channels + self.fusion_form = fusion_form + + def forward(self, *features): + x = self.fusion(features[0][-1], features[1][-1], self.fusion_form) + x = self.aspp(x) + return x + + +class ASPPConv(nn.Sequential): + def __init__(self, in_channels, out_channels, dilation): + super().__init__( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + padding=dilation, + dilation=dilation, + bias=False, + ), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + + +class ASPPSeparableConv(nn.Sequential): + def __init__(self, in_channels, out_channels, dilation): + super().__init__( + SeparableConv2d( + in_channels, + out_channels, + kernel_size=3, + padding=dilation, + dilation=dilation, + bias=False, + ), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + + +class ASPPPooling(nn.Sequential): + def __init__(self, in_channels, out_channels): + super().__init__( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + + def forward(self, x): + size = x.shape[-2:] + for mod in self: + x = mod(x) + return F.interpolate(x, size=size, mode='bilinear', align_corners=False) + + +class ASPP(nn.Module): + def __init__(self, in_channels, out_channels, atrous_rates, separable=False): + super(ASPP, self).__init__() + modules = [] + modules.append( + nn.Sequential( + nn.Conv2d(in_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + ) + + rate1, rate2, rate3 = tuple(atrous_rates) + ASPPConvModule = ASPPConv if not separable else ASPPSeparableConv + + modules.append(ASPPConvModule(in_channels, out_channels, rate1)) + modules.append(ASPPConvModule(in_channels, out_channels, rate2)) + 
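The fusion_form argument threaded through DVCADecoder here (and through the DPFCN and RCNN decoders elsewhere in this diff) controls how the two temporal branches are merged; each decoder doubles its expected channel counts only when fusion_form is 'concat', the lone entry in FUSION_DIC['2to2_fusion'] of the base Decoder added further down. A minimal illustration of why (tensor shapes are arbitrary, not taken from the code):

import torch

c = 64
f1 = torch.randn(1, c, 32, 32)              # feature map from branch 1
f2 = torch.randn(1, c, 32, 32)              # feature map from branch 2

fused_concat = torch.cat([f1, f2], dim=1)   # "concat": 2*c channels, so decoders double their inputs
fused_absdiff = torch.abs(f1 - f2)          # "abs_diff" (likewise "sum"/"diff"): still c channels
print(fused_concat.shape[1], fused_absdiff.shape[1])   # 128 64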
modules.append(ASPPConvModule(in_channels, out_channels, rate3)) + modules.append(ASPPPooling(in_channels, out_channels)) + + self.convs = nn.ModuleList(modules) + + self.project = nn.Sequential( + nn.Conv2d(5 * out_channels, out_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + nn.Dropout(0.5), + ) + + def forward(self, x): + res = [] + for conv in self.convs: + res.append(conv(x)) + res = torch.cat(res, dim=1) + return self.project(res) + + +class SeparableConv2d(nn.Sequential): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + bias=True, + ): + dephtwise_conv = nn.Conv2d( + in_channels, + in_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=False, + ) + pointwise_conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + bias=bias, + ) + super().__init__(dephtwise_conv, pointwise_conv) diff --git a/plugins/ai_method/packages/models/DVCA/model.py b/plugins/ai_method/packages/models/DVCA/model.py new file mode 100644 index 0000000..6757856 --- /dev/null +++ b/plugins/ai_method/packages/models/DVCA/model.py @@ -0,0 +1,69 @@ +from typing import Optional + +import torch.nn as nn + +from ..base import ClassificationHead, SegmentationHead, SegmentationModel +from ..encoders import get_encoder +from .decoder import DVCADecoder + + +class DVCA(SegmentationModel): + + + def __init__( + self, + encoder_name: str = "resnet34", + encoder_depth: int = 5, + encoder_weights: Optional[str] = "imagenet", + decoder_channels: int = 256, + in_channels: int = 3, + classes: int = 1, + activation: Optional[str] = None, + upsampling: int = 8, + aux_params: Optional[dict] = None, + siam_encoder: bool = True, + fusion_form: str = "concat", + **kwargs + ): + super().__init__() + + self.siam_encoder = siam_encoder + + self.encoder = get_encoder( + encoder_name, + in_channels=in_channels, + depth=encoder_depth, + weights=encoder_weights, + output_stride=8, + ) + + if not self.siam_encoder: + self.encoder_non_siam = get_encoder( + encoder_name, + in_channels=in_channels, + depth=encoder_depth, + weights=encoder_weights, + output_stride=8, + ) + + self.decoder = DVCADecoder( + in_channels=self.encoder.out_channels[-1], + out_channels=decoder_channels, + fusion_form=fusion_form, + ) + + self.segmentation_head = SegmentationHead( + in_channels=self.decoder.out_channels, + out_channels=classes, + activation=activation, + kernel_size=1, + upsampling=upsampling, + ) + + if aux_params is not None: + self.classification_head = ClassificationHead( + in_channels=self.encoder.out_channels[-1], **aux_params + ) + else: + self.classification_head = None + diff --git a/plugins/ai_method/packages/models/RCNN/__init__.py b/plugins/ai_method/packages/models/RCNN/__init__.py new file mode 100644 index 0000000..a0b2d80 --- /dev/null +++ b/plugins/ai_method/packages/models/RCNN/__init__.py @@ -0,0 +1 @@ +from .model import RCNN \ No newline at end of file diff --git a/plugins/ai_method/packages/models/RCNN/decoder.py b/plugins/ai_method/packages/models/RCNN/decoder.py new file mode 100644 index 0000000..03b63ef --- /dev/null +++ b/plugins/ai_method/packages/models/RCNN/decoder.py @@ -0,0 +1,131 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..base import Decoder + + +class Conv3x3GNReLU(nn.Module): + def __init__(self, in_channels, out_channels, upsample=False): + super().__init__() + self.upsample = upsample + self.block = nn.Sequential( 
+ nn.Conv2d( + in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False + ), + nn.GroupNorm(32, out_channels), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + x = self.block(x) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + return x + + +class FPNBlock(nn.Module): + def __init__(self, pyramid_channels, skip_channels): + super().__init__() + self.skip_conv = nn.Conv2d(skip_channels, pyramid_channels, kernel_size=1) + + def forward(self, x, skip=None): + x = F.interpolate(x, scale_factor=2, mode="nearest") + skip = self.skip_conv(skip) + x = x + skip + return x + + +class SegmentationBlock(nn.Module): + def __init__(self, in_channels, out_channels, n_upsamples=0): + super().__init__() + + blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))] + + if n_upsamples > 1: + for _ in range(1, n_upsamples): + blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True)) + + self.block = nn.Sequential(*blocks) + + def forward(self, x): + return self.block(x) + + +class MergeBlock(nn.Module): + def __init__(self, policy): + super().__init__() + if policy not in ["add", "cat"]: + raise ValueError( + "`merge_policy` must be one of: ['add', 'cat'], got {}".format( + policy + ) + ) + self.policy = policy + + def forward(self, x): + if self.policy == 'add': + return sum(x) + elif self.policy == 'cat': + return torch.cat(x, dim=1) + else: + raise ValueError( + "`merge_policy` must be one of: ['add', 'cat'], got {}".format(self.policy) + ) + + +class RCNNDecoder(Decoder): + def __init__( + self, + encoder_channels, + encoder_depth=5, + pyramid_channels=256, + segmentation_channels=128, + dropout=0.2, + merge_policy="add", + fusion_form="concat", + ): + super().__init__() + + self.out_channels = segmentation_channels if merge_policy == "add" else segmentation_channels * 4 + if encoder_depth < 3: + raise ValueError("Encoder depth for RCNN decoder cannot be less than 3, got {}.".format(encoder_depth)) + + encoder_channels = encoder_channels[::-1] + encoder_channels = encoder_channels[:encoder_depth + 1] + # (512, 256, 128, 64, 64, 3) + + # adjust encoder channels according to fusion form + self.fusion_form = fusion_form + if self.fusion_form in self.FUSION_DIC["2to2_fusion"]: + encoder_channels = [ch*2 for ch in encoder_channels] + + self.p5 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=1) + self.p4 = FPNBlock(pyramid_channels, encoder_channels[1]) + self.p3 = FPNBlock(pyramid_channels, encoder_channels[2]) + self.p2 = FPNBlock(pyramid_channels, encoder_channels[3]) + + self.seg_blocks = nn.ModuleList([ + SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=n_upsamples) + for n_upsamples in [3, 2, 1, 0] + ]) + + self.merge = MergeBlock(merge_policy) + self.dropout = nn.Dropout2d(p=dropout, inplace=True) + + def forward(self, *features): + + features = self.aggregation_layer(features[0], features[1], + self.fusion_form, ignore_original_img=True) + c2, c3, c4, c5 = features[-4:] + + p5 = self.p5(c5) + p4 = self.p4(p5, c4) + p3 = self.p3(p4, c3) + p2 = self.p2(p3, c2) + + feature_pyramid = [seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2])] + x = self.merge(feature_pyramid) + x = self.dropout(x) + + return x diff --git a/plugins/ai_method/packages/models/RCNN/model.py b/plugins/ai_method/packages/models/RCNN/model.py new file mode 100644 index 0000000..61656ad --- /dev/null +++ b/plugins/ai_method/packages/models/RCNN/model.py @@ -0,0 +1,73 @@ +from 
typing import Optional, Union + +from ..base import ClassificationHead, SegmentationHead, SegmentationModel +from ..encoders import get_encoder +from .decoder import RCNNDecoder + + +class RCNN(SegmentationModel): + + def __init__( + self, + encoder_name: str = "resnet34", + encoder_depth: int = 5, + encoder_weights: Optional[str] = "imagenet", + decoder_pyramid_channels: int = 256, + decoder_segmentation_channels: int = 128, + decoder_merge_policy: str = "add", + decoder_dropout: float = 0.2, + in_channels: int = 3, + classes: int = 1, + activation: Optional[str] = None, + upsampling: int = 4, + aux_params: Optional[dict] = None, + siam_encoder: bool = True, + fusion_form: str = "concat", + **kwargs + ): + super().__init__() + + self.siam_encoder = siam_encoder + + self.encoder = get_encoder( + encoder_name, + in_channels=in_channels, + depth=encoder_depth, + weights=encoder_weights, + ) + + if not self.siam_encoder: + self.encoder_non_siam = get_encoder( + encoder_name, + in_channels=in_channels, + depth=encoder_depth, + weights=encoder_weights, + ) + + self.decoder = RCNNDecoder( + encoder_channels=self.encoder.out_channels, + encoder_depth=encoder_depth, + pyramid_channels=decoder_pyramid_channels, + segmentation_channels=decoder_segmentation_channels, + dropout=decoder_dropout, + merge_policy=decoder_merge_policy, + fusion_form=fusion_form, + ) + + self.segmentation_head = SegmentationHead( + in_channels=self.decoder.out_channels, + out_channels=classes, + activation=activation, + kernel_size=1, + upsampling=upsampling, + ) + + if aux_params is not None: + self.classification_head = ClassificationHead( + in_channels=self.encoder.out_channels[-1], **aux_params + ) + else: + self.classification_head = None + + self.name = "rcnn-{}".format(encoder_name) + self.initialize() diff --git a/plugins/ai_method/packages/models/__init__.py b/plugins/ai_method/packages/models/__init__.py new file mode 100644 index 0000000..5c3a1fa --- /dev/null +++ b/plugins/ai_method/packages/models/__init__.py @@ -0,0 +1,42 @@ +from .RCNN import RCNN +from .DVCA import DVCA +from .DPFCN import DPFCN + +from . import encoders +from . import utils +from . import losses +from . import datasets + +from .__version__ import __version__ + +from typing import Optional +import torch + + +def create_model( + arch: str, + encoder_name: str = "resnet34", + encoder_weights: Optional[str] = "imagenet", + in_channels: int = 3, + classes: int = 1, + **kwargs, +) -> torch.nn.Module: + """Models wrapper. Allows to create any model just with parametes + + """ + + archs = [DVCA, DPFCN, RCNN] + archs_dict = {a.__name__.lower(): a for a in archs} + try: + model_class = archs_dict[arch.lower()] + except KeyError: + raise KeyError("Wrong architecture type `{}`. 
Available options are: {}".format( + arch, list(archs_dict.keys()), + )) + return model_class( + encoder_name=encoder_name, + encoder_weights=encoder_weights, + in_channels=in_channels, + classes=classes, + **kwargs, + ) diff --git a/plugins/ai_method/packages/models/__version__.py b/plugins/ai_method/packages/models/__version__.py new file mode 100644 index 0000000..d0727ee --- /dev/null +++ b/plugins/ai_method/packages/models/__version__.py @@ -0,0 +1,3 @@ +VERSION = (0, 1, 4) + +__version__ = '.'.join(map(str, VERSION)) diff --git a/plugins/ai_method/packages/models/base/__init__.py b/plugins/ai_method/packages/models/base/__init__.py new file mode 100644 index 0000000..5b4972e --- /dev/null +++ b/plugins/ai_method/packages/models/base/__init__.py @@ -0,0 +1,12 @@ +from .model import SegmentationModel +from .decoder import Decoder + +from .modules import ( + Conv2dReLU, + Attention, +) + +from .heads import ( + SegmentationHead, + ClassificationHead, +) diff --git a/plugins/ai_method/packages/models/base/decoder.py b/plugins/ai_method/packages/models/base/decoder.py new file mode 100644 index 0000000..79c67d3 --- /dev/null +++ b/plugins/ai_method/packages/models/base/decoder.py @@ -0,0 +1,33 @@ +import torch + + +class Decoder(torch.nn.Module): + # TODO: support learnable fusion modules + def __init__(self): + super().__init__() + self.FUSION_DIC = {"2to1_fusion": ["sum", "diff", "abs_diff"], + "2to2_fusion": ["concat"]} + + def fusion(self, x1, x2, fusion_form="concat"): + """Specify the form of feature fusion""" + if fusion_form == "concat": + x = torch.cat([x1, x2], dim=1) + elif fusion_form == "sum": + x = x1 + x2 + elif fusion_form == "diff": + x = x2 - x1 + elif fusion_form == "abs_diff": + x = torch.abs(x1 - x2) + else: + raise ValueError('the fusion form "{}" is not defined'.format(fusion_form)) + + return x + + def aggregation_layer(self, fea1, fea2, fusion_form="concat", ignore_original_img=True): + """aggregate features from siamese or non-siamese branches""" + + start_idx = 1 if ignore_original_img else 0 + aggregate_fea = [self.fusion(fea1[idx], fea2[idx], fusion_form) + for idx in range(start_idx, len(fea1))] + + return aggregate_fea diff --git a/plugins/ai_method/packages/models/base/heads.py b/plugins/ai_method/packages/models/base/heads.py new file mode 100644 index 0000000..f0fa177 --- /dev/null +++ b/plugins/ai_method/packages/models/base/heads.py @@ -0,0 +1,24 @@ +import torch.nn as nn +from .modules import Flatten, Activation + + +class SegmentationHead(nn.Sequential): + + def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1, align_corners=True): + conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) + upsampling = nn.Upsample(scale_factor=upsampling, mode='bilinear', align_corners=align_corners) if upsampling > 1 else nn.Identity() + activation = Activation(activation) + super().__init__(conv2d, upsampling, activation) + + +class ClassificationHead(nn.Sequential): + + def __init__(self, in_channels, classes, pooling="avg", dropout=0.2, activation=None): + if pooling not in ("max", "avg"): + raise ValueError("Pooling should be one of ('max', 'avg'), got {}.".format(pooling)) + pool = nn.AdaptiveAvgPool2d(1) if pooling == 'avg' else nn.AdaptiveMaxPool2d(1) + flatten = Flatten() + dropout = nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity() + linear = nn.Linear(in_channels, classes, bias=True) + activation = Activation(activation) + super().__init__(pool, flatten, dropout, 
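        # note: read in order, the Sequential assembled here maps
        # (N, C, H, W) --pool--> (N, C, 1, 1) --flatten--> (N, C)
        # --dropout/linear--> (N, classes) --activation--> class scores.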
linear, activation) diff --git a/plugins/ai_method/packages/models/base/initialization.py b/plugins/ai_method/packages/models/base/initialization.py new file mode 100644 index 0000000..9622130 --- /dev/null +++ b/plugins/ai_method/packages/models/base/initialization.py @@ -0,0 +1,27 @@ +import torch.nn as nn + + +def initialize_decoder(module): + for m in module.modules(): + + if isinstance(m, nn.Conv2d): + nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +def initialize_head(module): + for m in module.modules(): + if isinstance(m, (nn.Linear, nn.Conv2d)): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) diff --git a/plugins/ai_method/packages/models/base/model.py b/plugins/ai_method/packages/models/base/model.py new file mode 100644 index 0000000..2d3f846 --- /dev/null +++ b/plugins/ai_method/packages/models/base/model.py @@ -0,0 +1,53 @@ +import torch +from . import initialization as init + + +class SegmentationModel(torch.nn.Module): + + def initialize(self): + init.initialize_decoder(self.decoder) + init.initialize_head(self.segmentation_head) + if self.classification_head is not None: + init.initialize_head(self.classification_head) + + def base_forward(self, x1, x2): + """Sequentially pass `x1` `x2` trough model`s encoder, decoder and heads""" + if self.siam_encoder: + features = self.encoder(x1), self.encoder(x2) + else: + features = self.encoder(x1), self.encoder_non_siam(x2) + + decoder_output = self.decoder(*features) + + # TODO: features = self.fusion_policy(features) + + masks = self.segmentation_head(decoder_output) + + if self.classification_head is not None: + raise AttributeError("`classification_head` is not supported now.") + # labels = self.classification_head(features[-1]) + # return masks, labels + + return masks + + def forward(self, x1, x2): + """Sequentially pass `x1` `x2` trough model`s encoder, decoder and heads""" + return self.base_forward(x1, x2) + + def predict(self, x1, x2): + """Inference method. Switch model to `eval` mode, call `.forward(x1, x2)` with `torch.no_grad()` + + Args: + x1, x2: 4D torch tensor with shape (batch_size, channels, height, width) + + Return: + prediction: 4D torch tensor with shape (batch_size, classes, height, width) + + """ + if self.training: + self.eval() + + with torch.no_grad(): + x = self.forward(x1, x2) + + return x diff --git a/plugins/ai_method/packages/models/base/modules.py b/plugins/ai_method/packages/models/base/modules.py new file mode 100644 index 0000000..db12c23 --- /dev/null +++ b/plugins/ai_method/packages/models/base/modules.py @@ -0,0 +1,242 @@ +import torch +import torch.nn as nn + +try: + from inplace_abn import InPlaceABN +except ImportError: + InPlaceABN = None + + +class Conv2dReLU(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding=0, + stride=1, + use_batchnorm=True, + ): + + if use_batchnorm == "inplace" and InPlaceABN is None: + raise RuntimeError( + "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. 
" + + "To install see: https://github.com/mapillary/inplace_abn" + ) + + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=not (use_batchnorm), + ) + relu = nn.ReLU(inplace=True) + + if use_batchnorm == "inplace": + bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0) + relu = nn.Identity() + + elif use_batchnorm and use_batchnorm != "inplace": + bn = nn.BatchNorm2d(out_channels) + + else: + bn = nn.Identity() + + super(Conv2dReLU, self).__init__(conv, bn, relu) + + +class SCSEModule(nn.Module): + def __init__(self, in_channels, reduction=16): + super().__init__() + self.cSE = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, in_channels // reduction, 1), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels // reduction, in_channels, 1), + nn.Sigmoid(), + ) + self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid()) + + def forward(self, x): + return x * self.cSE(x) + x * self.sSE(x) + + +class CBAMChannel(nn.Module): + def __init__(self, in_channels, reduction=16): + super(CBAMChannel, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.max_pool = nn.AdaptiveMaxPool2d(1) + + self.fc = nn.Sequential(nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False), + nn.ReLU(), + nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False)) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + avg_out = self.fc(self.avg_pool(x)) + max_out = self.fc(self.max_pool(x)) + out = avg_out + max_out + return x * self.sigmoid(out) + + +class CBAMSpatial(nn.Module): + def __init__(self, in_channels, kernel_size=7): + super(CBAMSpatial, self).__init__() + + self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + avg_out = torch.mean(x, dim=1, keepdim=True) + max_out, _ = torch.max(x, dim=1, keepdim=True) + out = torch.cat([avg_out, max_out], dim=1) + out = self.conv1(out) + return x * self.sigmoid(out) + + +class CBAM(nn.Module): + """ + Woo S, Park J, Lee J Y, et al. Cbam: Convolutional block attention module[C] + //Proceedings of the European conference on computer vision (ECCV). + """ + def __init__(self, in_channels, reduction=16, kernel_size=7): + super(CBAM, self).__init__() + self.ChannelGate = CBAMChannel(in_channels, reduction) + self.SpatialGate = CBAMSpatial(kernel_size) + + def forward(self, x): + x = self.ChannelGate(x) + x = self.SpatialGate(x) + return x + + +class ECAM(nn.Module): + """ + Ensemble Channel Attention Module for UNetPlusPlus. + Fang S, Li K, Shao J, et al. SNUNet-CD: A Densely Connected Siamese Network for Change Detection of VHR Images[J]. + IEEE Geoscience and Remote Sensing Letters, 2021. + Not completely consistent, to be improved. + """ + def __init__(self, in_channels, out_channels, map_num=4): + super(ECAM, self).__init__() + self.ca1 = CBAMChannel(in_channels * map_num, reduction=16) + self.ca2 = CBAMChannel(in_channels, reduction=16 // 4) + self.up = nn.ConvTranspose2d(in_channels * map_num, in_channels * map_num, 2, stride=2) + self.conv_final = nn.Conv2d(in_channels * map_num, out_channels, kernel_size=1) + + def forward(self, x): + """ + x (list[tensor] or tuple(tensor)) + """ + out = torch.cat(x, 1) + intra = torch.sum(torch.stack(x), dim=0) + ca2 = self.ca2(intra) + out = self.ca1(out) * (out + ca2.repeat(1, 4, 1, 1)) + out = self.up(out) + out = self.conv_final(out) + return out + + +class SEModule(nn.Module): + """ + Hu J, Shen L, Sun G. 
Squeeze-and-excitation networks[C] + //Proceedings of the IEEE conference on computer vision and pattern recognition. 2018: 7132-7141. + """ + def __init__(self, in_channels, reduction=16): + super(SEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(in_channels, in_channels // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(in_channels // reduction, in_channels, bias=False), + nn.Sigmoid() + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y.expand_as(x) + + +class ArgMax(nn.Module): + + def __init__(self, dim=None): + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.argmax(x, dim=self.dim) + + +class Clamp(nn.Module): + def __init__(self, min=0, max=1): + super().__init__() + self.min, self.max = min, max + + def forward(self, x): + return torch.clamp(x, self.min, self.max) + + +class Activation(nn.Module): + + def __init__(self, name, **params): + + super().__init__() + + if name is None or name == 'identity': + self.activation = nn.Identity(**params) + elif name == 'sigmoid': + self.activation = nn.Sigmoid() + elif name == 'softmax2d': + self.activation = nn.Softmax(dim=1, **params) + elif name == 'softmax': + self.activation = nn.Softmax(**params) + elif name == 'logsoftmax': + self.activation = nn.LogSoftmax(**params) + elif name == 'tanh': + self.activation = nn.Tanh() + elif name == 'argmax': + self.activation = ArgMax(**params) + elif name == 'argmax2d': + self.activation = ArgMax(dim=1, **params) + elif name == 'clamp': + self.activation = Clamp(**params) + elif callable(name): + self.activation = name(**params) + else: + raise ValueError('Activation should be callable/sigmoid/softmax/logsoftmax/tanh/None; got {}'.format(name)) + + def forward(self, x): + return self.activation(x) + + +class Attention(nn.Module): + + def __init__(self, name, **params): + super().__init__() + + if name is None: + self.attention = nn.Identity(**params) + elif name == 'scse': + self.attention = SCSEModule(**params) + elif name == 'cbam_channel': + self.attention = CBAMChannel(**params) + elif name == 'cbam_spatial': + self.attention = CBAMSpatial(**params) + elif name == 'cbam': + self.attention = CBAM(**params) + elif name == 'se': + self.attention = SEModule(**params) + else: + raise ValueError("Attention {} is not implemented".format(name)) + + def forward(self, x): + return self.attention(x) + + +class Flatten(nn.Module): + def forward(self, x): + return x.view(x.shape[0], -1) diff --git a/plugins/ai_method/packages/models/encoders/__init__.py b/plugins/ai_method/packages/models/encoders/__init__.py new file mode 100644 index 0000000..231d1d2 --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/__init__.py @@ -0,0 +1,113 @@ +import functools + +import torch +import torch.utils.model_zoo as model_zoo + +from ._preprocessing import preprocess_input +from .densenet import densenet_encoders +from .dpn import dpn_encoders +from .efficientnet import efficient_net_encoders +from .inceptionresnetv2 import inceptionresnetv2_encoders +from .inceptionv4 import inceptionv4_encoders +from .mobilenet import mobilenet_encoders +from .resnet import resnet_encoders +from .senet import senet_encoders +from .timm_efficientnet import timm_efficientnet_encoders +from .timm_gernet import timm_gernet_encoders +from .timm_mobilenetv3 import timm_mobilenetv3_encoders +from .timm_regnet import timm_regnet_encoders +from .timm_res2net 
import timm_res2net_encoders +from .timm_resnest import timm_resnest_encoders +from .timm_sknet import timm_sknet_encoders +from .timm_universal import TimmUniversalEncoder +from .vgg import vgg_encoders +from .xception import xception_encoders +from .swin_transformer import swin_transformer_encoders +from .mit_encoder import mit_encoders +# from .hrnet import hrnet_encoders + +DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' + +encoders = {} +encoders.update(resnet_encoders) +encoders.update(dpn_encoders) +encoders.update(vgg_encoders) +encoders.update(senet_encoders) +encoders.update(densenet_encoders) +encoders.update(inceptionresnetv2_encoders) +encoders.update(inceptionv4_encoders) +encoders.update(efficient_net_encoders) +encoders.update(mobilenet_encoders) +encoders.update(xception_encoders) +encoders.update(timm_efficientnet_encoders) +encoders.update(timm_resnest_encoders) +encoders.update(timm_res2net_encoders) +encoders.update(timm_regnet_encoders) +encoders.update(timm_sknet_encoders) +encoders.update(timm_mobilenetv3_encoders) +encoders.update(timm_gernet_encoders) +encoders.update(swin_transformer_encoders) +encoders.update(mit_encoders) +# encoders.update(hrnet_encoders) + + +def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs): + + if name.startswith("tu-"): + name = name[3:] + encoder = TimmUniversalEncoder( + name=name, + in_channels=in_channels, + depth=depth, + output_stride=output_stride, + pretrained=weights is not None, + **kwargs + ) + return encoder + + try: + Encoder = encoders[name]["encoder"] + except KeyError: + raise KeyError("Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys()))) + + params = encoders[name]["params"] + params.update(depth=depth) + encoder = Encoder(**params) + + if weights is not None: + try: + settings = encoders[name]["pretrained_settings"][weights] + except KeyError: + raise KeyError("Wrong pretrained weights `{}` for encoder `{}`. 
Available options are: {}".format( + weights, name, list(encoders[name]["pretrained_settings"].keys()), + )) + encoder.load_state_dict(model_zoo.load_url(settings["url"], map_location=torch.device(DEVICE))) + + encoder.set_in_channels(in_channels, pretrained=weights is not None) + if output_stride != 32: + encoder.make_dilated(output_stride) + + return encoder + + +def get_encoder_names(): + return list(encoders.keys()) + + +def get_preprocessing_params(encoder_name, pretrained="imagenet"): + settings = encoders[encoder_name]["pretrained_settings"] + + if pretrained not in settings.keys(): + raise ValueError("Available pretrained options {}".format(settings.keys())) + + formatted_settings = {} + formatted_settings["input_space"] = settings[pretrained].get("input_space") + formatted_settings["input_range"] = settings[pretrained].get("input_range") + formatted_settings["mean"] = settings[pretrained].get("mean") + formatted_settings["std"] = settings[pretrained].get("std") + return formatted_settings + + +def get_preprocessing_fn(encoder_name, pretrained="imagenet"): + params = get_preprocessing_params(encoder_name, pretrained=pretrained) + return functools.partial(preprocess_input, **params) diff --git a/plugins/ai_method/packages/models/encoders/_base.py b/plugins/ai_method/packages/models/encoders/_base.py new file mode 100644 index 0000000..f4bca8b --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/_base.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn +from typing import List +from collections import OrderedDict + +from . import _utils as utils + + +class EncoderMixin: + """Add encoder functionality such as: + - output channels specification of feature tensors (produced by encoder) + - patching first convolution for arbitrary input channels + """ + + @property + def out_channels(self): + """Return channels dimensions for each tensor of forward output of encoder""" + return self._out_channels[: self._depth + 1] + + def set_in_channels(self, in_channels, pretrained=True): + """Change first convolution channels""" + if in_channels == 3: + return + + self._in_channels = in_channels + if self._out_channels[0] == 3: + self._out_channels = tuple([in_channels] + list(self._out_channels)[1:]) + + utils.patch_first_conv(model=self, new_in_channels=in_channels, pretrained=pretrained) + + def get_stages(self): + """Method should be overridden in encoder""" + raise NotImplementedError + + def make_dilated(self, output_stride): + + if output_stride == 16: + stage_list=[5,] + dilation_list=[2,] + + elif output_stride == 8: + stage_list=[4, 5] + dilation_list=[2, 4] + + else: + raise ValueError("Output stride should be 16 or 8, got {}.".format(output_stride)) + + stages = self.get_stages() + for stage_indx, dilation_rate in zip(stage_list, dilation_list): + utils.replace_strides_with_dilation( + module=stages[stage_indx], + dilation_rate=dilation_rate, + ) diff --git a/plugins/ai_method/packages/models/encoders/_preprocessing.py b/plugins/ai_method/packages/models/encoders/_preprocessing.py new file mode 100644 index 0000000..ec19d54 --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/_preprocessing.py @@ -0,0 +1,23 @@ +import numpy as np + + +def preprocess_input( + x, mean=None, std=None, input_space="RGB", input_range=None, **kwargs +): + + if input_space == "BGR": + x = x[..., ::-1].copy() + + if input_range is not None: + if x.max() > 1 and input_range[1] == 1: + x = x / 255.0 + + if mean is not None: + mean = np.array(mean) + x = x - mean + + if std is not None: + 
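        # note: with the ImageNet statistics used by the pretrained settings in this package
        # (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), a red-channel value of 0.6
        # ends up as (0.6 - 0.485) / 0.229 ≈ 0.502 once the mean and std branches have run.
        # get_preprocessing_fn(encoder_name) above returns a functools.partial of this
        # function with the matching mean/std already bound.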
std = np.array(std) + x = x / std + + return x diff --git a/plugins/ai_method/packages/models/encoders/_utils.py b/plugins/ai_method/packages/models/encoders/_utils.py new file mode 100644 index 0000000..859151c --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/_utils.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn + + +def patch_first_conv(model, new_in_channels, default_in_channels=3, pretrained=True): + """Change first convolution layer input channels. + In case: + in_channels == 1 or in_channels == 2 -> reuse original weights + in_channels > 3 -> make random kaiming normal initialization + """ + + # get first conv + for module in model.modules(): + if isinstance(module, nn.Conv2d) and module.in_channels == default_in_channels: + break + + weight = module.weight.detach() + module.in_channels = new_in_channels + + if not pretrained: + module.weight = nn.parameter.Parameter( + torch.Tensor( + module.out_channels, + new_in_channels // module.groups, + *module.kernel_size + ) + ) + module.reset_parameters() + + elif new_in_channels == 1: + new_weight = weight.sum(1, keepdim=True) + module.weight = nn.parameter.Parameter(new_weight) + + else: + new_weight = torch.Tensor( + module.out_channels, + new_in_channels // module.groups, + *module.kernel_size + ) + + for i in range(new_in_channels): + new_weight[:, i] = weight[:, i % default_in_channels] + + new_weight = new_weight * (default_in_channels / new_in_channels) + module.weight = nn.parameter.Parameter(new_weight) + + +def replace_strides_with_dilation(module, dilation_rate): + """Patch Conv2d modules replacing strides with dilation""" + for mod in module.modules(): + if isinstance(mod, nn.Conv2d): + mod.stride = (1, 1) + mod.dilation = (dilation_rate, dilation_rate) + kh, kw = mod.kernel_size + mod.padding = ((kh // 2) * dilation_rate, (kh // 2) * dilation_rate) + + # Kostyl for EfficientNet + if hasattr(mod, "static_padding"): + mod.static_padding = nn.Identity() diff --git a/plugins/ai_method/packages/models/encoders/densenet.py b/plugins/ai_method/packages/models/encoders/densenet.py new file mode 100644 index 0000000..fae4e29 --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/densenet.py @@ -0,0 +1,146 @@ +""" Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` + +Attributes: + + _out_channels (list of int): specify number of channels for each encoder feature tensor + _depth (int): specify number of stages in decoder (in other words number of downsampling operations) + _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) + +Methods: + + forward(self, x: torch.Tensor) + produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of + shape NCHW (features should be sorted in descending order according to spatial resolution, starting + with resolution same as input `x` tensor). + + Input: `x` with shape (1, 3, 64, 64) + Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes + [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), + (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) + + also should support number of features according to specified depth, e.g. if depth = 5, + number of feature tensors = 6 (one with same resolution as input and 5 downsampled), + depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
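    A minimal usage sketch (illustrative only; assumes `torch` and `get_encoder` from
    `..encoders` are available in the calling scope and no pretrained weights are needed):

        encoder = get_encoder("resnet34", in_channels=3, depth=5, weights=None)
        features = encoder(torch.rand(1, 3, 64, 64))
        # len(features) == depth + 1 == 6; features[0] keeps the input resolution,
        # features[-1] is the most strongly downsampled map.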
+""" + +import re +import torch.nn as nn + +from pretrainedmodels.models.torchvision_models import pretrained_settings +from torchvision.models.densenet import DenseNet + +from ._base import EncoderMixin + + +class TransitionWithSkip(nn.Module): + + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, x): + for module in self.module: + x = module(x) + if isinstance(module, nn.ReLU): + skip = x + return x, skip + + +class DenseNetEncoder(DenseNet, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._out_channels = out_channels + self._depth = depth + self._in_channels = 3 + del self.classifier + + def make_dilated(self, output_stride): + raise ValueError("DenseNet encoders do not support dilated mode " + "due to pooling operation for downsampling!") + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential(self.features.conv0, self.features.norm0, self.features.relu0), + nn.Sequential(self.features.pool0, self.features.denseblock1, + TransitionWithSkip(self.features.transition1)), + nn.Sequential(self.features.denseblock2, TransitionWithSkip(self.features.transition2)), + nn.Sequential(self.features.denseblock3, TransitionWithSkip(self.features.transition3)), + nn.Sequential(self.features.denseblock4, self.features.norm5) + ] + + def forward(self, x): + + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + if isinstance(x, (list, tuple)): + x, skip = x + features.append(skip) + else: + features.append(x) + + return features + + def load_state_dict(self, state_dict): + pattern = re.compile( + r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" + ) + for key in list(state_dict.keys()): + res = pattern.match(key) + if res: + new_key = res.group(1) + res.group(2) + state_dict[new_key] = state_dict[key] + del state_dict[key] + + # remove linear + state_dict.pop("classifier.bias", None) + state_dict.pop("classifier.weight", None) + + super().load_state_dict(state_dict) + + +densenet_encoders = { + "densenet121": { + "encoder": DenseNetEncoder, + "pretrained_settings": pretrained_settings["densenet121"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 1024), + "num_init_features": 64, + "growth_rate": 32, + "block_config": (6, 12, 24, 16), + }, + }, + "densenet169": { + "encoder": DenseNetEncoder, + "pretrained_settings": pretrained_settings["densenet169"], + "params": { + "out_channels": (3, 64, 256, 512, 1280, 1664), + "num_init_features": 64, + "growth_rate": 32, + "block_config": (6, 12, 32, 32), + }, + }, + "densenet201": { + "encoder": DenseNetEncoder, + "pretrained_settings": pretrained_settings["densenet201"], + "params": { + "out_channels": (3, 64, 256, 512, 1792, 1920), + "num_init_features": 64, + "growth_rate": 32, + "block_config": (6, 12, 48, 32), + }, + }, + "densenet161": { + "encoder": DenseNetEncoder, + "pretrained_settings": pretrained_settings["densenet161"], + "params": { + "out_channels": (3, 96, 384, 768, 2112, 2208), + "num_init_features": 96, + "growth_rate": 48, + "block_config": (6, 12, 36, 24), + }, + }, +} diff --git a/plugins/ai_method/packages/models/encoders/dpn.py b/plugins/ai_method/packages/models/encoders/dpn.py new file mode 100644 index 0000000..7f1bd7d --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/dpn.py @@ -0,0 +1,170 @@ +""" Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` + 
+Attributes: + + _out_channels (list of int): specify number of channels for each encoder feature tensor + _depth (int): specify number of stages in decoder (in other words number of downsampling operations) + _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) + +Methods: + + forward(self, x: torch.Tensor) + produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of + shape NCHW (features should be sorted in descending order according to spatial resolution, starting + with resolution same as input `x` tensor). + + Input: `x` with shape (1, 3, 64, 64) + Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes + [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), + (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) + + also should support number of features according to specified depth, e.g. if depth = 5, + number of feature tensors = 6 (one with same resolution as input and 5 downsampled), + depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pretrainedmodels.models.dpn import DPN +from pretrainedmodels.models.dpn import pretrained_settings + +from ._base import EncoderMixin + + +class DPNEncoder(DPN, EncoderMixin): + def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._stage_idxs = stage_idxs + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + + del self.last_linear + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential(self.features[0].conv, self.features[0].bn, self.features[0].act), + nn.Sequential(self.features[0].pool, self.features[1 : self._stage_idxs[0]]), + self.features[self._stage_idxs[0] : self._stage_idxs[1]], + self.features[self._stage_idxs[1] : self._stage_idxs[2]], + self.features[self._stage_idxs[2] : self._stage_idxs[3]], + ] + + def forward(self, x): + + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + if isinstance(x, (list, tuple)): + features.append(F.relu(torch.cat(x, dim=1), inplace=True)) + else: + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("last_linear.bias", None) + state_dict.pop("last_linear.weight", None) + super().load_state_dict(state_dict, **kwargs) + + +dpn_encoders = { + "dpn68": { + "encoder": DPNEncoder, + "pretrained_settings": pretrained_settings["dpn68"], + "params": { + "stage_idxs": (4, 8, 20, 24), + "out_channels": (3, 10, 144, 320, 704, 832), + "groups": 32, + "inc_sec": (16, 32, 32, 64), + "k_r": 128, + "k_sec": (3, 4, 12, 3), + "num_classes": 1000, + "num_init_features": 10, + "small": True, + "test_time_pool": True, + }, + }, + "dpn68b": { + "encoder": DPNEncoder, + "pretrained_settings": pretrained_settings["dpn68b"], + "params": { + "stage_idxs": (4, 8, 20, 24), + "out_channels": (3, 10, 144, 320, 704, 832), + "b": True, + "groups": 32, + "inc_sec": (16, 32, 32, 64), + "k_r": 128, + "k_sec": (3, 4, 12, 3), + "num_classes": 1000, + "num_init_features": 10, + "small": True, + "test_time_pool": True, + }, + }, + "dpn92": { + "encoder": DPNEncoder, + "pretrained_settings": pretrained_settings["dpn92"], + "params": { + "stage_idxs": (4, 8, 28, 32), + "out_channels": (3, 64, 336, 704, 1552, 2688), + "groups": 32, + "inc_sec": (16, 32, 24, 128), + "k_r": 96, + "k_sec": (3, 4, 20, 3), + 
"num_classes": 1000, + "num_init_features": 64, + "test_time_pool": True, + }, + }, + "dpn98": { + "encoder": DPNEncoder, + "pretrained_settings": pretrained_settings["dpn98"], + "params": { + "stage_idxs": (4, 10, 30, 34), + "out_channels": (3, 96, 336, 768, 1728, 2688), + "groups": 40, + "inc_sec": (16, 32, 32, 128), + "k_r": 160, + "k_sec": (3, 6, 20, 3), + "num_classes": 1000, + "num_init_features": 96, + "test_time_pool": True, + }, + }, + "dpn107": { + "encoder": DPNEncoder, + "pretrained_settings": pretrained_settings["dpn107"], + "params": { + "stage_idxs": (5, 13, 33, 37), + "out_channels": (3, 128, 376, 1152, 2432, 2688), + "groups": 50, + "inc_sec": (20, 64, 64, 128), + "k_r": 200, + "k_sec": (4, 8, 20, 3), + "num_classes": 1000, + "num_init_features": 128, + "test_time_pool": True, + }, + }, + "dpn131": { + "encoder": DPNEncoder, + "pretrained_settings": pretrained_settings["dpn131"], + "params": { + "stage_idxs": (5, 13, 41, 45), + "out_channels": (3, 128, 352, 832, 1984, 2688), + "groups": 40, + "inc_sec": (16, 32, 32, 128), + "k_r": 160, + "k_sec": (4, 8, 28, 3), + "num_classes": 1000, + "num_init_features": 128, + "test_time_pool": True, + }, + }, +} diff --git a/plugins/ai_method/packages/models/encoders/efficientnet.py b/plugins/ai_method/packages/models/encoders/efficientnet.py new file mode 100644 index 0000000..d0bf2d9 --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/efficientnet.py @@ -0,0 +1,178 @@ +""" Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` + +Attributes: + + _out_channels (list of int): specify number of channels for each encoder feature tensor + _depth (int): specify number of stages in decoder (in other words number of downsampling operations) + _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) + +Methods: + + forward(self, x: torch.Tensor) + produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of + shape NCHW (features should be sorted in descending order according to spatial resolution, starting + with resolution same as input `x` tensor). + + Input: `x` with shape (1, 3, 64, 64) + Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes + [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), + (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) + + also should support number of features according to specified depth, e.g. if depth = 5, + number of feature tensors = 6 (one with same resolution as input and 5 downsampled), + depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
+""" +import torch.nn as nn +from efficientnet_pytorch import EfficientNet +from efficientnet_pytorch.utils import url_map, url_map_advprop, get_model_params + +from ._base import EncoderMixin + + +class EfficientNetEncoder(EfficientNet, EncoderMixin): + def __init__(self, stage_idxs, out_channels, model_name, depth=5): + + blocks_args, global_params = get_model_params(model_name, override_params=None) + super().__init__(blocks_args, global_params) + + self._stage_idxs = stage_idxs + self._out_channels = out_channels + self._depth = depth + self._in_channels = 3 + + del self._fc + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential(self._conv_stem, self._bn0, self._swish), + self._blocks[:self._stage_idxs[0]], + self._blocks[self._stage_idxs[0]:self._stage_idxs[1]], + self._blocks[self._stage_idxs[1]:self._stage_idxs[2]], + self._blocks[self._stage_idxs[2]:], + ] + + def forward(self, x): + stages = self.get_stages() + + block_number = 0. + drop_connect_rate = self._global_params.drop_connect_rate + + features = [] + for i in range(self._depth + 1): + + # Identity and Sequential stages + if i < 2: + x = stages[i](x) + + # Block stages need drop_connect rate + else: + for module in stages[i]: + drop_connect = drop_connect_rate * block_number / len(self._blocks) + block_number += 1. + x = module(x, drop_connect) + + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("_fc.bias", None) + state_dict.pop("_fc.weight", None) + super().load_state_dict(state_dict, **kwargs) + + +def _get_pretrained_settings(encoder): + pretrained_settings = { + "imagenet": { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "url": url_map[encoder], + "input_space": "RGB", + "input_range": [0, 1], + }, + "advprop": { + "mean": [0.5, 0.5, 0.5], + "std": [0.5, 0.5, 0.5], + "url": url_map_advprop[encoder], + "input_space": "RGB", + "input_range": [0, 1], + } + } + return pretrained_settings + + +efficient_net_encoders = { + "efficientnet-b0": { + "encoder": EfficientNetEncoder, + "pretrained_settings": _get_pretrained_settings("efficientnet-b0"), + "params": { + "out_channels": (3, 32, 24, 40, 112, 320), + "stage_idxs": (3, 5, 9, 16), + "model_name": "efficientnet-b0", + }, + }, + "efficientnet-b1": { + "encoder": EfficientNetEncoder, + "pretrained_settings": _get_pretrained_settings("efficientnet-b1"), + "params": { + "out_channels": (3, 32, 24, 40, 112, 320), + "stage_idxs": (5, 8, 16, 23), + "model_name": "efficientnet-b1", + }, + }, + "efficientnet-b2": { + "encoder": EfficientNetEncoder, + "pretrained_settings": _get_pretrained_settings("efficientnet-b2"), + "params": { + "out_channels": (3, 32, 24, 48, 120, 352), + "stage_idxs": (5, 8, 16, 23), + "model_name": "efficientnet-b2", + }, + }, + "efficientnet-b3": { + "encoder": EfficientNetEncoder, + "pretrained_settings": _get_pretrained_settings("efficientnet-b3"), + "params": { + "out_channels": (3, 40, 32, 48, 136, 384), + "stage_idxs": (5, 8, 18, 26), + "model_name": "efficientnet-b3", + }, + }, + "efficientnet-b4": { + "encoder": EfficientNetEncoder, + "pretrained_settings": _get_pretrained_settings("efficientnet-b4"), + "params": { + "out_channels": (3, 48, 32, 56, 160, 448), + "stage_idxs": (6, 10, 22, 32), + "model_name": "efficientnet-b4", + }, + }, + "efficientnet-b5": { + "encoder": EfficientNetEncoder, + "pretrained_settings": _get_pretrained_settings("efficientnet-b5"), + "params": { + "out_channels": (3, 48, 40, 64, 176, 512), + "stage_idxs": (8, 13, 27, 39), + 
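            # note: stage_idxs are the indices into self._blocks where get_stages() above
            # starts the next encoder stage (i.e. the next downsampling level); out_channels
            # lists the channel width of each returned feature map, the raw input included.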
"model_name": "efficientnet-b5", + }, + }, + "efficientnet-b6": { + "encoder": EfficientNetEncoder, + "pretrained_settings": _get_pretrained_settings("efficientnet-b6"), + "params": { + "out_channels": (3, 56, 40, 72, 200, 576), + "stage_idxs": (9, 15, 31, 45), + "model_name": "efficientnet-b6", + }, + }, + "efficientnet-b7": { + "encoder": EfficientNetEncoder, + "pretrained_settings": _get_pretrained_settings("efficientnet-b7"), + "params": { + "out_channels": (3, 64, 48, 80, 224, 640), + "stage_idxs": (11, 18, 38, 55), + "model_name": "efficientnet-b7", + }, + }, +} diff --git a/plugins/ai_method/packages/models/encoders/inceptionresnetv2.py b/plugins/ai_method/packages/models/encoders/inceptionresnetv2.py new file mode 100644 index 0000000..bc2ffa0 --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/inceptionresnetv2.py @@ -0,0 +1,90 @@ +""" Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` + +Attributes: + + _out_channels (list of int): specify number of channels for each encoder feature tensor + _depth (int): specify number of stages in decoder (in other words number of downsampling operations) + _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) + +Methods: + + forward(self, x: torch.Tensor) + produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of + shape NCHW (features should be sorted in descending order according to spatial resolution, starting + with resolution same as input `x` tensor). + + Input: `x` with shape (1, 3, 64, 64) + Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes + [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), + (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) + + also should support number of features according to specified depth, e.g. if depth = 5, + number of feature tensors = 6 (one with same resolution as input and 5 downsampled), + depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
+""" + +import torch.nn as nn +from pretrainedmodels.models.inceptionresnetv2 import InceptionResNetV2 +from pretrainedmodels.models.inceptionresnetv2 import pretrained_settings + +from ._base import EncoderMixin + + +class InceptionResNetV2Encoder(InceptionResNetV2, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + + self._out_channels = out_channels + self._depth = depth + self._in_channels = 3 + + # correct paddings + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if m.kernel_size == (3, 3): + m.padding = (1, 1) + if isinstance(m, nn.MaxPool2d): + m.padding = (1, 1) + + # remove linear layers + del self.avgpool_1a + del self.last_linear + + def make_dilated(self, output_stride): + raise ValueError("InceptionResnetV2 encoder does not support dilated mode " + "due to pooling operation for downsampling!") + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential(self.conv2d_1a, self.conv2d_2a, self.conv2d_2b), + nn.Sequential(self.maxpool_3a, self.conv2d_3b, self.conv2d_4a), + nn.Sequential(self.maxpool_5a, self.mixed_5b, self.repeat), + nn.Sequential(self.mixed_6a, self.repeat_1), + nn.Sequential(self.mixed_7a, self.repeat_2, self.block8, self.conv2d_7b), + ] + + def forward(self, x): + + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("last_linear.bias", None) + state_dict.pop("last_linear.weight", None) + super().load_state_dict(state_dict, **kwargs) + + +inceptionresnetv2_encoders = { + "inceptionresnetv2": { + "encoder": InceptionResNetV2Encoder, + "pretrained_settings": pretrained_settings["inceptionresnetv2"], + "params": {"out_channels": (3, 64, 192, 320, 1088, 1536), "num_classes": 1000}, + } +} diff --git a/plugins/ai_method/packages/models/encoders/inceptionv4.py b/plugins/ai_method/packages/models/encoders/inceptionv4.py new file mode 100644 index 0000000..b91d0ef --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/inceptionv4.py @@ -0,0 +1,93 @@ +""" Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` + +Attributes: + + _out_channels (list of int): specify number of channels for each encoder feature tensor + _depth (int): specify number of stages in decoder (in other words number of downsampling operations) + _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) + +Methods: + + forward(self, x: torch.Tensor) + produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of + shape NCHW (features should be sorted in descending order according to spatial resolution, starting + with resolution same as input `x` tensor). + + Input: `x` with shape (1, 3, 64, 64) + Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes + [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), + (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) + + also should support number of features according to specified depth, e.g. if depth = 5, + number of feature tensors = 6 (one with same resolution as input and 5 downsampled), + depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
+""" + +import torch.nn as nn +from pretrainedmodels.models.inceptionv4 import InceptionV4, BasicConv2d +from pretrainedmodels.models.inceptionv4 import pretrained_settings + +from ._base import EncoderMixin + + +class InceptionV4Encoder(InceptionV4, EncoderMixin): + def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._stage_idxs = stage_idxs + self._out_channels = out_channels + self._depth = depth + self._in_channels = 3 + + # correct paddings + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if m.kernel_size == (3, 3): + m.padding = (1, 1) + if isinstance(m, nn.MaxPool2d): + m.padding = (1, 1) + + # remove linear layers + del self.last_linear + + def make_dilated(self, output_stride): + raise ValueError("InceptionV4 encoder does not support dilated mode " + "due to pooling operation for downsampling!") + + def get_stages(self): + return [ + nn.Identity(), + self.features[: self._stage_idxs[0]], + self.features[self._stage_idxs[0]: self._stage_idxs[1]], + self.features[self._stage_idxs[1]: self._stage_idxs[2]], + self.features[self._stage_idxs[2]: self._stage_idxs[3]], + self.features[self._stage_idxs[3]:], + ] + + def forward(self, x): + + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("last_linear.bias", None) + state_dict.pop("last_linear.weight", None) + super().load_state_dict(state_dict, **kwargs) + + +inceptionv4_encoders = { + "inceptionv4": { + "encoder": InceptionV4Encoder, + "pretrained_settings": pretrained_settings["inceptionv4"], + "params": { + "stage_idxs": (3, 5, 9, 15), + "out_channels": (3, 64, 192, 384, 1024, 1536), + "num_classes": 1001, + }, + } +} diff --git a/plugins/ai_method/packages/models/encoders/mit_encoder.py b/plugins/ai_method/packages/models/encoders/mit_encoder.py new file mode 100644 index 0000000..df24632 --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/mit_encoder.py @@ -0,0 +1,192 @@ +# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 
+# +# This work is licensed under the NVIDIA Source Code License +# --------------------------------------------------------------- +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from copy import deepcopy +from pretrainedmodels.models.torchvision_models import pretrained_settings + +from ._base import EncoderMixin +from .mix_transformer import MixVisionTransformer + + +class MixVisionTransformerEncoder(MixVisionTransformer, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + + def get_stages(self): + return [nn.Identity()] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for stage in stages: + x = stage(x) + features.append(x) + outs = self.forward_features(x) + add_feature = F.interpolate(outs[0], scale_factor=2) + features = features + [add_feature] + outs + return features + + def load_state_dict(self, state_dict, **kwargs): + new_state_dict = {} + if state_dict.get('state_dict'): + state_dict = state_dict['state_dict'] + for k, v in state_dict.items(): + if k.startswith('backbone'): + new_state_dict[k.replace('backbone.', '')] = v + else: + new_state_dict = deepcopy(state_dict) + super().load_state_dict(new_state_dict, **kwargs) + + +# https://github.com/NVlabs/SegFormer +new_settings = { + "mit-b0": { + "imagenet": "https://lino.local.server/mit_b0.pth" + }, + "mit-b1": { + "imagenet": "https://lino.local.server/mit_b1.pth" + }, + "mit-b2": { + "imagenet": "https://lino.local.server/mit_b2.pth" + }, + "mit-b3": { + "imagenet": "https://lino.local.server/mit_b3.pth" + }, + "mit-b4": { + "imagenet": "https://lino.local.server/mit_b4.pth" + }, + "mit-b5": { + "imagenet": "https://lino.local.server/mit_b5.pth" + }, +} + +pretrained_settings = deepcopy(pretrained_settings) +for model_name, sources in new_settings.items(): + if model_name not in pretrained_settings: + pretrained_settings[model_name] = {} + + for source_name, source_url in sources.items(): + pretrained_settings[model_name][source_name] = { + "url": source_url, + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + +mit_encoders = { + "mit-b0": { + "encoder": MixVisionTransformerEncoder, + "pretrained_settings": pretrained_settings["mit-b0"], + "params": { + "patch_size": 4, + "embed_dims": [32, 64, 160, 256], + "num_heads": [1, 2, 5, 8], + "mlp_ratios": [4, 4, 4, 4], + "qkv_bias": True, + "norm_layer": partial(nn.LayerNorm, eps=1e-6), + "depths": [2, 2, 2, 2], + "sr_ratios": [8, 4, 2, 1], + "drop_rate": 0.0, + "drop_path_rate": 0.1, + "out_channels": (3, 32, 32, 64, 160, 256) + } + }, + "mit-b1": { + "encoder": MixVisionTransformerEncoder, + "pretrained_settings": pretrained_settings["mit-b1"], + "params": { + "patch_size": 4, + "embed_dims": [64, 128, 320, 512], + "num_heads": [1, 2, 5, 8], + "mlp_ratios": [4, 4, 4, 4], + "qkv_bias": True, + "norm_layer": partial(nn.LayerNorm, eps=1e-6), + "depths": [2, 2, 2, 2], + "sr_ratios": [8, 4, 2, 1], + "drop_rate": 0.0, + "drop_path_rate": 0.1, + "out_channels": (3, 64, 64, 128, 320, 512) + } + }, + "mit-b2": { + "encoder": MixVisionTransformerEncoder, + "pretrained_settings": pretrained_settings["mit-b2"], + "params": { + "patch_size": 4, + "embed_dims": [64, 128, 320, 512], + "num_heads": [1, 2, 5, 8], + "mlp_ratios": [4, 4, 4, 4], + "qkv_bias": True, + "norm_layer": 
partial(nn.LayerNorm, eps=1e-6), + "depths": [3, 4, 6, 3], + "sr_ratios": [8, 4, 2, 1], + "drop_rate": 0.0, + "drop_path_rate": 0.1, + "out_channels": (3, 64, 64, 128, 320, 512) + } + }, + "mit-b3": { + "encoder": MixVisionTransformerEncoder, + "pretrained_settings": pretrained_settings["mit-b3"], + "params": { + "patch_size": 4, + "embed_dims": [64, 128, 320, 512], + "num_heads": [1, 2, 5, 8], + "mlp_ratios": [4, 4, 4, 4], + "qkv_bias": True, + "norm_layer": partial(nn.LayerNorm, eps=1e-6), + "depths": [3, 4, 18, 3], + "sr_ratios": [8, 4, 2, 1], + "drop_rate": 0.0, + "drop_path_rate": 0.1, + "out_channels": (3, 64, 64, 128, 320, 512) + } + }, + "mit-b4": { + "encoder": MixVisionTransformerEncoder, + "pretrained_settings": pretrained_settings["mit-b4"], + "params": { + "patch_size": 4, + "embed_dims": [64, 128, 320, 512], + "num_heads": [1, 2, 5, 8], + "mlp_ratios": [4, 4, 4, 4], + "qkv_bias": True, + "norm_layer": partial(nn.LayerNorm, eps=1e-6), + "depths": [3, 8, 27, 3], + "sr_ratios": [8, 4, 2, 1], + "drop_rate": 0.0, + "drop_path_rate": 0.1, + "out_channels": (3, 64, 64, 128, 320, 512) + } + }, + "mit-b5": { + "encoder": MixVisionTransformerEncoder, + "pretrained_settings": pretrained_settings["mit-b5"], + "params": { + "patch_size": 4, + "embed_dims": [64, 128, 320, 512], + "num_heads": [1, 2, 5, 8], + "mlp_ratios": [4, 4, 4, 4], + "qkv_bias": True, + "norm_layer": partial(nn.LayerNorm, eps=1e-6), + "depths": [3, 6, 40, 3], + "sr_ratios": [8, 4, 2, 1], + "drop_rate": 0.0, + "drop_path_rate": 0.1, + "out_channels": (3, 64, 64, 128, 320, 512) + } + }, +} diff --git a/plugins/ai_method/packages/models/encoders/mix_transformer.py b/plugins/ai_method/packages/models/encoders/mix_transformer.py new file mode 100644 index 0000000..77d458a --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/mix_transformer.py @@ -0,0 +1,361 @@ +# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 
+# +# This work is licensed under the NVIDIA Source Code License +# --------------------------------------------------------------- +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from timm.models.registry import register_model +from timm.models.vision_transformer import _cfg +import math + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
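        # note on the spatial-reduction attention used here: queries attend over all
        # N = H * W tokens, but when sr_ratio > 1 the keys/values are computed from a copy
        # of the token map downsampled by a stride-sr_ratio convolution (self.sr below),
        # so the attention matrix shrinks from N x N to roughly N x (N / sr_ratio**2).
        # Rough example: H = W = 56 gives N = 3136 queries; with sr_ratio = 8 the
        # keys/values come from a 7 x 7 map, i.e. only 49 tokens.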
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr_ratio > 1: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) + x_ = self.norm(x_) + kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + else: + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) + + return x + + +class OverlapPatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2)) + self.norm = nn.LayerNorm(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x.shape + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + + return x, H, W + + +class MixVisionTransformer(nn.Module): + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): + super().__init__() + self.num_classes = num_classes + self.depths = depths + + # patch_embed + self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, + embed_dim=embed_dims[0]) + self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1]) + self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], + embed_dim=embed_dims[2]) + self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3]) + + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList([Block( + dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], 
norm_layer=norm_layer, + sr_ratio=sr_ratios[0]) + for i in range(depths[0])]) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList([Block( + dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[1]) + for i in range(depths[1])]) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList([Block( + dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[2]) + for i in range(depths[2])]) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList([Block( + dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[3]) + for i in range(depths[3])]) + self.norm4 = norm_layer(embed_dims[3]) + + # classification head + # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x.shape[0] + outs = [] + + # stage 1 + x, H, W = self.patch_embed1(x) + for i, blk in enumerate(self.block1): + x = blk(x, H, W) + x = self.norm1(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 2 + x, H, W = self.patch_embed2(x) + for i, blk in enumerate(self.block2): + x = blk(x, H, W) + x = self.norm2(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 3 + x, H, W = self.patch_embed3(x) + for i, blk in enumerate(self.block3): + x = blk(x, H, W) + x = self.norm3(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # 
stage 4 + x, H, W = self.patch_embed4(x) + for i, blk in enumerate(self.block4): + x = blk(x, H, W) + x = self.norm4(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + return outs + + def forward(self, x): + x = self.forward_features(x) + # x = self.head(x) + + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2) + + return x diff --git a/plugins/ai_method/packages/models/encoders/mobilenet.py b/plugins/ai_method/packages/models/encoders/mobilenet.py new file mode 100644 index 0000000..8bfdb10 --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/mobilenet.py @@ -0,0 +1,83 @@ +""" Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` + +Attributes: + + _out_channels (list of int): specify number of channels for each encoder feature tensor + _depth (int): specify number of stages in decoder (in other words number of downsampling operations) + _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) + +Methods: + + forward(self, x: torch.Tensor) + produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of + shape NCHW (features should be sorted in descending order according to spatial resolution, starting + with resolution same as input `x` tensor). + + Input: `x` with shape (1, 3, 64, 64) + Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes + [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), + (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) + + also should support number of features according to specified depth, e.g. if depth = 5, + number of feature tensors = 6 (one with same resolution as input and 5 downsampled), + depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
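As a rough usage sketch of the feature-pyramid contract described above (illustrative only; the import path and the 64x64 input size are assumptions for this example, not part of the diff):

import torch
# MobileNetV2Encoder and mobilenet_encoders are defined just below in this file.
from plugins.ai_method.packages.models.encoders.mobilenet import MobileNetV2Encoder, mobilenet_encoders

params = mobilenet_encoders["mobilenet_v2"]["params"]   # out_channels = (3, 16, 24, 32, 96, 1280)
encoder = MobileNetV2Encoder(depth=5, **params)
features = encoder(torch.randn(1, 3, 64, 64))
for f in features:                                      # six tensors, spatial sizes 64, 32, 16, 8, 4, 2
    print(f.shape)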
+""" + +import torchvision +import torch.nn as nn + +from ._base import EncoderMixin + + +class MobileNetV2Encoder(torchvision.models.MobileNetV2, EncoderMixin): + + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + del self.classifier + + def get_stages(self): + return [ + nn.Identity(), + self.features[:2], + self.features[2:4], + self.features[4:7], + self.features[7:14], + self.features[14:], + ] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("classifier.1.bias", None) + state_dict.pop("classifier.1.weight", None) + super().load_state_dict(state_dict, **kwargs) + + +mobilenet_encoders = { + "mobilenet_v2": { + "encoder": MobileNetV2Encoder, + "pretrained_settings": { + "imagenet": { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "url": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", + "input_space": "RGB", + "input_range": [0, 1], + }, + }, + "params": { + "out_channels": (3, 16, 24, 32, 96, 1280), + }, + }, +} diff --git a/plugins/ai_method/packages/models/encoders/resnet.py b/plugins/ai_method/packages/models/encoders/resnet.py new file mode 100644 index 0000000..5528bd5 --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/resnet.py @@ -0,0 +1,238 @@ +""" Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` + +Attributes: + + _out_channels (list of int): specify number of channels for each encoder feature tensor + _depth (int): specify number of stages in decoder (in other words number of downsampling operations) + _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) + +Methods: + + forward(self, x: torch.Tensor) + produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of + shape NCHW (features should be sorted in descending order according to spatial resolution, starting + with resolution same as input `x` tensor). + + Input: `x` with shape (1, 3, 64, 64) + Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes + [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), + (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) + + also should support number of features according to specified depth, e.g. if depth = 5, + number of feature tensors = 6 (one with same resolution as input and 5 downsampled), + depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
+""" +from copy import deepcopy + +import torch.nn as nn + +from torchvision.models.resnet import ResNet +from torchvision.models.resnet import BasicBlock +from torchvision.models.resnet import Bottleneck +from pretrainedmodels.models.torchvision_models import pretrained_settings + +from ._base import EncoderMixin + + +class ResNetEncoder(ResNet, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + + del self.fc + del self.avgpool + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential(self.conv1, self.bn1, self.relu), + nn.Sequential(self.maxpool, self.layer1), + self.layer2, + self.layer3, + self.layer4, + ] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("fc.bias", None) + state_dict.pop("fc.weight", None) + super().load_state_dict(state_dict, **kwargs) + + +new_settings = { + "resnet18": { + "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth", + "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth" + }, + "resnet50": { + "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth", + "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth" + }, + "resnext50_32x4d": { + "imagenet": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", + "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext50_32x4-ddb3e555.pth", + "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext50_32x4-72679e44.pth", + }, + "resnext101_32x4d": { + "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth", + "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth" + }, + "resnext101_32x8d": { + "imagenet": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", + "instagram": "https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth", + "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth", + "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth", + }, + "resnext101_32x16d": { + "instagram": "https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth", + "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth", + "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth", + }, + "resnext101_32x32d": { + "instagram": "https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth", + }, + "resnext101_32x48d": { + "instagram": "https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth", + } +} + +pretrained_settings = deepcopy(pretrained_settings) +for model_name, sources in new_settings.items(): + if model_name not in pretrained_settings: + pretrained_settings[model_name] = {} + + for source_name, source_url in sources.items(): + 
pretrained_settings[model_name][source_name] = { + "url": source_url, + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + + +resnet_encoders = { + "resnet18": { + "encoder": ResNetEncoder, + "pretrained_settings": pretrained_settings["resnet18"], + "params": { + "out_channels": (3, 64, 64, 128, 256, 512), + "block": BasicBlock, + "layers": [2, 2, 2, 2], + }, + }, + "resnet34": { + "encoder": ResNetEncoder, + "pretrained_settings": pretrained_settings["resnet34"], + "params": { + "out_channels": (3, 64, 64, 128, 256, 512), + "block": BasicBlock, + "layers": [3, 4, 6, 3], + }, + }, + "resnet50": { + "encoder": ResNetEncoder, + "pretrained_settings": pretrained_settings["resnet50"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottleneck, + "layers": [3, 4, 6, 3], + }, + }, + "resnet101": { + "encoder": ResNetEncoder, + "pretrained_settings": pretrained_settings["resnet101"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottleneck, + "layers": [3, 4, 23, 3], + }, + }, + "resnet152": { + "encoder": ResNetEncoder, + "pretrained_settings": pretrained_settings["resnet152"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottleneck, + "layers": [3, 8, 36, 3], + }, + }, + "resnext50_32x4d": { + "encoder": ResNetEncoder, + "pretrained_settings": pretrained_settings["resnext50_32x4d"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottleneck, + "layers": [3, 4, 6, 3], + "groups": 32, + "width_per_group": 4, + }, + }, + "resnext101_32x4d": { + "encoder": ResNetEncoder, + "pretrained_settings": pretrained_settings["resnext101_32x4d"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottleneck, + "layers": [3, 4, 23, 3], + "groups": 32, + "width_per_group": 4, + }, + }, + "resnext101_32x8d": { + "encoder": ResNetEncoder, + "pretrained_settings": pretrained_settings["resnext101_32x8d"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottleneck, + "layers": [3, 4, 23, 3], + "groups": 32, + "width_per_group": 8, + }, + }, + "resnext101_32x16d": { + "encoder": ResNetEncoder, + "pretrained_settings": pretrained_settings["resnext101_32x16d"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottleneck, + "layers": [3, 4, 23, 3], + "groups": 32, + "width_per_group": 16, + }, + }, + "resnext101_32x32d": { + "encoder": ResNetEncoder, + "pretrained_settings": pretrained_settings["resnext101_32x32d"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottleneck, + "layers": [3, 4, 23, 3], + "groups": 32, + "width_per_group": 32, + }, + }, + "resnext101_32x48d": { + "encoder": ResNetEncoder, + "pretrained_settings": pretrained_settings["resnext101_32x48d"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottleneck, + "layers": [3, 4, 23, 3], + "groups": 32, + "width_per_group": 48, + }, + }, +} diff --git a/plugins/ai_method/packages/models/encoders/senet.py b/plugins/ai_method/packages/models/encoders/senet.py new file mode 100644 index 0000000..7cdbdbe --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/senet.py @@ -0,0 +1,174 @@ +""" Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` + +Attributes: + + _out_channels (list of int): specify number of channels for each encoder feature tensor + _depth (int): specify number 
of stages in decoder (in other words number of downsampling operations) + _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) + +Methods: + + forward(self, x: torch.Tensor) + produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of + shape NCHW (features should be sorted in descending order according to spatial resolution, starting + with resolution same as input `x` tensor). + + Input: `x` with shape (1, 3, 64, 64) + Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes + [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), + (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) + + also should support number of features according to specified depth, e.g. if depth = 5, + number of feature tensors = 6 (one with same resolution as input and 5 downsampled), + depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). +""" + +import torch.nn as nn + +from pretrainedmodels.models.senet import ( + SENet, + SEBottleneck, + SEResNetBottleneck, + SEResNeXtBottleneck, + pretrained_settings, +) +from ._base import EncoderMixin + + +class SENetEncoder(SENet, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + + self._out_channels = out_channels + self._depth = depth + self._in_channels = 3 + + del self.last_linear + del self.avg_pool + + def get_stages(self): + return [ + nn.Identity(), + self.layer0[:-1], + nn.Sequential(self.layer0[-1], self.layer1), + self.layer2, + self.layer3, + self.layer4, + ] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("last_linear.bias", None) + state_dict.pop("last_linear.weight", None) + super().load_state_dict(state_dict, **kwargs) + + +senet_encoders = { + "senet154": { + "encoder": SENetEncoder, + "pretrained_settings": pretrained_settings["senet154"], + "params": { + "out_channels": (3, 128, 256, 512, 1024, 2048), + "block": SEBottleneck, + "dropout_p": 0.2, + "groups": 64, + "layers": [3, 8, 36, 3], + "num_classes": 1000, + "reduction": 16, + }, + }, + "se_resnet50": { + "encoder": SENetEncoder, + "pretrained_settings": pretrained_settings["se_resnet50"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": SEResNetBottleneck, + "layers": [3, 4, 6, 3], + "downsample_kernel_size": 1, + "downsample_padding": 0, + "dropout_p": None, + "groups": 1, + "inplanes": 64, + "input_3x3": False, + "num_classes": 1000, + "reduction": 16, + }, + }, + "se_resnet101": { + "encoder": SENetEncoder, + "pretrained_settings": pretrained_settings["se_resnet101"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": SEResNetBottleneck, + "layers": [3, 4, 23, 3], + "downsample_kernel_size": 1, + "downsample_padding": 0, + "dropout_p": None, + "groups": 1, + "inplanes": 64, + "input_3x3": False, + "num_classes": 1000, + "reduction": 16, + }, + }, + "se_resnet152": { + "encoder": SENetEncoder, + "pretrained_settings": pretrained_settings["se_resnet152"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": SEResNetBottleneck, + "layers": [3, 8, 36, 3], + "downsample_kernel_size": 1, + "downsample_padding": 0, + "dropout_p": None, + "groups": 1, + "inplanes": 64, + "input_3x3": False, + "num_classes": 1000, + "reduction": 16, + }, + }, + 
"se_resnext50_32x4d": { + "encoder": SENetEncoder, + "pretrained_settings": pretrained_settings["se_resnext50_32x4d"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": SEResNeXtBottleneck, + "layers": [3, 4, 6, 3], + "downsample_kernel_size": 1, + "downsample_padding": 0, + "dropout_p": None, + "groups": 32, + "inplanes": 64, + "input_3x3": False, + "num_classes": 1000, + "reduction": 16, + }, + }, + "se_resnext101_32x4d": { + "encoder": SENetEncoder, + "pretrained_settings": pretrained_settings["se_resnext101_32x4d"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": SEResNeXtBottleneck, + "layers": [3, 4, 23, 3], + "downsample_kernel_size": 1, + "downsample_padding": 0, + "dropout_p": None, + "groups": 32, + "inplanes": 64, + "input_3x3": False, + "num_classes": 1000, + "reduction": 16, + }, + }, +} diff --git a/plugins/ai_method/packages/models/encoders/swin_transformer.py b/plugins/ai_method/packages/models/encoders/swin_transformer.py new file mode 100644 index 0000000..9a443a2 --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/swin_transformer.py @@ -0,0 +1,196 @@ +import torch.nn as nn +import torch.nn.functional as F +from copy import deepcopy +from collections import OrderedDict +from pretrainedmodels.models.torchvision_models import pretrained_settings + +from ._base import EncoderMixin +from .swin_transformer_model import SwinTransformer + + +class SwinTransformerEncoder(SwinTransformer, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + + def get_stages(self): + return [nn.Identity()] + + def feature_forward(self, x): + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = [] + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + outs.append(out) + return outs + + def forward(self, x): + stages = self.get_stages() + + features = [] + for stage in stages: + x = stage(x) + features.append(x) + outs = self.feature_forward(x) + + # Note: An additional interpolated feature to accommodate five-stage decoders,\ + # the additional feature will be ignored if a decoder with fewer stages is used. + add_feature = F.interpolate(outs[0], scale_factor=2) + features = features + [add_feature] + outs + return features + + def load_state_dict(self, state_dict, **kwargs): + + new_state_dict = OrderedDict() + + if 'state_dict' in state_dict: + _state_dict = state_dict['state_dict'] + elif 'model' in state_dict: + _state_dict = state_dict['model'] + else: + _state_dict = state_dict + + for k, v in _state_dict.items(): + if k.startswith('backbone.'): + new_state_dict[k[9:]] = v + else: + new_state_dict[k] = v + + # Note: In swin seg model: `attn_mask` is no longer a class attribute for + # multi-scale inputs; a norm layer is added for each output; the head layer + # is removed. 
+ kwargs.update({'strict': False}) + super().load_state_dict(new_state_dict, **kwargs) + + +# https://github.com/microsoft/Swin-Transformer +new_settings = { + "Swin-T": { + "imagenet": "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth", + "ADE20k": "https://github.com/SwinTransformer/storage/releases/download/v1.0.1/upernet_swin_tiny_patch4_window7_512x512.pth" + }, + "Swin-S": { + "imagenet": "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth", + "ADE20k": "https://github.com/SwinTransformer/storage/releases/download/v1.0.1/upernet_swin_small_patch4_window7_512x512.pth" + }, + "Swin-B": { + "imagenet": "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth", + "imagenet-22k": "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth", + "ADE20k": "https://github.com/SwinTransformer/storage/releases/download/v1.0.1/upernet_swin_base_patch4_window7_512x512.pth" + }, + "Swin-L": { + "imagenet-22k": "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth" + }, + +} + +pretrained_settings = deepcopy(pretrained_settings) +for model_name, sources in new_settings.items(): + if model_name not in pretrained_settings: + pretrained_settings[model_name] = {} + + for source_name, source_url in sources.items(): + pretrained_settings[model_name][source_name] = { + "url": source_url, + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + +swin_transformer_encoders = { + "Swin-T": { + "encoder": SwinTransformerEncoder, + "pretrained_settings": pretrained_settings["Swin-T"], + "params": { + "embed_dim": 96, + "out_channels": (3, 96, 96, 192, 384, 768), + "depths": [2, 2, 6, 2], + "num_heads": [3, 6, 12, 24], + "window_size": 7, + "ape": False, + "drop_path_rate": 0.3, + "patch_norm": True, + "use_checkpoint": False + } + }, + "Swin-S": { + "encoder": SwinTransformerEncoder, + "pretrained_settings": pretrained_settings["Swin-S"], + "params": { + "embed_dim": 96, + "out_channels": (3, 96, 96, 192, 384, 768), + "depths": [2, 2, 18, 2], + "num_heads": [3, 6, 12, 24], + "window_size": 7, + "ape": False, + "drop_path_rate": 0.3, + "patch_norm": True, + "use_checkpoint": False + } + }, + "Swin-B": { + "encoder": SwinTransformerEncoder, + "pretrained_settings": pretrained_settings["Swin-B"], + "params": { + "embed_dim": 128, + "out_channels": (3, 128, 128, 256, 512, 1024), + "depths": [2, 2, 18, 2], + "num_heads": [4, 8, 16, 32], + "window_size": 7, + "ape": False, + "drop_path_rate": 0.3, + "patch_norm": True, + "use_checkpoint": False + } + }, + "Swin-L": { + "encoder": SwinTransformerEncoder, + "pretrained_settings": pretrained_settings["Swin-L"], + "params": { + "embed_dim": 192, + "out_channels": (3, 192, 192, 384, 768, 1536), + "depths": [2, 2, 18, 2], + "num_heads": [6, 12, 24, 48], + "window_size": 7, + "ape": False, + "drop_path_rate": 0.3, + "patch_norm": True, + "use_checkpoint": False + } + } + +} + +if __name__ == "__main__": + import torch + + device = 'cuda' if torch.cuda.is_available() else 'cpu' + input = torch.randn(1, 3, 256, 256).to(device) + + model = SwinTransformerEncoder(2, window_size=8) + # print(model) + + res = model.forward(input) + for i in res: + print(i.shape) diff --git a/plugins/ai_method/packages/models/encoders/swin_transformer_model.py 
b/plugins/ai_method/packages/models/encoders/swin_transformer_model.py new file mode 100644 index 0000000..500d18a --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/swin_transformer_model.py @@ -0,0 +1,626 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu, Yutong Lin, Yixuan Wei +# -------------------------------------------------------- + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +import numpy as np +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
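A tiny worked example of the relative-position index that the constructor below builds; window_size = (2, 2) is chosen here purely for illustration:

import torch

ws = (2, 2)                                                 # a 2x2 window holds 4 tokens
coords = torch.stack(torch.meshgrid([torch.arange(ws[0]), torch.arange(ws[1])]))
flat = torch.flatten(coords, 1)                             # 2, Wh*Ww
rel = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0).contiguous()
rel[:, :, 0] += ws[0] - 1                                   # shift both axes to start from 0
rel[:, :, 1] += ws[1] - 1
rel[:, :, 0] *= 2 * ws[1] - 1
print(rel.sum(-1))                                          # 4x4 indices into a (2*2-1)*(2*2-1) = 9-row bias table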
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ Forward function. + + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """ Swin Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. 
Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. + """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer + + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. 
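For orientation, a shape sketch of what the merging step below does (sizes are hypothetical, and the snippet assumes PatchMerging has been imported from this module):

import torch

pm = PatchMerging(dim=96)
x = torch.randn(2, 56 * 56, 96)                # (B, H*W, C)
y = pm(x, 56, 56)
print(y.shape)                                 # torch.Size([2, 784, 192]): H and W halved, C doubled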
+ """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. 
+ """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +class SwinTransformer(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. 
Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + pretrain_img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + **kwargs + ): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self.init_weights() # the pre-trained model will be loaded later if needed + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. 
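A small worked example of the "stochastic depth decay rule" used in __init__ above, with the default depths and drop_path_rate:

import torch

depths, drop_path_rate = [2, 2, 6, 2], 0.2
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
# 12 rates increasing linearly from 0.0 to 0.2; the i-th block of stage s
# receives dpr[sum(depths[:s]) + i], so deeper blocks are dropped more often.
print(dpr)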
+ + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + # logger = get_root_logger() + # load_checkpoint(self, pretrained, strict=False, logger=logger) + pass + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = [] + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + outs.append(out) + + return tuple(outs) diff --git a/plugins/ai_method/packages/models/encoders/timm_efficientnet.py b/plugins/ai_method/packages/models/encoders/timm_efficientnet.py new file mode 100644 index 0000000..ddac946 --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/timm_efficientnet.py @@ -0,0 +1,382 @@ +from functools import partial + +import torch +import torch.nn as nn + +from timm.models.efficientnet import EfficientNet +from timm.models.efficientnet import decode_arch_def, round_channels, default_cfgs +from timm.models.layers.activations import Swish + +from ._base import EncoderMixin + + +def get_efficientnet_kwargs(channel_multiplier=1.0, depth_multiplier=1.0, drop_rate=0.2): + """Creates an EfficientNet model. 
+ Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py + Paper: https://arxiv.org/abs/1905.11946 + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + 'efficientnet-l2': (4.3, 5.3, 800, 0.5), + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25'], + ['ir_r4_k5_s2_e6_c192_se0.25'], + ['ir_r1_k3_s1_e6_c320_se0.25'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + act_layer=Swish, + drop_rate=drop_rate, + drop_path_rate=0.2, + ) + return model_kwargs + +def gen_efficientnet_lite_kwargs(channel_multiplier=1.0, depth_multiplier=1.0, drop_rate=0.2): + """Creates an EfficientNet-Lite model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-lite0': (1.0, 1.0, 224, 0.2), + 'efficientnet-lite1': (1.0, 1.1, 240, 0.2), + 'efficientnet-lite2': (1.1, 1.2, 260, 0.3), + 'efficientnet-lite3': (1.2, 1.4, 280, 0.3), + 'efficientnet-lite4': (1.4, 1.8, 300, 0.3), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16'], + ['ir_r2_k3_s2_e6_c24'], + ['ir_r2_k5_s2_e6_c40'], + ['ir_r3_k3_s2_e6_c80'], + ['ir_r3_k5_s1_e6_c112'], + ['ir_r4_k5_s2_e6_c192'], + ['ir_r1_k3_s1_e6_c320'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True), + num_features=1280, + stem_size=32, + fix_stem=True, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + act_layer=nn.ReLU6, + drop_rate=drop_rate, + drop_path_rate=0.2, + ) + return model_kwargs + +class EfficientNetBaseEncoder(EfficientNet, EncoderMixin): + + def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + + self._stage_idxs = stage_idxs + self._out_channels = out_channels + self._depth = depth + self._in_channels = 3 + + del self.classifier + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential(self.conv_stem, self.bn1, self.act1), + self.blocks[:self._stage_idxs[0]], + self.blocks[self._stage_idxs[0]:self._stage_idxs[1]], + self.blocks[self._stage_idxs[1]:self._stage_idxs[2]], + self.blocks[self._stage_idxs[2]:], + ] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("classifier.bias", None) + 
state_dict.pop("classifier.weight", None) + super().load_state_dict(state_dict, **kwargs) + + +class EfficientNetEncoder(EfficientNetBaseEncoder): + + def __init__(self, stage_idxs, out_channels, depth=5, channel_multiplier=1.0, depth_multiplier=1.0, drop_rate=0.2): + kwargs = get_efficientnet_kwargs(channel_multiplier, depth_multiplier, drop_rate) + super().__init__(stage_idxs, out_channels, depth, **kwargs) + + +class EfficientNetLiteEncoder(EfficientNetBaseEncoder): + + def __init__(self, stage_idxs, out_channels, depth=5, channel_multiplier=1.0, depth_multiplier=1.0, drop_rate=0.2): + kwargs = gen_efficientnet_lite_kwargs(channel_multiplier, depth_multiplier, drop_rate) + super().__init__(stage_idxs, out_channels, depth, **kwargs) + + +def prepare_settings(settings): + return { + "mean": settings["mean"], + "std": settings["std"], + "url": settings["url"], + "input_range": (0, 1), + "input_space": "RGB", + } + + +timm_efficientnet_encoders = { + + "timm-efficientnet-b0": { + "encoder": EfficientNetEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b0"]), + "advprop": prepare_settings(default_cfgs["tf_efficientnet_b0_ap"]), + "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b0_ns"]), + }, + "params": { + "out_channels": (3, 32, 24, 40, 112, 320), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.0, + "depth_multiplier": 1.0, + "drop_rate": 0.2, + }, + }, + + "timm-efficientnet-b1": { + "encoder": EfficientNetEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b1"]), + "advprop": prepare_settings(default_cfgs["tf_efficientnet_b1_ap"]), + "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b1_ns"]), + }, + "params": { + "out_channels": (3, 32, 24, 40, 112, 320), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.0, + "depth_multiplier": 1.1, + "drop_rate": 0.2, + }, + }, + + "timm-efficientnet-b2": { + "encoder": EfficientNetEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b2"]), + "advprop": prepare_settings(default_cfgs["tf_efficientnet_b2_ap"]), + "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b2_ns"]), + }, + "params": { + "out_channels": (3, 32, 24, 48, 120, 352), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.1, + "depth_multiplier": 1.2, + "drop_rate": 0.3, + }, + }, + + "timm-efficientnet-b3": { + "encoder": EfficientNetEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b3"]), + "advprop": prepare_settings(default_cfgs["tf_efficientnet_b3_ap"]), + "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b3_ns"]), + }, + "params": { + "out_channels": (3, 40, 32, 48, 136, 384), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.2, + "depth_multiplier": 1.4, + "drop_rate": 0.3, + }, + }, + + "timm-efficientnet-b4": { + "encoder": EfficientNetEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b4"]), + "advprop": prepare_settings(default_cfgs["tf_efficientnet_b4_ap"]), + "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b4_ns"]), + }, + "params": { + "out_channels": (3, 48, 32, 56, 160, 448), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.4, + "depth_multiplier": 1.8, + "drop_rate": 0.4, + }, + }, + + "timm-efficientnet-b5": { + "encoder": EfficientNetEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b5"]), + "advprop": 
prepare_settings(default_cfgs["tf_efficientnet_b5_ap"]), + "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b5_ns"]), + }, + "params": { + "out_channels": (3, 48, 40, 64, 176, 512), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.6, + "depth_multiplier": 2.2, + "drop_rate": 0.4, + }, + }, + + "timm-efficientnet-b6": { + "encoder": EfficientNetEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b6"]), + "advprop": prepare_settings(default_cfgs["tf_efficientnet_b6_ap"]), + "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b6_ns"]), + }, + "params": { + "out_channels": (3, 56, 40, 72, 200, 576), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.8, + "depth_multiplier": 2.6, + "drop_rate": 0.5, + }, + }, + + "timm-efficientnet-b7": { + "encoder": EfficientNetEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b7"]), + "advprop": prepare_settings(default_cfgs["tf_efficientnet_b7_ap"]), + "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b7_ns"]), + }, + "params": { + "out_channels": (3, 64, 48, 80, 224, 640), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 2.0, + "depth_multiplier": 3.1, + "drop_rate": 0.5, + }, + }, + + "timm-efficientnet-b8": { + "encoder": EfficientNetEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b8"]), + "advprop": prepare_settings(default_cfgs["tf_efficientnet_b8_ap"]), + }, + "params": { + "out_channels": (3, 72, 56, 88, 248, 704), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 2.2, + "depth_multiplier": 3.6, + "drop_rate": 0.5, + }, + }, + + "timm-efficientnet-l2": { + "encoder": EfficientNetEncoder, + "pretrained_settings": { + "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_l2_ns"]), + }, + "params": { + "out_channels": (3, 136, 104, 176, 480, 1376), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 4.3, + "depth_multiplier": 5.3, + "drop_rate": 0.5, + }, + }, + + "timm-tf_efficientnet_lite0": { + "encoder": EfficientNetLiteEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite0"]), + }, + "params": { + "out_channels": (3, 32, 24, 40, 112, 320), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.0, + "depth_multiplier": 1.0, + "drop_rate": 0.2, + }, + }, + + "timm-tf_efficientnet_lite1": { + "encoder": EfficientNetLiteEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite1"]), + }, + "params": { + "out_channels": (3, 32, 24, 40, 112, 320), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.0, + "depth_multiplier": 1.1, + "drop_rate": 0.2, + }, + }, + + "timm-tf_efficientnet_lite2": { + "encoder": EfficientNetLiteEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite2"]), + }, + "params": { + "out_channels": (3, 32, 24, 48, 120, 352), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.1, + "depth_multiplier": 1.2, + "drop_rate": 0.3, + }, + }, + + "timm-tf_efficientnet_lite3": { + "encoder": EfficientNetLiteEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite3"]), + }, + "params": { + "out_channels": (3, 32, 32, 48, 136, 384), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.2, + "depth_multiplier": 1.4, + "drop_rate": 0.3, + }, + }, + + "timm-tf_efficientnet_lite4": { + "encoder": EfficientNetLiteEncoder, + "pretrained_settings": { + "imagenet": 
prepare_settings(default_cfgs["tf_efficientnet_lite4"]), + }, + "params": { + "out_channels": (3, 32, 32, 56, 160, 448), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.4, + "depth_multiplier": 1.8, + "drop_rate": 0.4, + }, + }, +} diff --git a/plugins/ai_method/packages/models/encoders/timm_gernet.py b/plugins/ai_method/packages/models/encoders/timm_gernet.py new file mode 100644 index 0000000..f98c030 --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/timm_gernet.py @@ -0,0 +1,124 @@ +from timm.models import ByoModelCfg, ByoBlockCfg, ByobNet + +from ._base import EncoderMixin +import torch.nn as nn + + +class GERNetEncoder(ByobNet, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + + del self.head + + def get_stages(self): + return [ + nn.Identity(), + self.stem, + self.stages[0], + self.stages[1], + self.stages[2], + nn.Sequential(self.stages[3], self.stages[4], self.final_conv) + ] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("head.fc.weight", None) + state_dict.pop("head.fc.bias", None) + super().load_state_dict(state_dict, **kwargs) + + +regnet_weights = { + 'timm-gernet_s': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_s-756b4751.pth', + }, + 'timm-gernet_m': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_m-0873c53a.pth', + }, + 'timm-gernet_l': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_l-f31e2e8d.pth', + }, +} + +pretrained_settings = {} +for model_name, sources in regnet_weights.items(): + pretrained_settings[model_name] = {} + for source_name, source_url in sources.items(): + pretrained_settings[model_name][source_name] = { + "url": source_url, + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + +timm_gernet_encoders = { + 'timm-gernet_s': { + 'encoder': GERNetEncoder, + "pretrained_settings": pretrained_settings["timm-gernet_s"], + 'params': { + 'out_channels': (3, 13, 48, 48, 384, 1920), + 'cfg': ByoModelCfg( + blocks=( + ByoBlockCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.), + ByoBlockCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.), + ByoBlockCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4), + ByoBlockCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.), + ByoBlockCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.), + ), + stem_chs=13, + stem_pool=None, + num_features=1920, + ) + }, + }, + 'timm-gernet_m': { + 'encoder': GERNetEncoder, + "pretrained_settings": pretrained_settings["timm-gernet_m"], + 'params': { + 'out_channels': (3, 32, 128, 192, 640, 2560), + 'cfg': ByoModelCfg( + blocks=( + ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), + ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.), + ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4), + ByoBlockCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.), + ByoBlockCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.), + ), + stem_chs=32, + stem_pool=None, + num_features=2560, + ) + }, + }, + 'timm-gernet_l': { + 'encoder': GERNetEncoder, + "pretrained_settings": pretrained_settings["timm-gernet_l"], + 'params': { + 
'out_channels': (3, 32, 128, 192, 640, 2560), + 'cfg': ByoModelCfg( + blocks=( + ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), + ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.), + ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4), + ByoBlockCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.), + ByoBlockCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.), + ), + stem_chs=32, + stem_pool=None, + num_features=2560, + ) + }, + }, +} diff --git a/plugins/ai_method/packages/models/encoders/timm_mobilenetv3.py b/plugins/ai_method/packages/models/encoders/timm_mobilenetv3.py new file mode 100644 index 0000000..a4ab6ec --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/timm_mobilenetv3.py @@ -0,0 +1,175 @@ +import timm +import numpy as np +import torch.nn as nn + +from ._base import EncoderMixin + + +def _make_divisible(x, divisible_by=8): + return int(np.ceil(x * 1. / divisible_by) * divisible_by) + + +class MobileNetV3Encoder(nn.Module, EncoderMixin): + def __init__(self, model_name, width_mult, depth=5, **kwargs): + super().__init__() + if "large" not in model_name and "small" not in model_name: + raise ValueError( + 'MobileNetV3 wrong model name {}'.format(model_name) + ) + + self._mode = "small" if "small" in model_name else "large" + self._depth = depth + self._out_channels = self._get_channels(self._mode, width_mult) + self._in_channels = 3 + + # minimal models replace hardswish with relu + self.model = timm.create_model( + model_name=model_name, + scriptable=True, # torch.jit scriptable + exportable=True, # onnx export + features_only=True, + ) + + def _get_channels(self, mode, width_mult): + if mode == "small": + channels = [16, 16, 24, 48, 576] + else: + channels = [16, 24, 40, 112, 960] + channels = [3,] + [_make_divisible(x * width_mult) for x in channels] + return tuple(channels) + + def get_stages(self): + if self._mode == 'small': + return [ + nn.Identity(), + nn.Sequential( + self.model.conv_stem, + self.model.bn1, + self.model.act1, + ), + self.model.blocks[0], + self.model.blocks[1], + self.model.blocks[2:4], + self.model.blocks[4:], + ] + elif self._mode == 'large': + return [ + nn.Identity(), + nn.Sequential( + self.model.conv_stem, + self.model.bn1, + self.model.act1, + self.model.blocks[0], + ), + self.model.blocks[1], + self.model.blocks[2], + self.model.blocks[3:5], + self.model.blocks[5:], + ] + else: + ValueError('MobileNetV3 mode should be small or large, got {}'.format(self._mode)) + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop('conv_head.weight', None) + state_dict.pop('conv_head.bias', None) + state_dict.pop('classifier.weight', None) + state_dict.pop('classifier.bias', None) + self.model.load_state_dict(state_dict, **kwargs) + + +mobilenetv3_weights = { + 'tf_mobilenetv3_large_075': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth' + }, + 'tf_mobilenetv3_large_100': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth' + }, + 'tf_mobilenetv3_large_minimal_100': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth' + }, + 'tf_mobilenetv3_small_075': { + 'imagenet': 
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth' + }, + 'tf_mobilenetv3_small_100': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth' + }, + 'tf_mobilenetv3_small_minimal_100': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth' + }, + + +} + +pretrained_settings = {} +for model_name, sources in mobilenetv3_weights.items(): + pretrained_settings[model_name] = {} + for source_name, source_url in sources.items(): + pretrained_settings[model_name][source_name] = { + "url": source_url, + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'input_space': 'RGB', + } + + +timm_mobilenetv3_encoders = { + 'timm-mobilenetv3_large_075': { + 'encoder': MobileNetV3Encoder, + 'pretrained_settings': pretrained_settings['tf_mobilenetv3_large_075'], + 'params': { + 'model_name': 'tf_mobilenetv3_large_075', + 'width_mult': 0.75 + } + }, + 'timm-mobilenetv3_large_100': { + 'encoder': MobileNetV3Encoder, + 'pretrained_settings': pretrained_settings['tf_mobilenetv3_large_100'], + 'params': { + 'model_name': 'tf_mobilenetv3_large_100', + 'width_mult': 1.0 + } + }, + 'timm-mobilenetv3_large_minimal_100': { + 'encoder': MobileNetV3Encoder, + 'pretrained_settings': pretrained_settings['tf_mobilenetv3_large_minimal_100'], + 'params': { + 'model_name': 'tf_mobilenetv3_large_minimal_100', + 'width_mult': 1.0 + } + }, + 'timm-mobilenetv3_small_075': { + 'encoder': MobileNetV3Encoder, + 'pretrained_settings': pretrained_settings['tf_mobilenetv3_small_075'], + 'params': { + 'model_name': 'tf_mobilenetv3_small_075', + 'width_mult': 0.75 + } + }, + 'timm-mobilenetv3_small_100': { + 'encoder': MobileNetV3Encoder, + 'pretrained_settings': pretrained_settings['tf_mobilenetv3_small_100'], + 'params': { + 'model_name': 'tf_mobilenetv3_small_100', + 'width_mult': 1.0 + } + }, + 'timm-mobilenetv3_small_minimal_100': { + 'encoder': MobileNetV3Encoder, + 'pretrained_settings': pretrained_settings['tf_mobilenetv3_small_minimal_100'], + 'params': { + 'model_name': 'tf_mobilenetv3_small_minimal_100', + 'width_mult': 1.0 + } + }, +} diff --git a/plugins/ai_method/packages/models/encoders/timm_regnet.py b/plugins/ai_method/packages/models/encoders/timm_regnet.py new file mode 100644 index 0000000..7d801be --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/timm_regnet.py @@ -0,0 +1,332 @@ +from ._base import EncoderMixin +from timm.models.regnet import RegNet +import torch.nn as nn + + +class RegNetEncoder(RegNet, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + + del self.head + + def get_stages(self): + return [ + nn.Identity(), + self.stem, + self.s1, + self.s2, + self.s3, + self.s4, + ] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("head.fc.weight", None) + state_dict.pop("head.fc.bias", None) + super().load_state_dict(state_dict, **kwargs) + + +regnet_weights = { + 'timm-regnetx_002': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_002-e7e85e5c.pth', + }, + 
'timm-regnetx_004': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_004-7d0e9424.pth', + }, + 'timm-regnetx_006': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_006-85ec1baa.pth', + }, + 'timm-regnetx_008': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_008-d8b470eb.pth', + }, + 'timm-regnetx_016': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_016-65ca972a.pth', + }, + 'timm-regnetx_032': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_032-ed0c7f7e.pth', + }, + 'timm-regnetx_040': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_040-73c2a654.pth', + }, + 'timm-regnetx_064': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_064-29278baa.pth', + }, + 'timm-regnetx_080': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_080-7c7fcab1.pth', + }, + 'timm-regnetx_120': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth', + }, + 'timm-regnetx_160': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth', + }, + 'timm-regnetx_320': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth', + }, + 'timm-regnety_002': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth', + }, + 'timm-regnety_004': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth', + }, + 'timm-regnety_006': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth', + }, + 'timm-regnety_008': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth', + }, + 'timm-regnety_016': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth', + }, + 'timm-regnety_032': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth' + }, + 'timm-regnety_040': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth' + }, + 'timm-regnety_064': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth' + }, + 'timm-regnety_080': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth', + }, + 'timm-regnety_120': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth', + }, + 'timm-regnety_160': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_160-d64013cd.pth', + }, + 'timm-regnety_320': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth' + } +} + +pretrained_settings = {} +for model_name, sources in regnet_weights.items(): + 
pretrained_settings[model_name] = {} + for source_name, source_url in sources.items(): + pretrained_settings[model_name][source_name] = { + "url": source_url, + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + +# at this point I am too lazy to copy configs, so I just used the same configs from timm's repo + + +def _mcfg(**kwargs): + cfg = dict(se_ratio=0., bottle_ratio=1., stem_width=32) + cfg.update(**kwargs) + return cfg + + +timm_regnet_encoders = { + 'timm-regnetx_002': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_002"], + 'params': { + 'out_channels': (3, 32, 24, 56, 152, 368), + 'cfg': _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13) + }, + }, + 'timm-regnetx_004': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_004"], + 'params': { + 'out_channels': (3, 32, 32, 64, 160, 384), + 'cfg': _mcfg(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22) + }, + }, + 'timm-regnetx_006': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_006"], + 'params': { + 'out_channels': (3, 32, 48, 96, 240, 528), + 'cfg': _mcfg(w0=48, wa=36.97, wm=2.24, group_w=24, depth=16) + }, + }, + 'timm-regnetx_008': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_008"], + 'params': { + 'out_channels': (3, 32, 64, 128, 288, 672), + 'cfg': _mcfg(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16) + }, + }, + 'timm-regnetx_016': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_016"], + 'params': { + 'out_channels': (3, 32, 72, 168, 408, 912), + 'cfg': _mcfg(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18) + }, + }, + 'timm-regnetx_032': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_032"], + 'params': { + 'out_channels': (3, 32, 96, 192, 432, 1008), + 'cfg': _mcfg(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25) + }, + }, + 'timm-regnetx_040': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_040"], + 'params': { + 'out_channels': (3, 32, 80, 240, 560, 1360), + 'cfg': _mcfg(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23) + }, + }, + 'timm-regnetx_064': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_064"], + 'params': { + 'out_channels': (3, 32, 168, 392, 784, 1624), + 'cfg': _mcfg(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17) + }, + }, + 'timm-regnetx_080': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_080"], + 'params': { + 'out_channels': (3, 32, 80, 240, 720, 1920), + 'cfg': _mcfg(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23) + }, + }, + 'timm-regnetx_120': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_120"], + 'params': { + 'out_channels': (3, 32, 224, 448, 896, 2240), + 'cfg': _mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19) + }, + }, + 'timm-regnetx_160': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_160"], + 'params': { + 'out_channels': (3, 32, 256, 512, 896, 2048), + 'cfg': _mcfg(w0=216, wa=55.59, wm=2.1, group_w=128, depth=22) + }, + }, + 'timm-regnetx_320': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_320"], + 'params': { + 'out_channels': (3, 32, 336, 672, 1344, 2520), + 'cfg': _mcfg(w0=320, wa=69.86, wm=2.0, 
group_w=168, depth=23) + }, + }, + #regnety + 'timm-regnety_002': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_002"], + 'params': { + 'out_channels': (3, 32, 24, 56, 152, 368), + 'cfg': _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13, se_ratio=0.25) + }, + }, + 'timm-regnety_004': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_004"], + 'params': { + 'out_channels': (3, 32, 48, 104, 208, 440), + 'cfg': _mcfg(w0=48, wa=27.89, wm=2.09, group_w=8, depth=16, se_ratio=0.25) + }, + }, + 'timm-regnety_006': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_006"], + 'params': { + 'out_channels': (3, 32, 48, 112, 256, 608), + 'cfg': _mcfg(w0=48, wa=32.54, wm=2.32, group_w=16, depth=15, se_ratio=0.25) + }, + }, + 'timm-regnety_008': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_008"], + 'params': { + 'out_channels': (3, 32, 64, 128, 320, 768), + 'cfg': _mcfg(w0=56, wa=38.84, wm=2.4, group_w=16, depth=14, se_ratio=0.25) + }, + }, + 'timm-regnety_016': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_016"], + 'params': { + 'out_channels': (3, 32, 48, 120, 336, 888), + 'cfg': _mcfg(w0=48, wa=20.71, wm=2.65, group_w=24, depth=27, se_ratio=0.25) + }, + }, + 'timm-regnety_032': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_032"], + 'params': { + 'out_channels': (3, 32, 72, 216, 576, 1512), + 'cfg': _mcfg(w0=80, wa=42.63, wm=2.66, group_w=24, depth=21, se_ratio=0.25) + }, + }, + 'timm-regnety_040': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_040"], + 'params': { + 'out_channels': (3, 32, 128, 192, 512, 1088), + 'cfg': _mcfg(w0=96, wa=31.41, wm=2.24, group_w=64, depth=22, se_ratio=0.25) + }, + }, + 'timm-regnety_064': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_064"], + 'params': { + 'out_channels': (3, 32, 144, 288, 576, 1296), + 'cfg': _mcfg(w0=112, wa=33.22, wm=2.27, group_w=72, depth=25, se_ratio=0.25) + }, + }, + 'timm-regnety_080': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_080"], + 'params': { + 'out_channels': (3, 32, 168, 448, 896, 2016), + 'cfg': _mcfg(w0=192, wa=76.82, wm=2.19, group_w=56, depth=17, se_ratio=0.25) + }, + }, + 'timm-regnety_120': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_120"], + 'params': { + 'out_channels': (3, 32, 224, 448, 896, 2240), + 'cfg': _mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, se_ratio=0.25) + }, + }, + 'timm-regnety_160': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_160"], + 'params': { + 'out_channels': (3, 32, 224, 448, 1232, 3024), + 'cfg': _mcfg(w0=200, wa=106.23, wm=2.48, group_w=112, depth=18, se_ratio=0.25) + }, + }, + 'timm-regnety_320': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_320"], + 'params': { + 'out_channels': (3, 32, 232, 696, 1392, 3712), + 'cfg': _mcfg(w0=232, wa=115.89, wm=2.53, group_w=232, depth=20, se_ratio=0.25) + }, + }, +} diff --git a/plugins/ai_method/packages/models/encoders/timm_res2net.py b/plugins/ai_method/packages/models/encoders/timm_res2net.py new file mode 100644 index 0000000..3dca78a --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/timm_res2net.py @@ -0,0 +1,163 @@ +from ._base import 
EncoderMixin +from timm.models.resnet import ResNet +from timm.models.res2net import Bottle2neck +import torch.nn as nn + + +class Res2NetEncoder(ResNet, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + + del self.fc + del self.global_pool + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential(self.conv1, self.bn1, self.act1), + nn.Sequential(self.maxpool, self.layer1), + self.layer2, + self.layer3, + self.layer4, + ] + + def make_dilated(self, output_stride): + raise ValueError("Res2Net encoders do not support dilated mode") + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("fc.bias", None) + state_dict.pop("fc.weight", None) + super().load_state_dict(state_dict, **kwargs) + + +res2net_weights = { + 'timm-res2net50_26w_4s': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_4s-06e79181.pth' + }, + 'timm-res2net50_48w_2s': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_48w_2s-afed724a.pth' + }, + 'timm-res2net50_14w_8s': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_14w_8s-6527dddc.pth', + }, + 'timm-res2net50_26w_6s': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_6s-19041792.pth', + }, + 'timm-res2net50_26w_8s': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_8s-2c7c9f12.pth', + }, + 'timm-res2net101_26w_4s': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net101_26w_4s-02a759a1.pth', + }, + 'timm-res2next50': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next50_4s-6ef7e7bf.pth', + } +} + +pretrained_settings = {} +for model_name, sources in res2net_weights.items(): + pretrained_settings[model_name] = {} + for source_name, source_url in sources.items(): + pretrained_settings[model_name][source_name] = { + "url": source_url, + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + + +timm_res2net_encoders = { + 'timm-res2net50_26w_4s': { + 'encoder': Res2NetEncoder, + "pretrained_settings": pretrained_settings["timm-res2net50_26w_4s"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': Bottle2neck, + 'layers': [3, 4, 6, 3], + 'base_width': 26, + 'block_args': {'scale': 4} + }, + }, + 'timm-res2net101_26w_4s': { + 'encoder': Res2NetEncoder, + "pretrained_settings": pretrained_settings["timm-res2net101_26w_4s"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': Bottle2neck, + 'layers': [3, 4, 23, 3], + 'base_width': 26, + 'block_args': {'scale': 4} + }, + }, + 'timm-res2net50_26w_6s': { + 'encoder': Res2NetEncoder, + "pretrained_settings": pretrained_settings["timm-res2net50_26w_6s"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': Bottle2neck, + 'layers': [3, 4, 6, 3], + 'base_width': 26, + 'block_args': {'scale': 6} + }, + }, + 'timm-res2net50_26w_8s': { + 'encoder': 
Res2NetEncoder, + "pretrained_settings": pretrained_settings["timm-res2net50_26w_8s"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': Bottle2neck, + 'layers': [3, 4, 6, 3], + 'base_width': 26, + 'block_args': {'scale': 8} + }, + }, + 'timm-res2net50_48w_2s': { + 'encoder': Res2NetEncoder, + "pretrained_settings": pretrained_settings["timm-res2net50_48w_2s"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': Bottle2neck, + 'layers': [3, 4, 6, 3], + 'base_width': 48, + 'block_args': {'scale': 2} + }, + }, + 'timm-res2net50_14w_8s': { + 'encoder': Res2NetEncoder, + "pretrained_settings": pretrained_settings["timm-res2net50_14w_8s"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': Bottle2neck, + 'layers': [3, 4, 6, 3], + 'base_width': 14, + 'block_args': {'scale': 8} + }, + }, + 'timm-res2next50': { + 'encoder': Res2NetEncoder, + "pretrained_settings": pretrained_settings["timm-res2next50"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': Bottle2neck, + 'layers': [3, 4, 6, 3], + 'base_width': 4, + 'cardinality': 8, + 'block_args': {'scale': 4} + }, + } +} diff --git a/plugins/ai_method/packages/models/encoders/timm_resnest.py b/plugins/ai_method/packages/models/encoders/timm_resnest.py new file mode 100644 index 0000000..100df1d --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/timm_resnest.py @@ -0,0 +1,208 @@ +from ._base import EncoderMixin +from timm.models.resnet import ResNet +from timm.models.resnest import ResNestBottleneck +import torch.nn as nn + + +class ResNestEncoder(ResNet, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + + del self.fc + del self.global_pool + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential(self.conv1, self.bn1, self.act1), + nn.Sequential(self.maxpool, self.layer1), + self.layer2, + self.layer3, + self.layer4, + ] + + def make_dilated(self, output_stride): + raise ValueError("ResNest encoders do not support dilated mode") + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("fc.bias", None) + state_dict.pop("fc.weight", None) + super().load_state_dict(state_dict, **kwargs) + + +resnest_weights = { + 'timm-resnest14d': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth' + }, + 'timm-resnest26d': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth' + }, + 'timm-resnest50d': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50-528c19ca.pth', + }, + 'timm-resnest101e': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest101-22405ba7.pth', + }, + 'timm-resnest200e': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest200-75117900.pth', + }, + 'timm-resnest269e': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest269-0cc87c48.pth', + }, + 'timm-resnest50d_4s2x40d': { + 'imagenet': 
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_4s2x40d-41d14ed0.pth', + }, + 'timm-resnest50d_1s4x24d': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_1s4x24d-d4a4f76f.pth', + } +} + +pretrained_settings = {} +for model_name, sources in resnest_weights.items(): + pretrained_settings[model_name] = {} + for source_name, source_url in sources.items(): + pretrained_settings[model_name][source_name] = { + "url": source_url, + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + + +timm_resnest_encoders = { + 'timm-resnest14d': { + 'encoder': ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest14d"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': ResNestBottleneck, + 'layers': [1, 1, 1, 1], + 'stem_type': 'deep', + 'stem_width': 32, + 'avg_down': True, + 'base_width': 64, + 'cardinality': 1, + 'block_args': {'radix': 2, 'avd': True, 'avd_first': False} + } + }, + 'timm-resnest26d': { + 'encoder': ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest26d"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': ResNestBottleneck, + 'layers': [2, 2, 2, 2], + 'stem_type': 'deep', + 'stem_width': 32, + 'avg_down': True, + 'base_width': 64, + 'cardinality': 1, + 'block_args': {'radix': 2, 'avd': True, 'avd_first': False} + } + }, + 'timm-resnest50d': { + 'encoder': ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest50d"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': ResNestBottleneck, + 'layers': [3, 4, 6, 3], + 'stem_type': 'deep', + 'stem_width': 32, + 'avg_down': True, + 'base_width': 64, + 'cardinality': 1, + 'block_args': {'radix': 2, 'avd': True, 'avd_first': False} + } + }, + 'timm-resnest101e': { + 'encoder': ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest101e"], + 'params': { + 'out_channels': (3, 128, 256, 512, 1024, 2048), + 'block': ResNestBottleneck, + 'layers': [3, 4, 23, 3], + 'stem_type': 'deep', + 'stem_width': 64, + 'avg_down': True, + 'base_width': 64, + 'cardinality': 1, + 'block_args': {'radix': 2, 'avd': True, 'avd_first': False} + } + }, + 'timm-resnest200e': { + 'encoder': ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest200e"], + 'params': { + 'out_channels': (3, 128, 256, 512, 1024, 2048), + 'block': ResNestBottleneck, + 'layers': [3, 24, 36, 3], + 'stem_type': 'deep', + 'stem_width': 64, + 'avg_down': True, + 'base_width': 64, + 'cardinality': 1, + 'block_args': {'radix': 2, 'avd': True, 'avd_first': False} + } + }, + 'timm-resnest269e': { + 'encoder': ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest269e"], + 'params': { + 'out_channels': (3, 128, 256, 512, 1024, 2048), + 'block': ResNestBottleneck, + 'layers': [3, 30, 48, 8], + 'stem_type': 'deep', + 'stem_width': 64, + 'avg_down': True, + 'base_width': 64, + 'cardinality': 1, + 'block_args': {'radix': 2, 'avd': True, 'avd_first': False} + }, + }, + 'timm-resnest50d_4s2x40d': { + 'encoder': ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest50d_4s2x40d"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': ResNestBottleneck, + 'layers': [3, 4, 6, 3], + 'stem_type': 'deep', + 'stem_width': 32, + 'avg_down': True, + 'base_width': 40, + 'cardinality': 2, + 
'block_args': {'radix': 4, 'avd': True, 'avd_first': True} + } + }, + 'timm-resnest50d_1s4x24d': { + 'encoder': ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest50d_1s4x24d"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': ResNestBottleneck, + 'layers': [3, 4, 6, 3], + 'stem_type': 'deep', + 'stem_width': 32, + 'avg_down': True, + 'base_width': 24, + 'cardinality': 4, + 'block_args': {'radix': 1, 'avd': True, 'avd_first': True} + } + } +} diff --git a/plugins/ai_method/packages/models/encoders/timm_sknet.py b/plugins/ai_method/packages/models/encoders/timm_sknet.py new file mode 100644 index 0000000..38804d9 --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/timm_sknet.py @@ -0,0 +1,103 @@ +from ._base import EncoderMixin +from timm.models.resnet import ResNet +from timm.models.sknet import SelectiveKernelBottleneck, SelectiveKernelBasic +import torch.nn as nn + + +class SkNetEncoder(ResNet, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + + del self.fc + del self.global_pool + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential(self.conv1, self.bn1, self.act1), + nn.Sequential(self.maxpool, self.layer1), + self.layer2, + self.layer3, + self.layer4, + ] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("fc.bias", None) + state_dict.pop("fc.weight", None) + super().load_state_dict(state_dict, **kwargs) + + +sknet_weights = { + 'timm-skresnet18': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth' + }, + 'timm-skresnet34': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth' + }, + 'timm-skresnext50_32x4d': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth', + } +} + +pretrained_settings = {} +for model_name, sources in sknet_weights.items(): + pretrained_settings[model_name] = {} + for source_name, source_url in sources.items(): + pretrained_settings[model_name][source_name] = { + "url": source_url, + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + +timm_sknet_encoders = { + 'timm-skresnet18': { + 'encoder': SkNetEncoder, + "pretrained_settings": pretrained_settings["timm-skresnet18"], + 'params': { + 'out_channels': (3, 64, 64, 128, 256, 512), + 'block': SelectiveKernelBasic, + 'layers': [2, 2, 2, 2], + 'zero_init_last_bn': False, + 'block_args': {'sk_kwargs': {'rd_ratio': 1/8, 'split_input': True}} + } + }, + 'timm-skresnet34': { + 'encoder': SkNetEncoder, + "pretrained_settings": pretrained_settings["timm-skresnet34"], + 'params': { + 'out_channels': (3, 64, 64, 128, 256, 512), + 'block': SelectiveKernelBasic, + 'layers': [3, 4, 6, 3], + 'zero_init_last_bn': False, + 'block_args': {'sk_kwargs': {'rd_ratio': 1/8, 'split_input': True}} + } + }, + 'timm-skresnext50_32x4d': { + 'encoder': SkNetEncoder, + "pretrained_settings": pretrained_settings["timm-skresnext50_32x4d"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': SelectiveKernelBottleneck, + 
'layers': [3, 4, 6, 3], + 'zero_init_last_bn': False, + 'cardinality': 32, + 'base_width': 4 + } + } +} diff --git a/plugins/ai_method/packages/models/encoders/timm_universal.py b/plugins/ai_method/packages/models/encoders/timm_universal.py new file mode 100644 index 0000000..c818e54 --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/timm_universal.py @@ -0,0 +1,34 @@ +import timm +import torch.nn as nn + + +class TimmUniversalEncoder(nn.Module): + + def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride=32): + super().__init__() + kwargs = dict( + in_chans=in_channels, + features_only=True, + output_stride=output_stride, + pretrained=pretrained, + out_indices=tuple(range(depth)), + ) + + # not all models support output stride argument, drop it by default + if output_stride == 32: + kwargs.pop("output_stride") + + self.model = timm.create_model(name, **kwargs) + + self._in_channels = in_channels + self._out_channels = [3, ] + self.model.feature_info.channels() + self._depth = depth + + def forward(self, x): + features = self.model(x) + features = [x,] + features + return features + + @property + def out_channels(self): + return self._out_channels diff --git a/plugins/ai_method/packages/models/encoders/vgg.py b/plugins/ai_method/packages/models/encoders/vgg.py new file mode 100644 index 0000000..fe5f4f2 --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/vgg.py @@ -0,0 +1,157 @@ +""" Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` + +Attributes: + + _out_channels (list of int): specify number of channels for each encoder feature tensor + _depth (int): specify number of stages in decoder (in other words number of downsampling operations) + _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) + +Methods: + + forward(self, x: torch.Tensor) + produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of + shape NCHW (features should be sorted in descending order according to spatial resolution, starting + with resolution same as input `x` tensor). + + Input: `x` with shape (1, 3, 64, 64) + Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes + [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), + (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) + + also should support number of features according to specified depth, e.g. if depth = 5, + number of feature tensors = 6 (one with same resolution as input and 5 downsampled), + depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
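+
+A minimal usage sketch (illustrative only; it simply exercises the contract above with
+the VGGEncoder and cfg defined in this module, and every encoder in this package
+follows the same pattern):
+
+    import torch
+    encoder = VGGEncoder(out_channels=(64, 128, 256, 512, 512, 512), config=cfg["A"])
+    features = encoder(torch.rand(1, 3, 64, 64))
+    # len(features) == encoder._depth + 1 == 6, and the channel count of features[i]
+    # matches encoder._out_channels[i], i.e. (64, 128, 256, 512, 512, 512)[i]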
+""" + +import torch.nn as nn +from torchvision.models.vgg import VGG +from torchvision.models.vgg import make_layers +from pretrainedmodels.models.torchvision_models import pretrained_settings + +from ._base import EncoderMixin + +# fmt: off +cfg = { + 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], +} +# fmt: on + + +class VGGEncoder(VGG, EncoderMixin): + def __init__(self, out_channels, config, batch_norm=False, depth=5, **kwargs): + super().__init__(make_layers(config, batch_norm=batch_norm), **kwargs) + self._out_channels = out_channels + self._depth = depth + self._in_channels = 3 + del self.classifier + + def make_dilated(self, output_stride): + raise ValueError("'VGG' models do not support dilated mode due to Max Pooling" + " operations for downsampling!") + + def get_stages(self): + stages = [] + stage_modules = [] + for module in self.features: + if isinstance(module, nn.MaxPool2d): + stages.append(nn.Sequential(*stage_modules)) + stage_modules = [] + stage_modules.append(module) + stages.append(nn.Sequential(*stage_modules)) + return stages + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + keys = list(state_dict.keys()) + for k in keys: + if k.startswith("classifier"): + state_dict.pop(k, None) + super().load_state_dict(state_dict, **kwargs) + + +vgg_encoders = { + "vgg11": { + "encoder": VGGEncoder, + "pretrained_settings": pretrained_settings["vgg11"], + "params": { + "out_channels": (64, 128, 256, 512, 512, 512), + "config": cfg["A"], + "batch_norm": False, + }, + }, + "vgg11_bn": { + "encoder": VGGEncoder, + "pretrained_settings": pretrained_settings["vgg11_bn"], + "params": { + "out_channels": (64, 128, 256, 512, 512, 512), + "config": cfg["A"], + "batch_norm": True, + }, + }, + "vgg13": { + "encoder": VGGEncoder, + "pretrained_settings": pretrained_settings["vgg13"], + "params": { + "out_channels": (64, 128, 256, 512, 512, 512), + "config": cfg["B"], + "batch_norm": False, + }, + }, + "vgg13_bn": { + "encoder": VGGEncoder, + "pretrained_settings": pretrained_settings["vgg13_bn"], + "params": { + "out_channels": (64, 128, 256, 512, 512, 512), + "config": cfg["B"], + "batch_norm": True, + }, + }, + "vgg16": { + "encoder": VGGEncoder, + "pretrained_settings": pretrained_settings["vgg16"], + "params": { + "out_channels": (64, 128, 256, 512, 512, 512), + "config": cfg["D"], + "batch_norm": False, + }, + }, + "vgg16_bn": { + "encoder": VGGEncoder, + "pretrained_settings": pretrained_settings["vgg16_bn"], + "params": { + "out_channels": (64, 128, 256, 512, 512, 512), + "config": cfg["D"], + "batch_norm": True, + }, + }, + "vgg19": { + "encoder": VGGEncoder, + "pretrained_settings": pretrained_settings["vgg19"], + "params": { + "out_channels": (64, 128, 256, 512, 512, 512), + "config": cfg["E"], + "batch_norm": False, + }, + }, + "vgg19_bn": { + "encoder": VGGEncoder, + "pretrained_settings": pretrained_settings["vgg19_bn"], + "params": { + "out_channels": (64, 128, 256, 512, 512, 512), + "config": cfg["E"], + "batch_norm": True, + }, + }, +} diff --git 
a/plugins/ai_method/packages/models/encoders/xception.py b/plugins/ai_method/packages/models/encoders/xception.py new file mode 100644 index 0000000..480d269 --- /dev/null +++ b/plugins/ai_method/packages/models/encoders/xception.py @@ -0,0 +1,66 @@ +import re +import torch.nn as nn + +from pretrainedmodels.models.xception import pretrained_settings +from pretrainedmodels.models.xception import Xception + +from ._base import EncoderMixin + + +class XceptionEncoder(Xception, EncoderMixin): + + def __init__(self, out_channels, *args, depth=5, **kwargs): + super().__init__(*args, **kwargs) + + self._out_channels = out_channels + self._depth = depth + self._in_channels = 3 + + # modify padding to maintain output shape + self.conv1.padding = (1, 1) + self.conv2.padding = (1, 1) + + del self.fc + + def make_dilated(self, output_stride): + raise ValueError("Xception encoder does not support dilated mode " + "due to pooling operation for downsampling!") + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential(self.conv1, self.bn1, self.relu, self.conv2, self.bn2, self.relu), + self.block1, + self.block2, + nn.Sequential(self.block3, self.block4, self.block5, self.block6, self.block7, + self.block8, self.block9, self.block10, self.block11), + nn.Sequential(self.block12, self.conv3, self.bn3, self.relu, self.conv4, self.bn4), + ] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict): + # remove linear + state_dict.pop('fc.bias', None) + state_dict.pop('fc.weight', None) + + super().load_state_dict(state_dict) + + +xception_encoders = { + 'xception': { + 'encoder': XceptionEncoder, + 'pretrained_settings': pretrained_settings['xception'], + 'params': { + 'out_channels': (3, 64, 128, 256, 728, 2048), + } + }, +} diff --git a/plugins/ai_method/packages/sta_net.py b/plugins/ai_method/packages/sta_net.py deleted file mode 100644 index 6484bdc..0000000 --- a/plugins/ai_method/packages/sta_net.py +++ /dev/null @@ -1,109 +0,0 @@ -from rscder.utils.icons import IconInstance -from ai_method.basic_cd import AIMethodDialog, AI_METHOD, AlgFrontend -import os -from PyQt5.QtWidgets import QFormLayout, QLabel, QLineEdit, QPushButton, QDialogButtonBox, QWidget -from PyQt5.QtCore import Qt - -class STAParams(AlgFrontend): - - options = dict( - train_dir= dict( - label = '训练集根目录', - opt='--dataroot', - dtype='str', - default=None - ), - val_dir = dict( - label = '验证集根目录', - opt='--val_dataroot', - dtype='str', - default=None - ), - # test_dir = dict( - # label = '测试集根目录', - # opt='--dataroot', - # dtype='str', - # default=None - # ) - ) - - @staticmethod - def get_widget(parent=None): - widget = QWidget(parent) - form = QFormLayout(widget) - # form.setFormAlignment() - # for key in STAParams.options: - - train_dir_label = QLabel('训练集根目录') - train_dir_data = QLineEdit() - train_dir_data.setObjectName('train_dir') - form.addRow(train_dir_label, train_dir_data) - - val_dir_label = QLabel('验证集根目录') - val_dir_data = QLineEdit() - val_dir_data.setObjectName('val_dir') - form.addRow(val_dir_label, val_dir_data) - - test_dir_label = QLabel('测试集根目录') - test_dir_data = QLineEdit() - # test_dir_data.tex - test_dir_data.setObjectName('test_dir') - form.addRow(test_dir_label, test_dir_data) - - widget.setLayout(form) - return widget - - @staticmethod - def get_params(widget:QWidget=None): - if widget is None: - return None - opt = [] - for key in STAParams.options: 
- comp:QLineEdit = widget.findChild(QLineEdit, name=key) - if comp is None: - if STAParams.options[key]['default'] is not None: - opt.append(STAParams.options[key]['opt']) - opt.append(STAParams.options[key]['default']) - continue - - opt.append(STAParams.options[key]['opt']) - opt.append(comp.text()) - - return opt - - -class STAMethod(AIMethodDialog): - - ENV = 'torch1121cu113' - setting_widget = STAParams - stages = [ ('train', '训练'), ('test', '测试'), ('predict_batch', '批量预测') ] - name = 'STA Net' - - @property - def workdir(self): - return os.path.abspath(os.path.join(os.path.dirname(__file__), 'STA_net')) - - def stage_script(self, stage): - if stage == 'train': - return os.path.join(self.workdir, 'train.py') - elif stage == 'test': - return os.path.join(self.workdir, 'test.py') - else: - return None - - -@AI_METHOD.register -class STANet(AlgFrontend): - - @staticmethod - def get_name(): - return 'STA Net' - - @staticmethod - def get_icon(): - return IconInstance().AI_DETECT - - @staticmethod - def get_widget(parent=None): - return STAMethod(parent) - \ No newline at end of file diff --git a/plugins/ai_method/utils.py b/plugins/ai_method/utils.py new file mode 100644 index 0000000..e69de29
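All of the registries added above (timm_efficientnet_encoders, timm_gernet_encoders, timm_mobilenetv3_encoders, timm_regnet_encoders, timm_res2net_encoders, timm_resnest_encoders, timm_sknet_encoders, vgg_encoders, xception_encoders) share one layout: a model name maps to an encoder class, its pretrained_settings, and the params used to build it. The sketch below shows how one such entry could be consumed; build_encoder is a hypothetical helper written only for illustration, the import path assumes the project root is on sys.path, and the real factory that merges these registries (for example a get_encoder() in the encoders package __init__) is not shown in this diff.

    import torch
    import torch.utils.model_zoo as model_zoo

    # import path is an assumption about the project layout
    from plugins.ai_method.packages.models.encoders.timm_regnet import timm_regnet_encoders


    def build_encoder(registry, name, weights=None, depth=5):
        """Illustrative sketch: build an encoder from a registry entry."""
        entry = registry[name]
        params = dict(entry["params"], depth=depth)
        encoder = entry["encoder"](**params)  # e.g. RegNetEncoder(out_channels=..., cfg=...)
        if weights is not None:
            settings = entry["pretrained_settings"][weights]
            # the encoder classes above pop their classifier/fc/head keys in load_state_dict
            encoder.load_state_dict(model_zoo.load_url(settings["url"]))
        return encoder


    encoder = build_encoder(timm_regnet_encoders, "timm-regnetx_002", weights="imagenet")
    features = encoder(torch.rand(1, 3, 224, 224))  # 6 feature maps; channels follow entry["params"]["out_channels"]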