Author: copper
Date: 2023-04-18 11:04:50 +08:00
parent 08550bd746
commit 804a3064f6
106 changed files with 6061 additions and 3730 deletions


@ -1,5 +1,6 @@
from misc.utils import Register
AI_METHOD = Register('AI Method')
from .packages.sta_net import STANet
from .main import AIPlugin
from .basic_cd import DPFCN, DVCA, RCNN


@ -1,210 +1,183 @@
from PyQt5.QtCore import QSettings, pyqtSignal, Qt
from PyQt5.QtGui import QIcon, QTextBlock, QTextCursor
from PyQt5.QtWidgets import QDialog, QLabel,QComboBox, QDialogButtonBox, QFileDialog, QHBoxLayout, QMessageBox, QProgressBar, QPushButton, QTextEdit, QVBoxLayout
import subprocess
import threading
import sys
from . import AI_METHOD
from plugins.misc import AlgFrontend
from rscder.utils.icons import IconInstance
from rscder.utils.project import PairLayer
from osgeo import gdal, gdal_array
import os
import sys
import ai_method.subprcess_python as sp
from ai_method import AI_METHOD
from misc.main import AlgFrontend
import abc
from rscder.utils.project import Project
from rscder.utils.geomath import geo2imageRC, imageRC2geo
import math
import numpy as np
from .packages import get_model
class AIMethodDialog(QDialog):
class BasicAICD(AlgFrontend):
stage_end = pyqtSignal(int)
stage_log = pyqtSignal(str)
@staticmethod
def get_icon():
return IconInstance().ARITHMETIC3
@staticmethod
def run_alg(pth1: str, pth2: str, layer_parent: PairLayer, send_message=None, model=None, *args, **kargs):
if model is None and send_message is not None:
send_message.emit('未能加载模型!')
return
ENV = 'base'
name = ''
stages = []
setting_widget = None
ds1: gdal.Dataset = gdal.Open(pth1)
ds2: gdal.Dataset = gdal.Open(pth2)
def __init__(self, parent = None) -> None:
super().__init__(parent)
# self.setting_widget:AlgFrontend = setting_widget
cell_size = (512, 512)
xsize = layer_parent.size[0]
ysize = layer_parent.size[1]
vlayout = QVBoxLayout()
band = ds1.RasterCount
yblocks = ysize // cell_size[1]
xblocks = xsize // cell_size[0]
hlayout = QHBoxLayout()
select_label = QLabel('模式选择:')
self.select_mode = QComboBox()
self.select_mode.addItem('----------', 'NoValue')
for stage in self.stages:
self.select_mode.addItem(stage[1], stage[0])
driver = gdal.GetDriverByName('GTiff')
out_tif = os.path.join(Project().other_path, 'temp.tif')
out_ds = driver.Create(out_tif, xsize, ysize, 1, gdal.GDT_Float32)
geo = layer_parent.grid.geo
proj = layer_parent.grid.proj
out_ds.SetGeoTransform(geo)
out_ds.SetProjection(proj)
setting_btn = QPushButton(IconInstance().SELECT, '配置')
max_diff = 0
min_diff = math.inf
hlayout.addWidget(select_label)
hlayout.addWidget(self.select_mode, 2)
hlayout.addWidget(setting_btn)
self.setting_args = []
start1x, start1y = geo2imageRC(ds1.GetGeoTransform(
), layer_parent.mask.xy[0], layer_parent.mask.xy[1])
end1x, end1y = geo2imageRC(ds1.GetGeoTransform(
), layer_parent.mask.xy[2], layer_parent.mask.xy[3])
def show_setting():
if self.setting_widget is None:
return
dialog = QDialog(parent)
vlayout = QVBoxLayout()
dialog.setLayout(vlayout)
widget = self.setting_widget.get_widget(dialog)
vlayout.addWidget(widget)
start2x, start2y = geo2imageRC(ds2.GetGeoTransform(
), layer_parent.mask.xy[0], layer_parent.mask.xy[1])
end2x, end2y = geo2imageRC(ds2.GetGeoTransform(
), layer_parent.mask.xy[2], layer_parent.mask.xy[3])
dbtn_box = QDialogButtonBox(dialog)
dbtn_box.setStandardButtons(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
dbtn_box.button(QDialogButtonBox.Ok).setText('确定')
dbtn_box.button(QDialogButtonBox.Cancel).setText('取消')
dbtn_box.button(QDialogButtonBox.Ok).clicked.connect(dialog.accept)
dbtn_box.button(QDialogButtonBox.Cancel).clicked.connect(dialog.reject)
vlayout.addWidget(dbtn_box, 1, Qt.AlignRight)
dialog.setMinimumHeight(500)
dialog.setMinimumWidth(900)
if dialog.exec_() == QDialog.Accepted:
self.setting_args = self.setting_widget.get_params(widget)
for j in range(yblocks + 1): # TODO: this part needs to be revised
if send_message is not None:
send_message.emit(f'计算{j}/{yblocks}')
for i in range(xblocks +1):
setting_btn.clicked.connect(show_setting)
block_xy1 = (start1x + i * cell_size[0], start1y+j * cell_size[1])
block_xy2 = (start2x + i * cell_size[0], start2y+j * cell_size[1])
block_xy = (i * cell_size[0], j * cell_size[1])
if block_xy1[1] > end1y or block_xy2[1] > end2y:
break
if block_xy1[0] > end1x or block_xy2[0] > end2x:
break
block_size = list(cell_size)
btnbox = QDialogButtonBox(self)
btnbox.setStandardButtons(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
btnbox.button(QDialogButtonBox.Ok).setText('确定')
btnbox.button(QDialogButtonBox.Cancel).setText('取消')
if block_xy[1] + block_size[1] > ysize:
block_xy[1] = (ysize - block_size[1])
if block_xy[0] + block_size[0] > xsize:
block_xy[0] = ( xsize - block_size[0])
processbar = QProgressBar(self)
processbar.setMaximum(1)
processbar.setMinimum(0)
processbar.setEnabled(False)
def run_stage():
self.log = f'开始{self.current_stage}...\n'
btnbox.button(QDialogButtonBox.Cancel).setEnabled(True)
processbar.setMaximum(0)
processbar.setValue(0)
processbar.setEnabled(True)
q = threading.Thread(target=self.run_stage, args=(self.current_stage,))
q.start()
if block_xy1[1] + block_size[1] > end1y:
block_xy1[1] = (end1y - block_size[1])
if block_xy1[0] + block_size[0] > end1x:
block_xy1[0] = (end1x - block_size[0])
if block_xy2[1] + block_size[1] > end2y:
block_xy2[1] = (end2y - block_size[1])
if block_xy2[0] + block_size[0] > end2x:
block_xy2[0] = (end2x - block_size[0])
# if block_size1[0] * block_size1[1] == 0 or block_size2[0] * block_size2[1] == 0:
# continue
block_data1 = ds1.ReadAsArray(*block_xy1, *block_size)
block_data2 = ds2.ReadAsArray(*block_xy2, *block_size)
# if block_data1.shape[0] == 0:
# continue
if band == 1:
block_data1 = block_data1[None, ...]
block_data2 = block_data2[None, ...]
block_diff = model(block_data1, block_data2)
out_ds.GetRasterBand(1).WriteArray(block_diff, *block_xy)
if send_message is not None:
send_message.emit(f'完成{j}/{yblocks}')
del ds2
del ds1
out_ds.FlushCache()
del out_ds
if send_message is not None:
send_message.emit('归一化概率中...')
temp_in_ds = gdal.Open(out_tif)
out_normal_tif = os.path.join(Project().cmi_path, '{}_{}_cmi.tif'.format(
layer_parent.name, int(np.random.rand() * 100000)))
out_normal_ds = driver.Create(
out_normal_tif, xsize, ysize, 1, gdal.GDT_Byte)
out_normal_ds.SetGeoTransform(geo)
out_normal_ds.SetProjection(proj)
# hist = np.zeros(256, dtype=np.int32)
for j in range(yblocks+1):
block_xy = (0, j * cell_size[1])
if block_xy[1] > ysize:
break
block_size = (xsize, cell_size[1])
if block_xy[1] + block_size[1] > ysize:
block_size = (xsize, ysize - block_xy[1])
block_data = temp_in_ds.ReadAsArray(*block_xy, *block_size)
block_data = (block_data - min_diff) / (max_diff - min_diff) * 255
block_data = block_data.astype(np.uint8)
out_normal_ds.GetRasterBand(1).WriteArray(block_data, *block_xy)
# hist_t, _ = np.histogram(block_data, bins=256, range=(0, 256))
# hist += hist_t
# print(hist)
del temp_in_ds
del out_normal_ds
try:
os.remove(out_tif)
except:
pass
if send_message is not None:
send_message.emit('计算完成')
return out_normal_tif
btnbox.button(QDialogButtonBox.Cancel).setEnabled(False)
btnbox.accepted.connect(run_stage)
btnbox.rejected.connect(self._stage_stop)
self.processbar = processbar
vlayout.addLayout(hlayout)
self.detail = QTextEdit(self)
vlayout.addWidget(self.detail)
vlayout.addWidget(processbar)
vlayout.addWidget(btnbox)
self.detail.setReadOnly(True)
# self.detail.copyAvailable(True)
self.detail.setText(f'等待开始...')
self.setLayout(vlayout)
@AI_METHOD.register
class DVCA(BasicAICD):
self.setMinimumHeight(500)
self.setMinimumWidth(500)
self.setWindowIcon(IconInstance().AI_DETECT)
self.setWindowTitle(self.name)
@staticmethod
def get_name():
return 'DVCA'
self.stage_end.connect(self._stage_end)
self.log = f'等待开始...\n'
self.stage_log.connect(self._stage_log)
self.p = None
@staticmethod
def run_alg(pth1: str, pth2: str, layer_parent: PairLayer, send_message=None, *args, **kargs):
model = get_model('DVCA')
return BasicAICD.run_alg(pth1, pth2, layer_parent, send_message, model, *args, **kargs)
@property
def current_stage(self):
if self.select_mode.currentData() == 'NoValue':
return None
return self.select_mode.currentData()
@property
def _python_base(self):
return os.path.abspath(os.path.join(os.path.dirname(sys.executable), '..', '..'))
@property
def activate_env(self):
script = os.path.join(self._python_base, 'Scripts','activate')
if self.ENV == 'base':
return script
else:
return script + ' ' + self.ENV
@AI_METHOD.register
class DPFCN(BasicAICD):
@property
def python_path(self):
if self.ENV == 'base':
return self._python_base
return os.path.join(self._python_base, 'envs', self.ENV, 'python.exe')
@staticmethod
def get_name():
return 'DPFCN'
@staticmethod
def run_alg(pth1: str, pth2: str, layer_parent: PairLayer, send_message=None, *args, **kargs):
model = get_model('DPFCN')
return BasicAICD.run_alg(pth1, pth2, layer_parent, send_message, model, *args, **kargs)
@property
def workdir(self):
raise NotImplementedError()
@abc.abstractmethod
def stage_script(self, stage):
return None
def _stage_log(self, l):
self.log += l + '\n'
self.detail.setText(self.log)
cursor = self.detail.textCursor()
# QTextCursor
cursor.movePosition(QTextCursor.End)
self.detail.setTextCursor(cursor)
def closeEvent(self, e) -> None:
if self.p is None:
return super().closeEvent(e)
dialog = QMessageBox(QMessageBox.Warning, '警告', '关闭窗口将停止' + self.current_stage + '', QMessageBox.Ok | QMessageBox.Cancel)
dialog.button(QMessageBox.Ok).clicked.connect(dialog.accepted)
dialog.button(QMessageBox.Cancel).clicked.connect(dialog.rejected)
dialog.show()
r = dialog.exec()
# print(r)
# print(QMessageBox.Rejected)
# print(QMessageBox.Accepted)
if r == QMessageBox.Cancel:
e.ignore()
return
return super().closeEvent(e)
def _stage_stop(self):
if self.p is not None:
try:
self.stage_log.emit(f'用户停止{self.stage}...')
self.p.kill()
except:
pass
@AI_METHOD.register
class RCNN(BasicAICD):
def _stage_end(self, c):
self.processbar.setMaximum(1)
self.processbar.setValue(0)
self.processbar.setEnabled(False)
self.log += '完成!'
self.detail.setText(self.log)
@staticmethod
def get_name():
return 'RCNN'
def run_stage(self, stage):
if stage is None:
return
ss = self.stage_script(stage)
if ss is None:
self.stage_log.emit(f'开始{stage}时未发现脚本')
self.stage_end.emit(1)
return
if self.workdir is None:
self.stage_log.emit(f'未配置工作目录!')
self.stage_end.emit(2)
return
args = [ss, *self.setting_args]
self.p = sp.SubprocessWraper(self.python_path, self.workdir, args, self.activate_env)
for line in self.p.run():
self.stage_log.emit(line)
self.stage_end.emit(self.p.returncode)
from collections import OrderedDict
def option_to_gui(parent, options:OrderedDict):
for key in options:
pass
def gui_to_option(widget, options:OrderedDict):
pass
@staticmethod
def run_alg(pth1: str, pth2: str, layer_parent: PairLayer, send_message=None, *args, **kargs):
model = get_model('RCNN')
return BasicAICD.run_alg(pth1, pth2, layer_parent, send_message, model, *args, **kargs)


@ -1,27 +1,194 @@
from rscder.plugins.basic import BasicPlugin
from rscder.gui.actions import ActionManager
from ai_method import AI_METHOD
from PyQt5.QtCore import Qt
from PyQt5.QtWidgets import QAction, QToolBar, QMenu, QDialog, QHBoxLayout, QVBoxLayout, QPushButton,QWidget,QLabel,QLineEdit,QPushButton,QComboBox,QDialogButtonBox
from rscder.utils.icons import IconInstance
from plugins.misc.main import AlgFrontend
from functools import partial
from threading import Thread
from plugins.misc.main import AlgFrontend
from rscder.gui.actions import ActionManager
from rscder.plugins.basic import BasicPlugin
from PyQt5.QtWidgets import QAction, QToolBar, QMenu, QDialog, QHBoxLayout, QVBoxLayout, QPushButton,QWidget,QLabel,QLineEdit,QPushButton,QComboBox,QDialogButtonBox
from rscder.gui.layercombox import PairLayerCombox
from rscder.utils.icons import IconInstance
from filter_collection import FILTER
from . import AI_METHOD
from thres import THRES
from misc import table_layer, AlgSelectWidget
from follow import FOLLOW
import os
class AICDMethod(QDialog):
def __init__(self,parent=None, alg:AlgFrontend=None):
super(AICDMethod, self).__init__(parent)
self.alg = alg
self.setWindowTitle('AI变化检测:{}'.format(alg.get_name()))
self.setWindowIcon(IconInstance().LOGO)
self.initUI()
self.setMinimumWidth(500)
def initUI(self):
# layer selection
self.layer_combox = PairLayerCombox(self)
layerbox = QHBoxLayout()
layerbox.addWidget(self.layer_combox)
self.filter_select = AlgSelectWidget(self, FILTER)
self.param_widget = self.alg.get_widget(self)
self.unsupervised_menu = self.param_widget
self.thres_select = AlgSelectWidget(self, THRES)
self.ok_button = QPushButton('确定', self)
self.ok_button.setIcon(IconInstance().OK)
self.ok_button.clicked.connect(self.accept)
self.ok_button.setDefault(True)
self.cancel_button = QPushButton('取消', self)
self.cancel_button.setIcon(IconInstance().CANCEL)
self.cancel_button.clicked.connect(self.reject)
self.cancel_button.setDefault(False)
buttonbox=QDialogButtonBox(self)
buttonbox.addButton(self.ok_button,QDialogButtonBox.NoRole)
buttonbox.addButton(self.cancel_button,QDialogButtonBox.NoRole)
buttonbox.setCenterButtons(True)
totalvlayout=QVBoxLayout()
totalvlayout.addLayout(layerbox)
totalvlayout.addWidget(self.filter_select)
if self.param_widget is not None:
totalvlayout.addWidget(self.param_widget)
totalvlayout.addWidget(self.thres_select)
totalvlayout.addStretch(1)
hbox = QHBoxLayout()
hbox.addStretch(1)
hbox.addWidget(buttonbox)
totalvlayout.addLayout(hbox)
# totalvlayout.addStretch()
self.setLayout(totalvlayout)
@FOLLOW.register
class AICDFollow(AlgFrontend):
@staticmethod
def get_name():
return 'AI变化检测'
@staticmethod
def get_icon():
return IconInstance().UNSUPERVISED
@staticmethod
def get_widget(parent=None):
widget = QWidget(parent)
layer_combox = PairLayerCombox(widget)
layer_combox.setObjectName('layer_combox')
filter_select = AlgSelectWidget(widget, FILTER)
filter_select.setObjectName('filter_select')
ai_select = AlgSelectWidget(widget, AI_METHOD)
ai_select.setObjectName('ai_select')
thres_select = AlgSelectWidget(widget, THRES)
thres_select.setObjectName('thres_select')
totalvlayout=QVBoxLayout()
totalvlayout.addWidget(layer_combox)
totalvlayout.addWidget(filter_select)
totalvlayout.addWidget(ai_select)
totalvlayout.addWidget(thres_select)
totalvlayout.addStretch()
widget.setLayout(totalvlayout)
return widget
@staticmethod
def get_params(widget:QWidget=None):
if widget is None:
return dict()
layer_combox = widget.findChild(PairLayerCombox, 'layer_combox')
filter_select = widget.findChild(AlgSelectWidget, 'filter_select')
ai_select = widget.findChild(AlgSelectWidget, 'ai_select')
thres_select = widget.findChild(AlgSelectWidget, 'thres_select')
layer1=layer_combox.layer1
pth1 = layer_combox.layer1.path
pth2 = layer_combox.layer2.path
falg, fparams = filter_select.get_alg_and_params()
cdalg, cdparams = ai_select.get_alg_and_params()
thalg, thparams = thres_select.get_alg_and_params()
if cdalg is None or thalg is None:
return dict()
return dict(
layer1=layer1,
pth1 = pth1,
pth2 = pth2,
falg = falg,
fparams = fparams,
cdalg = cdalg,
cdparams = cdparams,
thalg = thalg,
thparams = thparams,
)
@staticmethod
def run_alg(layer1=None,
pth1 = None,
pth2 = None,
falg = None,
fparams = None,
cdalg = None,
cdparams = None,
thalg = None,
thparams = None,
send_message = None):
if cdalg is None or thalg is None:
return
name = layer1.name
method_info = dict()
if falg is not None:
pth1 = falg.run_alg(pth1, name=name, send_message= send_message, **fparams)
pth2 = falg.run_alg(pth2, name=name, send_message= send_message, **fparams)
method_info['滤波算法'] = falg.get_name()
else:
method_info['滤波算法'] = ''
cdpth = cdalg.run_alg(pth1, pth2, layer1.layer_parent, send_message= send_message,**cdparams)
if falg is not None:
try:
os.remove(pth1)
os.remove(pth2)
# send_message.emit('删除临时文件')
except:
# send_message.emit('删除临时文件失败!')
pass
thpth, th = thalg.run_alg(cdpth, name=name, send_message= send_message, **thparams)
method_info['变化检测算法'] = cdalg.get_name()
method_info['二值化算法'] = thalg.get_name()
table_layer(thpth,layer1,name, cdpath=cdpth, th=th, method_info=method_info, send_message = send_message)
class AIPlugin(BasicPlugin):
@staticmethod
def info():
return {
'name': 'AI 变化检测',
'author': 'RSC',
'name': 'AIPlugin',
'description': 'AIPlugin',
'author': 'RSCDER',
'version': '1.0.0',
'description': 'AI 变化检测',
'category': 'Ai method'
}
def set_action(self):
ai_menu = ActionManager().ai_menu
# ai_menu.setIcon(IconInstance().UNSUPERVISED)
# ActionManager().change_detection_menu.addMenu(ai_menu)
AI_menu = QMenu('&AI变化检测', self.mainwindow)
AI_menu.setIcon(IconInstance().AI_DETECT)
ActionManager().change_detection_menu.addMenu(AI_menu)
toolbar = ActionManager().add_toolbar('AI method')
for key in AI_METHOD.keys():
alg:AlgFrontend = AI_METHOD[key]
@ -30,15 +197,58 @@ class AIPlugin(BasicPlugin):
else:
name = alg.get_name()
action = QAction(alg.get_icon(), name, ai_menu)
action = QAction(alg.get_icon(), name, AI_menu)
func = partial(self.run_cd, alg)
action.triggered.connect(func)
toolbar.addAction(action)
ai_menu.addAction(action)
AI_menu.addAction(action)
def run_cd(self, alg:AlgFrontend):
dialog = alg.get_widget(self.mainwindow)
dialog.setWindowModality(Qt.NonModal)
def run_cd(self, alg):
# print(alg.get_name())
dialog = AICDMethod(self.mainwindow, alg)
dialog.show()
# dialog.exec()
if dialog.exec_() == QDialog.Accepted:
t = Thread(target=self.run_cd_alg, args=(dialog,))
t.start()
def run_cd_alg(self, w:AICDMethod):
layer1=w.layer_combox.layer1
pth1 = w.layer_combox.layer1.path
pth2 = w.layer_combox.layer2.path
name = layer1.layer_parent.name
falg, fparams = w.filter_select.get_alg_and_params()
cdalg = w.alg
cdparams = w.alg.get_params(w.param_widget)
thalg, thparams = w.thres_select.get_alg_and_params()
if cdalg is None or thalg is None:
return
method_info = dict()
if falg is not None:
pth1 = falg.run_alg(pth1, name=name, send_message=self.send_message, **fparams)
pth2 = falg.run_alg(pth2, name=name, send_message=self.send_message, **fparams)
method_info['滤波算法'] = falg.get_name()
cdpth = cdalg.run_alg(pth1, pth2, layer1.layer_parent, send_message=self.send_message,**cdparams)
if falg is not None:
try:
os.remove(pth1)
os.remove(pth2)
# send_message.emit('删除临时文件')
except:
# send_message.emit('删除临时文件失败!')
pass
thpth, th = thalg.run_alg(cdpth, name=name, send_message=self.send_message, **thparams)
method_info['变化检测算法'] = cdalg.get_name()
method_info['二值化算法'] = thalg.get_name()
table_layer(thpth,layer1,name, cdpath=cdpth, th=th, method_info=method_info, send_message = self.send_message)
# table_layer(thpth,layer1,name,self.send_message)


@ -1,19 +0,0 @@
.DS_Store
checkpoints/
results/
result/
build/
*.pth
*/*.pth
*/*/*.pth
torch.egg-info/
*/*.pyc
*/**/*.pyc
*/**/**/*.pyc
*/**/**/**/*.pyc
*/**/**/**/**/*.pyc
*/*.so*
*/**/*.so*
*/**/*.dylib*
.idea


@ -1,25 +0,0 @@
BSD 2-Clause License
Copyright (c) 2020, justchenhao
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


@ -1,233 +0,0 @@
# STANet for remote sensing image change detection
This is the implementation of the paper: A Spatial-Temporal Attention-Based Method and a New Dataset for Remote Sensing Image Change Detection.
Here, we provide the PyTorch implementation of the spatial-temporal attention neural network (STANet) for remote sensing image change detection.
![image-20200601213320103](/src/stanet-overview.png)
## Change log
20210112:
- add the pretraining weight of PAM. [baidupan link](https://pan.baidu.com/s/1O1kg7JWunqd87ajtVMM6pg), code: 2rja
20201105:
- add a demo for quick start.
- add more dataset loader modes.
- enhance the image augmentation module (crop and rotation).
20200601:
- first commit
## Prerequisites
- Windows or Linux
- Python 3.6+
- CPU or NVIDIA GPU
- CUDA 9.0+
- PyTorch > 1.0
- visdom
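A quick sanity check of these prerequisites (a minimal sketch; it only assumes the version requirements listed above):
```python
import sys
import torch

# Check the prerequisites listed above: Python 3.6+, PyTorch > 1.0, optional CUDA GPU.
assert sys.version_info >= (3, 6), 'Python 3.6+ is required'
print('PyTorch version:', torch.__version__)
print('CUDA available:', torch.cuda.is_available())
if torch.cuda.is_available():
    print('CUDA version:', torch.version.cuda)
    print('GPU:', torch.cuda.get_device_name(0))
```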
## Installation
Clone this repo:
```bash
git clone https://github.com/justchenhao/STANet
cd STANet
```
Install [PyTorch](http://pytorch.org/) 1.0+ and other dependencies (e.g., torchvision, [visdom](https://github.com/facebookresearch/visdom) and [dominate](https://github.com/Knio/dominate))
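For example, in a pip-based environment (a sketch only; use the command from the PyTorch website for a specific CUDA build):
```bash
pip install torch torchvision visdom dominate
```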
## Quick Start
You can run a demo to get started.
```bash
python demo.py
```
The input samples are in `samples`. After the script runs successfully, you can find the predicted results in `samples/output`.
## Prepare Datasets
### download the change detection dataset
You can download the LEVIR-CD dataset at https://justchenhao.github.io/LEVIR/.
The directory structure of the downloaded folder is as follows:
```
path to LEVIR-CD:
├─train
│ ├─A
│ ├─B
│ ├─label
├─val
│ ├─A
│ ├─B
│ ├─label
├─test
│ ├─A
│ ├─B
│ ├─label
```
where `A` contains the pre-change images, `B` contains the post-change images, and `label` contains the change label maps.
### cut bitemporal image pairs
The original images in LEVIR-CD have a size of 1024 × 1024, which consumes too much memory during training. Therefore, we cut the original images into smaller patches (e.g., 256 × 256 or 512 × 512). In our paper, we cut the original images into non-overlapping patches of 256 × 256 pixels.
Make sure that the corresponding patch samples in the `A`, `B`, and `label` subfolders have the same filename (see the sketch below).
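The snippet below is a minimal sketch of this preprocessing step (the Pillow dependency, the `cut_patches` helper, and the example paths are illustrative assumptions, not part of this repository):
```python
import os
from PIL import Image

def cut_patches(src_root, dst_root, patch=256):
    # Cut every image in the A/, B/ and label/ subfolders into non-overlapping patches.
    for sub in ('A', 'B', 'label'):
        src_dir = os.path.join(src_root, sub)
        dst_dir = os.path.join(dst_root, sub)
        os.makedirs(dst_dir, exist_ok=True)
        for fname in sorted(os.listdir(src_dir)):
            img = Image.open(os.path.join(src_dir, fname))
            stem, ext = os.path.splitext(fname)
            w, h = img.size
            for y in range(0, h - patch + 1, patch):
                for x in range(0, w - patch + 1, patch):
                    tile = img.crop((x, y, x + patch, y + patch))
                    # Keep identical filenames across A, B and label so patch triplets match.
                    tile.save(os.path.join(dst_dir, '{}_{}_{}{}'.format(stem, y, x, ext)))

# Usage (illustrative paths): cut_patches('path-to-LEVIR-CD/train', 'path-to-LEVIR-CD-256/train')
```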
## Train
### Monitor training status
To view training results and loss plots, run this script and click the URL [http://localhost:8097](http://localhost:8097/).
```bash
python -m visdom.server
```
### train with our base method
Run the following script:
```bash
python ./train.py --save_epoch_freq 1 --angle 15 --dataroot path-to-LEVIR-CD-train --val_dataroot path-to-LEVIR-CD-val --name LEVIR-CDF0 --lr 0.001 --model CDF0 --batch_size 8 --load_size 256 --crop_size 256 --preprocess rotate_and_crop
```
Once finished, you can find the best model and the log files in the project folder.
### train with Basic spatial-temporal Attention Module (BAM) method
```bash
python ./train.py --save_epoch_freq 1 --angle 15 --dataroot path-to-LEVIR-CD-train --val_dataroot path-to-LEVIR-CD-val --name LEVIR-CDFA0 --lr 0.001 --model CDFA --SA_mode BAM --batch_size 8 --load_size 256 --crop_size 256 --preprocess rotate_and_crop
```
### train with Pyramid spatial-temporal Attention Module (PAM) method
```bash
python ./train.py --save_epoch_freq 1 --angle 15 --dataroot path-to-LEVIR-CD-train --val_dataroot path-to-LEVIR-CD-val --name LEVIR-CDFAp0 --lr 0.001 --model CDFA --SA_mode PAM --batch_size 8 --load_size 256 --crop_size 256 --preprocess rotate_and_crop
```
## Test
You can edit the file `val.py`, for example:
```python
if __name__ == '__main__':
opt = TestOptions().parse() # get test options
opt = make_val_opt(opt)
opt.phase = 'test'
opt.dataroot = 'path-to-LEVIR-CD-test' # data root
opt.dataset_mode = 'changedetection'
opt.n_class = 2
opt.SA_mode = 'PAM' # BAM | PAM
opt.arch = 'mynet3'
opt.model = 'CDFA' # model type
opt.name = 'LEVIR-CDFAp0' # project name
opt.results_dir = './results/' # save predicted images
opt.epoch = 'best-epoch-in-val' # which epoch to test
opt.num_test = np.inf
val(opt)
```
then run the script: `python val.py`. Once finished, you can find the prediction log file in the project directory and predicted image files in the result directory.
## Using other dataset mode
### List mode
```bash
list=train
lr=0.001
dataset_mode=list
dataroot=path-to-dataroot
name=project_name
python ./train.py --num_threads 4 --display_id 0 --dataroot ${dataroot} --val_dataroot ${dataroot} --save_epoch_freq 1 --niter 100 --angle 15 --niter_decay 100 --display_env FAp0 --SA_mode PAM --name $name --lr $lr --model CDFA --batch_size 4 --dataset_mode $dataset_mode --val_dataset_mode $dataset_mode --split $list --load_size 256 --crop_size 256 --preprocess resize_rotate_and_crop
```
In this case, the data structure should be the following:
```
"""
data structure
-dataroot
├─A
├─train1.png
...
├─B
├─train1.png
...
├─label
├─train1.png
...
└─list
├─val.txt
├─test.txt
└─train.txt
# In list/train.txt, each line lists the filename of one sample,
# for example:
list/train.txt
train1.png
train2.png
...
"""
```
### Concat mode for loading multiple datasets (each sub-dataset defaults to list mode)
```bash
list=train
lr=0.001
dataset_type=CD_data1,CD_data2,...,
val_dataset_type=CD_data
dataset_mode=concat
name=project_name
python ./train.py --num_threads 4 --display_id 0 --dataset_type $dataset_type --val_dataset_type $val_dataset_type --save_epoch_freq 1 --niter 100 --angle 15 --niter_decay 100 --display_env FAp0 --SA_mode PAM --name $name --lr $lr --model CDFA --batch_size 4 --dataset_mode $dataset_mode --val_dataset_mode $dataset_mode --split $list --load_size 256 --crop_size 256 --preprocess resize_rotate_and_crop
```
Note that in this case, you should modify `get_dataset_info` in `data/data_config.py` to add the corresponding `dataset_name` and `dataroot`.
```python
if dataset_type == 'LEVIR_CD':
root = 'path-to-LEVIR_CD-dataroot'
elif ...
# add more dataset ...
```
## Other TIPS
For more training/testing options, see the option files in the `./options/` folder.
## Citation
If you use this code for your research, please cite our paper.
```
@Article{rs12101662,
AUTHOR = {Chen, Hao and Shi, Zhenwei},
TITLE = {A Spatial-Temporal Attention-Based Method and a New Dataset for Remote Sensing Image Change Detection},
JOURNAL = {Remote Sensing},
VOLUME = {12},
YEAR = {2020},
NUMBER = {10},
ARTICLE-NUMBER = {1662},
URL = {https://www.mdpi.com/2072-4292/12/10/1662},
ISSN = {2072-4292},
DOI = {10.3390/rs12101662}
}
```
## Acknowledgments
Our code is inspired by [pytorch-CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).


@ -1,108 +0,0 @@
import importlib
import torch.utils.data
from data.base_dataset import BaseDataset
import torch
from torch.utils.data import ConcatDataset
from data.data_config import get_dataset_info
def find_dataset_using_name(dataset_name):
"""Import the module "data/[dataset_name]_dataset.py".
In the file, the class called DatasetNameDataset() will
be instantiated. It has to be a subclass of BaseDataset,
and it is case-insensitive.
"""
dataset_filename = "data." + dataset_name + "_dataset"
datasetlib = importlib.import_module(dataset_filename)
dataset = None
target_dataset_name = dataset_name.replace('_', '') + 'dataset'
for name, cls in datasetlib.__dict__.items():
if name.lower() == target_dataset_name.lower() \
and issubclass(cls, BaseDataset):
dataset = cls
if dataset is None:
raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
return dataset
def get_option_setter(dataset_name):
"""Return the static method <modify_commandline_options> of the dataset class."""
dataset_class = find_dataset_using_name(dataset_name)
return dataset_class.modify_commandline_options
def create_dataset(opt):
"""Create a dataset given the option.
This function wraps the class CustomDatasetDataLoader.
This is the main interface between this package and 'train.py'/'test.py'
Example:
>>> from data import create_dataset
>>> dataset = create_dataset(opt)
"""
data_loader = CustomDatasetDataLoader(opt)
dataset = data_loader.load_data()
return dataset
def create_single_dataset(opt, dataset_type_):
# return dataset_class
dataset_class = find_dataset_using_name('list')
# get dataset root
opt.dataroot = get_dataset_info(dataset_type_)
return dataset_class(opt)
class CustomDatasetDataLoader():
"""Wrapper class of Dataset class that performs multi-threaded data loading"""
def __init__(self, opt):
"""Initialize this class
Step 1: create a dataset instance given the name [dataset_mode]
Step 2: create a multi-threaded data loader.
"""
self.opt = opt
print(opt.dataset_mode)
if opt.dataset_mode == 'concat':
# concatenate multiple datasets
datasets = []
# get the list of dataset types to concatenate
self.dataset_type = opt.dataset_type.split(',')
# drop the empty entry left by a trailing comma
if self.dataset_type[-1] == '':
self.dataset_type = self.dataset_type[:-1]
for dataset_type_ in self.dataset_type:
dataset_ = create_single_dataset(opt, dataset_type_)
datasets.append(dataset_)
self.dataset = ConcatDataset(datasets)
else:
dataset_class = find_dataset_using_name(opt.dataset_mode)
self.dataset = dataset_class(opt)
print("dataset [%s] was created" % type(self.dataset).__name__)
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batch_size,
shuffle=not opt.serial_batches,
num_workers=int(opt.num_threads),
drop_last=True)
def load_data(self):
return self
def __len__(self):
"""Return the number of data in the dataset"""
return min(len(self.dataset), self.opt.max_dataset_size)
def __iter__(self):
"""Return a batch of data"""
for i, data in enumerate(self.dataloader):
if i * self.opt.batch_size >= self.opt.max_dataset_size:
break
yield data


@ -1,189 +0,0 @@
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
"""
import random
import numpy as np
import torch.utils.data as data
from PIL import Image,ImageFilter
import torchvision.transforms as transforms
from abc import ABC, abstractmethod
import math
class BaseDataset(data.Dataset, ABC):
"""This class is an abstract base class (ABC) for datasets.
To create a subclass, you need to implement the following four functions:
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
-- <__len__>: return the size of dataset.
-- <__getitem__>: get a data point.
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
"""
def __init__(self, opt):
"""Initialize the class; save the options in the class
Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
self.opt = opt
self.root = opt.dataroot
@staticmethod
def modify_commandline_options(parser, is_train):
"""Add new dataset-specific options, and rewrite default values for existing options.
Parameters:
parser -- original option parser
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
"""
return parser
@abstractmethod
def __len__(self):
"""Return the total number of images in the dataset."""
return 0
@abstractmethod
def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
index - - a random integer for data indexing
Returns:
a dictionary of data with their names. It usually contains the data itself and its metadata information.
"""
pass
def get_params(opt, size, test=False):
w, h = size
new_h = h
new_w = w
angle = 0
if opt.preprocess == 'resize_and_crop':
new_h = new_w = opt.load_size
if 'rotate' in opt.preprocess and test is False:
angle = random.uniform(0, opt.angle)
# print(angle)
new_w = int(new_w * math.cos(angle*math.pi/180) \
+ new_h*math.sin(angle*math.pi/180))
new_h = int(new_h * math.cos(angle*math.pi/180) \
+ new_w*math.sin(angle*math.pi/180))
new_w = min(new_w,new_h)
new_h = min(new_w,new_h)
# print(new_h,new_w)
x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
# print('x,y: ',x,y)
flip = random.random() > 0.5 # left-right
return {'crop_pos': (x, y), 'flip': flip, 'angle': angle}
def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC,
convert=True, normalize=True, test=False):
transform_list = []
if grayscale:
transform_list.append(transforms.Grayscale(1))
if 'resize' in opt.preprocess:
osize = [opt.load_size, opt.load_size]
transform_list.append(transforms.Resize(osize, method))
# gaussian blur
if 'blur' in opt.preprocess:
transform_list.append(transforms.Lambda(lambda img: __blur(img)))
if 'rotate' in opt.preprocess and test==False:
if params is None:
transform_list.append(transforms.RandomRotation(5))
else:
degree = params['angle']
transform_list.append(transforms.Lambda(lambda img: __rotate(img, degree)))
if 'crop' in opt.preprocess:
if params is None:
transform_list.append(transforms.RandomCrop(opt.crop_size))
else:
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'],
opt.crop_size)))
if not opt.no_flip:
if params is None:
transform_list.append(transforms.RandomHorizontalFlip())
elif params['flip']:
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
if convert:
transform_list += [transforms.ToTensor()]
if normalize:
transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)
def __blur(img):
if img.mode == 'RGB':
img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
return img
def __rotate(img, degree):
if img.mode =='RGB':
# set img padding == 128
img2 = img.convert('RGBA')
rot = img2.rotate(degree,expand=1)
fff = Image.new('RGBA', rot.size, (128,) * 4) # gray padding
out = Image.composite(rot, fff, rot)
img = out.convert(img.mode)
return img
else:
# set label padding == 0
img2 = img.convert('RGBA')
rot = img2.rotate(degree,expand=1)
# a white image same size as rotated image
fff = Image.new('RGBA', rot.size, (255,) * 4)
# create a composite image using the alpha layer of rot as a mask
out = Image.composite(rot, fff, rot)
img = out.convert(img.mode)
return img
def __crop(img, pos, size):
ow, oh = img.size
x1, y1 = pos
tw = th = size
# print('imagesize:',ow,oh)
# crop only if the image is larger than the crop size; otherwise pad
if (ow > tw and oh > th):
return img.crop((x1, y1, x1 + tw, y1 + th))
size = [size, size]
if img.mode == 'RGB':
new_image = Image.new('RGB', size, (128, 128, 128))
new_image.paste(img, (int((1+size[1] - img.size[0]) / 2),
int((1+size[0] - img.size[1]) / 2)))
return new_image
else:
new_image = Image.new(img.mode, size, 255)
# upper left corner
new_image.paste(img, (int((1 + size[1] - img.size[0]) / 2),
int((1 + size[0] - img.size[1]) / 2)))
return new_image
def __flip(img, flip):
if flip:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img
def __print_size_warning(ow, oh, w, h):
"""Print warning information about image size(only print once)"""
if not hasattr(__print_size_warning, 'has_printed'):
print("The image size needs to be a multiple of 4. "
"The loaded image size was (%d, %d), so it was adjusted to "
"(%d, %d). This adjustment will be done to all images "
"whose sizes are not multiples of 4" % (ow, oh, w, h))
__print_size_warning.has_printed = True


@ -1,63 +0,0 @@
from data.base_dataset import BaseDataset, get_transform, get_params
from data.image_folder import make_dataset
from PIL import Image
import os
import numpy as np
class ChangeDetectionDataset(BaseDataset):
"""This dataset class can load a set of images specified by the path --dataroot /path/to/data.
datafolder-tree
dataroot:.
A
B
label
"""
def __init__(self, opt):
BaseDataset.__init__(self, opt)
folder_A = 'A'
folder_B = 'B'
folder_L = 'label'
self.istest = False
if opt.phase == 'test':
self.istest = True
self.A_paths = sorted(make_dataset(os.path.join(opt.dataroot, folder_A), opt.max_dataset_size))
self.B_paths = sorted(make_dataset(os.path.join(opt.dataroot, folder_B), opt.max_dataset_size))
if not self.istest:
self.L_paths = sorted(make_dataset(os.path.join(opt.dataroot, folder_L), opt.max_dataset_size))
print(self.A_paths)
def __getitem__(self, index):
A_path = self.A_paths[index]
B_path = self.B_paths[index]
A_img = Image.open(A_path).convert('RGB')
B_img = Image.open(B_path).convert('RGB')
transform_params = get_params(self.opt, A_img.size, test=self.istest)
# apply the same transform to A B L
transform = get_transform(self.opt, transform_params, test=self.istest)
A = transform(A_img)
B = transform(B_img)
if self.istest:
return {'A': A, 'A_paths': A_path, 'B': B, 'B_paths': B_path}
L_path = self.L_paths[index]
tmp = np.array(Image.open(L_path), dtype=np.uint32)/255
L_img = Image.fromarray(tmp)
transform_L = get_transform(self.opt, transform_params, method=Image.NEAREST, normalize=False,
test=self.istest)
L = transform_L(L_img)
return {'A': A, 'A_paths': A_path,
'B': B, 'B_paths': B_path,
'L': L, 'L_paths': L_path}
def __len__(self):
"""Return the total number of images in the dataset."""
return len(self.A_paths)


@ -1,12 +0,0 @@
def get_dataset_info(dataset_type):
"""define dataset_name and its dataroot"""
root = ''
if dataset_type == 'LEVIR_CD':
root = 'path-to-LEVIR_CD-dataroot'
# add more dataset ...
else:
raise TypeError("dataset type %s is not defined" % dataset_type)
return root


@ -1,66 +0,0 @@
"""A modified image folder class
We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both current directory and its subdirectories.
"""
import torch.utils.data as data
from PIL import Image
import os
import os.path
IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif', '.tiff'
]
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def make_dataset(dir, max_dataset_size=float("inf")):
images = []
assert os.path.isdir(dir), '%s is not a valid directory' % dir
for root, _, fnames in sorted(os.walk(dir)):
for fname in fnames:
if is_image_file(fname):
path = os.path.join(root, fname)
images.append(path)
return images[:min(max_dataset_size, len(images))]
def default_loader(path):
return Image.open(path).convert('RGB')
class ImageFolder(data.Dataset):
def __init__(self, root, transform=None, return_paths=False,
loader=default_loader):
imgs = make_dataset(root)
if len(imgs) == 0:
raise(RuntimeError("Found 0 images in: " + root + "\n"
"Supported image extensions are: " +
",".join(IMG_EXTENSIONS)))
self.root = root
self.imgs = imgs
self.transform = transform
self.return_paths = return_paths
self.loader = loader
def __getitem__(self, index):
path = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.return_paths:
return img, path
else:
return img
def __len__(self):
return len(self.imgs)


@ -1,101 +0,0 @@
import os
import collections
import numpy as np
from data.base_dataset import BaseDataset, get_transform, get_params
from PIL import Image
class listDataset(BaseDataset):
"""
data structure
-dataroot
A
train1.png
...
B
train1.png
...
label
train1.png
...
list
val.txt
test.txt
train.txt
# In list/train.txt, each line lists the filename of one sample,
# for example:
list/train.txt
train1.png
train2.png
...
"""
def __init__(self, opt):
BaseDataset.__init__(self, opt)
self.split = opt.split
self.files = collections.defaultdict(list)
self.istest = False if opt.phase == 'train' else True # whether this is the test/validation phase; if so, no scale or rotation augmentation is applied
path = os.path.join(self.root, 'list', self.split + '.txt')
file_list = tuple(open(path, 'r'))
file_list = [id_.rstrip() for id_ in file_list]
self.files[self.split] = file_list
# print(file_list)
def __len__(self):
return len(self.files[self.split])
def __getitem__(self, index):
paths = self.files[self.split][index]
path = paths.split(" ")
A_path = os.path.join(self.root,'A', path[0])
B_path = os.path.join(self.root,'B', path[0])
L_path = os.path.join(self.root,'label', path[0])
A = Image.open(A_path).convert('RGB')
B = Image.open(B_path).convert('RGB')
tmp = np.array(Image.open(L_path), dtype=np.uint32) / 255
# print(tmp.max())
L = Image.fromarray(tmp)
transform_params = get_params(self.opt, A.size, self.istest)
transform = get_transform(self.opt, transform_params, test=self.istest)
transform_L = get_transform(self.opt, transform_params, method=Image.NEAREST, normalize=False,
test=self.istest) # labels are not normalized
A = transform(A)
B = transform(B)
L = transform_L(L)
return {'A': A, 'A_paths': A_path, 'B': B, 'B_paths': B_path, 'L': L, 'L_paths': L_path}
# Leave code for debugging purposes
if __name__ == '__main__':
from options.train_options import TrainOptions
opt = TrainOptions().parse()
opt.dataroot = r'I:\data\change_detection\LEVIR-CD-r'
opt.split = 'train'
opt.load_size = 500
opt.crop_size = 500
opt.batch_size = 1
opt.dataset_mode = 'list'
from data import create_dataset
dataset = create_dataset(opt)
import matplotlib.pyplot as plt
from util.util import tensor2im
for i, data in enumerate(dataset):
A = data['A']
L = data['L']
A = tensor2im(A)
color = tensor2im(L)[:,:,0]*255
plt.imshow(A)
plt.show()


@ -1,82 +0,0 @@
import time
from options.test_options import TestOptions
from data import create_dataset
from models import create_model
import os
from util.util import save_images
import numpy as np
from util.util import mkdir
from PIL import Image
def make_val_opt(opt):
# hard-code some parameters for test
opt.num_threads = 0 # test code only supports num_threads = 0
opt.batch_size = 1 # test code only supports batch_size = 1
opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
opt.no_flip = True # no flip; comment this line if results on flipped images are needed.
opt.no_flip2 = True # no flip; comment this line if results on flipped images are needed.
opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file.
opt.phase = 'val'
opt.preprocess = 'none1'
opt.isTrain = False
opt.aspect_ratio = 1
opt.eval = True
return opt
def val(opt):
dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
model = create_model(opt) # create a model given opt.model and other options
model.setup(opt) # regular setup: load and print networks; create schedulers
save_path = opt.results_dir
mkdir(save_path)
model.eval()
for i, data in enumerate(dataset):
if i >= opt.num_test: # only apply our model to opt.num_test images.
break
model.set_input(data) # unpack data from data loader
pred = model.test(val=False) # run inference return pred
img_path = model.get_image_paths() # get image paths
if i % 5 == 0: # save images to an HTML file
print('processing (%04d)-th image... %s' % (i, img_path))
save_images(pred, save_path, img_path)
def pred_image(data_root, results_dir):
opt = TestOptions().parse() # get test options
opt = make_val_opt(opt)
opt.phase = 'test'
opt.dataset_mode = 'changedetection'
opt.n_class = 2
opt.SA_mode = 'PAM'
opt.arch = 'mynet3'
opt.model = 'CDFA'
opt.epoch = 'pam'
opt.num_test = np.inf
opt.name = 'pam'
opt.dataroot = data_root
opt.results_dir = results_dir
val(opt)
if __name__ == '__main__':
# define the data_root and the results_dir
# note:
# data_root should have such structure:
# ├─A
# ├─B
# A for before images
# B for after images
data_root = './samples'
results_dir = './samples/output/'
pred_image(data_root, results_dir)


@ -1,52 +0,0 @@
import torch
import torch.nn.functional as F
from torch import nn
class BAM(nn.Module):
""" Basic self-attention module
"""
def __init__(self, in_dim, ds=8, activation=nn.ReLU):
super(BAM, self).__init__()
self.chanel_in = in_dim
self.key_channel = self.chanel_in //8
self.activation = activation
self.ds = ds #
self.pool = nn.AvgPool2d(self.ds)
print('ds: ',ds)
self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1) #
def forward(self, input):
"""
inputs :
x : input feature maps( B X C X W X H)
returns :
out : self attention value + input feature
attention: B X N X N (N is Width*Height)
"""
x = self.pool(input)
m_batchsize, C, width, height = x.size()
proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1) # B X C X (N)/(ds*ds)
proj_key = self.key_conv(x).view(m_batchsize, -1, width * height) # B X C x (*W*H)/(ds*ds)
energy = torch.bmm(proj_query, proj_key) # transpose check
energy = (self.key_channel**-.5) * energy
attention = self.softmax(energy) # BX (N) X (N)/(ds*ds)/(ds*ds)
proj_value = self.value_conv(x).view(m_batchsize, -1, width * height) # B X C X N
out = torch.bmm(proj_value, attention.permute(0, 2, 1))
out = out.view(m_batchsize, C, width, height)
out = F.interpolate(out, [width*self.ds,height*self.ds])
out = out + input
return out


@ -1,114 +0,0 @@
import torch
import itertools
from .base_model import BaseModel
from . import backbone
import torch.nn.functional as F
from . import loss
class CDF0Model(BaseModel):
"""
change detection module:
feature extractor
contrastive loss
"""
@staticmethod
def modify_commandline_options(parser, is_train=True):
return parser
def __init__(self, opt):
BaseModel.__init__(self, opt)
self.istest = opt.istest
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
self.loss_names = ['f']
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
self.visual_names = ['A', 'B', 'L', 'pred_L_show'] # visualizations for A and B
if self.istest:
self.visual_names = ['A', 'B', 'pred_L_show']
self.visual_features = ['feat_A', 'feat_B']
# specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
if self.isTrain:
self.model_names = ['F']
else: # during test time, only load Gs
self.model_names = ['F']
self.ds=1
# define networks (both Generators and discriminators)
self.n_class = 2
self.netF = backbone.define_F(in_c=3, f_c=opt.f_c, type=opt.arch).to(self.device)
if self.isTrain:
# define loss functions
self.criterionF = loss.BCL()
# initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
self.optimizer_G = torch.optim.Adam(itertools.chain(self.netF.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizers.append(self.optimizer_G)
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): include the data itself and its metadata information.
The option 'direction' can be used to swap domain A and domain B.
"""
self.A = input['A'].to(self.device)
self.B = input['B'].to(self.device)
if not self.istest:
self.L = input['L'].to(self.device).long()
self.image_paths = input['A_paths']
if self.isTrain:
self.L_s = self.L.float()
self.L_s = F.interpolate(self.L_s, size=torch.Size([self.A.shape[2]//self.ds, self.A.shape[3]//self.ds]),mode='nearest')
self.L_s[self.L_s == 1] = -1 # change
self.L_s[self.L_s == 0] = 1 # no change
def test(self, val=False):
"""Forward function used in test time.
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
It also calls <compute_visuals> to produce additional visualization results
"""
with torch.no_grad():
self.forward()
self.compute_visuals()
if val: # score
from util.metrics import RunningMetrics
metrics = RunningMetrics(self.n_class)
pred = self.pred_L.long()
metrics.update(self.L.detach().cpu().numpy(), pred.detach().cpu().numpy())
scores = metrics.get_cm()
return scores
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.feat_A = self.netF(self.A) # f(A)
self.feat_B = self.netF(self.B) # f(B)
self.dist = F.pairwise_distance(self.feat_A, self.feat_B, keepdim=True)
# print(self.dist.shape)
self.dist = F.interpolate(self.dist, size=self.A.shape[2:], mode='bilinear',align_corners=True)
self.pred_L = (self.dist > 1).float()
self.pred_L_show = self.pred_L.long()
return self.pred_L
def backward(self):
"""Calculate the loss for generators F and L"""
# print(self.weight)
self.loss_f = self.criterionF(self.dist, self.L_s)
self.loss = self.loss_f
if torch.isnan(self.loss):
print(self.image_paths)
self.loss.backward()
def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
# forward
self.forward() # compute feat and dist
self.optimizer_G.zero_grad() # set G's gradients to zero
self.backward() # calculate gradients for G
self.optimizer_G.step() # update G's weights


@ -1,119 +0,0 @@
import torch
import itertools
from .base_model import BaseModel
from . import backbone
import torch.nn.functional as F
from . import loss
class CDFAModel(BaseModel):
"""
change detection module:
feature extractor+ spatial-temporal-self-attention
contrastive loss
"""
@staticmethod
def modify_commandline_options(parser, is_train=True):
return parser
def __init__(self, opt):
BaseModel.__init__(self, opt)
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
self.loss_names = ['f']
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
if opt.phase == 'test':
self.istest = True
self.visual_names = ['A', 'B', 'L', 'pred_L_show'] # visualizations for A and B
if self.istest:
self.visual_names = ['A', 'B', 'pred_L_show'] # visualizations for A and B
self.visual_features = ['feat_A','feat_B']
# specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
if self.isTrain:
self.model_names = ['F','A']
else: # during test time, only load Gs
self.model_names = ['F','A']
self.istest = False
self.ds = 1
self.n_class =2
self.netF = backbone.define_F(in_c=3, f_c=opt.f_c, type=opt.arch).to(self.device)
self.netA = backbone.CDSA(in_c=opt.f_c, ds=opt.ds, mode=opt.SA_mode).to(self.device)
if self.isTrain:
# define loss functions
self.criterionF = loss.BCL()
# initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
self.optimizer_G = torch.optim.Adam(itertools.chain(
self.netF.parameters(),
), lr=opt.lr*opt.lr_decay, betas=(opt.beta1, 0.999))
self.optimizer_A = torch.optim.Adam(self.netA.parameters(), lr=opt.lr*1, betas=(opt.beta1, 0.999))
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_A)
def set_input(self, input):
self.A = input['A'].to(self.device)
self.B = input['B'].to(self.device)
if self.istest is False:
if 'L' in input.keys():
self.L = input['L'].to(self.device).long()
self.image_paths = input['A_paths']
if self.isTrain:
self.L_s = self.L.float()
self.L_s = F.interpolate(self.L_s, size=torch.Size([self.A.shape[2]//self.ds, self.A.shape[3]//self.ds]),mode='nearest')
self.L_s[self.L_s == 1] = -1 # change
self.L_s[self.L_s == 0] = 1 # no change
def test(self, val=False):
with torch.no_grad():
self.forward()
self.compute_visuals()
if val: # return evaluation scores
from util.metrics import RunningMetrics
metrics = RunningMetrics(self.n_class)
pred = self.pred_L.long()
metrics.update(self.L.detach().cpu().numpy(), pred.detach().cpu().numpy())
scores = metrics.get_cm()
return scores
else:
return self.pred_L.long()
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.feat_A = self.netF(self.A) # f(A)
self.feat_B = self.netF(self.B) # f(B)
self.feat_A, self.feat_B = self.netA(self.feat_A,self.feat_B)
self.dist = F.pairwise_distance(self.feat_A, self.feat_B, keepdim=True) # feature distance
self.dist = F.interpolate(self.dist, size=self.A.shape[2:], mode='bilinear',align_corners=True)
self.pred_L = (self.dist > 1).float()
# self.pred_L = F.interpolate(self.pred_L, size=self.A.shape[2:], mode='nearest')
self.pred_L_show = self.pred_L.long()
return self.pred_L
def backward(self):
self.loss_f = self.criterionF(self.dist, self.L_s)
self.loss = self.loss_f
# print(self.loss)
self.loss.backward()
def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
# forward
self.forward() # compute feat and dist
self.set_requires_grad([self.netF, self.netA], True)
self.optimizer_G.zero_grad() # set G's gradients to zero
self.optimizer_A.zero_grad()
self.backward() # calculate gradients for G
self.optimizer_G.step() # update G's weights
self.optimizer_A.step()

View File

@ -1,168 +0,0 @@
import torch
import torch.nn.functional as F
from torch import nn
class _PAMBlock(nn.Module):
'''
The basic implementation for self-attention block/non-local block
Input/Output:
N * C * H * (2*W)
Parameters:
in_channels : the dimension of the input feature map
key_channels : the dimension after the key/query transform
value_channels : the dimension after the value transform
scale : choose the scale to partition the input feature maps
ds : downsampling scale
'''
def __init__(self, in_channels, key_channels, value_channels, scale=1, ds=1):
super(_PAMBlock, self).__init__()
self.scale = scale
self.ds = ds
self.pool = nn.AvgPool2d(self.ds)
self.in_channels = in_channels
self.key_channels = key_channels
self.value_channels = value_channels
self.f_key = nn.Sequential(
nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(self.key_channels)
)
self.f_query = nn.Sequential(
nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(self.key_channels)
)
self.f_value = nn.Conv2d(in_channels=self.in_channels, out_channels=self.value_channels,
kernel_size=1, stride=1, padding=0)
def forward(self, input):
x = input
if self.ds != 1:
x = self.pool(input)
# input shape: b,c,h,2w
batch_size, c, h, w = x.size(0), x.size(1), x.size(2), x.size(3)//2
local_y = []
local_x = []
step_h, step_w = h//self.scale, w//self.scale
for i in range(0, self.scale):
for j in range(0, self.scale):
start_x, start_y = i*step_h, j*step_w
end_x, end_y = min(start_x+step_h, h), min(start_y+step_w, w)
if i == (self.scale-1):
end_x = h
if j == (self.scale-1):
end_y = w
local_x += [start_x, end_x]
local_y += [start_y, end_y]
value = self.f_value(x)
query = self.f_query(x)
key = self.f_key(x)
value = torch.stack([value[:, :, :, :w], value[:,:,:,w:]], 4) # B*N*H*W*2
query = torch.stack([query[:, :, :, :w], query[:,:,:,w:]], 4) # B*N*H*W*2
key = torch.stack([key[:, :, :, :w], key[:,:,:,w:]], 4) # B*N*H*W*2
local_block_cnt = 2*self.scale*self.scale
# self-attention func
def func(value_local, query_local, key_local):
batch_size_new = value_local.size(0)
h_local, w_local = value_local.size(2), value_local.size(3)
value_local = value_local.contiguous().view(batch_size_new, self.value_channels, -1)
query_local = query_local.contiguous().view(batch_size_new, self.key_channels, -1)
query_local = query_local.permute(0, 2, 1)
key_local = key_local.contiguous().view(batch_size_new, self.key_channels, -1)
sim_map = torch.bmm(query_local, key_local) # batch matrix multiplication
sim_map = (self.key_channels**-.5) * sim_map
sim_map = F.softmax(sim_map, dim=-1)
context_local = torch.bmm(value_local, sim_map.permute(0,2,1))
# context_local = context_local.permute(0, 2, 1).contiguous()
context_local = context_local.view(batch_size_new, self.value_channels, h_local, w_local, 2)
return context_local
# Parallel Computing to speed up
# reshape value_local, q, k
v_list = [value[:,:,local_x[i]:local_x[i+1],local_y[i]:local_y[i+1]] for i in range(0, local_block_cnt, 2)]
v_locals = torch.cat(v_list,dim=0)
q_list = [query[:,:,local_x[i]:local_x[i+1],local_y[i]:local_y[i+1]] for i in range(0, local_block_cnt, 2)]
q_locals = torch.cat(q_list,dim=0)
k_list = [key[:,:,local_x[i]:local_x[i+1],local_y[i]:local_y[i+1]] for i in range(0, local_block_cnt, 2)]
k_locals = torch.cat(k_list,dim=0)
# print(v_locals.shape)
context_locals = func(v_locals,q_locals,k_locals)
context_list = []
for i in range(0, self.scale):
row_tmp = []
for j in range(0, self.scale):
left = batch_size*(j+i*self.scale)
right = batch_size*(j+i*self.scale) + batch_size
tmp = context_locals[left:right]
row_tmp.append(tmp)
context_list.append(torch.cat(row_tmp, 3))
context = torch.cat(context_list, 2)
context = torch.cat([context[:,:,:,:,0],context[:,:,:,:,1]],3)
if self.ds !=1:
context = F.interpolate(context, [h*self.ds, 2*w*self.ds])
return context
class PAMBlock(_PAMBlock):
def __init__(self, in_channels, key_channels=None, value_channels=None, scale=1, ds=1):
        if key_channels is None:
            key_channels = in_channels // 8
        if value_channels is None:
            value_channels = in_channels
super(PAMBlock, self).__init__(in_channels,key_channels,value_channels,scale,ds)
class PAM(nn.Module):
"""
PAM module
"""
def __init__(self, in_channels, out_channels, sizes=([1]), ds=1):
super(PAM, self).__init__()
self.group = len(sizes)
self.stages = []
self.ds = ds # output stride
self.value_channels = out_channels
self.key_channels = out_channels // 8
self.stages = nn.ModuleList(
[self._make_stage(in_channels, self.key_channels, self.value_channels, size, self.ds)
for size in sizes])
self.conv_bn = nn.Sequential(
nn.Conv2d(in_channels * self.group, out_channels, kernel_size=1, padding=0,bias=False),
# nn.BatchNorm2d(out_channels),
)
def _make_stage(self, in_channels, key_channels, value_channels, size, ds):
return PAMBlock(in_channels,key_channels,value_channels,size,ds)
def forward(self, feats):
priors = [stage(feats) for stage in self.stages]
# concat
context = []
for i in range(0, len(priors)):
context += [priors[i]]
output = self.conv_bn(torch.cat(context, 1))
return output
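# Minimal usage sketch (illustrative, not part of the original file): PAM expects the two temporal
# feature maps concatenated along the width axis, i.e. a tensor of shape [B, C, H, 2W], and returns
# a tensor of the same layout. The batch size, channel count and spatial sizes below are assumptions.
if __name__ == '__main__':
    import torch
    feat1 = torch.rand(2, 64, 32, 32)                   # features of the first date
    feat2 = torch.rand(2, 64, 32, 32)                   # features of the second date
    pam = PAM(in_channels=64, out_channels=64, sizes=[1, 2], ds=1)
    out = pam(torch.cat([feat1, feat2], dim=3))         # -> [2, 64, 32, 64]
    out1, out2 = out[:, :, :, :32], out[:, :, :, 32:]   # split back into per-date features
    print(out1.shape, out2.shape)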

View File

@ -1,67 +0,0 @@
"""This package contains modules related to objective functions, optimizations, and network architectures.
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
-- <set_input>: unpack data from dataset and apply preprocessing.
-- <forward>: produce intermediate results.
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
In the function <__init__>, you need to define four lists:
-- self.loss_names (str list): specify the training losses that you want to plot and save.
-- self.model_names (str list): define networks used in our training.
-- self.visual_names (str list): specify the images that you want to display and save.
    -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example.
Now you can use the model class by specifying flag '--model dummy'.
See our template model class 'template_model.py' for more details.
"""
import importlib
from models.base_model import BaseModel
def find_model_using_name(model_name):
"""Import the module "models/[model_name]_model.py".
    In the file, the class called [ModelName]Model() will
    be instantiated. It has to be a subclass of BaseModel,
    and it is case-insensitive.
"""
model_filename = "models." + model_name + "_model"
modellib = importlib.import_module(model_filename)
model = None
target_model_name = model_name.replace('_', '') + 'model'
for name, cls in modellib.__dict__.items():
if name.lower() == target_model_name.lower() \
and issubclass(cls, BaseModel):
model = cls
if model is None:
print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
exit(0)
return model
def get_option_setter(model_name):
"""Return the static method <modify_commandline_options> of the model class."""
model_class = find_model_using_name(model_name)
return model_class.modify_commandline_options
def create_model(opt):
"""Create a model given the option.
    This function wraps the model class selected by opt.model.
This is the main interface between this package and 'train.py'/'test.py'
Example:
>>> from models import create_model
>>> model = create_model(opt)
"""
model = find_model_using_name(opt.model)
instance = model(opt)
print("model [%s] was created" % type(instance).__name__)
return instance
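# Illustrative skeleton (not part of the original file): a hypothetical 'dummy' model following the
# five-function contract described in the module docstring above. The file name models/dummy_model.py,
# the class name DummyModel and the toy 1x1-conv network are assumptions used only to show the
# pattern; the models this package actually provides (e.g. CDF0/CDFA, see --model in the options
# files below) are expected to follow the same structure.
import torch

class DummyModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        # no model-specific options in this sketch
        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = ['L1']                 # -> self.loss_L1 is printed/saved
        self.model_names = ['F']                 # -> self.netF is saved/loaded as <epoch>_net_F.pth
        self.visual_names = ['A', 'B', 'pred']   # images returned by get_current_visuals()
        self.netF = torch.nn.Conv2d(opt.input_nc, opt.input_nc, kernel_size=1).to(self.device)
        if self.isTrain:
            self.criterion = torch.nn.L1Loss()
            self.optimizer = torch.optim.Adam(self.netF.parameters(), lr=opt.lr)
            self.optimizers = [self.optimizer]

    def set_input(self, input):
        self.A = input['A'].to(self.device)              # first-date image
        self.B = input['B'].to(self.device)              # second-date image
        self.image_paths = input.get('A_paths', [])

    def forward(self):
        self.pred = self.netF(self.A)                    # toy prediction: map date 1 towards date 2
        return self.pred

    def optimize_parameters(self):
        self.forward()
        self.loss_L1 = self.criterion(self.pred, self.B)
        self.optimizer.zero_grad()
        self.loss_L1.backward()
        self.optimizer.step()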

View File

@ -1,51 +0,0 @@
# coding: utf-8
import torch.nn as nn
import torch
from .mynet3 import F_mynet3
from .BAM import BAM
from .PAM2 import PAM as PAM
def define_F(in_c, f_c, type='unet'):
if type == 'mynet3':
print("using mynet3 backbone")
return F_mynet3(backbone='resnet18', in_c=in_c,f_c=f_c, output_stride=32)
else:
        raise NotImplementedError('no such F type!')
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
class CDSA(nn.Module):
"""self attention module for change detection
"""
def __init__(self, in_c, ds=1, mode='BAM'):
super(CDSA, self).__init__()
self.in_C = in_c
self.ds = ds
print('ds: ',self.ds)
self.mode = mode
if self.mode == 'BAM':
self.Self_Att = BAM(self.in_C, ds=self.ds)
elif self.mode == 'PAM':
self.Self_Att = PAM(in_channels=self.in_C, out_channels=self.in_C, sizes=[1,2,4,8],ds=self.ds)
self.apply(weights_init)
def forward(self, x1, x2):
        width = x1.shape[3]
        x = torch.cat((x1, x2), 3)
        x = self.Self_Att(x)
        return x[:, :, :, 0:width], x[:, :, :, width:]
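# Minimal usage sketch (illustrative, not part of the original file): CDSA concatenates the two
# feature maps along the width axis, applies the chosen self-attention module, and splits the
# result back into per-date features. The sizes below are assumptions.
if __name__ == '__main__':
    f1 = torch.rand(2, 64, 32, 32)
    f2 = torch.rand(2, 64, 32, 32)
    att = CDSA(in_c=64, ds=1, mode='PAM')
    o1, o2 = att(f1, f2)
    print(o1.shape, o2.shape)   # both torch.Size([2, 64, 32, 32])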

View File

@ -1,333 +0,0 @@
import os
import torch
from collections import OrderedDict
from abc import ABC, abstractmethod
from torch.optim import lr_scheduler
def get_scheduler(optimizer, opt):
"""Return a learning rate scheduler
Parameters:
optimizer -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
For 'linear', we keep the same learning rate for the first <opt.niter> epochs
and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
See https://pytorch.org/docs/stable/optim.html for more details.
"""
if opt.lr_policy == 'linear':
def lambda_rule(epoch):
lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
return lr_l
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
elif opt.lr_policy == 'step':
scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
elif opt.lr_policy == 'plateau':
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
elif opt.lr_policy == 'cosine':
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
else:
return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
return scheduler
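# Worked example (illustrative, not part of the original file): with epoch_count=1, niter=100 and
# niter_decay=100, lambda_rule returns 1.0 while epoch + epoch_count <= niter and then decays the
# multiplier linearly towards zero:
#     epoch 100 -> 1 - 1/101   ~ 0.990
#     epoch 150 -> 1 - 51/101  ~ 0.495
#     epoch 200 -> 1 - 101/101 = 0.0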
class BaseModel(ABC):
"""This class is an abstract base class (ABC) for models.
To create a subclass, you need to implement the following five functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
-- <set_input>: unpack data from dataset and apply preprocessing.
-- <forward>: produce intermediate results.
-- <optimize_parameters>: calculate losses, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
"""
def __init__(self, opt):
"""Initialize the BaseModel class.
Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>
Then, you need to define four lists:
-- self.loss_names (str list): specify the training losses that you want to plot and save.
            -- self.model_names (str list): define networks used in our training.
            -- self.visual_names (str list): specify the images that you want to display and save.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
"""
self.opt = opt
self.gpu_ids = opt.gpu_ids
self.isTrain = opt.isTrain
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
torch.backends.cudnn.benchmark = True
self.loss_names = []
self.model_names = []
self.visual_names = []
self.visual_features = []
self.optimizers = []
self.image_paths = []
self.metric = 0 # used for learning rate policy 'plateau'
        self.istest = True if opt.phase == 'test' else False  # in test mode there are no labeled samples
@staticmethod
def modify_commandline_options(parser, is_train):
"""Add new model-specific options, and rewrite default values for existing options.
Parameters:
parser -- original option parser
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
"""
return parser
@abstractmethod
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): includes the data itself and its metadata information.
"""
pass
@abstractmethod
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
pass
@abstractmethod
def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
pass
def setup(self, opt):
"""Load and print networks; create schedulers
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
if self.isTrain:
self.schedulers = [get_scheduler(optimizer, opt) for optimizer in self.optimizers]
if not self.isTrain or opt.continue_train:
load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
self.load_networks(load_suffix)
self.print_networks(opt.verbose)
def eval(self):
"""Make models eval mode during test time"""
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
net.eval()
def train(self):
"""Make models train mode during train time"""
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
net.train()
def test(self):
"""Forward function used in test time.
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
It also calls <compute_visuals> to produce additional visualization results
"""
with torch.no_grad():
self.forward()
self.compute_visuals()
def compute_visuals(self):
"""Calculate additional output images for visdom and HTML visualization"""
pass
def get_image_paths(self):
""" Return image paths that are used to load current data"""
return self.image_paths
def update_learning_rate(self):
"""Update learning rates for all the networks; called at the end of every epoch"""
for scheduler in self.schedulers:
if self.opt.lr_policy == 'plateau':
scheduler.step(self.metric)
else:
scheduler.step()
lr = self.optimizers[0].param_groups[0]['lr']
print('learning rate = %.7f' % lr)
def get_current_visuals(self):
"""Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
visual_ret = OrderedDict()
for name in self.visual_names:
if isinstance(name, str):
visual_ret[name] = getattr(self, name)
return visual_ret
def get_current_losses(self):
"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
errors_ret = OrderedDict()
for name in self.loss_names:
if isinstance(name, str):
errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
return errors_ret
def save_networks(self, epoch):
"""Save all the networks to the disk.
Parameters:
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
"""
for name in self.model_names:
if isinstance(name, str):
save_filename = '%s_net_%s.pth' % (epoch, name)
save_path = os.path.join(self.save_dir, save_filename)
net = getattr(self, 'net' + name)
if len(self.gpu_ids) > 0 and torch.cuda.is_available():
# torch.save(net.module.cpu().state_dict(), save_path)
torch.save(net.cpu().state_dict(), save_path)
net.cuda(self.gpu_ids[0])
else:
torch.save(net.cpu().state_dict(), save_path)
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
key = keys[i]
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'running_mean' or key == 'running_var'):
if getattr(module, key) is None:
state_dict.pop('.'.join(keys))
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'num_batches_tracked'):
state_dict.pop('.'.join(keys))
else:
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
def get_visual(self, name):
visual_ret = {}
visual_ret[name] = getattr(self, name)
return visual_ret
def pred_large(self, A, B, input_size=256, stride=0):
"""
        Run tiled prediction on a large pre-/post-change image pair.
        The central part of each patch prediction is assumed to be reliable; the edge padding is (input_size - stride) / 2.
:param A: tensor, N*C*H*W
:param B: tensor, N*C*H*W
        :param input_size: int, patch size fed to the network
        :param stride: int, stride of the sliding window during prediction
:return: pred, tensor, N*1*H*W
"""
import math
import numpy as np
n, c, h, w = A.shape
assert A.shape == B.shape
        # number of tiles along each axis
n_h = math.ceil((h - input_size) / stride) + 1
n_w = math.ceil((w - input_size) / stride) + 1
        # recompute the padded height and width
new_h = (n_h - 1) * stride + input_size
new_w = (n_w - 1) * stride + input_size
print("new_h: ", new_h)
print("new_w: ", new_w)
print("n_h: ", n_h)
print("n_w: ", n_w)
new_A = torch.zeros([n, c, new_h, new_w], dtype=torch.float32)
new_B = torch.zeros([n, c, new_h, new_w], dtype=torch.float32)
new_A[:, :, :h, :w] = A
new_B[:, :, :h, :w] = B
new_pred = torch.zeros([n, 1, new_h, new_w], dtype=torch.uint8)
del A
del B
#
for i in range(0, new_h - input_size + 1, stride):
for j in range(0, new_w - input_size + 1, stride):
left = j
right = input_size + j
top = i
bottom = input_size + i
patch_A = new_A[:, :, top:bottom, left:right]
patch_B = new_B[:, :, top:bottom, left:right]
# print(left,' ',right,' ', top,' ', bottom)
self.A = patch_A.to(self.device)
self.B = patch_B.to(self.device)
with torch.no_grad():
patch_pred = self.forward()
new_pred[:, :, top:bottom, left:right] = patch_pred.detach().cpu()
pred = new_pred[:, :, :h, :w]
return pred
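    # Minimal usage sketch for pred_large (illustrative, not part of the original file); `model`
    # stands for any concrete subclass of BaseModel whose forward() returns a prediction map:
    #     A = torch.rand(1, 3, 1000, 1500)     # large pre-change image
    #     B = torch.rand(1, 3, 1000, 1500)     # large post-change image
    #     pred = model.pred_large(A, B, input_size=256, stride=128)   # -> [1, 1, 1000, 1500]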
def load_networks(self, epoch):
"""Load all the networks from the disk.
Parameters:
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
"""
for name in self.model_names:
if isinstance(name, str):
load_filename = '%s_net_%s.pth' % (epoch, name)
load_path = os.path.join(self.save_dir, load_filename)
net = getattr(self, 'net' + name)
# if isinstance(net, torch.nn.DataParallel):
# net = net.module
                # net = net.module  # adapt to checkpoints saved from a wrapped module
print('loading the model from %s' % load_path)
# if you are using PyTorch newer than 0.4 (e.g., built from
# GitHub source), you can remove str() on self.device
state_dict = torch.load(load_path, map_location=str(self.device))
if hasattr(state_dict, '_metadata'):
del state_dict._metadata
# patch InstanceNorm checkpoints prior to 0.4
for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
# print(key)
net.load_state_dict(state_dict,strict=False)
def print_networks(self, verbose):
"""Print the total number of parameters in the network and (if verbose) network architecture
Parameters:
verbose (bool) -- if verbose: print the network architecture
"""
print('---------- Networks initialized -------------')
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
num_params = 0
for param in net.parameters():
num_params += param.numel()
if verbose:
print(net)
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
print('-----------------------------------------------')
def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
Parameters:
nets (network list) -- a list of networks
requires_grad (bool) -- whether the networks require gradients or not
"""
if not isinstance(nets, list):
nets = [nets]
for net in nets:
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad
if __name__ == '__main__':
A = torch.rand([1,3,512,512],dtype=torch.float32)
B = torch.rand([1,3,512,512],dtype=torch.float32)

View File

@ -1,29 +0,0 @@
import torch.nn as nn
import torch
class BCL(nn.Module):
"""
batch-balanced contrastive loss
    no-change: 1
    change: -1
"""
def __init__(self, margin=2.0):
super(BCL, self).__init__()
self.margin = margin
def forward(self, distance, label):
label[label==255] = 1
mask = (label != 255).float()
distance = distance * mask
pos_num = torch.sum((label==1).float())+0.0001
neg_num = torch.sum((label==-1).float())+0.0001
loss_1 = torch.sum((1+label) / 2 * torch.pow(distance, 2)) /pos_num
loss_2 = torch.sum((1-label) / 2 * mask *
torch.pow(torch.clamp(self.margin - distance, min=0.0), 2)
) / neg_num
loss = loss_1 + loss_2
return loss
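# Worked example (illustrative, not part of the original file): with margin=2.0, a 2x2 distance map
# and labels where 1 marks unchanged and -1 marks changed pixels, each term is averaged over the
# corresponding pixel count before the two are summed.
if __name__ == '__main__':
    criterion = BCL(margin=2.0)
    distance = torch.tensor([[[0.5, 1.5],
                              [3.0, 0.1]]])
    label = torch.tensor([[[1.0, -1.0],
                           [-1.0, 1.0]]])
    # loss_1 = (0.5**2 + 0.1**2) / 2 = 0.13          (unchanged pairs pulled together)
    # loss_2 = ((2.0 - 1.5)**2 + 0.0) / 2 = 0.125    (changed pairs pushed beyond the margin)
    print(criterion(distance, label))                # ~ 0.255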

View File

@ -1,352 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import math
class F_mynet3(nn.Module):
def __init__(self, backbone='resnet18',in_c=3, f_c=64, output_stride=8):
self.in_c = in_c
super(F_mynet3, self).__init__()
self.module = mynet3(backbone=backbone, output_stride=output_stride, f_c=f_c, in_c=self.in_c)
def forward(self, input):
return self.module(input)
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
def ResNet34(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
"""
output, low_level_feat:
512, 64
"""
print(in_c)
model = ResNet(BasicBlock, [3, 4, 6, 3], output_stride, BatchNorm, in_c=in_c)
if in_c != 3:
pretrained = False
if pretrained:
model._load_pretrained_model(model_urls['resnet34'])
return model
def ResNet18(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
"""
output, low_level_feat:
512, 256, 128, 64, 64
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], output_stride, BatchNorm, in_c=in_c)
if in_c !=3:
pretrained=False
if pretrained:
model._load_pretrained_model(model_urls['resnet18'])
return model
def ResNet50(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
"""
output, low_level_feat:
2048, 256
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], output_stride, BatchNorm, in_c=in_c)
if in_c !=3:
pretrained=False
if pretrained:
model._load_pretrained_model(model_urls['resnet50'])
return model
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
dilation=dilation, padding=dilation, bias=False)
self.bn1 = BatchNorm(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = BatchNorm(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = BatchNorm(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
dilation=dilation, padding=dilation, bias=False)
self.bn2 = BatchNorm(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = BatchNorm(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True, in_c=3):
self.inplanes = 64
self.in_c = in_c
print('in_c: ',self.in_c)
super(ResNet, self).__init__()
blocks = [1, 2, 4]
if output_stride == 32:
strides = [1, 2, 2, 2]
dilations = [1, 1, 1, 1]
elif output_stride == 16:
strides = [1, 2, 2, 1]
dilations = [1, 1, 1, 2]
elif output_stride == 8:
strides = [1, 2, 1, 1]
dilations = [1, 1, 2, 4]
elif output_stride == 4:
strides = [1, 1, 1, 1]
dilations = [1, 2, 4, 8]
else:
raise NotImplementedError
# Modules
self.conv1 = nn.Conv2d(self.in_c, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = BatchNorm(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
# self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
self._init_weight()
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
BatchNorm(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))
return nn.Sequential(*layers)
def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
BatchNorm(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
downsample=downsample, BatchNorm=BatchNorm))
self.inplanes = planes * block.expansion
for i in range(1, len(blocks)):
layers.append(block(self.inplanes, planes, stride=1,
dilation=blocks[i]*dilation, BatchNorm=BatchNorm))
return nn.Sequential(*layers)
def forward(self, input):
x = self.conv1(input)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x) # | 4
x = self.layer1(x) # | 4
low_level_feat2 = x # | 4
x = self.layer2(x) # | 8
low_level_feat3 = x
x = self.layer3(x) # | 16
low_level_feat4 = x
x = self.layer4(x) # | 32
return x, low_level_feat2, low_level_feat3, low_level_feat4
def _init_weight(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _load_pretrained_model(self, model_path):
pretrain_dict = model_zoo.load_url(model_path)
model_dict = {}
state_dict = self.state_dict()
for k, v in pretrain_dict.items():
if k in state_dict:
model_dict[k] = v
state_dict.update(model_dict)
self.load_state_dict(state_dict)
def build_backbone(backbone, output_stride, BatchNorm, in_c=3):
if backbone == 'resnet50':
return ResNet50(output_stride, BatchNorm, in_c=in_c)
elif backbone == 'resnet34':
return ResNet34(output_stride, BatchNorm, in_c=in_c)
elif backbone == 'resnet18':
return ResNet18(output_stride, BatchNorm, in_c=in_c)
else:
raise NotImplementedError
class DR(nn.Module):
def __init__(self, in_d, out_d):
super(DR, self).__init__()
self.in_d = in_d
self.out_d = out_d
self.conv1 = nn.Conv2d(self.in_d, self.out_d, 1, bias=False)
self.bn1 = nn.BatchNorm2d(self.out_d)
self.relu = nn.ReLU()
def forward(self, input):
x = self.conv1(input)
x = self.bn1(x)
x = self.relu(x)
return x
class Decoder(nn.Module):
def __init__(self, fc, BatchNorm):
super(Decoder, self).__init__()
self.fc = fc
self.dr2 = DR(64, 96)
self.dr3 = DR(128, 96)
self.dr4 = DR(256, 96)
self.dr5 = DR(512, 96)
self.last_conv = nn.Sequential(nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1, bias=False),
BatchNorm(256),
nn.ReLU(),
nn.Dropout(0.5),
nn.Conv2d(256, self.fc, kernel_size=1, stride=1, padding=0, bias=False),
BatchNorm(self.fc),
nn.ReLU(),
)
self._init_weight()
def forward(self, x,low_level_feat2, low_level_feat3, low_level_feat4):
# x1 = self.dr1(low_level_feat1)
x2 = self.dr2(low_level_feat2)
x3 = self.dr3(low_level_feat3)
x4 = self.dr4(low_level_feat4)
x = self.dr5(x)
x = F.interpolate(x, size=x2.size()[2:], mode='bilinear', align_corners=True)
# x2 = F.interpolate(x2, size=x3.size()[2:], mode='bilinear', align_corners=True)
x3 = F.interpolate(x3, size=x2.size()[2:], mode='bilinear', align_corners=True)
x4 = F.interpolate(x4, size=x2.size()[2:], mode='bilinear', align_corners=True)
x = torch.cat((x, x2, x3, x4), dim=1)
x = self.last_conv(x)
return x
def _init_weight(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
torch.nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def build_decoder(fc, backbone, BatchNorm):
return Decoder(fc, BatchNorm)
class mynet3(nn.Module):
def __init__(self, backbone='resnet18', output_stride=16, f_c=64, freeze_bn=False, in_c=3):
super(mynet3, self).__init__()
print('arch: mynet3')
BatchNorm = nn.BatchNorm2d
self.backbone = build_backbone(backbone, output_stride, BatchNorm, in_c)
self.decoder = build_decoder(f_c, backbone, BatchNorm)
if freeze_bn:
self.freeze_bn()
def forward(self, input):
x, f2, f3, f4 = self.backbone(input)
x = self.decoder(x, f2, f3, f4)
return x
def freeze_bn(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
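# Minimal usage sketch (illustrative, not part of the original file): with a resnet18 backbone the
# decoder fuses features at 1/4 of the input resolution, so a 3x256x256 input gives an f_c-channel
# map of spatial size 64x64. ImageNet weights are downloaded via model_zoo unless in_c != 3.
if __name__ == '__main__':
    net = F_mynet3(backbone='resnet18', in_c=3, f_c=64, output_stride=32)
    net.eval()
    x = torch.rand(1, 3, 256, 256)
    with torch.no_grad():
        feat = net(x)
    print(feat.shape)   # torch.Size([1, 64, 64, 64])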

View File

@ -1 +0,0 @@
"""This package options includes option modules: training options, test options, and basic options (used in both training and test)."""

View File

@ -1,147 +0,0 @@
import argparse
import os
from util import util
import torch
import models
import data
class BaseOptions():
"""This class defines options used during both training and test time.
It also implements several helper functions such as parsing, printing, and saving the options.
It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
"""
def __init__(self):
"""Reset the class; indicates the class hasn't been initailized"""
self.initialized = False
def initialize(self, parser):
"""Define the common options that are used in both training and test."""
# basic parameters
parser.add_argument('--dataroot', type=str, default='./LEVIR-CD', help='path to images (should have subfolders A, B, label)')
parser.add_argument('--val_dataroot', type=str, default='./LEVIR-CD', help='path to images in the val phase (should have subfolders A, B, label)')
parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
# model parameters
parser.add_argument('--model', type=str, default='CDF0', help='chooses which model to use. [CDF0 | CDFA]')
parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB ')
parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB')
parser.add_argument('--arch', type=str, default='mynet3', help='feature extractor architecture | mynet3')
parser.add_argument('--f_c', type=int, default=64, help='feature extractor channel num')
parser.add_argument('--n_class', type=int, default=2, help='# of output pred channels: 2 for num of classes')
parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
parser.add_argument('--SA_mode', type=str, default='BAM', help='choose self attention mode for change detection, | ori |1 | 2 |pyramid, ...')
# dataset parameters
parser.add_argument('--dataset_mode', type=str, default='changedetection', help='chooses how datasets are loaded. [changedetection | concat | list | json]')
parser.add_argument('--val_dataset_mode', type=str, default='changedetection', help='chooses how datasets are loaded. [changedetection | concat| list | json]')
        parser.add_argument('--dataset_type', type=str, default='CD_LEVIR', help='chooses which datasets to load. [LEVIR | WHU ]')
        parser.add_argument('--val_dataset_type', type=str, default='CD_LEVIR', help='chooses which datasets to load. [LEVIR | WHU ]')
        parser.add_argument('--split', type=str, default='train', help='chooses which list-file to open when use listDataset. [train | val | test]')
        parser.add_argument('--val_split', type=str, default='val', help='chooses which list-file to open when use listDataset. [train | val | test]')
parser.add_argument('--json_name', type=str, default='train_val_test', help='input the json name which contain the file names of images of different phase')
parser.add_argument('--val_json_name', type=str, default='train_val_test', help='input the json name which contain the file names of images of different phase')
parser.add_argument('--ds', type=int, default='1', help='self attention module downsample rate')
parser.add_argument('--angle', type=int, default=0, help='rotate angle')
parser.add_argument('--istest', type=bool, default=False, help='True for the case without label')
parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | none]')
parser.add_argument('--no_flip', type=bool, default=True, help='if specified, do not flip(left-right) the images for data augmentation')
parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
# additional parameters
parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
self.initialized = True
return parser
def gather_options(self):
"""Initialize our parser with basic options(only once).
Add additional model-specific and dataset-specific options.
These options are defined in the <modify_commandline_options> function
in model and dataset classes.
"""
if not self.initialized: # check if it has been initialized
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = self.initialize(parser)
# get the basic options
opt, _ = parser.parse_known_args()
# modify model-related parser options
model_name = opt.model
model_option_setter = models.get_option_setter(model_name)
parser = model_option_setter(parser, self.isTrain)
opt, _ = parser.parse_known_args() # parse again with new defaults
# modify dataset-related parser options
dataset_name = opt.dataset_mode
if dataset_name != 'concat':
dataset_option_setter = data.get_option_setter(dataset_name)
parser = dataset_option_setter(parser, self.isTrain)
# save and return the parser
self.parser = parser
return parser.parse_args()
def print_options(self, opt):
"""Print and save options
        It will print both current options and default values (if different).
It will save options into a text file / [checkpoints_dir] / opt.txt
"""
message = ''
message += '----------------- Options ---------------\n'
for k, v in sorted(vars(opt).items()):
comment = ''
default = self.parser.get_default(k)
if v != default:
comment = '\t[default: %s]' % str(default)
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
message += '----------------- End -------------------'
print(message)
# save to the disk
expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
util.mkdirs(expr_dir)
file_name = os.path.join(expr_dir, 'opt.txt')
with open(file_name, 'wt') as opt_file:
opt_file.write(message)
opt_file.write('\n')
def parse(self):
"""Parse our options, create checkpoints directory suffix, and set up gpu device."""
opt = self.gather_options()
opt.isTrain = self.isTrain # train or test
# process opt.suffix
if opt.suffix:
suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
opt.name = opt.name + suffix
self.print_options(opt)
# set gpu ids
str_ids = opt.gpu_ids.split(',')
opt.gpu_ids = []
for str_id in str_ids:
id = int(str_id)
if id >= 0:
opt.gpu_ids.append(id)
if len(opt.gpu_ids) > 0:
torch.cuda.set_device(opt.gpu_ids[0])
self.opt = opt
return self.opt

View File

@ -1,22 +0,0 @@
from .base_options import BaseOptions
class TestOptions(BaseOptions):
"""This class includes test options.
It also includes shared options defined in BaseOptions.
"""
def initialize(self, parser):
parser = BaseOptions.initialize(self, parser) # define shared options
parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
        # Dropout and BatchNorm behave differently during training and test.
parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
# To avoid cropping, the load_size should be the same as crop_size
parser.set_defaults(load_size=parser.get_default('crop_size'))
self.isTrain = False
return parser

View File

@ -1,40 +0,0 @@
from .base_options import BaseOptions
class TrainOptions(BaseOptions):
"""This class includes training options.
It also includes shared options defined in BaseOptions.
"""
def initialize(self, parser):
parser = BaseOptions.initialize(self, parser)
# visdom and HTML visualization parameters
parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
# network saving and loading parameters
parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
parser.add_argument('--lr_decay', type=float, default=1, help='learning rate decay for certain module ...')
parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
# training parameters
parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
self.isTrain = True
return parser

View File

@ -1,5 +0,0 @@
# train PAM
python ./train.py --save_epoch_freq 1 --angle 15 --dataroot path-to-LEVIR-CD-train --val_dataroot path-to-LEVIR-CD-val --name LEVIR-CDFAp0 --lr 0.001 --model CDFA --SA_mode PAM --batch_size 8 --load_size 256 --crop_size 256 --preprocess rotate_and_crop

View File

@ -1,11 +0,0 @@
# concat mode
list=train
lr=0.001
dataset_type=CD_data1,CD_data2,...,
dataset_type=LEVIR_CD1
val_dataset_type=LEVIR_CD1
dataset_mode=concat
name=project_name
python ./train.py --num_threads 4 --display_id 0 --dataset_type $dataset_type --val_dataset_type $val_dataset_type --save_epoch_freq 1 --niter 100 --angle 15 --niter_decay 100 --display_env FAp0 --SA_mode PAM --name $name --lr $lr --model CDFA --batch_size 4 --dataset_mode $dataset_mode --val_dataset_mode $dataset_mode --split $list --load_size 256 --crop_size 256 --preprocess resize_rotate_and_crop

View File

@ -1,9 +0,0 @@
# List mode
list=train
lr=0.001
dataset_mode=list
dataroot=path-to-dataroot
name=project_name
#
python ./train.py --num_threads 4 --display_id 0 --dataroot ${dataroot} --val_dataroot ${dataroot} --save_epoch_freq 1 --niter 100 --angle 25 --niter_decay 100 --display_env FAp0 --SA_mode PAM --name $name --lr $lr --model CDFA --batch_size 4 --dataset_mode $dataset_mode --val_dataset_mode $dataset_mode --split $list --load_size 256 --crop_size 256 --preprocess resize_rotate_and_crop

View File

@ -1,105 +0,0 @@
from data import create_dataset
from models import create_model
from util.util import save_images
import numpy as np
from util.util import mkdir
import argparse
from PIL import Image
import torchvision.transforms as transforms
def transform():
transform_list = []
transform_list += [transforms.ToTensor()]
transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)
def val(opt):
image_1_path = opt.image1_path
image_2_path = opt.image2_path
A_img = Image.open(image_1_path).convert('RGB')
B_img = Image.open(image_2_path).convert('RGB')
trans = transform()
A = trans(A_img).unsqueeze(0)
B = trans(B_img).unsqueeze(0)
# dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
model = create_model(opt) # create a model given opt.model and other options
model.setup(opt) # regular setup: load and print networks; create schedulers
save_path = opt.results_dir
mkdir(save_path)
model.eval()
data = {}
data['A']= A
data['B'] = B
data['A_paths'] = [image_1_path]
model.set_input(data) # unpack data from data loader
pred = model.test(val=False) # run inference return pred
img_path = [image_1_path] # get image paths
save_images(pred, save_path, img_path)
if __name__ == '__main__':
    # How to invoke from the command line:
# python test.py --image1_path [path-to-img1] --image2_path [path-to-img2] --results_dir [path-to-result_dir]
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--image1_path', type=str, default='./samples/A/test_2_0000_0000.png',
help='path to images A')
parser.add_argument('--image2_path', type=str, default='./samples/B/test_2_0000_0000.png',
help='path to images B')
parser.add_argument('--results_dir', type=str, default='./samples/output/', help='saves results here.')
parser.add_argument('--name', type=str, default='pam',
help='name of the experiment. It decides where to store samples and models')
parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
# model parameters
parser.add_argument('--model', type=str, default='CDFA', help='chooses which model to use. [CDF0 | CDFA]')
parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB ')
parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB')
parser.add_argument('--arch', type=str, default='mynet3', help='feature extractor architecture | mynet3')
parser.add_argument('--f_c', type=int, default=64, help='feature extractor channel num')
parser.add_argument('--n_class', type=int, default=2, help='# of output pred channels: 2 for num of classes')
parser.add_argument('--SA_mode', type=str, default='PAM',
help='choose self attention mode for change detection, | ori |1 | 2 |pyramid, ...')
# dataset parameters
parser.add_argument('--dataset_mode', type=str, default='changedetection',
help='chooses how datasets are loaded. [changedetection | json]')
parser.add_argument('--val_dataset_mode', type=str, default='changedetection',
help='chooses how datasets are loaded. [changedetection | json]')
parser.add_argument('--split', type=str, default='train',
                        help='chooses which list-file to open when use listDataset. [train | val | test]')
parser.add_argument('--ds', type=int, default='1', help='self attention module downsample rate')
parser.add_argument('--angle', type=int, default=0, help='rotate angle')
parser.add_argument('--istest', type=bool, default=False, help='True for the case without label')
parser.add_argument('--serial_batches', action='store_true',
help='if true, takes images in order to make batches, otherwise takes them randomly')
parser.add_argument('--num_threads', default=0, type=int, help='# threads for loading data')
parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
parser.add_argument('--load_size', type=int, default=256, help='scale images to this size')
parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
parser.add_argument('--max_dataset_size', type=int, default=float("inf"),
help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
parser.add_argument('--preprocess', type=str, default='resize_and_crop',
help='scaling and cropping of images at load time [resize_and_crop | none]')
parser.add_argument('--no_flip', type=bool, default=True,
help='if specified, do not flip(left-right) the images for data augmentation')
parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
parser.add_argument('--epoch', type=str, default='pam',
help='which epoch to load? set to latest to use latest cached model')
parser.add_argument('--load_iter', type=int, default='0',
help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
parser.add_argument('--isTrain', type=bool, default=False, help='is or not')
parser.add_argument('--num_test', type=int, default=np.inf, help='how many test images to run')
opt = parser.parse_args()
val(opt)

View File

@ -1,181 +0,0 @@
import time
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.visualizer import Visualizer
import os
from util import html
from util.visualizer import save_images
from util.metrics import AverageMeter
import copy
import numpy as np
import torch
import random
def seed_torch(seed=2019):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# set seeds
# seed_torch(2019)
ifSaveImage = False
def make_val_opt(opt):
val_opt = copy.deepcopy(opt)
val_opt.preprocess = '' #
# hard-code some parameters for test
    val_opt.num_threads = 0   # validation code only supports num_threads = 0
    val_opt.batch_size = 4    # batch size used during validation
val_opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
val_opt.no_flip = True # no flip; comment this line if results on flipped images are needed.
val_opt.angle = 0
val_opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file.
val_opt.phase = 'val'
val_opt.split = opt.val_split # function in jsonDataset and ListDataset
val_opt.isTrain = False
val_opt.aspect_ratio = 1
val_opt.results_dir = './results/'
val_opt.dataroot = opt.val_dataroot
val_opt.dataset_mode = opt.val_dataset_mode
val_opt.dataset_type = opt.val_dataset_type
val_opt.json_name = opt.val_json_name
val_opt.eval = True
val_opt.num_test = 2000
return val_opt
def print_current_acc(log_name, epoch, score):
"""print current acc on console; also save the losses to the disk
Parameters:
"""
message = '(epoch: %d) ' % epoch
for k, v in score.items():
message += '%s: %.3f ' % (k, v)
print(message) # print the message
with open(log_name, "a") as log_file:
log_file.write('%s\n' % message) # save the message
def val(opt, model):
opt = make_val_opt(opt)
dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
# model = create_model(opt) # create a model given opt.model and other options
# model.setup(opt) # regular setup: load and print networks; create schedulers
web_dir = os.path.join(opt.checkpoints_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch)) # define the website directory
webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))
model.eval()
# create a logging file to store training losses
log_name = os.path.join(opt.checkpoints_dir, opt.name, 'val_log.txt')
with open(log_name, "a") as log_file:
now = time.strftime("%c")
log_file.write('================ val acc (%s) ================\n' % now)
running_metrics = AverageMeter()
for i, data in enumerate(dataset):
if i >= opt.num_test: # only apply our model to opt.num_test images.
break
model.set_input(data) # unpack data from data loader
score = model.test(val=True) # run inference
running_metrics.update(score)
visuals = model.get_current_visuals() # get image results
img_path = model.get_image_paths() # get image paths
if i % 5 == 0: # save images to an HTML file
print('processing (%04d)-th image... %s' % (i, img_path))
if ifSaveImage:
save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
score = running_metrics.get_scores()
print_current_acc(log_name, epoch, score)
if opt.display_id > 0:
visualizer.plot_current_acc(epoch, float(epoch_iter) / dataset_size, score)
webpage.save() # save the HTML
return score[metric_name]
metric_name = 'F1_1'
if __name__ == '__main__':
opt = TrainOptions().parse() # get training options
dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
dataset_size = len(dataset) # get the number of images in the dataset.
print('The number of training images = %d' % dataset_size)
model = create_model(opt) # create a model given opt.model and other options
model.setup(opt) # regular setup: load and print networks; create schedulers
visualizer = Visualizer(opt) # create a visualizer that display/save images and plots
total_iters = 0 # the total number of training iterations
miou_best = 0
n_epoch_bad = 0
epoch_best = 0
time_metric = AverageMeter()
time_log_name = os.path.join(opt.checkpoints_dir, opt.name, 'time_log.txt')
with open(time_log_name, "a") as log_file:
now = time.strftime("%c")
log_file.write('================ training time (%s) ================\n' % now)
for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
epoch_start_time = time.time() # timer for entire epoch
iter_data_time = time.time() # timer for data loading per iteration
epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch
model.train()
# miou_current = val(opt, model)
for i, data in enumerate(dataset): # inner loop within one epoch
iter_start_time = time.time() # timer for computation per iteration
if total_iters % opt.print_freq == 0:
t_data = iter_start_time - iter_data_time
visualizer.reset()
total_iters += opt.batch_size
epoch_iter += opt.batch_size
n_epoch = opt.niter + opt.niter_decay
model.set_input(data) # unpack data from dataset and apply preprocessing
model.optimize_parameters() # calculate loss functions, get gradients, update network weights
if ifSaveImage:
if total_iters % opt.display_freq == 0: # display images on visdom and save images to a HTML file
save_result = total_iters % opt.update_html_freq == 0
model.compute_visuals()
visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
if total_iters % opt.print_freq == 0: # print training losses and save logging information to the disk
losses = model.get_current_losses()
t_comp = (time.time() - iter_start_time) / opt.batch_size
visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
if opt.display_id > 0:
visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)
if total_iters % opt.save_latest_freq == 0: # cache our latest model every <save_latest_freq> iterations
print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
model.save_networks(save_suffix)
iter_data_time = time.time()
t_epoch = time.time()-epoch_start_time
time_metric.update(t_epoch)
print_current_acc(time_log_name, epoch,{"current_t_epoch": t_epoch})
if epoch % opt.save_epoch_freq == 0: # cache our model every <save_epoch_freq> epochs
print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
model.save_networks('latest')
miou_current = val(opt, model)
if miou_current > miou_best:
miou_best = miou_current
epoch_best = epoch
model.save_networks(str(epoch_best)+"_"+metric_name+'_'+'%0.5f'% miou_best)
print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
model.update_learning_rate() # update learning rates at the end of every epoch.
time_ave = time_metric.average()
print_current_acc(time_log_name, epoch, {"ave_t_epoch": time_ave})

View File

@ -1 +0,0 @@
"""This package includes a miscellaneous collection of useful helper functions."""

View File

@ -1,86 +0,0 @@
import dominate
from dominate.tags import meta, h3, table, tr, td, p, a, img, br
import os
class HTML:
"""This HTML class allows us to save images and write texts into a single HTML file.
It consists of functions such as <add_header> (add a text header to the HTML file),
<add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
    It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
"""
def __init__(self, web_dir, title, refresh=0):
"""Initialize the HTML classes
Parameters:
            web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
title (str) -- the webpage name
refresh (int) -- how often the website refresh itself; if 0; no refreshing
"""
self.title = title
self.web_dir = web_dir
self.img_dir = os.path.join(self.web_dir, 'images')
if not os.path.exists(self.web_dir):
os.makedirs(self.web_dir)
if not os.path.exists(self.img_dir):
os.makedirs(self.img_dir)
self.doc = dominate.document(title=title)
if refresh > 0:
with self.doc.head:
meta(http_equiv="refresh", content=str(refresh))
def get_image_dir(self):
"""Return the directory that stores images"""
return self.img_dir
def add_header(self, text):
"""Insert a header to the HTML file
Parameters:
text (str) -- the header text
"""
with self.doc:
h3(text)
def add_images(self, ims, txts, links, width=400):
"""add images to the HTML file
Parameters:
ims (str list) -- a list of image paths
txts (str list) -- a list of image names shown on the website
links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
"""
self.t = table(border=1, style="table-layout: fixed;") # Insert a table
self.doc.add(self.t)
with self.t:
with tr():
for im, txt, link in zip(ims, txts, links):
with td(style="word-wrap: break-word;", halign="center", valign="top"):
with p():
with a(href=os.path.join('images', link)):
img(style="width:%dpx" % width, src=os.path.join('images', im))
br()
p(txt)
def save(self):
"""save the current content to the HMTL file"""
html_file = '%s/index.html' % self.web_dir
f = open(html_file, 'wt')
f.write(self.doc.render())
f.close()
if __name__ == '__main__': # we show an example usage here.
html = HTML('web/', 'test_html')
html.add_header('hello world')
ims, txts, links = [], [], []
for n in range(4):
ims.append('image_%d.png' % n)
txts.append('text_%d' % n)
links.append('image_%d.png' % n)
html.add_images(ims, txts, links)
html.save()

View File

@ -1,176 +0,0 @@
# Adapted from score written by wkentaro
# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
import numpy as np
eps=np.finfo(float).eps
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.initialized = False
self.val = None
self.avg = None
self.sum = None
self.count = None
def initialize(self, val, weight):
self.val = val
self.avg = val
self.sum = val * weight
self.count = weight
self.initialized = True
def update(self, val, weight=1):
if not self.initialized:
self.initialize(val, weight)
else:
self.add(val, weight)
def add(self, val, weight):
self.val = val
self.sum += val * weight
self.count += weight
self.avg = self.sum / self.count
def value(self):
return self.val
def average(self):
return self.avg
def get_scores(self):
scores, cls_iu, m_1 = cm2score(self.sum)
scores.update(cls_iu)
scores.update(m_1)
return scores
def cm2score(confusion_matrix):
hist = confusion_matrix
n_class = hist.shape[0]
tp = np.diag(hist)
sum_a1 = hist.sum(axis=1)
sum_a0 = hist.sum(axis=0)
# ---------------------------------------------------------------------- #
# 1. Accuracy & Class Accuracy
# ---------------------------------------------------------------------- #
acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps)
acc_cls_ = tp / (sum_a1 + np.finfo(np.float32).eps)
# precision
precision = tp / (sum_a0 + np.finfo(np.float32).eps)
# F1 score
F1 = 2*acc_cls_ * precision / (acc_cls_ + precision + np.finfo(np.float32).eps)
# ---------------------------------------------------------------------- #
# 2. Mean IoU
# ---------------------------------------------------------------------- #
iu = tp / (sum_a1 + hist.sum(axis=0) - tp + np.finfo(np.float32).eps)
mean_iu = np.nanmean(iu)
cls_iu = dict(zip(range(n_class), iu))
return {'Overall_Acc': acc,
'Mean_IoU': mean_iu}, cls_iu, \
{
'precision_1': precision[1],
'recall_1': acc_cls_[1],
'F1_1': F1[1],}
class RunningMetrics(object):
def __init__(self, num_classes):
"""
Computes and stores the Metric values from Confusion Matrix
- overall accuracy
- mean accuracy
- mean IU
- fwavacc
For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix
:param num_classes: <int> number of classes
"""
self.num_classes = num_classes
self.confusion_matrix = np.zeros((num_classes, num_classes))
def __fast_hist(self, label_gt, label_pred):
"""
Collect values for Confusion Matrix
For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix
:param label_gt: <np.array> ground-truth
:param label_pred: <np.array> prediction
:return: <np.ndarray> values for confusion matrix
"""
mask = (label_gt >= 0) & (label_gt < self.num_classes)
hist = np.bincount(self.num_classes * label_gt[mask].astype(int) + label_pred[mask],
minlength=self.num_classes**2).reshape(self.num_classes, self.num_classes)
return hist
def update(self, label_gts, label_preds):
"""
Compute Confusion Matrix
For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix
:param label_gts: <np.ndarray> ground-truths
:param label_preds: <np.ndarray> predictions
:return:
"""
for lt, lp in zip(label_gts, label_preds):
self.confusion_matrix += self.__fast_hist(lt.flatten(), lp.flatten())
def reset(self):
"""
Reset Confusion Matrix
:return:
"""
self.confusion_matrix = np.zeros((self.num_classes, self.num_classes))
def get_cm(self):
return self.confusion_matrix
def get_scores(self):
"""
Returns score about:
- overall accuracy
- mean accuracy
- mean IU
- fwavacc
For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix
:return:
"""
hist = self.confusion_matrix
tp = np.diag(hist)
sum_a1 = hist.sum(axis=1)
sum_a0 = hist.sum(axis=0)
# ---------------------------------------------------------------------- #
# 1. Accuracy & Class Accuracy
# ---------------------------------------------------------------------- #
acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps)
# recall
acc_cls_ = tp / (sum_a1 + np.finfo(np.float32).eps)
# precision
precision = tp / (sum_a0 + np.finfo(np.float32).eps)
# ---------------------------------------------------------------------- #
# 2. Mean IoU
# ---------------------------------------------------------------------- #
iu = tp / (sum_a1 + hist.sum(axis=0) - tp + np.finfo(np.float32).eps)
mean_iu = np.nanmean(iu)
cls_iu = dict(zip(range(self.num_classes), iu))
# F1 score
F1 = 2 * acc_cls_ * precision / (acc_cls_ + precision + np.finfo(np.float32).eps)
scores = {'Overall_Acc': acc,
'Mean_IoU': mean_iu}
scores.update(cls_iu)
scores.update({'precision_1': precision[1],
'recall_1': acc_cls_[1],
'F1_1': F1[1]})
return scores
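# --- Illustrative usage sketch (not part of the original file) ---
# RunningMetrics accumulates a confusion matrix over batches; get_scores()
# then reports overall accuracy, per-class IoU, mean IoU and precision /
# recall / F1 of the change class. The helper name below is hypothetical.
def _running_metrics_demo():
    metrics = RunningMetrics(num_classes=2)
    gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])    # toy ground-truth masks
    pred = np.array([[0, 1, 1, 1], [0, 0, 1, 1]])  # toy predictions
    metrics.update(gt, pred)                       # rows are treated as samples
    return metrics.get_scores()                    # e.g. {'Overall_Acc': 0.75, ...}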

View File

@ -1,110 +0,0 @@
"""This module contains simple helper functions """
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import os
import ntpath
def save_images(images, img_dir, name):
"""save images in img_dir, with name
images: torch.float, B*C*H*W
img_dir: str
name: list [str]
"""
for i, image in enumerate(images):
print(image.shape)
image_numpy = tensor2im(image.unsqueeze(0),normalize=False)*255
basename = os.path.basename(name[i])
print('name:', basename)
save_path = os.path.join(img_dir,basename)
save_image(image_numpy,save_path)
def save_visuals(visuals, img_dir, name):
"""Save every tensor in `visuals` as <img_dir>/<name>_<label>.png"""
name = ntpath.basename(name)
name = name.split(".")[0]
print(name)
# save images to the disk
for label, image in visuals.items():
image_numpy = tensor2im(image)
img_path = os.path.join(img_dir, '%s_%s.png' % (name, label))
save_image(image_numpy, img_path)
def tensor2im(input_image, imtype=np.uint8, normalize=True):
""""Converts a Tensor array into a numpy image array.
Parameters:
input_image (tensor) -- the input image tensor array
imtype (type) -- the desired type of the converted numpy array
"""
if not isinstance(input_image, np.ndarray):
if isinstance(input_image, torch.Tensor): # get the data from a variable
image_tensor = input_image.data
else:
return input_image
image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array
if image_numpy.shape[0] == 1: # grayscale to RGB
image_numpy = np.tile(image_numpy, (3, 1, 1))
image_numpy = np.transpose(image_numpy, (1, 2, 0))
if normalize:
image_numpy = (image_numpy + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
else: # if it is a numpy array, do nothing
image_numpy = input_image
return image_numpy.astype(imtype)
def save_image(image_numpy, image_path):
"""Save a numpy image to the disk
Parameters:
image_numpy (numpy array) -- input numpy array
image_path (str) -- the path of the image
"""
image_pil = Image.fromarray(image_numpy)
image_pil.save(image_path)
def print_numpy(x, val=True, shp=False):
"""Print the mean, min, max, median, std, and size of a numpy array
Parameters:
val (bool) -- if print the values of the numpy array
shp (bool) -- if print the shape of the numpy array
"""
x = x.astype(np.float64)
if shp:
print('shape,', x.shape)
if val:
x = x.flatten()
print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
def mkdirs(paths):
"""create empty directories if they don't exist
Parameters:
paths (str list) -- a list of directory paths
"""
if isinstance(paths, list) and not isinstance(paths, str):
for path in paths:
mkdir(path)
else:
mkdir(paths)
def mkdir(path):
"""create a single empty directory if it didn't exist
Parameters:
path (str) -- a single directory path
"""
if not os.path.exists(path):
os.makedirs(path)
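# --- Illustrative usage sketch (not part of the original file) ---
# tensor2im() converts a BCHW tensor in [-1, 1] into a uint8 HWC array that
# save_image() can write to disk. The directory and file names are hypothetical.
def _tensor2im_demo(out_dir='tmp_vis'):
    fake = torch.rand(1, 3, 64, 64) * 2 - 1        # pretend network output in [-1, 1]
    arr = tensor2im(fake)                          # -> (64, 64, 3) uint8 array
    mkdir(out_dir)
    save_image(arr, os.path.join(out_dir, 'demo.png'))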

View File

@ -1,246 +0,0 @@
import numpy as np
import os
import sys
import ntpath
import time
from . import util, html
from subprocess import Popen, PIPE
# from scipy.misc import imresize
if sys.version_info[0] == 2:
VisdomExceptionBase = Exception
else:
VisdomExceptionBase = ConnectionError
def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
"""Save images to the disk.
Parameters:
webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
image_path (str) -- the string is used to create image paths
aspect_ratio (float) -- the aspect ratio of saved images
width (int) -- the images will be resized to width x width
This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
"""
image_dir = webpage.get_image_dir()
short_path = ntpath.basename(image_path[0])
name = os.path.splitext(short_path)[0]
webpage.add_header(name)
ims, txts, links = [], [], []
for label, im_data in visuals.items():
im = util.tensor2im(im_data)
image_name = '%s_%s.png' % (name, label)
save_path = os.path.join(image_dir, image_name)
h, w, _ = im.shape
# if aspect_ratio > 1.0:
# im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
# if aspect_ratio < 1.0:
# im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
util.save_image(im, save_path)
ims.append(image_name)
txts.append(label)
links.append(image_name)
webpage.add_images(ims, txts, links, width=width)
class Visualizer():
"""This class includes several functions that can display/save images and print/save logging information.
It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
"""
def __init__(self, opt):
"""Initialize the Visualizer class
Parameters:
opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
Step 1: Cache the training/test options
Step 2: connect to a visdom server
Step 3: create an HTML object for saving HTML files
Step 4: create a logging file to store training losses
"""
self.opt = opt # cache the option
self.display_id = opt.display_id
self.use_html = opt.isTrain and not opt.no_html
self.win_size = opt.display_winsize
self.name = opt.name
self.port = opt.display_port
self.saved = False
if self.display_id > 0: # connect to a visdom server given <display_port> and <display_server>
import visdom
self.ncols = opt.display_ncols
self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
if not self.vis.check_connection():
self.create_visdom_connections()
if self.use_html: # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
self.img_dir = os.path.join(self.web_dir, 'images')
print('create web directory %s...' % self.web_dir)
util.mkdirs([self.web_dir, self.img_dir])
# create a logging file to store training losses
self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
with open(self.log_name, "a") as log_file:
now = time.strftime("%c")
log_file.write('================ Training Loss (%s) ================\n' % now)
def reset(self):
"""Reset the self.saved status"""
self.saved = False
def create_visdom_connections(self):
"""If the program could not connect to Visdom server, this function will start a new server at port < self.port > """
cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
print('Command: %s' % cmd)
Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
def display_current_results(self, visuals, epoch, save_result):
"""Display current results on visdom; save current results to an HTML file.
Parameters:
visuals (OrderedDict) - - dictionary of images to display or save
epoch (int) - - the current epoch
save_result (bool) - - if save the current results to an HTML file
"""
if self.display_id > 0: # show images in the browser using visdom
ncols = self.ncols
if ncols > 0: # show all the images in one visdom panel
ncols = min(ncols, len(visuals))
h, w = next(iter(visuals.values())).shape[:2]
table_css = """<style>
table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
table td {width: % dpx; height: % dpx; padding: 4px; outline: 4px solid black}
</style>""" % (w, h) # create a table css
# create a table of images.
title = self.name
label_html = ''
label_html_row = ''
images = []
idx = 0
for label, image in visuals.items():
image_numpy = util.tensor2im(image)
label_html_row += '<td>%s</td>' % label
images.append(image_numpy.transpose([2, 0, 1]))
idx += 1
if idx % ncols == 0:
label_html += '<tr>%s</tr>' % label_html_row
label_html_row = ''
white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
while idx % ncols != 0:
images.append(white_image)
label_html_row += '<td></td>'
idx += 1
if label_html_row != '':
label_html += '<tr>%s</tr>' % label_html_row
try:
self.vis.images(images, nrow=ncols, win=self.display_id + 1,
padding=2, opts=dict(title=title + ' images'))
label_html = '<table>%s</table>' % label_html
self.vis.text(table_css + label_html, win=self.display_id + 2,
opts=dict(title=title + ' labels'))
except VisdomExceptionBase:
self.create_visdom_connections()
else: # show each image in a separate visdom panel;
idx = 1
try:
for label, image in visuals.items():
image_numpy = util.tensor2im(image)
self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
win=self.display_id + idx)
idx += 1
except VisdomExceptionBase:
self.create_visdom_connections()
if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
self.saved = True
# save images to the disk
for label, image in visuals.items():
image_numpy = util.tensor2im(image)
img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
util.save_image(image_numpy, img_path)
# update website
webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
for n in range(epoch, 0, -1):
webpage.add_header('epoch [%d]' % n)
ims, txts, links = [], [], []
for label, image_numpy in visuals.items():
image_numpy = util.tensor2im(image_numpy)
img_path = 'epoch%.3d_%s.png' % (n, label)
ims.append(img_path)
txts.append(label)
links.append(img_path)
webpage.add_images(ims, txts, links, width=self.win_size)
webpage.save()
def plot_current_losses(self, epoch, counter_ratio, losses):
"""display the current losses on visdom display: dictionary of error labels and values
Parameters:
epoch (int) -- current epoch
counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1
losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
"""
if not hasattr(self, 'plot_data'):
self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
self.plot_data['X'].append(epoch + counter_ratio)
self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
try:
self.vis.line(
X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
Y=np.array(self.plot_data['Y']),
opts={
'title': self.name + ' loss over time',
'legend': self.plot_data['legend'],
'xlabel': 'epoch',
'ylabel': 'loss'},
win=self.display_id)
except VisdomExceptionBase:
self.create_visdom_connections()
def plot_current_acc(self, epoch, counter_ratio, acc):
if not hasattr(self, 'acc_data'):
self.acc_data = {'X': [], 'Y': [], 'legend': list(acc.keys())}
self.acc_data['X'].append(epoch + counter_ratio)
self.acc_data['Y'].append([acc[k] for k in self.acc_data['legend']])
try:
self.vis.line(
X=np.stack([np.array(self.acc_data['X'])] * len(self.acc_data['legend']), 1),
Y=np.array(self.acc_data['Y']),
opts={
'title': self.name + ' acc over time',
'legend': self.acc_data['legend'],
'xlabel': 'epoch',
'ylabel': 'acc'},
win=self.display_id+3)
except VisdomExceptionBase:
self.create_visdom_connections()
# losses: same format as |losses| of plot_current_losses
def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
"""print current losses on console; also save the losses to the disk
Parameters:
epoch (int) -- current epoch
iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
t_comp (float) -- computational time per data point (normalized by batch_size)
t_data (float) -- data loading time per data point (normalized by batch_size)
"""
message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
for k, v in losses.items():
message += '%s: %.3f ' % (k, v)
print(message) # print the message
with open(self.log_name, "a") as log_file:
log_file.write('%s\n' % message) # save the message
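# --- Illustrative usage sketch (not part of the original file) ---
# With display_id <= 0 and no_html=True, Visualizer degrades to a plain
# console / file logger, so only a handful of option fields are required.
# The option object below is a stand-in for the real BaseOptions result.
def _visualizer_demo():
    from types import SimpleNamespace
    opt = SimpleNamespace(display_id=0, isTrain=True, no_html=True,
                          display_winsize=256, name='demo', display_port=8097,
                          checkpoints_dir='./checkpoints')
    util.mkdirs(os.path.join(opt.checkpoints_dir, opt.name))  # loss_log.txt lives here
    vis = Visualizer(opt)
    vis.print_current_losses(epoch=1, iters=100,
                             losses={'G': 0.5, 'D': 0.7}, t_comp=0.1, t_data=0.02)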

View File

@ -1,96 +0,0 @@
import time
from options.test_options import TestOptions
from data import create_dataset
from models import create_model
import os
from util.util import save_visuals
from util.metrics import AverageMeter
import numpy as np
from util.util import mkdir
def make_val_opt(opt):
# hard-code some parameters for test
opt.num_threads = 0   # test code only supports num_threads = 0
opt.batch_size = 1 # test code only supports batch_size = 1
opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
opt.no_flip = True # no flip; comment this line if results on flipped images are needed.
opt.no_flip2 = True # no flip; comment this line if results on flipped images are needed.
opt.display_id = -1   # no visdom display; the test code saves the results to an HTML file.
opt.phase = 'val'
opt.preprocess = 'none1'
opt.isTrain = False
opt.aspect_ratio = 1
opt.eval = True
return opt
def print_current_acc(log_name, epoch, score):
"""print current acc on console; also save the losses to the disk
Parameters:
"""
message = '(epoch: %s) ' % str(epoch)
for k, v in score.items():
message += '%s: %.3f ' % (k, v)
print(message) # print the message
with open(log_name, "a") as log_file:
log_file.write('%s\n' % message) # save the message
def val(opt):
dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
model = create_model(opt) # create a model given opt.model and other options
model.setup(opt) # regular setup: load and print networks; create schedulers
save_path = os.path.join(opt.checkpoints_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch))
mkdir(save_path)
model.eval()
# create a logging file to store training losses
log_name = os.path.join(opt.checkpoints_dir, opt.name, 'val1_log.txt')
with open(log_name, "a") as log_file:
now = time.strftime("%c")
log_file.write('================ val acc (%s) ================\n' % now)
running_metrics = AverageMeter()
for i, data in enumerate(dataset):
if i >= opt.num_test: # only apply our model to opt.num_test images.
break
model.set_input(data) # unpack data from data loader
score = model.test(val=True)           # run inference and return the confusion matrix
running_metrics.update(score)
visuals = model.get_current_visuals() # get image results
img_path = model.get_image_paths() # get image paths
if i % 5 == 0: # save images to an HTML file
print('processing (%04d)-th image... %s' % (i, img_path))
save_visuals(visuals,save_path,img_path[0])
score = running_metrics.get_scores()
print_current_acc(log_name, opt.epoch, score)
if __name__ == '__main__':
opt = TestOptions().parse() # get training options
opt = make_val_opt(opt)
opt.phase = 'val'
opt.dataroot = 'path-to-LEVIR-CD-test'
opt.dataset_mode = 'changedetection'
opt.n_class = 2
opt.SA_mode = 'PAM'
opt.arch = 'mynet3'
opt.model = 'CDFA'
opt.name = 'pam'
opt.results_dir = './results/'
opt.epoch = '78_F1_1_0.88780'
opt.num_test = np.inf
val(opt)

View File

@ -0,0 +1,39 @@
from .models import create_model
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class Front:
def __init__(self, model) -> None:
self.model = model
self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
self.model = model.to(self.device)
def __call__(self, inp1, inp2):
inp1 = torch.from_numpy(inp1).to(self.device)
inp1 = torch.transpose(inp1, 0, 2).transpose(1, 2).unsqueeze(0)
inp2 = torch.from_numpy(inp2).to(self.device)
inp2 = torch.transpose(inp2, 0, 2).transpose(1, 2).unsqueeze(0)
out = self.model(inp1, inp2)
out = out.sigmoid()
out = out.cpu().detach().numpy()[0,0]
return out
def get_model(name):
try:
model = create_model(name, encoder_name="resnet34", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
encoder_weights="imagenet", # use `imagenet` pre-trained weights for encoder initialization
in_channels=3, # model input channels (1 for gray-scale images, 3 for RGB, etc.)
classes=1, # model output channels (number of classes in your datasets)
siam_encoder=True, # whether to use a siamese encoder
fusion_form='concat', # the form of fusing features from two branches, e.g. concat, sum, diff, or abs_diff
)
except Exception:
return None
return Front(model)
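# --- Illustrative usage sketch (not part of the original file) ---
# get_model() wraps create_model() and returns a Front callable that maps two
# co-registered HWC float32 patches to a change-probability map in [0, 1].
# Note: create_model() above requests ImageNet encoder weights, which may be
# downloaded on first use. The helper name below is hypothetical.
def _front_demo():
    front = get_model('DPFCN')                    # any arch known to create_model
    if front is None:                             # get_model returns None on failure
        return None
    patch1 = np.random.rand(256, 256, 3).astype(np.float32)
    patch2 = np.random.rand(256, 256, 3).astype(np.float32)
    return front(patch1, patch2)                  # (256, 256) array of probabilities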

View File

@ -0,0 +1 @@
from .model import DPFCN

View File

@ -0,0 +1,134 @@
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..base import modules as md
from ..base import Decoder
class DecoderBlock(nn.Module):
def __init__(
self,
in_channels,
skip_channels,
out_channels,
use_batchnorm=True,
attention_type=None,
):
super().__init__()
self.conv1 = md.Conv2dReLU(
in_channels + skip_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels)
self.conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.attention2 = md.Attention(attention_type, in_channels=out_channels)
def forward(self, x, skip=None):
x = F.interpolate(x, scale_factor=2, mode="nearest")
if skip is not None:
x = torch.cat([x, skip], dim=1)
x = self.attention1(x)
x = self.conv1(x)
x = self.conv2(x)
x = self.attention2(x)
return x
class CenterBlock(nn.Sequential):
def __init__(self, in_channels, out_channels, use_batchnorm=True):
conv1 = md.Conv2dReLU(
in_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
super().__init__(conv1, conv2)
class DPFCNDecoder(Decoder):
def __init__(
self,
encoder_channels,
decoder_channels,
n_blocks=5,
use_batchnorm=True,
attention_type=None,
center=False,
fusion_form="concat",
):
super().__init__()
if n_blocks != len(decoder_channels):
raise ValueError(
"Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
n_blocks, len(decoder_channels)
)
)
encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution
encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder
# computing blocks input and output channels
head_channels = encoder_channels[0]
in_channels = [head_channels] + list(decoder_channels[:-1])
skip_channels = list(encoder_channels[1:]) + [0]
out_channels = decoder_channels
# adjust encoder channels according to fusion form
self.fusion_form = fusion_form
if self.fusion_form in self.FUSION_DIC["2to2_fusion"]:
skip_channels = [ch*2 for ch in skip_channels]
in_channels[0] = in_channels[0] * 2
head_channels = head_channels * 2
if center:
self.center = CenterBlock(
head_channels, head_channels, use_batchnorm=use_batchnorm
)
else:
self.center = nn.Identity()
# combine decoder keyword arguments
kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
blocks = [
DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
]
self.blocks = nn.ModuleList(blocks)
def forward(self, *features):
features = self.aggregation_layer(features[0], features[1],
self.fusion_form, ignore_original_img=True)
# features = features[1:] # remove first skip with same spatial resolution
features = features[::-1] # reverse channels to start from head of encoder
head = features[0]
skips = features[1:]
x = self.center(head)
for i, decoder_block in enumerate(self.blocks):
skip = skips[i] if i < len(skips) else None
x = decoder_block(x, skip)
return x

View File

@ -0,0 +1,72 @@
from typing import Optional, Union, List
from .decoder import DPFCNDecoder
from ..encoders import get_encoder
from ..base import SegmentationModel
from ..base import SegmentationHead, ClassificationHead
class DPFCN(SegmentationModel):
def __init__(
self,
encoder_name: str = "resnet34",
encoder_depth: int = 5,
encoder_weights: Optional[str] = "imagenet",
decoder_use_batchnorm: bool = True,
decoder_channels: List[int] = (256, 128, 64, 32, 16),
decoder_attention_type: Optional[str] = None,
in_channels: int = 3,
classes: int = 1,
activation: Optional[Union[str, callable]] = None,
aux_params: Optional[dict] = None,
siam_encoder: bool = True,
fusion_form: str = "concat",
**kwargs
):
super().__init__()
self.siam_encoder = siam_encoder
self.encoder = get_encoder(
encoder_name,
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
)
if not self.siam_encoder:
self.encoder_non_siam = get_encoder(
encoder_name,
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
)
self.decoder = DPFCNDecoder(
encoder_channels=self.encoder.out_channels,
decoder_channels=decoder_channels,
n_blocks=encoder_depth,
use_batchnorm=decoder_use_batchnorm,
center=True if encoder_name.startswith("vgg") else False,
attention_type=decoder_attention_type,
fusion_form=fusion_form,
)
self.segmentation_head = SegmentationHead(
in_channels=decoder_channels[-1],
out_channels=classes,
activation=activation,
kernel_size=3,
)
if aux_params is not None:
self.classification_head = ClassificationHead(
in_channels=self.encoder.out_channels[-1], **aux_params
)
else:
self.classification_head = None
self.name = "u-{}".format(encoder_name)
self.initialize()

View File

@ -0,0 +1 @@
from .model import DVCA

View File

@ -0,0 +1,179 @@
"""
BSD 3-Clause License
Copyright (c) Soumith Chintala 2016,
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
import torch
from torch import nn
from torch.nn import functional as F
from ..base import Decoder
__all__ = ["DVCADecoder"]
class DVCADecoder(Decoder):
def __init__(self, in_channels, out_channels=256, atrous_rates=(12, 24, 36), fusion_form="concat"):
super().__init__()
# adjust encoder channels according to fusion form
if fusion_form in self.FUSION_DIC["2to2_fusion"]:
in_channels = in_channels * 2
self.aspp = nn.Sequential(
ASPP(in_channels, out_channels, atrous_rates),
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
self.out_channels = out_channels
self.fusion_form = fusion_form
def forward(self, *features):
x = self.fusion(features[0][-1], features[1][-1], self.fusion_form)
x = self.aspp(x)
return x
class ASPPConv(nn.Sequential):
def __init__(self, in_channels, out_channels, dilation):
super().__init__(
nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
padding=dilation,
dilation=dilation,
bias=False,
),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
class ASPPSeparableConv(nn.Sequential):
def __init__(self, in_channels, out_channels, dilation):
super().__init__(
SeparableConv2d(
in_channels,
out_channels,
kernel_size=3,
padding=dilation,
dilation=dilation,
bias=False,
),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
class ASPPPooling(nn.Sequential):
def __init__(self, in_channels, out_channels):
super().__init__(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
def forward(self, x):
size = x.shape[-2:]
for mod in self:
x = mod(x)
return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
class ASPP(nn.Module):
def __init__(self, in_channels, out_channels, atrous_rates, separable=False):
super(ASPP, self).__init__()
modules = []
modules.append(
nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
)
rate1, rate2, rate3 = tuple(atrous_rates)
ASPPConvModule = ASPPConv if not separable else ASPPSeparableConv
modules.append(ASPPConvModule(in_channels, out_channels, rate1))
modules.append(ASPPConvModule(in_channels, out_channels, rate2))
modules.append(ASPPConvModule(in_channels, out_channels, rate3))
modules.append(ASPPPooling(in_channels, out_channels))
self.convs = nn.ModuleList(modules)
self.project = nn.Sequential(
nn.Conv2d(5 * out_channels, out_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.Dropout(0.5),
)
def forward(self, x):
res = []
for conv in self.convs:
res.append(conv(x))
res = torch.cat(res, dim=1)
return self.project(res)
class SeparableConv2d(nn.Sequential):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
bias=True,
):
depthwise_conv = nn.Conv2d(
in_channels,
in_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=in_channels,
bias=False,
)
pointwise_conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
bias=bias,
)
super().__init__(depthwise_conv, pointwise_conv)

View File

@ -0,0 +1,69 @@
from typing import Optional
import torch.nn as nn
from ..base import ClassificationHead, SegmentationHead, SegmentationModel
from ..encoders import get_encoder
from .decoder import DVCADecoder
class DVCA(SegmentationModel):
def __init__(
self,
encoder_name: str = "resnet34",
encoder_depth: int = 5,
encoder_weights: Optional[str] = "imagenet",
decoder_channels: int = 256,
in_channels: int = 3,
classes: int = 1,
activation: Optional[str] = None,
upsampling: int = 8,
aux_params: Optional[dict] = None,
siam_encoder: bool = True,
fusion_form: str = "concat",
**kwargs
):
super().__init__()
self.siam_encoder = siam_encoder
self.encoder = get_encoder(
encoder_name,
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
output_stride=8,
)
if not self.siam_encoder:
self.encoder_non_siam = get_encoder(
encoder_name,
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
output_stride=8,
)
self.decoder = DVCADecoder(
in_channels=self.encoder.out_channels[-1],
out_channels=decoder_channels,
fusion_form=fusion_form,
)
self.segmentation_head = SegmentationHead(
in_channels=self.decoder.out_channels,
out_channels=classes,
activation=activation,
kernel_size=1,
upsampling=upsampling,
)
if aux_params is not None:
self.classification_head = ClassificationHead(
in_channels=self.encoder.out_channels[-1], **aux_params
)
else:
self.classification_head = None

View File

@ -0,0 +1 @@
from .model import RCNN

View File

@ -0,0 +1,131 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..base import Decoder
class Conv3x3GNReLU(nn.Module):
def __init__(self, in_channels, out_channels, upsample=False):
super().__init__()
self.upsample = upsample
self.block = nn.Sequential(
nn.Conv2d(
in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False
),
nn.GroupNorm(32, out_channels),
nn.ReLU(inplace=True),
)
def forward(self, x):
x = self.block(x)
if self.upsample:
x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
return x
class FPNBlock(nn.Module):
def __init__(self, pyramid_channels, skip_channels):
super().__init__()
self.skip_conv = nn.Conv2d(skip_channels, pyramid_channels, kernel_size=1)
def forward(self, x, skip=None):
x = F.interpolate(x, scale_factor=2, mode="nearest")
skip = self.skip_conv(skip)
x = x + skip
return x
class SegmentationBlock(nn.Module):
def __init__(self, in_channels, out_channels, n_upsamples=0):
super().__init__()
blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))]
if n_upsamples > 1:
for _ in range(1, n_upsamples):
blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True))
self.block = nn.Sequential(*blocks)
def forward(self, x):
return self.block(x)
class MergeBlock(nn.Module):
def __init__(self, policy):
super().__init__()
if policy not in ["add", "cat"]:
raise ValueError(
"`merge_policy` must be one of: ['add', 'cat'], got {}".format(
policy
)
)
self.policy = policy
def forward(self, x):
if self.policy == 'add':
return sum(x)
elif self.policy == 'cat':
return torch.cat(x, dim=1)
else:
raise ValueError(
"`merge_policy` must be one of: ['add', 'cat'], got {}".format(self.policy)
)
class RCNNDecoder(Decoder):
def __init__(
self,
encoder_channels,
encoder_depth=5,
pyramid_channels=256,
segmentation_channels=128,
dropout=0.2,
merge_policy="add",
fusion_form="concat",
):
super().__init__()
self.out_channels = segmentation_channels if merge_policy == "add" else segmentation_channels * 4
if encoder_depth < 3:
raise ValueError("Encoder depth for RCNN decoder cannot be less than 3, got {}.".format(encoder_depth))
encoder_channels = encoder_channels[::-1]
encoder_channels = encoder_channels[:encoder_depth + 1]
# (512, 256, 128, 64, 64, 3)
# adjust encoder channels according to fusion form
self.fusion_form = fusion_form
if self.fusion_form in self.FUSION_DIC["2to2_fusion"]:
encoder_channels = [ch*2 for ch in encoder_channels]
self.p5 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=1)
self.p4 = FPNBlock(pyramid_channels, encoder_channels[1])
self.p3 = FPNBlock(pyramid_channels, encoder_channels[2])
self.p2 = FPNBlock(pyramid_channels, encoder_channels[3])
self.seg_blocks = nn.ModuleList([
SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=n_upsamples)
for n_upsamples in [3, 2, 1, 0]
])
self.merge = MergeBlock(merge_policy)
self.dropout = nn.Dropout2d(p=dropout, inplace=True)
def forward(self, *features):
features = self.aggregation_layer(features[0], features[1],
self.fusion_form, ignore_original_img=True)
c2, c3, c4, c5 = features[-4:]
p5 = self.p5(c5)
p4 = self.p4(p5, c4)
p3 = self.p3(p4, c3)
p2 = self.p2(p3, c2)
feature_pyramid = [seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2])]
x = self.merge(feature_pyramid)
x = self.dropout(x)
return x

View File

@ -0,0 +1,73 @@
from typing import Optional, Union
from ..base import ClassificationHead, SegmentationHead, SegmentationModel
from ..encoders import get_encoder
from .decoder import RCNNDecoder
class RCNN(SegmentationModel):
def __init__(
self,
encoder_name: str = "resnet34",
encoder_depth: int = 5,
encoder_weights: Optional[str] = "imagenet",
decoder_pyramid_channels: int = 256,
decoder_segmentation_channels: int = 128,
decoder_merge_policy: str = "add",
decoder_dropout: float = 0.2,
in_channels: int = 3,
classes: int = 1,
activation: Optional[str] = None,
upsampling: int = 4,
aux_params: Optional[dict] = None,
siam_encoder: bool = True,
fusion_form: str = "concat",
**kwargs
):
super().__init__()
self.siam_encoder = siam_encoder
self.encoder = get_encoder(
encoder_name,
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
)
if not self.siam_encoder:
self.encoder_non_siam = get_encoder(
encoder_name,
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
)
self.decoder = RCNNDecoder(
encoder_channels=self.encoder.out_channels,
encoder_depth=encoder_depth,
pyramid_channels=decoder_pyramid_channels,
segmentation_channels=decoder_segmentation_channels,
dropout=decoder_dropout,
merge_policy=decoder_merge_policy,
fusion_form=fusion_form,
)
self.segmentation_head = SegmentationHead(
in_channels=self.decoder.out_channels,
out_channels=classes,
activation=activation,
kernel_size=1,
upsampling=upsampling,
)
if aux_params is not None:
self.classification_head = ClassificationHead(
in_channels=self.encoder.out_channels[-1], **aux_params
)
else:
self.classification_head = None
self.name = "rcnn-{}".format(encoder_name)
self.initialize()

View File

@ -0,0 +1,42 @@
from .RCNN import RCNN
from .DVCA import DVCA
from .DPFCN import DPFCN
from . import encoders
from . import utils
from . import losses
from . import datasets
from .__version__ import __version__
from typing import Optional
import torch
def create_model(
arch: str,
encoder_name: str = "resnet34",
encoder_weights: Optional[str] = "imagenet",
in_channels: int = 3,
classes: int = 1,
**kwargs,
) -> torch.nn.Module:
"""Models wrapper. Allows to create any model just with parametes
"""
archs = [DVCA, DPFCN, RCNN]
archs_dict = {a.__name__.lower(): a for a in archs}
try:
model_class = archs_dict[arch.lower()]
except KeyError:
raise KeyError("Wrong architecture type `{}`. Available options are: {}".format(
arch, list(archs_dict.keys()),
))
return model_class(
encoder_name=encoder_name,
encoder_weights=encoder_weights,
in_channels=in_channels,
classes=classes,
**kwargs,
)
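# --- Illustrative usage sketch (not part of the original file) ---
# Build a siamese change-detection network by name and run it on a pair of
# dummy tensors; encoder_weights=None avoids downloading pretrained weights.
# The helper name below is hypothetical.
def _create_model_demo():
    net = create_model('RCNN', encoder_name='resnet18', encoder_weights=None,
                       in_channels=3, classes=1, siam_encoder=True,
                       fusion_form='concat')
    t1 = torch.rand(2, 3, 256, 256)
    t2 = torch.rand(2, 3, 256, 256)
    return net(t1, t2)                            # logits of shape (2, 1, 256, 256)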

View File

@ -0,0 +1,3 @@
VERSION = (0, 1, 4)
__version__ = '.'.join(map(str, VERSION))

View File

@ -0,0 +1,12 @@
from .model import SegmentationModel
from .decoder import Decoder
from .modules import (
Conv2dReLU,
Attention,
)
from .heads import (
SegmentationHead,
ClassificationHead,
)

View File

@ -0,0 +1,33 @@
import torch
class Decoder(torch.nn.Module):
# TODO: support learnable fusion modules
def __init__(self):
super().__init__()
self.FUSION_DIC = {"2to1_fusion": ["sum", "diff", "abs_diff"],
"2to2_fusion": ["concat"]}
def fusion(self, x1, x2, fusion_form="concat"):
"""Specify the form of feature fusion"""
if fusion_form == "concat":
x = torch.cat([x1, x2], dim=1)
elif fusion_form == "sum":
x = x1 + x2
elif fusion_form == "diff":
x = x2 - x1
elif fusion_form == "abs_diff":
x = torch.abs(x1 - x2)
else:
raise ValueError('the fusion form "{}" is not defined'.format(fusion_form))
return x
def aggregation_layer(self, fea1, fea2, fusion_form="concat", ignore_original_img=True):
"""aggregate features from siamese or non-siamese branches"""
start_idx = 1 if ignore_original_img else 0
aggregate_fea = [self.fusion(fea1[idx], fea2[idx], fusion_form)
for idx in range(start_idx, len(fea1))]
return aggregate_fea
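# --- Illustrative usage sketch (not part of the original file) ---
# fusion() combines one pair of feature maps: "concat" doubles the channel
# dimension, while "sum" / "diff" / "abs_diff" keep it. aggregation_layer()
# applies fusion() to every stage, skipping the full-resolution input when
# ignore_original_img=True. The helper name below is hypothetical.
def _fusion_demo():
    dec = Decoder()
    f1 = [torch.rand(1, 3, 64, 64), torch.rand(1, 16, 32, 32)]
    f2 = [torch.rand(1, 3, 64, 64), torch.rand(1, 16, 32, 32)]
    cat = dec.fusion(f1[1], f2[1], 'concat')       # shape (1, 32, 32, 32)
    dif = dec.fusion(f1[1], f2[1], 'abs_diff')     # shape (1, 16, 32, 32)
    agg = dec.aggregation_layer(f1, f2, 'concat')  # [tensor of shape (1, 32, 32, 32)]
    return cat, dif, agg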

View File

@ -0,0 +1,24 @@
import torch.nn as nn
from .modules import Flatten, Activation
class SegmentationHead(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1, align_corners=True):
conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
upsampling = nn.Upsample(scale_factor=upsampling, mode='bilinear', align_corners=align_corners) if upsampling > 1 else nn.Identity()
activation = Activation(activation)
super().__init__(conv2d, upsampling, activation)
class ClassificationHead(nn.Sequential):
def __init__(self, in_channels, classes, pooling="avg", dropout=0.2, activation=None):
if pooling not in ("max", "avg"):
raise ValueError("Pooling should be one of ('max', 'avg'), got {}.".format(pooling))
pool = nn.AdaptiveAvgPool2d(1) if pooling == 'avg' else nn.AdaptiveMaxPool2d(1)
flatten = Flatten()
dropout = nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity()
linear = nn.Linear(in_channels, classes, bias=True)
activation = Activation(activation)
super().__init__(pool, flatten, dropout, linear, activation)

View File

@ -0,0 +1,27 @@
import torch.nn as nn
def initialize_decoder(module):
for m in module.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def initialize_head(module):
for m in module.modules():
if isinstance(m, (nn.Linear, nn.Conv2d)):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)

View File

@ -0,0 +1,53 @@
import torch
from . import initialization as init
class SegmentationModel(torch.nn.Module):
def initialize(self):
init.initialize_decoder(self.decoder)
init.initialize_head(self.segmentation_head)
if self.classification_head is not None:
init.initialize_head(self.classification_head)
def base_forward(self, x1, x2):
"""Sequentially pass `x1` `x2` trough model`s encoder, decoder and heads"""
if self.siam_encoder:
features = self.encoder(x1), self.encoder(x2)
else:
features = self.encoder(x1), self.encoder_non_siam(x2)
decoder_output = self.decoder(*features)
# TODO: features = self.fusion_policy(features)
masks = self.segmentation_head(decoder_output)
if self.classification_head is not None:
raise AttributeError("`classification_head` is not supported now.")
# labels = self.classification_head(features[-1])
# return masks, labels
return masks
def forward(self, x1, x2):
"""Sequentially pass `x1` `x2` trough model`s encoder, decoder and heads"""
return self.base_forward(x1, x2)
def predict(self, x1, x2):
"""Inference method. Switch model to `eval` mode, call `.forward(x1, x2)` with `torch.no_grad()`
Args:
x1, x2: 4D torch tensor with shape (batch_size, channels, height, width)
Return:
prediction: 4D torch tensor with shape (batch_size, classes, height, width)
"""
if self.training:
self.eval()
with torch.no_grad():
x = self.forward(x1, x2)
return x

View File

@ -0,0 +1,242 @@
import torch
import torch.nn as nn
try:
from inplace_abn import InPlaceABN
except ImportError:
InPlaceABN = None
class Conv2dReLU(nn.Sequential):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
padding=0,
stride=1,
use_batchnorm=True,
):
if use_batchnorm == "inplace" and InPlaceABN is None:
raise RuntimeError(
"In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
+ "To install see: https://github.com/mapillary/inplace_abn"
)
conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=not (use_batchnorm),
)
relu = nn.ReLU(inplace=True)
if use_batchnorm == "inplace":
bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
relu = nn.Identity()
elif use_batchnorm and use_batchnorm != "inplace":
bn = nn.BatchNorm2d(out_channels)
else:
bn = nn.Identity()
super(Conv2dReLU, self).__init__(conv, bn, relu)
class SCSEModule(nn.Module):
def __init__(self, in_channels, reduction=16):
super().__init__()
self.cSE = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, in_channels // reduction, 1),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels // reduction, in_channels, 1),
nn.Sigmoid(),
)
self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())
def forward(self, x):
return x * self.cSE(x) + x * self.sSE(x)
class CBAMChannel(nn.Module):
def __init__(self, in_channels, reduction=16):
super(CBAMChannel, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc = nn.Sequential(nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
nn.ReLU(),
nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False))
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = self.fc(self.avg_pool(x))
max_out = self.fc(self.max_pool(x))
out = avg_out + max_out
return x * self.sigmoid(out)
class CBAMSpatial(nn.Module):
def __init__(self, in_channels, kernel_size=7):
super(CBAMSpatial, self).__init__()
self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
out = torch.cat([avg_out, max_out], dim=1)
out = self.conv1(out)
return x * self.sigmoid(out)
class CBAM(nn.Module):
"""
Woo S, Park J, Lee J Y, et al. Cbam: Convolutional block attention module[C]
//Proceedings of the European conference on computer vision (ECCV).
"""
def __init__(self, in_channels, reduction=16, kernel_size=7):
super(CBAM, self).__init__()
self.ChannelGate = CBAMChannel(in_channels, reduction)
self.SpatialGate = CBAMSpatial(in_channels, kernel_size=kernel_size)
def forward(self, x):
x = self.ChannelGate(x)
x = self.SpatialGate(x)
return x
class ECAM(nn.Module):
"""
Ensemble Channel Attention Module for UNetPlusPlus.
Fang S, Li K, Shao J, et al. SNUNet-CD: A Densely Connected Siamese Network for Change Detection of VHR Images[J].
IEEE Geoscience and Remote Sensing Letters, 2021.
Not completely consistent, to be improved.
"""
def __init__(self, in_channels, out_channels, map_num=4):
super(ECAM, self).__init__()
self.ca1 = CBAMChannel(in_channels * map_num, reduction=16)
self.ca2 = CBAMChannel(in_channels, reduction=16 // 4)
self.up = nn.ConvTranspose2d(in_channels * map_num, in_channels * map_num, 2, stride=2)
self.conv_final = nn.Conv2d(in_channels * map_num, out_channels, kernel_size=1)
def forward(self, x):
"""
x (list[tensor] or tuple(tensor))
"""
out = torch.cat(x, 1)
intra = torch.sum(torch.stack(x), dim=0)
ca2 = self.ca2(intra)
out = self.ca1(out) * (out + ca2.repeat(1, 4, 1, 1))
out = self.up(out)
out = self.conv_final(out)
return out
class SEModule(nn.Module):
"""
Hu J, Shen L, Sun G. Squeeze-and-excitation networks[C]
//Proceedings of the IEEE conference on computer vision and pattern recognition. 2018: 7132-7141.
"""
def __init__(self, in_channels, reduction=16):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(in_channels, in_channels // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(in_channels // reduction, in_channels, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y.expand_as(x)
class ArgMax(nn.Module):
def __init__(self, dim=None):
super().__init__()
self.dim = dim
def forward(self, x):
return torch.argmax(x, dim=self.dim)
class Clamp(nn.Module):
def __init__(self, min=0, max=1):
super().__init__()
self.min, self.max = min, max
def forward(self, x):
return torch.clamp(x, self.min, self.max)
class Activation(nn.Module):
def __init__(self, name, **params):
super().__init__()
if name is None or name == 'identity':
self.activation = nn.Identity(**params)
elif name == 'sigmoid':
self.activation = nn.Sigmoid()
elif name == 'softmax2d':
self.activation = nn.Softmax(dim=1, **params)
elif name == 'softmax':
self.activation = nn.Softmax(**params)
elif name == 'logsoftmax':
self.activation = nn.LogSoftmax(**params)
elif name == 'tanh':
self.activation = nn.Tanh()
elif name == 'argmax':
self.activation = ArgMax(**params)
elif name == 'argmax2d':
self.activation = ArgMax(dim=1, **params)
elif name == 'clamp':
self.activation = Clamp(**params)
elif callable(name):
self.activation = name(**params)
else:
raise ValueError('Activation should be callable/sigmoid/softmax/logsoftmax/tanh/None; got {}'.format(name))
def forward(self, x):
return self.activation(x)
class Attention(nn.Module):
def __init__(self, name, **params):
super().__init__()
if name is None:
self.attention = nn.Identity(**params)
elif name == 'scse':
self.attention = SCSEModule(**params)
elif name == 'cbam_channel':
self.attention = CBAMChannel(**params)
elif name == 'cbam_spatial':
self.attention = CBAMSpatial(**params)
elif name == 'cbam':
self.attention = CBAM(**params)
elif name == 'se':
self.attention = SEModule(**params)
else:
raise ValueError("Attention {} is not implemented".format(name))
def forward(self, x):
return self.attention(x)
class Flatten(nn.Module):
def forward(self, x):
return x.view(x.shape[0], -1)

View File

@ -0,0 +1,113 @@
import functools
import torch
import torch.utils.model_zoo as model_zoo
from ._preprocessing import preprocess_input
from .densenet import densenet_encoders
from .dpn import dpn_encoders
from .efficientnet import efficient_net_encoders
from .inceptionresnetv2 import inceptionresnetv2_encoders
from .inceptionv4 import inceptionv4_encoders
from .mobilenet import mobilenet_encoders
from .resnet import resnet_encoders
from .senet import senet_encoders
from .timm_efficientnet import timm_efficientnet_encoders
from .timm_gernet import timm_gernet_encoders
from .timm_mobilenetv3 import timm_mobilenetv3_encoders
from .timm_regnet import timm_regnet_encoders
from .timm_res2net import timm_res2net_encoders
from .timm_resnest import timm_resnest_encoders
from .timm_sknet import timm_sknet_encoders
from .timm_universal import TimmUniversalEncoder
from .vgg import vgg_encoders
from .xception import xception_encoders
from .swin_transformer import swin_transformer_encoders
from .mit_encoder import mit_encoders
# from .hrnet import hrnet_encoders
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
encoders = {}
encoders.update(resnet_encoders)
encoders.update(dpn_encoders)
encoders.update(vgg_encoders)
encoders.update(senet_encoders)
encoders.update(densenet_encoders)
encoders.update(inceptionresnetv2_encoders)
encoders.update(inceptionv4_encoders)
encoders.update(efficient_net_encoders)
encoders.update(mobilenet_encoders)
encoders.update(xception_encoders)
encoders.update(timm_efficientnet_encoders)
encoders.update(timm_resnest_encoders)
encoders.update(timm_res2net_encoders)
encoders.update(timm_regnet_encoders)
encoders.update(timm_sknet_encoders)
encoders.update(timm_mobilenetv3_encoders)
encoders.update(timm_gernet_encoders)
encoders.update(swin_transformer_encoders)
encoders.update(mit_encoders)
# encoders.update(hrnet_encoders)
def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
if name.startswith("tu-"):
name = name[3:]
encoder = TimmUniversalEncoder(
name=name,
in_channels=in_channels,
depth=depth,
output_stride=output_stride,
pretrained=weights is not None,
**kwargs
)
return encoder
try:
Encoder = encoders[name]["encoder"]
except KeyError:
raise KeyError("Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys())))
params = encoders[name]["params"]
params.update(depth=depth)
encoder = Encoder(**params)
if weights is not None:
try:
settings = encoders[name]["pretrained_settings"][weights]
except KeyError:
raise KeyError("Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
weights, name, list(encoders[name]["pretrained_settings"].keys()),
))
encoder.load_state_dict(model_zoo.load_url(settings["url"], map_location=torch.device(DEVICE)))
encoder.set_in_channels(in_channels, pretrained=weights is not None)
if output_stride != 32:
encoder.make_dilated(output_stride)
return encoder
def get_encoder_names():
return list(encoders.keys())
def get_preprocessing_params(encoder_name, pretrained="imagenet"):
settings = encoders[encoder_name]["pretrained_settings"]
if pretrained not in settings.keys():
raise ValueError("Available pretrained options {}".format(settings.keys()))
formatted_settings = {}
formatted_settings["input_space"] = settings[pretrained].get("input_space")
formatted_settings["input_range"] = settings[pretrained].get("input_range")
formatted_settings["mean"] = settings[pretrained].get("mean")
formatted_settings["std"] = settings[pretrained].get("std")
return formatted_settings
def get_preprocessing_fn(encoder_name, pretrained="imagenet"):
params = get_preprocessing_params(encoder_name, pretrained=pretrained)
return functools.partial(preprocess_input, **params)
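# --- Illustrative usage sketch (not part of the original file) ---
# get_encoder() returns a backbone that yields one feature map per stage
# (strides 1, 2, 4, 8, 16, 32 for depth=5); get_preprocessing_fn() returns the
# matching input normalization. weights=None skips the pretrained download.
# The helper name below is hypothetical.
def _encoder_demo():
    enc = get_encoder('resnet18', in_channels=3, depth=5, weights=None)
    feats = enc(torch.rand(1, 3, 224, 224))        # list of 6 feature tensors
    prep = get_preprocessing_fn('resnet18', pretrained='imagenet')
    return [f.shape for f in feats], prep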

View File

@ -0,0 +1,53 @@
import torch
import torch.nn as nn
from typing import List
from collections import OrderedDict
from . import _utils as utils
class EncoderMixin:
"""Add encoder functionality such as:
- output channels specification of feature tensors (produced by encoder)
- patching first convolution for arbitrary input channels
"""
@property
def out_channels(self):
"""Return channels dimensions for each tensor of forward output of encoder"""
return self._out_channels[: self._depth + 1]
def set_in_channels(self, in_channels, pretrained=True):
"""Change first convolution channels"""
if in_channels == 3:
return
self._in_channels = in_channels
if self._out_channels[0] == 3:
self._out_channels = tuple([in_channels] + list(self._out_channels)[1:])
utils.patch_first_conv(model=self, new_in_channels=in_channels, pretrained=pretrained)
def get_stages(self):
"""Method should be overridden in encoder"""
raise NotImplementedError
def make_dilated(self, output_stride):
if output_stride == 16:
stage_list=[5,]
dilation_list=[2,]
elif output_stride == 8:
stage_list=[4, 5]
dilation_list=[2, 4]
else:
raise ValueError("Output stride should be 16 or 8, got {}.".format(output_stride))
stages = self.get_stages()
for stage_indx, dilation_rate in zip(stage_list, dilation_list):
utils.replace_strides_with_dilation(
module=stages[stage_indx],
dilation_rate=dilation_rate,
)

View File

@ -0,0 +1,23 @@
import numpy as np
def preprocess_input(
x, mean=None, std=None, input_space="RGB", input_range=None, **kwargs
):
if input_space == "BGR":
x = x[..., ::-1].copy()
if input_range is not None:
if x.max() > 1 and input_range[1] == 1:
x = x / 255.0
if mean is not None:
mean = np.array(mean)
x = x - mean
if std is not None:
std = np.array(std)
x = x / std
return x

View File

@ -0,0 +1,59 @@
import torch
import torch.nn as nn
def patch_first_conv(model, new_in_channels, default_in_channels=3, pretrained=True):
"""Change first convolution layer input channels.
In case:
in_channels == 1 or in_channels == 2 -> reuse original weights
in_channels > 3 -> make random kaiming normal initialization
"""
# get first conv
for module in model.modules():
if isinstance(module, nn.Conv2d) and module.in_channels == default_in_channels:
break
weight = module.weight.detach()
module.in_channels = new_in_channels
if not pretrained:
module.weight = nn.parameter.Parameter(
torch.Tensor(
module.out_channels,
new_in_channels // module.groups,
*module.kernel_size
)
)
module.reset_parameters()
elif new_in_channels == 1:
new_weight = weight.sum(1, keepdim=True)
module.weight = nn.parameter.Parameter(new_weight)
else:
new_weight = torch.Tensor(
module.out_channels,
new_in_channels // module.groups,
*module.kernel_size
)
for i in range(new_in_channels):
new_weight[:, i] = weight[:, i % default_in_channels]
new_weight = new_weight * (default_in_channels / new_in_channels)
module.weight = nn.parameter.Parameter(new_weight)
def replace_strides_with_dilation(module, dilation_rate):
"""Patch Conv2d modules replacing strides with dilation"""
for mod in module.modules():
if isinstance(mod, nn.Conv2d):
mod.stride = (1, 1)
mod.dilation = (dilation_rate, dilation_rate)
kh, kw = mod.kernel_size
mod.padding = ((kh // 2) * dilation_rate, (kw // 2) * dilation_rate)
# Workaround (kludge) for EfficientNet's static padding
if hasattr(mod, "static_padding"):
mod.static_padding = nn.Identity()
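# --- Illustrative usage sketch (not part of the original file) ---
# patch_first_conv() lets a 3-channel backbone accept a different number of
# input bands (e.g. 4-band imagery): with pretrained=True the existing RGB
# kernels are reused cyclically and rescaled to keep activations comparable.
# The helper name below is hypothetical.
def _patch_first_conv_demo():
    import torchvision
    net = torchvision.models.resnet18()            # randomly initialized backbone
    patch_first_conv(net, new_in_channels=4, pretrained=True)
    out = net(torch.rand(1, 4, 224, 224))          # forward now accepts 4 bands
    return net.conv1.in_channels, out.shape        # 4, torch.Size([1, 1000])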

View File

@ -0,0 +1,146 @@
""" Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin`
Attributes:
_out_channels (list of int): specify number of channels for each encoder feature tensor
_depth (int): specify number of stages in decoder (in other words number of downsampling operations)
_in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)
Methods:
forward(self, x: torch.Tensor)
produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
shape NCHW (features should be sorted in descending order according to spatial resolution, starting
with resolution same as input `x` tensor).
Input: `x` with shape (1, 3, 64, 64)
Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
[(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
(1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)
also should support number of features according to specified depth, e.g. if depth = 5,
number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
"""
import re
import torch.nn as nn
from pretrainedmodels.models.torchvision_models import pretrained_settings
from torchvision.models.densenet import DenseNet
from ._base import EncoderMixin
class TransitionWithSkip(nn.Module):
    """Run the wrapped DenseNet transition block and additionally return the
    pre-pooling activation (captured right after the ReLU) as the skip feature."""
    def __init__(self, module):
super().__init__()
self.module = module
def forward(self, x):
for module in self.module:
x = module(x)
if isinstance(module, nn.ReLU):
skip = x
return x, skip
class DenseNetEncoder(DenseNet, EncoderMixin):
def __init__(self, out_channels, depth=5, **kwargs):
super().__init__(**kwargs)
self._out_channels = out_channels
self._depth = depth
self._in_channels = 3
del self.classifier
def make_dilated(self, output_stride):
raise ValueError("DenseNet encoders do not support dilated mode "
"due to pooling operation for downsampling!")
def get_stages(self):
return [
nn.Identity(),
nn.Sequential(self.features.conv0, self.features.norm0, self.features.relu0),
nn.Sequential(self.features.pool0, self.features.denseblock1,
TransitionWithSkip(self.features.transition1)),
nn.Sequential(self.features.denseblock2, TransitionWithSkip(self.features.transition2)),
nn.Sequential(self.features.denseblock3, TransitionWithSkip(self.features.transition3)),
nn.Sequential(self.features.denseblock4, self.features.norm5)
]
def forward(self, x):
stages = self.get_stages()
features = []
for i in range(self._depth + 1):
x = stages[i](x)
if isinstance(x, (list, tuple)):
x, skip = x
features.append(skip)
else:
features.append(x)
return features
def load_state_dict(self, state_dict):
pattern = re.compile(
r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
# remove linear
state_dict.pop("classifier.bias", None)
state_dict.pop("classifier.weight", None)
super().load_state_dict(state_dict)
densenet_encoders = {
"densenet121": {
"encoder": DenseNetEncoder,
"pretrained_settings": pretrained_settings["densenet121"],
"params": {
"out_channels": (3, 64, 256, 512, 1024, 1024),
"num_init_features": 64,
"growth_rate": 32,
"block_config": (6, 12, 24, 16),
},
},
"densenet169": {
"encoder": DenseNetEncoder,
"pretrained_settings": pretrained_settings["densenet169"],
"params": {
"out_channels": (3, 64, 256, 512, 1280, 1664),
"num_init_features": 64,
"growth_rate": 32,
"block_config": (6, 12, 32, 32),
},
},
"densenet201": {
"encoder": DenseNetEncoder,
"pretrained_settings": pretrained_settings["densenet201"],
"params": {
"out_channels": (3, 64, 256, 512, 1792, 1920),
"num_init_features": 64,
"growth_rate": 32,
"block_config": (6, 12, 48, 32),
},
},
"densenet161": {
"encoder": DenseNetEncoder,
"pretrained_settings": pretrained_settings["densenet161"],
"params": {
"out_channels": (3, 96, 384, 768, 2112, 2208),
"num_init_features": 96,
"growth_rate": 48,
"block_config": (6, 12, 36, 24),
},
},
}
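The module docstring at the top of this file (repeated in the other encoder modules of this commit) defines the encoder contract: `forward` returns depth + 1 features, highest resolution first, with channel counts matching `_out_channels`. A minimal sketch that checks this for densenet121 with randomly initialized weights (the names are taken from this file; importing them is an assumption):

import torch
# assumed: densenet_encoders is importable from this module

entry = densenet_encoders["densenet121"]
encoder = entry["encoder"](**entry["params"])     # DenseNetEncoder(out_channels=(3, 64, 256, 512, 1024, 1024), ...)

features = encoder(torch.randn(1, 3, 64, 64))
print([f.shape[1] for f in features])             # [3, 64, 256, 512, 1024, 1024] - matches out_channels
print([f.shape[-1] for f in features])            # [64, 32, 16, 8, 4, 2] - resolution halves at every stage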

View File

@ -0,0 +1,170 @@
""" Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin`
Attributes:
_out_channels (list of int): specify number of channels for each encoder feature tensor
_depth (int): specify number of stages in decoder (in other words number of downsampling operations)
_in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)
Methods:
forward(self, x: torch.Tensor)
produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
shape NCHW (features should be sorted in descending order according to spatial resolution, starting
with resolution same as input `x` tensor).
Input: `x` with shape (1, 3, 64, 64)
Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
[(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
(1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)
also should support number of features according to specified depth, e.g. if depth = 5,
number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from pretrainedmodels.models.dpn import DPN
from pretrainedmodels.models.dpn import pretrained_settings
from ._base import EncoderMixin
class DPNEncoder(DPN, EncoderMixin):
def __init__(self, stage_idxs, out_channels, depth=5, **kwargs):
super().__init__(**kwargs)
self._stage_idxs = stage_idxs
self._depth = depth
self._out_channels = out_channels
self._in_channels = 3
del self.last_linear
def get_stages(self):
return [
nn.Identity(),
nn.Sequential(self.features[0].conv, self.features[0].bn, self.features[0].act),
nn.Sequential(self.features[0].pool, self.features[1 : self._stage_idxs[0]]),
self.features[self._stage_idxs[0] : self._stage_idxs[1]],
self.features[self._stage_idxs[1] : self._stage_idxs[2]],
self.features[self._stage_idxs[2] : self._stage_idxs[3]],
]
def forward(self, x):
stages = self.get_stages()
features = []
for i in range(self._depth + 1):
x = stages[i](x)
if isinstance(x, (list, tuple)):
features.append(F.relu(torch.cat(x, dim=1), inplace=True))
else:
features.append(x)
return features
def load_state_dict(self, state_dict, **kwargs):
state_dict.pop("last_linear.bias", None)
state_dict.pop("last_linear.weight", None)
super().load_state_dict(state_dict, **kwargs)
dpn_encoders = {
"dpn68": {
"encoder": DPNEncoder,
"pretrained_settings": pretrained_settings["dpn68"],
"params": {
"stage_idxs": (4, 8, 20, 24),
"out_channels": (3, 10, 144, 320, 704, 832),
"groups": 32,
"inc_sec": (16, 32, 32, 64),
"k_r": 128,
"k_sec": (3, 4, 12, 3),
"num_classes": 1000,
"num_init_features": 10,
"small": True,
"test_time_pool": True,
},
},
"dpn68b": {
"encoder": DPNEncoder,
"pretrained_settings": pretrained_settings["dpn68b"],
"params": {
"stage_idxs": (4, 8, 20, 24),
"out_channels": (3, 10, 144, 320, 704, 832),
"b": True,
"groups": 32,
"inc_sec": (16, 32, 32, 64),
"k_r": 128,
"k_sec": (3, 4, 12, 3),
"num_classes": 1000,
"num_init_features": 10,
"small": True,
"test_time_pool": True,
},
},
"dpn92": {
"encoder": DPNEncoder,
"pretrained_settings": pretrained_settings["dpn92"],
"params": {
"stage_idxs": (4, 8, 28, 32),
"out_channels": (3, 64, 336, 704, 1552, 2688),
"groups": 32,
"inc_sec": (16, 32, 24, 128),
"k_r": 96,
"k_sec": (3, 4, 20, 3),
"num_classes": 1000,
"num_init_features": 64,
"test_time_pool": True,
},
},
"dpn98": {
"encoder": DPNEncoder,
"pretrained_settings": pretrained_settings["dpn98"],
"params": {
"stage_idxs": (4, 10, 30, 34),
"out_channels": (3, 96, 336, 768, 1728, 2688),
"groups": 40,
"inc_sec": (16, 32, 32, 128),
"k_r": 160,
"k_sec": (3, 6, 20, 3),
"num_classes": 1000,
"num_init_features": 96,
"test_time_pool": True,
},
},
"dpn107": {
"encoder": DPNEncoder,
"pretrained_settings": pretrained_settings["dpn107"],
"params": {
"stage_idxs": (5, 13, 33, 37),
"out_channels": (3, 128, 376, 1152, 2432, 2688),
"groups": 50,
"inc_sec": (20, 64, 64, 128),
"k_r": 200,
"k_sec": (4, 8, 20, 3),
"num_classes": 1000,
"num_init_features": 128,
"test_time_pool": True,
},
},
"dpn131": {
"encoder": DPNEncoder,
"pretrained_settings": pretrained_settings["dpn131"],
"params": {
"stage_idxs": (5, 13, 41, 45),
"out_channels": (3, 128, 352, 832, 1984, 2688),
"groups": 40,
"inc_sec": (16, 32, 32, 128),
"k_r": 160,
"k_sec": (4, 8, 28, 3),
"num_classes": 1000,
"num_init_features": 128,
"test_time_pool": True,
},
},
}

View File

@ -0,0 +1,178 @@
""" Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin`
Attributes:
_out_channels (list of int): specify number of channels for each encoder feature tensor
_depth (int): specify number of stages in decoder (in other words number of downsampling operations)
_in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)
Methods:
forward(self, x: torch.Tensor)
produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
shape NCHW (features should be sorted in descending order according to spatial resolution, starting
with resolution same as input `x` tensor).
Input: `x` with shape (1, 3, 64, 64)
Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
[(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
(1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)
also should support number of features according to specified depth, e.g. if depth = 5,
number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
"""
import torch.nn as nn
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import url_map, url_map_advprop, get_model_params
from ._base import EncoderMixin
class EfficientNetEncoder(EfficientNet, EncoderMixin):
def __init__(self, stage_idxs, out_channels, model_name, depth=5):
blocks_args, global_params = get_model_params(model_name, override_params=None)
super().__init__(blocks_args, global_params)
self._stage_idxs = stage_idxs
self._out_channels = out_channels
self._depth = depth
self._in_channels = 3
del self._fc
def get_stages(self):
return [
nn.Identity(),
nn.Sequential(self._conv_stem, self._bn0, self._swish),
self._blocks[:self._stage_idxs[0]],
self._blocks[self._stage_idxs[0]:self._stage_idxs[1]],
self._blocks[self._stage_idxs[1]:self._stage_idxs[2]],
self._blocks[self._stage_idxs[2]:],
]
def forward(self, x):
stages = self.get_stages()
block_number = 0.
drop_connect_rate = self._global_params.drop_connect_rate
features = []
for i in range(self._depth + 1):
# Identity and Sequential stages
if i < 2:
x = stages[i](x)
# Block stages need drop_connect rate
else:
for module in stages[i]:
drop_connect = drop_connect_rate * block_number / len(self._blocks)
block_number += 1.
x = module(x, drop_connect)
features.append(x)
return features
def load_state_dict(self, state_dict, **kwargs):
state_dict.pop("_fc.bias", None)
state_dict.pop("_fc.weight", None)
super().load_state_dict(state_dict, **kwargs)
def _get_pretrained_settings(encoder):
pretrained_settings = {
"imagenet": {
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
"url": url_map[encoder],
"input_space": "RGB",
"input_range": [0, 1],
},
"advprop": {
"mean": [0.5, 0.5, 0.5],
"std": [0.5, 0.5, 0.5],
"url": url_map_advprop[encoder],
"input_space": "RGB",
"input_range": [0, 1],
}
}
return pretrained_settings
efficient_net_encoders = {
"efficientnet-b0": {
"encoder": EfficientNetEncoder,
"pretrained_settings": _get_pretrained_settings("efficientnet-b0"),
"params": {
"out_channels": (3, 32, 24, 40, 112, 320),
"stage_idxs": (3, 5, 9, 16),
"model_name": "efficientnet-b0",
},
},
"efficientnet-b1": {
"encoder": EfficientNetEncoder,
"pretrained_settings": _get_pretrained_settings("efficientnet-b1"),
"params": {
"out_channels": (3, 32, 24, 40, 112, 320),
"stage_idxs": (5, 8, 16, 23),
"model_name": "efficientnet-b1",
},
},
"efficientnet-b2": {
"encoder": EfficientNetEncoder,
"pretrained_settings": _get_pretrained_settings("efficientnet-b2"),
"params": {
"out_channels": (3, 32, 24, 48, 120, 352),
"stage_idxs": (5, 8, 16, 23),
"model_name": "efficientnet-b2",
},
},
"efficientnet-b3": {
"encoder": EfficientNetEncoder,
"pretrained_settings": _get_pretrained_settings("efficientnet-b3"),
"params": {
"out_channels": (3, 40, 32, 48, 136, 384),
"stage_idxs": (5, 8, 18, 26),
"model_name": "efficientnet-b3",
},
},
"efficientnet-b4": {
"encoder": EfficientNetEncoder,
"pretrained_settings": _get_pretrained_settings("efficientnet-b4"),
"params": {
"out_channels": (3, 48, 32, 56, 160, 448),
"stage_idxs": (6, 10, 22, 32),
"model_name": "efficientnet-b4",
},
},
"efficientnet-b5": {
"encoder": EfficientNetEncoder,
"pretrained_settings": _get_pretrained_settings("efficientnet-b5"),
"params": {
"out_channels": (3, 48, 40, 64, 176, 512),
"stage_idxs": (8, 13, 27, 39),
"model_name": "efficientnet-b5",
},
},
"efficientnet-b6": {
"encoder": EfficientNetEncoder,
"pretrained_settings": _get_pretrained_settings("efficientnet-b6"),
"params": {
"out_channels": (3, 56, 40, 72, 200, 576),
"stage_idxs": (9, 15, 31, 45),
"model_name": "efficientnet-b6",
},
},
"efficientnet-b7": {
"encoder": EfficientNetEncoder,
"pretrained_settings": _get_pretrained_settings("efficientnet-b7"),
"params": {
"out_channels": (3, 64, 48, 80, 224, 640),
"stage_idxs": (11, 18, 38, 55),
"model_name": "efficientnet-b7",
},
},
}

View File

@ -0,0 +1,90 @@
""" Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin`
Attributes:
_out_channels (list of int): specify number of channels for each encoder feature tensor
_depth (int): specify number of stages in decoder (in other words number of downsampling operations)
_in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)
Methods:
forward(self, x: torch.Tensor)
produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
shape NCHW (features should be sorted in descending order according to spatial resolution, starting
with resolution same as input `x` tensor).
Input: `x` with shape (1, 3, 64, 64)
Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
[(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
(1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)
also should support number of features according to specified depth, e.g. if depth = 5,
number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
"""
import torch.nn as nn
from pretrainedmodels.models.inceptionresnetv2 import InceptionResNetV2
from pretrainedmodels.models.inceptionresnetv2 import pretrained_settings
from ._base import EncoderMixin
class InceptionResNetV2Encoder(InceptionResNetV2, EncoderMixin):
def __init__(self, out_channels, depth=5, **kwargs):
super().__init__(**kwargs)
self._out_channels = out_channels
self._depth = depth
self._in_channels = 3
# correct paddings
for m in self.modules():
if isinstance(m, nn.Conv2d):
if m.kernel_size == (3, 3):
m.padding = (1, 1)
if isinstance(m, nn.MaxPool2d):
m.padding = (1, 1)
# remove linear layers
del self.avgpool_1a
del self.last_linear
def make_dilated(self, output_stride):
raise ValueError("InceptionResnetV2 encoder does not support dilated mode "
"due to pooling operation for downsampling!")
def get_stages(self):
return [
nn.Identity(),
nn.Sequential(self.conv2d_1a, self.conv2d_2a, self.conv2d_2b),
nn.Sequential(self.maxpool_3a, self.conv2d_3b, self.conv2d_4a),
nn.Sequential(self.maxpool_5a, self.mixed_5b, self.repeat),
nn.Sequential(self.mixed_6a, self.repeat_1),
nn.Sequential(self.mixed_7a, self.repeat_2, self.block8, self.conv2d_7b),
]
def forward(self, x):
stages = self.get_stages()
features = []
for i in range(self._depth + 1):
x = stages[i](x)
features.append(x)
return features
def load_state_dict(self, state_dict, **kwargs):
state_dict.pop("last_linear.bias", None)
state_dict.pop("last_linear.weight", None)
super().load_state_dict(state_dict, **kwargs)
inceptionresnetv2_encoders = {
"inceptionresnetv2": {
"encoder": InceptionResNetV2Encoder,
"pretrained_settings": pretrained_settings["inceptionresnetv2"],
"params": {"out_channels": (3, 64, 192, 320, 1088, 1536), "num_classes": 1000},
}
}

View File

@ -0,0 +1,93 @@
""" Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin`
Attributes:
_out_channels (list of int): specify number of channels for each encoder feature tensor
_depth (int): specify number of stages in decoder (in other words number of downsampling operations)
_in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)
Methods:
forward(self, x: torch.Tensor)
produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
shape NCHW (features should be sorted in descending order according to spatial resolution, starting
with resolution same as input `x` tensor).
Input: `x` with shape (1, 3, 64, 64)
Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
[(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
(1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)
also should support number of features according to specified depth, e.g. if depth = 5,
number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
"""
import torch.nn as nn
from pretrainedmodels.models.inceptionv4 import InceptionV4, BasicConv2d
from pretrainedmodels.models.inceptionv4 import pretrained_settings
from ._base import EncoderMixin
class InceptionV4Encoder(InceptionV4, EncoderMixin):
def __init__(self, stage_idxs, out_channels, depth=5, **kwargs):
super().__init__(**kwargs)
self._stage_idxs = stage_idxs
self._out_channels = out_channels
self._depth = depth
self._in_channels = 3
# correct paddings
for m in self.modules():
if isinstance(m, nn.Conv2d):
if m.kernel_size == (3, 3):
m.padding = (1, 1)
if isinstance(m, nn.MaxPool2d):
m.padding = (1, 1)
# remove linear layers
del self.last_linear
def make_dilated(self, output_stride):
raise ValueError("InceptionV4 encoder does not support dilated mode "
"due to pooling operation for downsampling!")
def get_stages(self):
return [
nn.Identity(),
self.features[: self._stage_idxs[0]],
self.features[self._stage_idxs[0]: self._stage_idxs[1]],
self.features[self._stage_idxs[1]: self._stage_idxs[2]],
self.features[self._stage_idxs[2]: self._stage_idxs[3]],
self.features[self._stage_idxs[3]:],
]
def forward(self, x):
stages = self.get_stages()
features = []
for i in range(self._depth + 1):
x = stages[i](x)
features.append(x)
return features
def load_state_dict(self, state_dict, **kwargs):
state_dict.pop("last_linear.bias", None)
state_dict.pop("last_linear.weight", None)
super().load_state_dict(state_dict, **kwargs)
inceptionv4_encoders = {
"inceptionv4": {
"encoder": InceptionV4Encoder,
"pretrained_settings": pretrained_settings["inceptionv4"],
"params": {
"stage_idxs": (3, 5, 9, 15),
"out_channels": (3, 64, 192, 384, 1024, 1536),
"num_classes": 1001,
},
}
}

View File

@ -0,0 +1,192 @@
# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from pretrainedmodels.models.torchvision_models import pretrained_settings
from ._base import EncoderMixin
from .mix_transformer import MixVisionTransformer
class MixVisionTransformerEncoder(MixVisionTransformer, EncoderMixin):
def __init__(self, out_channels, depth=5, **kwargs):
super().__init__(**kwargs)
self._depth = depth
self._out_channels = out_channels
self._in_channels = 3
def get_stages(self):
return [nn.Identity()]
def forward(self, x):
stages = self.get_stages()
features = []
for stage in stages:
x = stage(x)
features.append(x)
outs = self.forward_features(x)
        # An additional interpolated feature is prepended so that five-stage decoders can be
        # used; decoders with fewer stages simply ignore it (same idea as in the Swin encoder).
        add_feature = F.interpolate(outs[0], scale_factor=2)
        features = features + [add_feature] + outs
return features
def load_state_dict(self, state_dict, **kwargs):
new_state_dict = {}
if state_dict.get('state_dict'):
state_dict = state_dict['state_dict']
for k, v in state_dict.items():
if k.startswith('backbone'):
new_state_dict[k.replace('backbone.', '')] = v
else:
new_state_dict = deepcopy(state_dict)
super().load_state_dict(new_state_dict, **kwargs)
# https://github.com/NVlabs/SegFormer
new_settings = {
"mit-b0": {
"imagenet": "https://lino.local.server/mit_b0.pth"
},
"mit-b1": {
"imagenet": "https://lino.local.server/mit_b1.pth"
},
"mit-b2": {
"imagenet": "https://lino.local.server/mit_b2.pth"
},
"mit-b3": {
"imagenet": "https://lino.local.server/mit_b3.pth"
},
"mit-b4": {
"imagenet": "https://lino.local.server/mit_b4.pth"
},
"mit-b5": {
"imagenet": "https://lino.local.server/mit_b5.pth"
},
}
pretrained_settings = deepcopy(pretrained_settings)
for model_name, sources in new_settings.items():
if model_name not in pretrained_settings:
pretrained_settings[model_name] = {}
for source_name, source_url in sources.items():
pretrained_settings[model_name][source_name] = {
"url": source_url,
'input_size': [3, 224, 224],
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'num_classes': 1000
}
mit_encoders = {
"mit-b0": {
"encoder": MixVisionTransformerEncoder,
"pretrained_settings": pretrained_settings["mit-b0"],
"params": {
"patch_size": 4,
"embed_dims": [32, 64, 160, 256],
"num_heads": [1, 2, 5, 8],
"mlp_ratios": [4, 4, 4, 4],
"qkv_bias": True,
"norm_layer": partial(nn.LayerNorm, eps=1e-6),
"depths": [2, 2, 2, 2],
"sr_ratios": [8, 4, 2, 1],
"drop_rate": 0.0,
"drop_path_rate": 0.1,
"out_channels": (3, 32, 32, 64, 160, 256)
}
},
"mit-b1": {
"encoder": MixVisionTransformerEncoder,
"pretrained_settings": pretrained_settings["mit-b1"],
"params": {
"patch_size": 4,
"embed_dims": [64, 128, 320, 512],
"num_heads": [1, 2, 5, 8],
"mlp_ratios": [4, 4, 4, 4],
"qkv_bias": True,
"norm_layer": partial(nn.LayerNorm, eps=1e-6),
"depths": [2, 2, 2, 2],
"sr_ratios": [8, 4, 2, 1],
"drop_rate": 0.0,
"drop_path_rate": 0.1,
"out_channels": (3, 64, 64, 128, 320, 512)
}
},
"mit-b2": {
"encoder": MixVisionTransformerEncoder,
"pretrained_settings": pretrained_settings["mit-b2"],
"params": {
"patch_size": 4,
"embed_dims": [64, 128, 320, 512],
"num_heads": [1, 2, 5, 8],
"mlp_ratios": [4, 4, 4, 4],
"qkv_bias": True,
"norm_layer": partial(nn.LayerNorm, eps=1e-6),
"depths": [3, 4, 6, 3],
"sr_ratios": [8, 4, 2, 1],
"drop_rate": 0.0,
"drop_path_rate": 0.1,
"out_channels": (3, 64, 64, 128, 320, 512)
}
},
"mit-b3": {
"encoder": MixVisionTransformerEncoder,
"pretrained_settings": pretrained_settings["mit-b3"],
"params": {
"patch_size": 4,
"embed_dims": [64, 128, 320, 512],
"num_heads": [1, 2, 5, 8],
"mlp_ratios": [4, 4, 4, 4],
"qkv_bias": True,
"norm_layer": partial(nn.LayerNorm, eps=1e-6),
"depths": [3, 4, 18, 3],
"sr_ratios": [8, 4, 2, 1],
"drop_rate": 0.0,
"drop_path_rate": 0.1,
"out_channels": (3, 64, 64, 128, 320, 512)
}
},
"mit-b4": {
"encoder": MixVisionTransformerEncoder,
"pretrained_settings": pretrained_settings["mit-b4"],
"params": {
"patch_size": 4,
"embed_dims": [64, 128, 320, 512],
"num_heads": [1, 2, 5, 8],
"mlp_ratios": [4, 4, 4, 4],
"qkv_bias": True,
"norm_layer": partial(nn.LayerNorm, eps=1e-6),
"depths": [3, 8, 27, 3],
"sr_ratios": [8, 4, 2, 1],
"drop_rate": 0.0,
"drop_path_rate": 0.1,
"out_channels": (3, 64, 64, 128, 320, 512)
}
},
"mit-b5": {
"encoder": MixVisionTransformerEncoder,
"pretrained_settings": pretrained_settings["mit-b5"],
"params": {
"patch_size": 4,
"embed_dims": [64, 128, 320, 512],
"num_heads": [1, 2, 5, 8],
"mlp_ratios": [4, 4, 4, 4],
"qkv_bias": True,
"norm_layer": partial(nn.LayerNorm, eps=1e-6),
"depths": [3, 6, 40, 3],
"sr_ratios": [8, 4, 2, 1],
"drop_rate": 0.0,
"drop_path_rate": 0.1,
"out_channels": (3, 64, 64, 128, 320, 512)
}
},
}
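A minimal sketch of what the encoder wrapper above returns, with random weights (names from this file; importing them is an assumption): the identity feature at input resolution, the interpolated 1/2-resolution copy prepended in `forward`, and the four SegFormer pyramid levels.

import torch
# assumed: MixVisionTransformerEncoder and mit_encoders are importable from this module

encoder = MixVisionTransformerEncoder(**mit_encoders["mit-b0"]["params"])

features = encoder(torch.randn(1, 3, 64, 64))
print([tuple(f.shape[1:]) for f in features])
# [(3, 64, 64), (32, 32, 32), (32, 16, 16), (64, 8, 8), (160, 4, 4), (256, 2, 2)]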

View File

@ -0,0 +1,361 @@
# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
import math
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.dwconv = DWConv(hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = self.fc1(x)
x = self.dwconv(x, H, W)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.sr_ratio = sr_ratio
if sr_ratio > 1:
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = nn.LayerNorm(dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
B, N, C = x.shape
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
if self.sr_ratio > 1:
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
x_ = self.norm(x_)
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
else:
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
return x
class OverlapPatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
self.num_patches = self.H * self.W
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
padding=(patch_size[0] // 2, patch_size[1] // 2))
self.norm = nn.LayerNorm(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
x = self.proj(x)
_, _, H, W = x.shape
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x, H, W
class MixVisionTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
super().__init__()
self.num_classes = num_classes
self.depths = depths
# patch_embed
self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
embed_dim=embed_dims[0])
self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
embed_dim=embed_dims[1])
self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
embed_dim=embed_dims[2])
self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
embed_dim=embed_dims[3])
# transformer encoder
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
cur = 0
self.block1 = nn.ModuleList([Block(
dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[0])
for i in range(depths[0])])
self.norm1 = norm_layer(embed_dims[0])
cur += depths[0]
self.block2 = nn.ModuleList([Block(
dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[1])
for i in range(depths[1])])
self.norm2 = norm_layer(embed_dims[1])
cur += depths[1]
self.block3 = nn.ModuleList([Block(
dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[2])
for i in range(depths[2])])
self.norm3 = norm_layer(embed_dims[2])
cur += depths[2]
self.block4 = nn.ModuleList([Block(
dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[3])
for i in range(depths[3])])
self.norm4 = norm_layer(embed_dims[3])
# classification head
# self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def reset_drop_path(self, drop_path_rate):
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
cur = 0
for i in range(self.depths[0]):
self.block1[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[0]
for i in range(self.depths[1]):
self.block2[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[1]
for i in range(self.depths[2]):
self.block3[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[2]
for i in range(self.depths[3]):
self.block4[i].drop_path.drop_prob = dpr[cur + i]
def freeze_patch_emb(self):
self.patch_embed1.requires_grad = False
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
B = x.shape[0]
outs = []
# stage 1
x, H, W = self.patch_embed1(x)
for i, blk in enumerate(self.block1):
x = blk(x, H, W)
x = self.norm1(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
# stage 2
x, H, W = self.patch_embed2(x)
for i, blk in enumerate(self.block2):
x = blk(x, H, W)
x = self.norm2(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
# stage 3
x, H, W = self.patch_embed3(x)
for i, blk in enumerate(self.block3):
x = blk(x, H, W)
x = self.norm3(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
# stage 4
x, H, W = self.patch_embed4(x)
for i, blk in enumerate(self.block4):
x = blk(x, H, W)
x = self.norm4(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
return outs
def forward(self, x):
x = self.forward_features(x)
# x = self.head(x)
return x
class DWConv(nn.Module):
def __init__(self, dim=768):
super(DWConv, self).__init__()
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
def forward(self, x, H, W):
B, N, C = x.shape
x = x.transpose(1, 2).view(B, C, H, W)
x = self.dwconv(x)
x = x.flatten(2).transpose(1, 2)
return x

View File

@ -0,0 +1,83 @@
""" Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin`
Attributes:
_out_channels (list of int): specify number of channels for each encoder feature tensor
_depth (int): specify number of stages in decoder (in other words number of downsampling operations)
_in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)
Methods:
forward(self, x: torch.Tensor)
produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
shape NCHW (features should be sorted in descending order according to spatial resolution, starting
with resolution same as input `x` tensor).
Input: `x` with shape (1, 3, 64, 64)
Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
[(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
(1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)
also should support number of features according to specified depth, e.g. if depth = 5,
number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
"""
import torchvision
import torch.nn as nn
from ._base import EncoderMixin
class MobileNetV2Encoder(torchvision.models.MobileNetV2, EncoderMixin):
def __init__(self, out_channels, depth=5, **kwargs):
super().__init__(**kwargs)
self._depth = depth
self._out_channels = out_channels
self._in_channels = 3
del self.classifier
def get_stages(self):
return [
nn.Identity(),
self.features[:2],
self.features[2:4],
self.features[4:7],
self.features[7:14],
self.features[14:],
]
def forward(self, x):
stages = self.get_stages()
features = []
for i in range(self._depth + 1):
x = stages[i](x)
features.append(x)
return features
def load_state_dict(self, state_dict, **kwargs):
state_dict.pop("classifier.1.bias", None)
state_dict.pop("classifier.1.weight", None)
super().load_state_dict(state_dict, **kwargs)
mobilenet_encoders = {
"mobilenet_v2": {
"encoder": MobileNetV2Encoder,
"pretrained_settings": {
"imagenet": {
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
"url": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
"input_space": "RGB",
"input_range": [0, 1],
},
},
"params": {
"out_channels": (3, 16, 24, 32, 96, 1280),
},
},
}

View File

@ -0,0 +1,238 @@
""" Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin`
Attributes:
_out_channels (list of int): specify number of channels for each encoder feature tensor
_depth (int): specify number of stages in decoder (in other words number of downsampling operations)
_in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)
Methods:
forward(self, x: torch.Tensor)
produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
shape NCHW (features should be sorted in descending order according to spatial resolution, starting
with resolution same as input `x` tensor).
Input: `x` with shape (1, 3, 64, 64)
Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
[(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
(1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)
also should support number of features according to specified depth, e.g. if depth = 5,
number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
"""
from copy import deepcopy
import torch.nn as nn
from torchvision.models.resnet import ResNet
from torchvision.models.resnet import BasicBlock
from torchvision.models.resnet import Bottleneck
from pretrainedmodels.models.torchvision_models import pretrained_settings
from ._base import EncoderMixin
class ResNetEncoder(ResNet, EncoderMixin):
def __init__(self, out_channels, depth=5, **kwargs):
super().__init__(**kwargs)
self._depth = depth
self._out_channels = out_channels
self._in_channels = 3
del self.fc
del self.avgpool
def get_stages(self):
return [
nn.Identity(),
nn.Sequential(self.conv1, self.bn1, self.relu),
nn.Sequential(self.maxpool, self.layer1),
self.layer2,
self.layer3,
self.layer4,
]
def forward(self, x):
stages = self.get_stages()
features = []
for i in range(self._depth + 1):
x = stages[i](x)
features.append(x)
return features
def load_state_dict(self, state_dict, **kwargs):
state_dict.pop("fc.bias", None)
state_dict.pop("fc.weight", None)
super().load_state_dict(state_dict, **kwargs)
new_settings = {
"resnet18": {
"ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth",
"swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth"
},
"resnet50": {
"ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth",
"swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth"
},
"resnext50_32x4d": {
"imagenet": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
"ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext50_32x4-ddb3e555.pth",
"swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext50_32x4-72679e44.pth",
},
"resnext101_32x4d": {
"ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth",
"swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth"
},
"resnext101_32x8d": {
"imagenet": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
"instagram": "https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth",
"ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth",
"swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth",
},
"resnext101_32x16d": {
"instagram": "https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth",
"ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth",
"swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth",
},
"resnext101_32x32d": {
"instagram": "https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth",
},
"resnext101_32x48d": {
"instagram": "https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth",
}
}
pretrained_settings = deepcopy(pretrained_settings)
for model_name, sources in new_settings.items():
if model_name not in pretrained_settings:
pretrained_settings[model_name] = {}
for source_name, source_url in sources.items():
pretrained_settings[model_name][source_name] = {
"url": source_url,
'input_size': [3, 224, 224],
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'num_classes': 1000
}
resnet_encoders = {
"resnet18": {
"encoder": ResNetEncoder,
"pretrained_settings": pretrained_settings["resnet18"],
"params": {
"out_channels": (3, 64, 64, 128, 256, 512),
"block": BasicBlock,
"layers": [2, 2, 2, 2],
},
},
"resnet34": {
"encoder": ResNetEncoder,
"pretrained_settings": pretrained_settings["resnet34"],
"params": {
"out_channels": (3, 64, 64, 128, 256, 512),
"block": BasicBlock,
"layers": [3, 4, 6, 3],
},
},
"resnet50": {
"encoder": ResNetEncoder,
"pretrained_settings": pretrained_settings["resnet50"],
"params": {
"out_channels": (3, 64, 256, 512, 1024, 2048),
"block": Bottleneck,
"layers": [3, 4, 6, 3],
},
},
"resnet101": {
"encoder": ResNetEncoder,
"pretrained_settings": pretrained_settings["resnet101"],
"params": {
"out_channels": (3, 64, 256, 512, 1024, 2048),
"block": Bottleneck,
"layers": [3, 4, 23, 3],
},
},
"resnet152": {
"encoder": ResNetEncoder,
"pretrained_settings": pretrained_settings["resnet152"],
"params": {
"out_channels": (3, 64, 256, 512, 1024, 2048),
"block": Bottleneck,
"layers": [3, 8, 36, 3],
},
},
"resnext50_32x4d": {
"encoder": ResNetEncoder,
"pretrained_settings": pretrained_settings["resnext50_32x4d"],
"params": {
"out_channels": (3, 64, 256, 512, 1024, 2048),
"block": Bottleneck,
"layers": [3, 4, 6, 3],
"groups": 32,
"width_per_group": 4,
},
},
"resnext101_32x4d": {
"encoder": ResNetEncoder,
"pretrained_settings": pretrained_settings["resnext101_32x4d"],
"params": {
"out_channels": (3, 64, 256, 512, 1024, 2048),
"block": Bottleneck,
"layers": [3, 4, 23, 3],
"groups": 32,
"width_per_group": 4,
},
},
"resnext101_32x8d": {
"encoder": ResNetEncoder,
"pretrained_settings": pretrained_settings["resnext101_32x8d"],
"params": {
"out_channels": (3, 64, 256, 512, 1024, 2048),
"block": Bottleneck,
"layers": [3, 4, 23, 3],
"groups": 32,
"width_per_group": 8,
},
},
"resnext101_32x16d": {
"encoder": ResNetEncoder,
"pretrained_settings": pretrained_settings["resnext101_32x16d"],
"params": {
"out_channels": (3, 64, 256, 512, 1024, 2048),
"block": Bottleneck,
"layers": [3, 4, 23, 3],
"groups": 32,
"width_per_group": 16,
},
},
"resnext101_32x32d": {
"encoder": ResNetEncoder,
"pretrained_settings": pretrained_settings["resnext101_32x32d"],
"params": {
"out_channels": (3, 64, 256, 512, 1024, 2048),
"block": Bottleneck,
"layers": [3, 4, 23, 3],
"groups": 32,
"width_per_group": 32,
},
},
"resnext101_32x48d": {
"encoder": ResNetEncoder,
"pretrained_settings": pretrained_settings["resnext101_32x48d"],
"params": {
"out_channels": (3, 64, 256, 512, 1024, 2048),
"block": Bottleneck,
"layers": [3, 4, 23, 3],
"groups": 32,
"width_per_group": 48,
},
},
}
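A minimal sketch (names from this file; importing them is an assumption) of how a registry entry is consumed and how `make_dilated` from `EncoderMixin` changes the output stride: with `output_stride=8`, the strided convolutions of the last two stages are replaced by dilated ones, so those features keep the 1/8 resolution.

import torch
# assumed: resnet_encoders is importable from this module

entry = resnet_encoders["resnet18"]
encoder = entry["encoder"](**entry["params"])     # ResNetEncoder(block=BasicBlock, layers=[2, 2, 2, 2], ...)

x = torch.randn(1, 3, 256, 256)
print([f.shape[-1] for f in encoder(x)])          # [256, 128, 64, 32, 16, 8] - default strides
encoder.make_dilated(output_stride=8)             # stages 4 and 5 now use dilation 2 and 4
print([f.shape[-1] for f in encoder(x)])          # [256, 128, 64, 32, 32, 32]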

View File

@ -0,0 +1,174 @@
""" Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin`
Attributes:
_out_channels (list of int): specify number of channels for each encoder feature tensor
_depth (int): specify number of stages in decoder (in other words number of downsampling operations)
_in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)
Methods:
forward(self, x: torch.Tensor)
produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
shape NCHW (features should be sorted in descending order according to spatial resolution, starting
with resolution same as input `x` tensor).
Input: `x` with shape (1, 3, 64, 64)
Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
[(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
(1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)
also should support number of features according to specified depth, e.g. if depth = 5,
number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
"""
import torch.nn as nn
from pretrainedmodels.models.senet import (
SENet,
SEBottleneck,
SEResNetBottleneck,
SEResNeXtBottleneck,
pretrained_settings,
)
from ._base import EncoderMixin
class SENetEncoder(SENet, EncoderMixin):
def __init__(self, out_channels, depth=5, **kwargs):
super().__init__(**kwargs)
self._out_channels = out_channels
self._depth = depth
self._in_channels = 3
del self.last_linear
del self.avg_pool
def get_stages(self):
return [
nn.Identity(),
self.layer0[:-1],
nn.Sequential(self.layer0[-1], self.layer1),
self.layer2,
self.layer3,
self.layer4,
]
def forward(self, x):
stages = self.get_stages()
features = []
for i in range(self._depth + 1):
x = stages[i](x)
features.append(x)
return features
def load_state_dict(self, state_dict, **kwargs):
state_dict.pop("last_linear.bias", None)
state_dict.pop("last_linear.weight", None)
super().load_state_dict(state_dict, **kwargs)
senet_encoders = {
"senet154": {
"encoder": SENetEncoder,
"pretrained_settings": pretrained_settings["senet154"],
"params": {
"out_channels": (3, 128, 256, 512, 1024, 2048),
"block": SEBottleneck,
"dropout_p": 0.2,
"groups": 64,
"layers": [3, 8, 36, 3],
"num_classes": 1000,
"reduction": 16,
},
},
"se_resnet50": {
"encoder": SENetEncoder,
"pretrained_settings": pretrained_settings["se_resnet50"],
"params": {
"out_channels": (3, 64, 256, 512, 1024, 2048),
"block": SEResNetBottleneck,
"layers": [3, 4, 6, 3],
"downsample_kernel_size": 1,
"downsample_padding": 0,
"dropout_p": None,
"groups": 1,
"inplanes": 64,
"input_3x3": False,
"num_classes": 1000,
"reduction": 16,
},
},
"se_resnet101": {
"encoder": SENetEncoder,
"pretrained_settings": pretrained_settings["se_resnet101"],
"params": {
"out_channels": (3, 64, 256, 512, 1024, 2048),
"block": SEResNetBottleneck,
"layers": [3, 4, 23, 3],
"downsample_kernel_size": 1,
"downsample_padding": 0,
"dropout_p": None,
"groups": 1,
"inplanes": 64,
"input_3x3": False,
"num_classes": 1000,
"reduction": 16,
},
},
"se_resnet152": {
"encoder": SENetEncoder,
"pretrained_settings": pretrained_settings["se_resnet152"],
"params": {
"out_channels": (3, 64, 256, 512, 1024, 2048),
"block": SEResNetBottleneck,
"layers": [3, 8, 36, 3],
"downsample_kernel_size": 1,
"downsample_padding": 0,
"dropout_p": None,
"groups": 1,
"inplanes": 64,
"input_3x3": False,
"num_classes": 1000,
"reduction": 16,
},
},
"se_resnext50_32x4d": {
"encoder": SENetEncoder,
"pretrained_settings": pretrained_settings["se_resnext50_32x4d"],
"params": {
"out_channels": (3, 64, 256, 512, 1024, 2048),
"block": SEResNeXtBottleneck,
"layers": [3, 4, 6, 3],
"downsample_kernel_size": 1,
"downsample_padding": 0,
"dropout_p": None,
"groups": 32,
"inplanes": 64,
"input_3x3": False,
"num_classes": 1000,
"reduction": 16,
},
},
"se_resnext101_32x4d": {
"encoder": SENetEncoder,
"pretrained_settings": pretrained_settings["se_resnext101_32x4d"],
"params": {
"out_channels": (3, 64, 256, 512, 1024, 2048),
"block": SEResNeXtBottleneck,
"layers": [3, 4, 23, 3],
"downsample_kernel_size": 1,
"downsample_padding": 0,
"dropout_p": None,
"groups": 32,
"inplanes": 64,
"input_3x3": False,
"num_classes": 1000,
"reduction": 16,
},
},
}

View File

@ -0,0 +1,196 @@
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from collections import OrderedDict
from pretrainedmodels.models.torchvision_models import pretrained_settings
from ._base import EncoderMixin
from .swin_transformer_model import SwinTransformer
class SwinTransformerEncoder(SwinTransformer, EncoderMixin):
def __init__(self, out_channels, depth=5, **kwargs):
super().__init__(**kwargs)
self._depth = depth
self._out_channels = out_channels
self._in_channels = 3
def get_stages(self):
return [nn.Identity()]
def feature_forward(self, x):
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
if self.ape:
# interpolate the position embedding to the corresponding size
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
else:
x = x.flatten(2).transpose(1, 2)
x = self.pos_drop(x)
outs = []
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
outs.append(out)
return outs
def forward(self, x):
stages = self.get_stages()
features = []
for stage in stages:
x = stage(x)
features.append(x)
outs = self.feature_forward(x)
        # Note: an additional interpolated feature is prepended to accommodate five-stage
        # decoders; it will be ignored if a decoder with fewer stages is used.
add_feature = F.interpolate(outs[0], scale_factor=2)
features = features + [add_feature] + outs
return features
def load_state_dict(self, state_dict, **kwargs):
new_state_dict = OrderedDict()
if 'state_dict' in state_dict:
_state_dict = state_dict['state_dict']
elif 'model' in state_dict:
_state_dict = state_dict['model']
else:
_state_dict = state_dict
for k, v in _state_dict.items():
if k.startswith('backbone.'):
new_state_dict[k[9:]] = v
else:
new_state_dict[k] = v
        # Note: in the Swin segmentation model, `attn_mask` is no longer a class attribute
        # (so that multi-scale inputs are supported), a norm layer is added for each output,
        # and the classification head is removed.
kwargs.update({'strict': False})
super().load_state_dict(new_state_dict, **kwargs)
# https://github.com/microsoft/Swin-Transformer
new_settings = {
"Swin-T": {
"imagenet": "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth",
"ADE20k": "https://github.com/SwinTransformer/storage/releases/download/v1.0.1/upernet_swin_tiny_patch4_window7_512x512.pth"
},
"Swin-S": {
"imagenet": "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth",
"ADE20k": "https://github.com/SwinTransformer/storage/releases/download/v1.0.1/upernet_swin_small_patch4_window7_512x512.pth"
},
"Swin-B": {
"imagenet": "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth",
"imagenet-22k": "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth",
"ADE20k": "https://github.com/SwinTransformer/storage/releases/download/v1.0.1/upernet_swin_base_patch4_window7_512x512.pth"
},
"Swin-L": {
"imagenet-22k": "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth"
},
}
pretrained_settings = deepcopy(pretrained_settings)
for model_name, sources in new_settings.items():
if model_name not in pretrained_settings:
pretrained_settings[model_name] = {}
for source_name, source_url in sources.items():
pretrained_settings[model_name][source_name] = {
"url": source_url,
'input_size': [3, 224, 224],
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'num_classes': 1000
}
swin_transformer_encoders = {
"Swin-T": {
"encoder": SwinTransformerEncoder,
"pretrained_settings": pretrained_settings["Swin-T"],
"params": {
"embed_dim": 96,
"out_channels": (3, 96, 96, 192, 384, 768),
"depths": [2, 2, 6, 2],
"num_heads": [3, 6, 12, 24],
"window_size": 7,
"ape": False,
"drop_path_rate": 0.3,
"patch_norm": True,
"use_checkpoint": False
}
},
"Swin-S": {
"encoder": SwinTransformerEncoder,
"pretrained_settings": pretrained_settings["Swin-S"],
"params": {
"embed_dim": 96,
"out_channels": (3, 96, 96, 192, 384, 768),
"depths": [2, 2, 18, 2],
"num_heads": [3, 6, 12, 24],
"window_size": 7,
"ape": False,
"drop_path_rate": 0.3,
"patch_norm": True,
"use_checkpoint": False
}
},
"Swin-B": {
"encoder": SwinTransformerEncoder,
"pretrained_settings": pretrained_settings["Swin-B"],
"params": {
"embed_dim": 128,
"out_channels": (3, 128, 128, 256, 512, 1024),
"depths": [2, 2, 18, 2],
"num_heads": [4, 8, 16, 32],
"window_size": 7,
"ape": False,
"drop_path_rate": 0.3,
"patch_norm": True,
"use_checkpoint": False
}
},
"Swin-L": {
"encoder": SwinTransformerEncoder,
"pretrained_settings": pretrained_settings["Swin-L"],
"params": {
"embed_dim": 192,
"out_channels": (3, 192, 192, 384, 768, 1536),
"depths": [2, 2, 18, 2],
"num_heads": [6, 12, 24, 48],
"window_size": 7,
"ape": False,
"drop_path_rate": 0.3,
"patch_norm": True,
"use_checkpoint": False
}
}
}
if __name__ == "__main__":
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(1, 3, 256, 256).to(device)    # renamed from `input` to avoid shadowing the built-in
    model = SwinTransformerEncoder((3, 96, 96, 192, 384, 768), window_size=8).to(device)  # Swin-T channels; model moved to the same device as the input
    # print(model)
    res = model.forward(x)
    for feature in res:
        print(feature.shape)

View File

@ -0,0 +1,626 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu, Yutong Lin, Yixuan Wei
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
class Mlp(nn.Module):
""" Multilayer perceptron."""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
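# Illustrative sanity check (a sketch, not part of the upstream Swin code):
# window_partition and window_reverse are inverse operations, e.g. for a tensor
# x of shape (2, 8, 8, 96) and window_size=4:
#   windows = window_partition(x, 4)                     # -> (8, 4, 4, 96): 2 * (8 // 4) * (8 // 4) windows
#   torch.allclose(window_reverse(windows, 4, 8, 8), x)  # -> True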
class WindowAttention(nn.Module):
""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
""" Forward function.
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
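# --- Illustrative shapes (assumption, helper not in the original file): for a 7x7 window the
# relative position bias table has (2*7-1)*(2*7-1) = 169 rows, one per (dy, dx) offset between
# two tokens, and relative_position_index maps each of the 49*49 token pairs to one of those rows.
def _relative_bias_shapes(dim=96, num_heads=3, window_size=(7, 7)):
    attn = WindowAttention(dim, window_size=window_size, num_heads=num_heads)
    table = attn.relative_position_bias_table           # shape (169, num_heads)
    index = attn.relative_position_index                # shape (49, 49), values in [0, 168]
    return table.shape, index.shape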
class SwinTransformerBlock(nn.Module):
""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.H = None
self.W = None
def forward(self, x, mask_matrix):
""" Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
mask_matrix: Attention mask for cyclic shift.
"""
B, L, C = x.shape
H, W = self.H, self.W
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# pad feature maps to multiples of window size
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
attn_mask = mask_matrix
else:
shifted_x = x
attn_mask = None
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchMerging(nn.Module):
""" Patch Merging Layer
Args:
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x, H, W):
""" Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
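# --- Hypothetical shape check (assumption, not part of the original file): PatchMerging
# concatenates each 2x2 neighborhood, so it halves the spatial resolution and doubles the
# channel dimension, e.g. (B, 56*56, 96) -> (B, 28*28, 192).
def _patch_merging_shape_check():
    merge = PatchMerging(dim=96)
    out = merge(torch.randn(2, 56 * 56, 96), 56, 56)
    return out.shape                                     # torch.Size([2, 784, 192])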
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of feature channels
depth (int): Depths of this stage.
num_heads (int): Number of attention head.
window_size (int): Local window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self,
dim,
depth,
num_heads,
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False):
super().__init__()
self.window_size = window_size
self.shift_size = window_size // 2
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, H, W):
""" Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
for blk in self.blocks:
blk.H, blk.W = H, W
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, attn_mask)
else:
x = blk(x, attn_mask)
if self.downsample is not None:
x_down = self.downsample(x, H, W)
Wh, Ww = (H + 1) // 2, (W + 1) // 2
return x, H, W, x_down, Wh, Ww
else:
return x, H, W, x, H, W
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
Args:
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
# padding
_, _, H, W = x.size()
if W % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
if H % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
x = self.proj(x) # B C Wh Ww
if self.norm is not None:
Wh, Ww = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
return x
class SwinTransformer(nn.Module):
""" Swin Transformer backbone.
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute position embedding. Default: 224.
patch_size (int | tuple(int)): Patch size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
depths (tuple[int]): Depths of each Swin Transformer stage.
num_heads (tuple[int]): Number of attention head of each stage.
window_size (int): Window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
drop_rate (float): Dropout rate.
attn_drop_rate (float): Attention dropout rate. Default: 0.
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
out_indices (Sequence[int]): Output from which stages.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self,
pretrain_img_size=224,
patch_size=4,
in_chans=3,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.2,
norm_layer=nn.LayerNorm,
ape=False,
patch_norm=True,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
use_checkpoint=False,
**kwargs
):
super().__init__()
self.pretrain_img_size = pretrain_img_size
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.out_indices = out_indices
self.frozen_stages = frozen_stages
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
# absolute position embedding
if self.ape:
pretrain_img_size = to_2tuple(pretrain_img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
self.absolute_pos_embed = nn.Parameter(
torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint)
self.layers.append(layer)
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
self.num_features = num_features
# add a norm layer for each output
for i_layer in out_indices:
layer = norm_layer(num_features[i_layer])
layer_name = f'norm{i_layer}'
self.add_module(layer_name, layer)
self.init_weights() # the pre-trained model will be loaded later if needed
self._freeze_stages()
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
if self.frozen_stages >= 1 and self.ape:
self.absolute_pos_embed.requires_grad = False
if self.frozen_stages >= 2:
self.pos_drop.eval()
for i in range(0, self.frozen_stages - 1):
m = self.layers[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
def _init_weights(m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
if isinstance(pretrained, str):
self.apply(_init_weights)
# logger = get_root_logger()
# load_checkpoint(self, pretrained, strict=False, logger=logger)
pass
elif pretrained is None:
self.apply(_init_weights)
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x):
"""Forward function."""
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
if self.ape:
# interpolate the position embedding to the corresponding size
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
else:
x = x.flatten(2).transpose(1, 2)
x = self.pos_drop(x)
outs = []
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
outs.append(out)
return tuple(outs)
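# --- Minimal usage sketch (assumption, not part of the original file): the backbone returns one
# feature map per stage listed in out_indices, at strides 4, 8, 16 and 32 of the input.
if __name__ == "__main__":
    model = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24])
    feats = model(torch.randn(1, 3, 224, 224))
    for f in feats:
        print(f.shape)  # (1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)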

View File

@@ -0,0 +1,382 @@
from functools import partial
import torch
import torch.nn as nn
from timm.models.efficientnet import EfficientNet
from timm.models.efficientnet import decode_arch_def, round_channels, default_cfgs
from timm.models.layers.activations import Swish
from ._base import EncoderMixin
def get_efficientnet_kwargs(channel_multiplier=1.0, depth_multiplier=1.0, drop_rate=0.2):
"""Creates an EfficientNet model.
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
Paper: https://arxiv.org/abs/1905.11946
EfficientNet params
name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
'efficientnet-b0': (1.0, 1.0, 224, 0.2),
'efficientnet-b1': (1.0, 1.1, 240, 0.2),
'efficientnet-b2': (1.1, 1.2, 260, 0.3),
'efficientnet-b3': (1.2, 1.4, 300, 0.3),
'efficientnet-b4': (1.4, 1.8, 380, 0.4),
'efficientnet-b5': (1.6, 2.2, 456, 0.4),
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
'efficientnet-b7': (2.0, 3.1, 600, 0.5),
'efficientnet-b8': (2.2, 3.6, 672, 0.5),
'efficientnet-l2': (4.3, 5.3, 800, 0.5),
Args:
channel_multiplier: multiplier to number of channels per layer
depth_multiplier: multiplier to number of repeats per stage
"""
arch_def = [
['ds_r1_k3_s1_e1_c16_se0.25'],
['ir_r2_k3_s2_e6_c24_se0.25'],
['ir_r2_k5_s2_e6_c40_se0.25'],
['ir_r3_k3_s2_e6_c80_se0.25'],
['ir_r3_k5_s1_e6_c112_se0.25'],
['ir_r4_k5_s2_e6_c192_se0.25'],
['ir_r1_k3_s1_e6_c320_se0.25'],
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
num_features=round_channels(1280, channel_multiplier, 8, None),
stem_size=32,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
act_layer=Swish,
drop_rate=drop_rate,
drop_path_rate=0.2,
)
return model_kwargs
def gen_efficientnet_lite_kwargs(channel_multiplier=1.0, depth_multiplier=1.0, drop_rate=0.2):
"""Creates an EfficientNet-Lite model.
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite
Paper: https://arxiv.org/abs/1905.11946
EfficientNet params
name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
'efficientnet-lite0': (1.0, 1.0, 224, 0.2),
'efficientnet-lite1': (1.0, 1.1, 240, 0.2),
'efficientnet-lite2': (1.1, 1.2, 260, 0.3),
'efficientnet-lite3': (1.2, 1.4, 280, 0.3),
'efficientnet-lite4': (1.4, 1.8, 300, 0.3),
Args:
channel_multiplier: multiplier to number of channels per layer
depth_multiplier: multiplier to number of repeats per stage
"""
arch_def = [
['ds_r1_k3_s1_e1_c16'],
['ir_r2_k3_s2_e6_c24'],
['ir_r2_k5_s2_e6_c40'],
['ir_r3_k3_s2_e6_c80'],
['ir_r3_k5_s1_e6_c112'],
['ir_r4_k5_s2_e6_c192'],
['ir_r1_k3_s1_e6_c320'],
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True),
num_features=1280,
stem_size=32,
fix_stem=True,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
act_layer=nn.ReLU6,
drop_rate=drop_rate,
drop_path_rate=0.2,
)
return model_kwargs
class EfficientNetBaseEncoder(EfficientNet, EncoderMixin):
def __init__(self, stage_idxs, out_channels, depth=5, **kwargs):
super().__init__(**kwargs)
self._stage_idxs = stage_idxs
self._out_channels = out_channels
self._depth = depth
self._in_channels = 3
del self.classifier
def get_stages(self):
return [
nn.Identity(),
nn.Sequential(self.conv_stem, self.bn1, self.act1),
self.blocks[:self._stage_idxs[0]],
self.blocks[self._stage_idxs[0]:self._stage_idxs[1]],
self.blocks[self._stage_idxs[1]:self._stage_idxs[2]],
self.blocks[self._stage_idxs[2]:],
]
def forward(self, x):
stages = self.get_stages()
features = []
for i in range(self._depth + 1):
x = stages[i](x)
features.append(x)
return features
def load_state_dict(self, state_dict, **kwargs):
state_dict.pop("classifier.bias", None)
state_dict.pop("classifier.weight", None)
super().load_state_dict(state_dict, **kwargs)
class EfficientNetEncoder(EfficientNetBaseEncoder):
def __init__(self, stage_idxs, out_channels, depth=5, channel_multiplier=1.0, depth_multiplier=1.0, drop_rate=0.2):
kwargs = get_efficientnet_kwargs(channel_multiplier, depth_multiplier, drop_rate)
super().__init__(stage_idxs, out_channels, depth, **kwargs)
class EfficientNetLiteEncoder(EfficientNetBaseEncoder):
def __init__(self, stage_idxs, out_channels, depth=5, channel_multiplier=1.0, depth_multiplier=1.0, drop_rate=0.2):
kwargs = gen_efficientnet_lite_kwargs(channel_multiplier, depth_multiplier, drop_rate)
super().__init__(stage_idxs, out_channels, depth, **kwargs)
def prepare_settings(settings):
return {
"mean": settings["mean"],
"std": settings["std"],
"url": settings["url"],
"input_range": (0, 1),
"input_space": "RGB",
}
timm_efficientnet_encoders = {
"timm-efficientnet-b0": {
"encoder": EfficientNetEncoder,
"pretrained_settings": {
"imagenet": prepare_settings(default_cfgs["tf_efficientnet_b0"]),
"advprop": prepare_settings(default_cfgs["tf_efficientnet_b0_ap"]),
"noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b0_ns"]),
},
"params": {
"out_channels": (3, 32, 24, 40, 112, 320),
"stage_idxs": (2, 3, 5),
"channel_multiplier": 1.0,
"depth_multiplier": 1.0,
"drop_rate": 0.2,
},
},
"timm-efficientnet-b1": {
"encoder": EfficientNetEncoder,
"pretrained_settings": {
"imagenet": prepare_settings(default_cfgs["tf_efficientnet_b1"]),
"advprop": prepare_settings(default_cfgs["tf_efficientnet_b1_ap"]),
"noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b1_ns"]),
},
"params": {
"out_channels": (3, 32, 24, 40, 112, 320),
"stage_idxs": (2, 3, 5),
"channel_multiplier": 1.0,
"depth_multiplier": 1.1,
"drop_rate": 0.2,
},
},
"timm-efficientnet-b2": {
"encoder": EfficientNetEncoder,
"pretrained_settings": {
"imagenet": prepare_settings(default_cfgs["tf_efficientnet_b2"]),
"advprop": prepare_settings(default_cfgs["tf_efficientnet_b2_ap"]),
"noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b2_ns"]),
},
"params": {
"out_channels": (3, 32, 24, 48, 120, 352),
"stage_idxs": (2, 3, 5),
"channel_multiplier": 1.1,
"depth_multiplier": 1.2,
"drop_rate": 0.3,
},
},
"timm-efficientnet-b3": {
"encoder": EfficientNetEncoder,
"pretrained_settings": {
"imagenet": prepare_settings(default_cfgs["tf_efficientnet_b3"]),
"advprop": prepare_settings(default_cfgs["tf_efficientnet_b3_ap"]),
"noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b3_ns"]),
},
"params": {
"out_channels": (3, 40, 32, 48, 136, 384),
"stage_idxs": (2, 3, 5),
"channel_multiplier": 1.2,
"depth_multiplier": 1.4,
"drop_rate": 0.3,
},
},
"timm-efficientnet-b4": {
"encoder": EfficientNetEncoder,
"pretrained_settings": {
"imagenet": prepare_settings(default_cfgs["tf_efficientnet_b4"]),
"advprop": prepare_settings(default_cfgs["tf_efficientnet_b4_ap"]),
"noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b4_ns"]),
},
"params": {
"out_channels": (3, 48, 32, 56, 160, 448),
"stage_idxs": (2, 3, 5),
"channel_multiplier": 1.4,
"depth_multiplier": 1.8,
"drop_rate": 0.4,
},
},
"timm-efficientnet-b5": {
"encoder": EfficientNetEncoder,
"pretrained_settings": {
"imagenet": prepare_settings(default_cfgs["tf_efficientnet_b5"]),
"advprop": prepare_settings(default_cfgs["tf_efficientnet_b5_ap"]),
"noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b5_ns"]),
},
"params": {
"out_channels": (3, 48, 40, 64, 176, 512),
"stage_idxs": (2, 3, 5),
"channel_multiplier": 1.6,
"depth_multiplier": 2.2,
"drop_rate": 0.4,
},
},
"timm-efficientnet-b6": {
"encoder": EfficientNetEncoder,
"pretrained_settings": {
"imagenet": prepare_settings(default_cfgs["tf_efficientnet_b6"]),
"advprop": prepare_settings(default_cfgs["tf_efficientnet_b6_ap"]),
"noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b6_ns"]),
},
"params": {
"out_channels": (3, 56, 40, 72, 200, 576),
"stage_idxs": (2, 3, 5),
"channel_multiplier": 1.8,
"depth_multiplier": 2.6,
"drop_rate": 0.5,
},
},
"timm-efficientnet-b7": {
"encoder": EfficientNetEncoder,
"pretrained_settings": {
"imagenet": prepare_settings(default_cfgs["tf_efficientnet_b7"]),
"advprop": prepare_settings(default_cfgs["tf_efficientnet_b7_ap"]),
"noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b7_ns"]),
},
"params": {
"out_channels": (3, 64, 48, 80, 224, 640),
"stage_idxs": (2, 3, 5),
"channel_multiplier": 2.0,
"depth_multiplier": 3.1,
"drop_rate": 0.5,
},
},
"timm-efficientnet-b8": {
"encoder": EfficientNetEncoder,
"pretrained_settings": {
"imagenet": prepare_settings(default_cfgs["tf_efficientnet_b8"]),
"advprop": prepare_settings(default_cfgs["tf_efficientnet_b8_ap"]),
},
"params": {
"out_channels": (3, 72, 56, 88, 248, 704),
"stage_idxs": (2, 3, 5),
"channel_multiplier": 2.2,
"depth_multiplier": 3.6,
"drop_rate": 0.5,
},
},
"timm-efficientnet-l2": {
"encoder": EfficientNetEncoder,
"pretrained_settings": {
"noisy-student": prepare_settings(default_cfgs["tf_efficientnet_l2_ns"]),
},
"params": {
"out_channels": (3, 136, 104, 176, 480, 1376),
"stage_idxs": (2, 3, 5),
"channel_multiplier": 4.3,
"depth_multiplier": 5.3,
"drop_rate": 0.5,
},
},
"timm-tf_efficientnet_lite0": {
"encoder": EfficientNetLiteEncoder,
"pretrained_settings": {
"imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite0"]),
},
"params": {
"out_channels": (3, 32, 24, 40, 112, 320),
"stage_idxs": (2, 3, 5),
"channel_multiplier": 1.0,
"depth_multiplier": 1.0,
"drop_rate": 0.2,
},
},
"timm-tf_efficientnet_lite1": {
"encoder": EfficientNetLiteEncoder,
"pretrained_settings": {
"imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite1"]),
},
"params": {
"out_channels": (3, 32, 24, 40, 112, 320),
"stage_idxs": (2, 3, 5),
"channel_multiplier": 1.0,
"depth_multiplier": 1.1,
"drop_rate": 0.2,
},
},
"timm-tf_efficientnet_lite2": {
"encoder": EfficientNetLiteEncoder,
"pretrained_settings": {
"imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite2"]),
},
"params": {
"out_channels": (3, 32, 24, 48, 120, 352),
"stage_idxs": (2, 3, 5),
"channel_multiplier": 1.1,
"depth_multiplier": 1.2,
"drop_rate": 0.3,
},
},
"timm-tf_efficientnet_lite3": {
"encoder": EfficientNetLiteEncoder,
"pretrained_settings": {
"imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite3"]),
},
"params": {
"out_channels": (3, 32, 32, 48, 136, 384),
"stage_idxs": (2, 3, 5),
"channel_multiplier": 1.2,
"depth_multiplier": 1.4,
"drop_rate": 0.3,
},
},
"timm-tf_efficientnet_lite4": {
"encoder": EfficientNetLiteEncoder,
"pretrained_settings": {
"imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite4"]),
},
"params": {
"out_channels": (3, 32, 32, 56, 160, 448),
"stage_idxs": (2, 3, 5),
"channel_multiplier": 1.4,
"depth_multiplier": 1.8,
"drop_rate": 0.4,
},
},
}
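# --- Illustrative sketch (assumption, following the segmentation_models_pytorch-style registry
# convention; this helper is not part of the original file): an entry is consumed by
# instantiating its "encoder" class with its "params", while "pretrained_settings" carries the
# weight URL and normalization statistics.
def _build_encoder_from_registry(name="timm-efficientnet-b0", weights="imagenet"):
    entry = timm_efficientnet_encoders[name]
    encoder = entry["encoder"](**entry["params"])        # EfficientNetEncoder(...)
    settings = entry["pretrained_settings"][weights]
    return encoder, settings["url"], settings["mean"], settings["std"]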

View File

@@ -0,0 +1,124 @@
from timm.models import ByoModelCfg, ByoBlockCfg, ByobNet
from ._base import EncoderMixin
import torch.nn as nn
class GERNetEncoder(ByobNet, EncoderMixin):
def __init__(self, out_channels, depth=5, **kwargs):
super().__init__(**kwargs)
self._depth = depth
self._out_channels = out_channels
self._in_channels = 3
del self.head
def get_stages(self):
return [
nn.Identity(),
self.stem,
self.stages[0],
self.stages[1],
self.stages[2],
nn.Sequential(self.stages[3], self.stages[4], self.final_conv)
]
def forward(self, x):
stages = self.get_stages()
features = []
for i in range(self._depth + 1):
x = stages[i](x)
features.append(x)
return features
def load_state_dict(self, state_dict, **kwargs):
state_dict.pop("head.fc.weight", None)
state_dict.pop("head.fc.bias", None)
super().load_state_dict(state_dict, **kwargs)
regnet_weights = {
'timm-gernet_s': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_s-756b4751.pth',
},
'timm-gernet_m': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_m-0873c53a.pth',
},
'timm-gernet_l': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_l-f31e2e8d.pth',
},
}
pretrained_settings = {}
for model_name, sources in regnet_weights.items():
pretrained_settings[model_name] = {}
for source_name, source_url in sources.items():
pretrained_settings[model_name][source_name] = {
"url": source_url,
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'num_classes': 1000
}
timm_gernet_encoders = {
'timm-gernet_s': {
'encoder': GERNetEncoder,
"pretrained_settings": pretrained_settings["timm-gernet_s"],
'params': {
'out_channels': (3, 13, 48, 48, 384, 1920),
'cfg': ByoModelCfg(
blocks=(
ByoBlockCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.),
ByoBlockCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.),
ByoBlockCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4),
ByoBlockCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.),
ByoBlockCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.),
),
stem_chs=13,
stem_pool=None,
num_features=1920,
)
},
},
'timm-gernet_m': {
'encoder': GERNetEncoder,
"pretrained_settings": pretrained_settings["timm-gernet_m"],
'params': {
'out_channels': (3, 32, 128, 192, 640, 2560),
'cfg': ByoModelCfg(
blocks=(
ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
ByoBlockCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.),
ByoBlockCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.),
),
stem_chs=32,
stem_pool=None,
num_features=2560,
)
},
},
'timm-gernet_l': {
'encoder': GERNetEncoder,
"pretrained_settings": pretrained_settings["timm-gernet_l"],
'params': {
'out_channels': (3, 32, 128, 192, 640, 2560),
'cfg': ByoModelCfg(
blocks=(
ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
ByoBlockCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.),
ByoBlockCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.),
),
stem_chs=32,
stem_pool=None,
num_features=2560,
)
},
},
}

View File

@@ -0,0 +1,175 @@
import timm
import numpy as np
import torch.nn as nn
from ._base import EncoderMixin
def _make_divisible(x, divisible_by=8):
return int(np.ceil(x * 1. / divisible_by) * divisible_by)
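# --- Worked example (assumption, helper not in the original file): width scaling rounds each
# scaled channel count up to a multiple of 8, e.g. 40 * 0.75 = 30 -> 32 and 112 * 0.75 = 84 -> 88.
def _example_width_scaling(width_mult=0.75):
    large_channels = [16, 24, 40, 112, 960]              # the "large" channel list from _get_channels below
    return [_make_divisible(c * width_mult) for c in large_channels]   # [16, 24, 32, 88, 720]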
class MobileNetV3Encoder(nn.Module, EncoderMixin):
def __init__(self, model_name, width_mult, depth=5, **kwargs):
super().__init__()
if "large" not in model_name and "small" not in model_name:
raise ValueError(
                'Wrong MobileNetV3 model name: {}'.format(model_name)
)
self._mode = "small" if "small" in model_name else "large"
self._depth = depth
self._out_channels = self._get_channels(self._mode, width_mult)
self._in_channels = 3
# minimal models replace hardswish with relu
self.model = timm.create_model(
model_name=model_name,
scriptable=True, # torch.jit scriptable
exportable=True, # onnx export
features_only=True,
)
def _get_channels(self, mode, width_mult):
if mode == "small":
channels = [16, 16, 24, 48, 576]
else:
channels = [16, 24, 40, 112, 960]
channels = [3,] + [_make_divisible(x * width_mult) for x in channels]
return tuple(channels)
def get_stages(self):
if self._mode == 'small':
return [
nn.Identity(),
nn.Sequential(
self.model.conv_stem,
self.model.bn1,
self.model.act1,
),
self.model.blocks[0],
self.model.blocks[1],
self.model.blocks[2:4],
self.model.blocks[4:],
]
elif self._mode == 'large':
return [
nn.Identity(),
nn.Sequential(
self.model.conv_stem,
self.model.bn1,
self.model.act1,
self.model.blocks[0],
),
self.model.blocks[1],
self.model.blocks[2],
self.model.blocks[3:5],
self.model.blocks[5:],
]
else:
            raise ValueError('MobileNetV3 mode should be small or large, got {}'.format(self._mode))
def forward(self, x):
stages = self.get_stages()
features = []
for i in range(self._depth + 1):
x = stages[i](x)
features.append(x)
return features
def load_state_dict(self, state_dict, **kwargs):
state_dict.pop('conv_head.weight', None)
state_dict.pop('conv_head.bias', None)
state_dict.pop('classifier.weight', None)
state_dict.pop('classifier.bias', None)
self.model.load_state_dict(state_dict, **kwargs)
mobilenetv3_weights = {
'tf_mobilenetv3_large_075': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth'
},
'tf_mobilenetv3_large_100': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth'
},
'tf_mobilenetv3_large_minimal_100': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth'
},
'tf_mobilenetv3_small_075': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth'
},
'tf_mobilenetv3_small_100': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth'
},
'tf_mobilenetv3_small_minimal_100': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth'
},
}
pretrained_settings = {}
for model_name, sources in mobilenetv3_weights.items():
pretrained_settings[model_name] = {}
for source_name, source_url in sources.items():
pretrained_settings[model_name][source_name] = {
"url": source_url,
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'input_space': 'RGB',
}
timm_mobilenetv3_encoders = {
'timm-mobilenetv3_large_075': {
'encoder': MobileNetV3Encoder,
'pretrained_settings': pretrained_settings['tf_mobilenetv3_large_075'],
'params': {
'model_name': 'tf_mobilenetv3_large_075',
'width_mult': 0.75
}
},
'timm-mobilenetv3_large_100': {
'encoder': MobileNetV3Encoder,
'pretrained_settings': pretrained_settings['tf_mobilenetv3_large_100'],
'params': {
'model_name': 'tf_mobilenetv3_large_100',
'width_mult': 1.0
}
},
'timm-mobilenetv3_large_minimal_100': {
'encoder': MobileNetV3Encoder,
'pretrained_settings': pretrained_settings['tf_mobilenetv3_large_minimal_100'],
'params': {
'model_name': 'tf_mobilenetv3_large_minimal_100',
'width_mult': 1.0
}
},
'timm-mobilenetv3_small_075': {
'encoder': MobileNetV3Encoder,
'pretrained_settings': pretrained_settings['tf_mobilenetv3_small_075'],
'params': {
'model_name': 'tf_mobilenetv3_small_075',
'width_mult': 0.75
}
},
'timm-mobilenetv3_small_100': {
'encoder': MobileNetV3Encoder,
'pretrained_settings': pretrained_settings['tf_mobilenetv3_small_100'],
'params': {
'model_name': 'tf_mobilenetv3_small_100',
'width_mult': 1.0
}
},
'timm-mobilenetv3_small_minimal_100': {
'encoder': MobileNetV3Encoder,
'pretrained_settings': pretrained_settings['tf_mobilenetv3_small_minimal_100'],
'params': {
'model_name': 'tf_mobilenetv3_small_minimal_100',
'width_mult': 1.0
}
},
}

View File

@@ -0,0 +1,332 @@
from ._base import EncoderMixin
from timm.models.regnet import RegNet
import torch.nn as nn
class RegNetEncoder(RegNet, EncoderMixin):
def __init__(self, out_channels, depth=5, **kwargs):
super().__init__(**kwargs)
self._depth = depth
self._out_channels = out_channels
self._in_channels = 3
del self.head
def get_stages(self):
return [
nn.Identity(),
self.stem,
self.s1,
self.s2,
self.s3,
self.s4,
]
def forward(self, x):
stages = self.get_stages()
features = []
for i in range(self._depth + 1):
x = stages[i](x)
features.append(x)
return features
def load_state_dict(self, state_dict, **kwargs):
state_dict.pop("head.fc.weight", None)
state_dict.pop("head.fc.bias", None)
super().load_state_dict(state_dict, **kwargs)
regnet_weights = {
'timm-regnetx_002': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_002-e7e85e5c.pth',
},
'timm-regnetx_004': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_004-7d0e9424.pth',
},
'timm-regnetx_006': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_006-85ec1baa.pth',
},
'timm-regnetx_008': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_008-d8b470eb.pth',
},
'timm-regnetx_016': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_016-65ca972a.pth',
},
'timm-regnetx_032': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_032-ed0c7f7e.pth',
},
'timm-regnetx_040': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_040-73c2a654.pth',
},
'timm-regnetx_064': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_064-29278baa.pth',
},
'timm-regnetx_080': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_080-7c7fcab1.pth',
},
'timm-regnetx_120': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth',
},
'timm-regnetx_160': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth',
},
'timm-regnetx_320': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth',
},
'timm-regnety_002': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth',
},
'timm-regnety_004': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth',
},
'timm-regnety_006': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth',
},
'timm-regnety_008': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth',
},
'timm-regnety_016': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth',
},
'timm-regnety_032': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth'
},
'timm-regnety_040': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth'
},
'timm-regnety_064': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth'
},
'timm-regnety_080': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth',
},
'timm-regnety_120': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth',
},
'timm-regnety_160': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_160-d64013cd.pth',
},
'timm-regnety_320': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'
}
}
pretrained_settings = {}
for model_name, sources in regnet_weights.items():
pretrained_settings[model_name] = {}
for source_name, source_url in sources.items():
pretrained_settings[model_name][source_name] = {
"url": source_url,
'input_size': [3, 224, 224],
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'num_classes': 1000
}
# the configs below reuse the RegNet configs from timm's repo rather than duplicating them here
def _mcfg(**kwargs):
cfg = dict(se_ratio=0., bottle_ratio=1., stem_width=32)
cfg.update(**kwargs)
return cfg
timm_regnet_encoders = {
'timm-regnetx_002': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnetx_002"],
'params': {
'out_channels': (3, 32, 24, 56, 152, 368),
'cfg': _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13)
},
},
'timm-regnetx_004': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnetx_004"],
'params': {
'out_channels': (3, 32, 32, 64, 160, 384),
'cfg': _mcfg(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22)
},
},
'timm-regnetx_006': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnetx_006"],
'params': {
'out_channels': (3, 32, 48, 96, 240, 528),
'cfg': _mcfg(w0=48, wa=36.97, wm=2.24, group_w=24, depth=16)
},
},
'timm-regnetx_008': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnetx_008"],
'params': {
'out_channels': (3, 32, 64, 128, 288, 672),
'cfg': _mcfg(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16)
},
},
'timm-regnetx_016': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnetx_016"],
'params': {
'out_channels': (3, 32, 72, 168, 408, 912),
'cfg': _mcfg(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18)
},
},
'timm-regnetx_032': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnetx_032"],
'params': {
'out_channels': (3, 32, 96, 192, 432, 1008),
'cfg': _mcfg(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25)
},
},
'timm-regnetx_040': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnetx_040"],
'params': {
'out_channels': (3, 32, 80, 240, 560, 1360),
'cfg': _mcfg(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23)
},
},
'timm-regnetx_064': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnetx_064"],
'params': {
'out_channels': (3, 32, 168, 392, 784, 1624),
'cfg': _mcfg(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17)
},
},
'timm-regnetx_080': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnetx_080"],
'params': {
'out_channels': (3, 32, 80, 240, 720, 1920),
'cfg': _mcfg(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23)
},
},
'timm-regnetx_120': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnetx_120"],
'params': {
'out_channels': (3, 32, 224, 448, 896, 2240),
'cfg': _mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19)
},
},
'timm-regnetx_160': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnetx_160"],
'params': {
'out_channels': (3, 32, 256, 512, 896, 2048),
'cfg': _mcfg(w0=216, wa=55.59, wm=2.1, group_w=128, depth=22)
},
},
'timm-regnetx_320': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnetx_320"],
'params': {
'out_channels': (3, 32, 336, 672, 1344, 2520),
'cfg': _mcfg(w0=320, wa=69.86, wm=2.0, group_w=168, depth=23)
},
},
#regnety
'timm-regnety_002': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnety_002"],
'params': {
'out_channels': (3, 32, 24, 56, 152, 368),
'cfg': _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13, se_ratio=0.25)
},
},
'timm-regnety_004': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnety_004"],
'params': {
'out_channels': (3, 32, 48, 104, 208, 440),
'cfg': _mcfg(w0=48, wa=27.89, wm=2.09, group_w=8, depth=16, se_ratio=0.25)
},
},
'timm-regnety_006': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnety_006"],
'params': {
'out_channels': (3, 32, 48, 112, 256, 608),
'cfg': _mcfg(w0=48, wa=32.54, wm=2.32, group_w=16, depth=15, se_ratio=0.25)
},
},
'timm-regnety_008': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnety_008"],
'params': {
'out_channels': (3, 32, 64, 128, 320, 768),
'cfg': _mcfg(w0=56, wa=38.84, wm=2.4, group_w=16, depth=14, se_ratio=0.25)
},
},
'timm-regnety_016': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnety_016"],
'params': {
'out_channels': (3, 32, 48, 120, 336, 888),
'cfg': _mcfg(w0=48, wa=20.71, wm=2.65, group_w=24, depth=27, se_ratio=0.25)
},
},
'timm-regnety_032': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnety_032"],
'params': {
'out_channels': (3, 32, 72, 216, 576, 1512),
'cfg': _mcfg(w0=80, wa=42.63, wm=2.66, group_w=24, depth=21, se_ratio=0.25)
},
},
'timm-regnety_040': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnety_040"],
'params': {
'out_channels': (3, 32, 128, 192, 512, 1088),
'cfg': _mcfg(w0=96, wa=31.41, wm=2.24, group_w=64, depth=22, se_ratio=0.25)
},
},
'timm-regnety_064': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnety_064"],
'params': {
'out_channels': (3, 32, 144, 288, 576, 1296),
'cfg': _mcfg(w0=112, wa=33.22, wm=2.27, group_w=72, depth=25, se_ratio=0.25)
},
},
'timm-regnety_080': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnety_080"],
'params': {
'out_channels': (3, 32, 168, 448, 896, 2016),
'cfg': _mcfg(w0=192, wa=76.82, wm=2.19, group_w=56, depth=17, se_ratio=0.25)
},
},
'timm-regnety_120': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnety_120"],
'params': {
'out_channels': (3, 32, 224, 448, 896, 2240),
'cfg': _mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, se_ratio=0.25)
},
},
'timm-regnety_160': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnety_160"],
'params': {
'out_channels': (3, 32, 224, 448, 1232, 3024),
'cfg': _mcfg(w0=200, wa=106.23, wm=2.48, group_w=112, depth=18, se_ratio=0.25)
},
},
'timm-regnety_320': {
'encoder': RegNetEncoder,
"pretrained_settings": pretrained_settings["timm-regnety_320"],
'params': {
'out_channels': (3, 32, 232, 696, 1392, 3712),
'cfg': _mcfg(w0=232, wa=115.89, wm=2.53, group_w=232, depth=20, se_ratio=0.25)
},
},
}

View File

@@ -0,0 +1,163 @@
from ._base import EncoderMixin
from timm.models.resnet import ResNet
from timm.models.res2net import Bottle2neck
import torch.nn as nn
class Res2NetEncoder(ResNet, EncoderMixin):
def __init__(self, out_channels, depth=5, **kwargs):
super().__init__(**kwargs)
self._depth = depth
self._out_channels = out_channels
self._in_channels = 3
del self.fc
del self.global_pool
def get_stages(self):
return [
nn.Identity(),
nn.Sequential(self.conv1, self.bn1, self.act1),
nn.Sequential(self.maxpool, self.layer1),
self.layer2,
self.layer3,
self.layer4,
]
def make_dilated(self, output_stride):
raise ValueError("Res2Net encoders do not support dilated mode")
def forward(self, x):
stages = self.get_stages()
features = []
for i in range(self._depth + 1):
x = stages[i](x)
features.append(x)
return features
def load_state_dict(self, state_dict, **kwargs):
state_dict.pop("fc.bias", None)
state_dict.pop("fc.weight", None)
super().load_state_dict(state_dict, **kwargs)
res2net_weights = {
'timm-res2net50_26w_4s': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_4s-06e79181.pth'
},
'timm-res2net50_48w_2s': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_48w_2s-afed724a.pth'
},
'timm-res2net50_14w_8s': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_14w_8s-6527dddc.pth',
},
'timm-res2net50_26w_6s': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_6s-19041792.pth',
},
'timm-res2net50_26w_8s': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_8s-2c7c9f12.pth',
},
'timm-res2net101_26w_4s': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net101_26w_4s-02a759a1.pth',
},
'timm-res2next50': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next50_4s-6ef7e7bf.pth',
}
}
pretrained_settings = {}
for model_name, sources in res2net_weights.items():
pretrained_settings[model_name] = {}
for source_name, source_url in sources.items():
pretrained_settings[model_name][source_name] = {
"url": source_url,
'input_size': [3, 224, 224],
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'num_classes': 1000
}
timm_res2net_encoders = {
'timm-res2net50_26w_4s': {
'encoder': Res2NetEncoder,
"pretrained_settings": pretrained_settings["timm-res2net50_26w_4s"],
'params': {
'out_channels': (3, 64, 256, 512, 1024, 2048),
'block': Bottle2neck,
'layers': [3, 4, 6, 3],
'base_width': 26,
'block_args': {'scale': 4}
},
},
'timm-res2net101_26w_4s': {
'encoder': Res2NetEncoder,
"pretrained_settings": pretrained_settings["timm-res2net101_26w_4s"],
'params': {
'out_channels': (3, 64, 256, 512, 1024, 2048),
'block': Bottle2neck,
'layers': [3, 4, 23, 3],
'base_width': 26,
'block_args': {'scale': 4}
},
},
'timm-res2net50_26w_6s': {
'encoder': Res2NetEncoder,
"pretrained_settings": pretrained_settings["timm-res2net50_26w_6s"],
'params': {
'out_channels': (3, 64, 256, 512, 1024, 2048),
'block': Bottle2neck,
'layers': [3, 4, 6, 3],
'base_width': 26,
'block_args': {'scale': 6}
},
},
'timm-res2net50_26w_8s': {
'encoder': Res2NetEncoder,
"pretrained_settings": pretrained_settings["timm-res2net50_26w_8s"],
'params': {
'out_channels': (3, 64, 256, 512, 1024, 2048),
'block': Bottle2neck,
'layers': [3, 4, 6, 3],
'base_width': 26,
'block_args': {'scale': 8}
},
},
'timm-res2net50_48w_2s': {
'encoder': Res2NetEncoder,
"pretrained_settings": pretrained_settings["timm-res2net50_48w_2s"],
'params': {
'out_channels': (3, 64, 256, 512, 1024, 2048),
'block': Bottle2neck,
'layers': [3, 4, 6, 3],
'base_width': 48,
'block_args': {'scale': 2}
},
},
'timm-res2net50_14w_8s': {
'encoder': Res2NetEncoder,
"pretrained_settings": pretrained_settings["timm-res2net50_14w_8s"],
'params': {
'out_channels': (3, 64, 256, 512, 1024, 2048),
'block': Bottle2neck,
'layers': [3, 4, 6, 3],
'base_width': 14,
'block_args': {'scale': 8}
},
},
'timm-res2next50': {
'encoder': Res2NetEncoder,
"pretrained_settings": pretrained_settings["timm-res2next50"],
'params': {
'out_channels': (3, 64, 256, 512, 1024, 2048),
'block': Bottle2neck,
'layers': [3, 4, 6, 3],
'base_width': 4,
'cardinality': 8,
'block_args': {'scale': 4}
},
}
}

View File

@@ -0,0 +1,208 @@
from ._base import EncoderMixin
from timm.models.resnet import ResNet
from timm.models.resnest import ResNestBottleneck
import torch.nn as nn
class ResNestEncoder(ResNet, EncoderMixin):
def __init__(self, out_channels, depth=5, **kwargs):
super().__init__(**kwargs)
self._depth = depth
self._out_channels = out_channels
self._in_channels = 3
del self.fc
del self.global_pool
def get_stages(self):
return [
nn.Identity(),
nn.Sequential(self.conv1, self.bn1, self.act1),
nn.Sequential(self.maxpool, self.layer1),
self.layer2,
self.layer3,
self.layer4,
]
def make_dilated(self, output_stride):
raise ValueError("ResNest encoders do not support dilated mode")
def forward(self, x):
stages = self.get_stages()
features = []
for i in range(self._depth + 1):
x = stages[i](x)
features.append(x)
return features
def load_state_dict(self, state_dict, **kwargs):
state_dict.pop("fc.bias", None)
state_dict.pop("fc.weight", None)
super().load_state_dict(state_dict, **kwargs)
resnest_weights = {
'timm-resnest14d': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth'
},
'timm-resnest26d': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth'
},
'timm-resnest50d': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50-528c19ca.pth',
},
'timm-resnest101e': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest101-22405ba7.pth',
},
'timm-resnest200e': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest200-75117900.pth',
},
'timm-resnest269e': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest269-0cc87c48.pth',
},
'timm-resnest50d_4s2x40d': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_4s2x40d-41d14ed0.pth',
},
'timm-resnest50d_1s4x24d': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_1s4x24d-d4a4f76f.pth',
}
}
pretrained_settings = {}
for model_name, sources in resnest_weights.items():
pretrained_settings[model_name] = {}
for source_name, source_url in sources.items():
pretrained_settings[model_name][source_name] = {
"url": source_url,
'input_size': [3, 224, 224],
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'num_classes': 1000
}
timm_resnest_encoders = {
'timm-resnest14d': {
'encoder': ResNestEncoder,
"pretrained_settings": pretrained_settings["timm-resnest14d"],
'params': {
'out_channels': (3, 64, 256, 512, 1024, 2048),
'block': ResNestBottleneck,
'layers': [1, 1, 1, 1],
'stem_type': 'deep',
'stem_width': 32,
'avg_down': True,
'base_width': 64,
'cardinality': 1,
'block_args': {'radix': 2, 'avd': True, 'avd_first': False}
}
},
'timm-resnest26d': {
'encoder': ResNestEncoder,
"pretrained_settings": pretrained_settings["timm-resnest26d"],
'params': {
'out_channels': (3, 64, 256, 512, 1024, 2048),
'block': ResNestBottleneck,
'layers': [2, 2, 2, 2],
'stem_type': 'deep',
'stem_width': 32,
'avg_down': True,
'base_width': 64,
'cardinality': 1,
'block_args': {'radix': 2, 'avd': True, 'avd_first': False}
}
},
'timm-resnest50d': {
'encoder': ResNestEncoder,
"pretrained_settings": pretrained_settings["timm-resnest50d"],
'params': {
'out_channels': (3, 64, 256, 512, 1024, 2048),
'block': ResNestBottleneck,
'layers': [3, 4, 6, 3],
'stem_type': 'deep',
'stem_width': 32,
'avg_down': True,
'base_width': 64,
'cardinality': 1,
'block_args': {'radix': 2, 'avd': True, 'avd_first': False}
}
},
'timm-resnest101e': {
'encoder': ResNestEncoder,
"pretrained_settings": pretrained_settings["timm-resnest101e"],
'params': {
'out_channels': (3, 128, 256, 512, 1024, 2048),
'block': ResNestBottleneck,
'layers': [3, 4, 23, 3],
'stem_type': 'deep',
'stem_width': 64,
'avg_down': True,
'base_width': 64,
'cardinality': 1,
'block_args': {'radix': 2, 'avd': True, 'avd_first': False}
}
},
'timm-resnest200e': {
'encoder': ResNestEncoder,
"pretrained_settings": pretrained_settings["timm-resnest200e"],
'params': {
'out_channels': (3, 128, 256, 512, 1024, 2048),
'block': ResNestBottleneck,
'layers': [3, 24, 36, 3],
'stem_type': 'deep',
'stem_width': 64,
'avg_down': True,
'base_width': 64,
'cardinality': 1,
'block_args': {'radix': 2, 'avd': True, 'avd_first': False}
}
},
'timm-resnest269e': {
'encoder': ResNestEncoder,
"pretrained_settings": pretrained_settings["timm-resnest269e"],
'params': {
'out_channels': (3, 128, 256, 512, 1024, 2048),
'block': ResNestBottleneck,
'layers': [3, 30, 48, 8],
'stem_type': 'deep',
'stem_width': 64,
'avg_down': True,
'base_width': 64,
'cardinality': 1,
'block_args': {'radix': 2, 'avd': True, 'avd_first': False}
},
},
'timm-resnest50d_4s2x40d': {
'encoder': ResNestEncoder,
"pretrained_settings": pretrained_settings["timm-resnest50d_4s2x40d"],
'params': {
'out_channels': (3, 64, 256, 512, 1024, 2048),
'block': ResNestBottleneck,
'layers': [3, 4, 6, 3],
'stem_type': 'deep',
'stem_width': 32,
'avg_down': True,
'base_width': 40,
'cardinality': 2,
'block_args': {'radix': 4, 'avd': True, 'avd_first': True}
}
},
'timm-resnest50d_1s4x24d': {
'encoder': ResNestEncoder,
"pretrained_settings": pretrained_settings["timm-resnest50d_1s4x24d"],
'params': {
'out_channels': (3, 64, 256, 512, 1024, 2048),
'block': ResNestBottleneck,
'layers': [3, 4, 6, 3],
'stem_type': 'deep',
'stem_width': 32,
'avg_down': True,
'base_width': 24,
'cardinality': 4,
'block_args': {'radix': 1, 'avd': True, 'avd_first': True}
}
}
}

Some files were not shown because too many files have changed in this diff.