add mean filter method

This commit is contained in:
copper 2022-11-05 20:04:27 +08:00
parent 80e03e82f7
commit 905c65b277
7 changed files with 362 additions and 175 deletions

View File

@ -0,0 +1,174 @@
# BilateralFilter
通过调节进度条,改变相关参数
```python
import cv2 as cv
import numpy as np
cv.namedWindow("image")
cv.createTrackbar("d","image",0,255,print)
cv.createTrackbar("sigmaColor","image",0,255,print)
cv.createTrackbar("sigmaSpace","image",0,255,print)
img = cv.imread("train_306.png",0)
while(1):
d = cv.getTrackbarPos("d","image")
sigmaColor = cv.getTrackbarPos("sigmaColor","image")
sigmaSpace = cv.getTrackbarPos("sigmaSpace","image")
result_img = cv.bilateralFilter(img,d,sigmaColor,sigmaSpace)
cv.imshow("result",result_img)
k = cv.waitKey(1) & 0xFF
if k ==27:
break
cv.destroyAllWindows()
```
# BilateralFilter2
运行速度慢,但是效果好
```python
import cv2 as cv
import numpy as np
import math
import copy
def spilt( a ):
if a/2 == 0:
x1 = x2 = a/2
else:
x1 = math.floor( a/2 )
x2 = a - x1
return -x1,x2
def d_value():
value = np.zeros(256)
var_temp = 30
for i in range(0,255):
t = i*i
value[i] = math.e ** (-t / (2 * var_temp * var_temp))
return value
def gaussian_b0x(a, b):
judge = 10
box =[]
x1, x2 = spilt(a)
y1, y2 = spilt(b)
for i in range (x1, x2 ):
for j in range(y1, y2):
t = i*i + j*j
re = math.e ** (-t/(2*judge*judge))
box.append(re)
# for x in box :
# print (x)
return box
def original (i, j, k, a, b, img):
x1, x2 = spilt(a)
y1, y2 = spilt(b)
temp = np.zeros(a * b)
count = 0
for m in range(x1, x2):
for n in range(y1, y2):
if i + m < 0 or i + m > img.shape[0] - 1 or j + n < 0 or j + n > img.shape[1] - 1:
temp[count] = img[i, j, k]
else:
temp[count] = img[i + m, j + n, k]
count += 1
return temp
def bilateral_function(a, b, img, gauss_fun,d_value_e ):
x1, x2 = spilt(a)
y1, y2 = spilt(b)
re = np.zeros(a * b)
img0 = copy.copy(img)
for i in range(img.shape[0]):
for j in range(img.shape[1]):
for k in range(0,2):
temp = original(i, j, k, a, b, img0)
# print("ave:",ave_temp)
count = 0
for m in range (x1,x2):
for n in range(y1,y2):
if i+m < 0 or i+m >img.shape[0]-1 or j+n <0 or j+n >img.shape[1]-1:
x = img[i,j,k]
else :
x = img[i+m,j+n,k]
t = int(math.fabs(int(x) - int(img[i,j,k])) )
re[count] = d_value_e[t]
count += 1
evalue = np.multiply(re, gauss_fun)
img[i,j,k] = int(np.average(temp, weights = evalue))
return img
def main():
gauss_new = gaussian_b0x(30, 30 )
# print(gauss_new)
d_value_e = d_value()
img0 = cv.imread(r'train_1.png')
bilateral_img = bilateral_function(30, 30, copy.copy(img0), gauss_new, d_value_e)
cv.imshow("shuangbian", bilateral_img)
cv.imshow("yuantu", img0)
cv.imwrite("shuangbian.jpg", bilateral_img)
cv.waitKey(0)
cv.destroyAllWindows()
if __name__ == "__main__":
main()
```
# MorphologyEx1
```python
import cv2
import numpy as np
#读图
img = cv2.imread('train_306.png',0)
#设置核
kernel = np.ones((5,5),np.uint8)
#形态学梯度调用
gradient = cv2.morphologyEx(img, cv2.MORPH_GRADIENT, kernel)
#显示效果
cv2.imshow('src',img)
cv2.imshow('result',gradient)
cv2.waitKey()
```
# MorphologyE2
```python
import cv2
import numpy as np
# 形态学梯度运算=膨胀-腐蚀
img = cv2.imread('train_306.png', 0)
kernel = np.ones((5, 5), np.uint8)
dilate = cv2.dilate(img, kernel, iterations=5)
erosion = cv2.erode(img, kernel, iterations=5)
gradient = cv2.morphologyEx(img, cv2.MORPH_GRADIENT, kernel)
res = np.hstack((dilate, erosion, gradient))
cv2.imshow('res', res)
cv2.imshow('result',gradient)
cv2.waitKey(0)
cv2.destroyAllWindows()
```

View File

@ -2,4 +2,5 @@ from misc import Register
FILTER = Register('滤波处理算法')
from .mean_filter import MeanFilter
from filter_collection.main import *

View File

@ -0,0 +1,11 @@
from misc import AlgFrontend
from osgeo import gdal, gdal_array
from skimage.filters import rank
from skimage.morphology import rectangle
from filter_collection import FILTER
from PyQt5.QtWidgets import QDialog, QAction
from PyQt5 import QtCore, QtGui, QtWidgets
from rscder.utils.project import PairLayer, Project, RasterLayer, ResultPointLayer
import os
from datetime import datetime

View File

@ -1,7 +1,7 @@
from datetime import datetime
import os
from threading import Thread
from PyQt5.QtWidgets import QDialog, QAction
from PyQt5.QtWidgets import QDialog, QAction, QVBoxLayout, QDialogButtonBox, QHBoxLayout, QPushButton
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtCore import Qt, QModelIndex, pyqtSignal
from rscder.gui.actions import ActionManager
@ -13,143 +13,55 @@ from osgeo import gdal, gdal_array
from skimage.filters import rank
from skimage.morphology import rectangle
from filter_collection import FILTER
from misc import AlgFrontend
from misc import AlgFrontend, AlgSelectWidget
import functools
@FILTER.register
class MainFilter(AlgFrontend):
@staticmethod
def get_name():
return '均值滤波'
@staticmethod
def get_widget(parent=None):
widget = QtWidgets.QWidget(parent)
x_size_input = QtWidgets.QLineEdit(widget)
x_size_input.setText('3')
x_size_input.setValidator(QtGui.QIntValidator())
x_size_input.setObjectName('xinput')
y_size_input = QtWidgets.QLineEdit(widget)
y_size_input.setValidator(QtGui.QIntValidator())
y_size_input.setObjectName('yinput')
y_size_input.setText('3')
size_label = QtWidgets.QLabel(widget)
size_label.setText('窗口大小:')
time_label = QtWidgets.QLabel(widget)
time_label.setText('X')
hlayout1 = QtWidgets.QHBoxLayout()
hlayout1.addWidget(size_label)
hlayout1.addWidget(x_size_input)
hlayout1.addWidget(time_label)
hlayout1.addWidget(y_size_input)
widget.setLayout(hlayout1)
return widget
@staticmethod
def get_params(widget:QtWidgets.QWidget=None):
if widget is None:
return dict(x_size=3, y_size=3)
x_input = widget.findChild(QtWidgets.QLineEdit, 'xinput')
y_input = widget.findChild(QtWidgets.QLineEdit, 'yinput')
if x_input is None or y_input is None:
return dict(x_size=3, y_size=3)
x_size = int(x_input.text())
y_size = int(y_input.text())
return dict(x_size=x_size, y_size=y_size)
@staticmethod
def run_alg(pth, x_size, y_size, *args, **kargs):
x_size = int(x_size)
y_size = int(y_size)
# pth = layer.path
if pth is None:
return
ds = gdal.Open(pth)
band_count = ds.RasterCount
out_path = os.path.join(Project().other_path, 'mean_filter_{}.tif'.format(int(datetime.now().timestamp() * 1000)))
out_ds = gdal.GetDriverByName('GTiff').Create(out_path, ds.RasterXSize, ds.RasterYSize, band_count, ds.GetRasterBand(1).DataType)
out_ds.SetProjection(ds.GetProjection())
out_ds.SetGeoTransform(ds.GetGeoTransform())
for i in range(band_count):
band = ds.GetRasterBand(i+1)
data = band.ReadAsArray()
data = rank.mean(data, rectangle(y_size, x_size))
out_band = out_ds.GetRasterBand(i+1)
out_band.WriteArray(data)
out_ds.FlushCache()
del out_ds
del ds
return out_path
class FilterSetting(QDialog):
def __init__(self, parent=None):
super(FilterSetting, self).__init__(parent)
self.setWindowTitle('滤波设置')
self.setWindowIcon(IconInstance().FILTER)
class FilterMethod(QDialog):
def __init__(self,parent=None, alg:AlgFrontend=None):
super(FilterMethod, self).__init__(parent)
self.alg = alg
self.setWindowTitle('滤波算法:{}'.format(alg.get_name()))
self.setWindowIcon(IconInstance().LOGO)
self.initUI()
self.setMinimumWidth(500)
def initUI(self):
#图层
self.layer_combox = RasterLayerCombox(self)
layer_label = QtWidgets.QLabel('图层:')
layerbox = QHBoxLayout()
layerbox.addWidget(self.layer_combox)
self.param_widget = self.alg.get_widget(self)
hbox = QtWidgets.QHBoxLayout()
hbox.addWidget(layer_label)
hbox.addWidget(self.layer_combox)
self.ok_button = QPushButton('确定', self)
self.ok_button.setIcon(IconInstance().OK)
self.ok_button.clicked.connect(self.accept)
self.ok_button.setDefault(True)
x_size_input = QtWidgets.QLineEdit(self)
x_size_input.setText('3')
y_size_input = QtWidgets.QLineEdit(self)
y_size_input.setText('3')
size_label = QtWidgets.QLabel(self)
size_label.setText('窗口大小:')
time_label = QtWidgets.QLabel(self)
time_label.setText('X')
self.x_size_input = x_size_input
self.y_size_input = y_size_input
hlayout1 = QtWidgets.QHBoxLayout()
hlayout1.addWidget(size_label)
hlayout1.addWidget(x_size_input)
hlayout1.addWidget(time_label)
hlayout1.addWidget(y_size_input)
ok_button = QtWidgets.QPushButton(self)
ok_button.setText('确定')
ok_button.clicked.connect(self.accept)
cancel_button = QtWidgets.QPushButton(self)
cancel_button.setText('取消')
cancel_button.clicked.connect(self.reject)
hlayout2 = QtWidgets.QHBoxLayout()
hlayout2.addWidget(ok_button,0,alignment=Qt.AlignHCenter)
hlayout2.addWidget(cancel_button,0,alignment=Qt.AlignHCenter)
vlayout = QtWidgets.QVBoxLayout()
vlayout.addLayout(hbox)
vlayout.addLayout(hlayout1)
vlayout.addLayout(hlayout2)
self.setLayout(vlayout)
self.cancel_button = QPushButton('取消', self)
self.cancel_button.setIcon(IconInstance().CANCEL)
self.cancel_button.clicked.connect(self.reject)
self.cancel_button.setDefault(False)
buttonbox=QDialogButtonBox(self)
buttonbox.addButton(self.ok_button,QDialogButtonBox.NoRole)
buttonbox.addButton(self.cancel_button,QDialogButtonBox.NoRole)
buttonbox.setCenterButtons(True)
totalvlayout=QVBoxLayout()
totalvlayout.addLayout(layerbox)
# totalvlayout.addWidget(self.filter_select)
if self.param_widget is not None:
totalvlayout.addWidget(self.param_widget)
# totalvlayout.addWidget(self.thres_select)
totalvlayout.addStretch(1)
hbox = QHBoxLayout()
hbox.addStretch(1)
hbox.addWidget(buttonbox)
totalvlayout.addLayout(hbox)
# totalvlayout.addStretch()
self.setLayout(totalvlayout)
class MainPlugin(BasicPlugin):
@ -165,53 +77,49 @@ class MainPlugin(BasicPlugin):
}
def set_action(self):
self.action = QAction('均值滤波', self.mainwindow)
# self.action.setCheckable)
# self.action.setChecked(False)
self.action.triggered.connect(self.run)
ActionManager().filter_menu.addAction(self.action)
self.alg_ok.connect(self.alg_oked)
toolbar = ActionManager().add_toolbar('Filter Collection')
for key in FILTER.keys():
alg:AlgFrontend = FILTER[key]
name = alg.get_name() or key
action = QAction(name, self.mainwindow)
func = functools.partial(self.run, key)
action.triggered.connect(func)
toolbar.addAction(action)
ActionManager().filter_menu.addAction(action)
# self.action = QAction('均值滤波', self.mainwindow)
# # self.action.setCheckable)
# # self.action.setChecked(False)
# self.action.triggered.connect(self.run)
# ActionManager().filter_menu.addAction(self.action)
# self.alg_ok.connect(self.alg_oked)
# basic
def alg_oked(self, parent, layer:RasterLayer):
parent.add_result_layer(layer)
def run_alg(self, layer:RasterLayer, x_size, y_size, method='mean'):
x_size = int(x_size)
y_size = int(y_size)
def run_alg(self, layer:RasterLayer, alg:AlgFrontend, p):
pth = layer.path
if pth is None:
return
ds = gdal.Open(pth)
band_count = ds.RasterCount
out_path = os.path.join(Project().other_path, '{}_mean_filter.tif'.format(layer.name))
out_ds = gdal.GetDriverByName('GTiff').Create(out_path, ds.RasterXSize, ds.RasterYSize, band_count, ds.GetRasterBand(1).DataType)
out_ds.SetProjection(ds.GetProjection())
out_ds.SetGeoTransform(ds.GetGeoTransform())
for i in range(band_count):
band = ds.GetRasterBand(i+1)
data = band.ReadAsArray()
data = rank.mean(data, rectangle(y_size, x_size))
out_band = out_ds.GetRasterBand(i+1)
out_band.WriteArray(data)
out_ds.FlushCache()
del out_ds
del ds
out_path = alg.run_alg(pth, **p)
rlayer = RasterLayer(path = out_path, enable= True, view_mode = layer.view_mode )
self.alg_ok.emit(layer.layer_parent, rlayer)
def run(self):
dialog = FilterSetting(self.mainwindow)
dialog.show()
def run(self, key):
if key not in FILTER:
self.send_message.emit(f'{key} not in {FILTER.keys()}')
return
alg:AlgFrontend = FILTER[key]
dialog = FilterMethod(self.mainwindow, alg)
if dialog.exec_():
x_size = int(dialog.x_size_input.text())
y_size = int(dialog.y_size_input.text())
t = Thread(target=self.run_alg, args=(dialog.layer_combox.current_layer, x_size, y_size))
t = Thread(target=self.run_alg, args=(dialog.layer_combox.current_layer, alg, alg.get_params(dialog.param_widget)))
t.start()

View File

@ -0,0 +1,96 @@
from misc import AlgFrontend
from osgeo import gdal, gdal_array
from skimage.filters import rank
from skimage.morphology import rectangle
from filter_collection import FILTER
from PyQt5.QtWidgets import QDialog, QAction
from PyQt5 import QtCore, QtGui, QtWidgets
from rscder.utils.project import PairLayer, Project, RasterLayer, ResultPointLayer
import os
from datetime import datetime
@FILTER.register
class MeanFilter(AlgFrontend):
@staticmethod
def get_name():
return '均值滤波'
@staticmethod
def get_icon():
return None
@staticmethod
def get_widget(parent=None):
widget = QtWidgets.QWidget(parent)
x_size_input = QtWidgets.QLineEdit(widget)
x_size_input.setText('3')
x_size_input.setValidator(QtGui.QIntValidator())
x_size_input.setObjectName('xinput')
y_size_input = QtWidgets.QLineEdit(widget)
y_size_input.setValidator(QtGui.QIntValidator())
y_size_input.setObjectName('yinput')
y_size_input.setText('3')
size_label = QtWidgets.QLabel(widget)
size_label.setText('窗口大小:')
time_label = QtWidgets.QLabel(widget)
time_label.setText('X')
hlayout1 = QtWidgets.QHBoxLayout()
hlayout1.addWidget(size_label)
hlayout1.addWidget(x_size_input)
hlayout1.addWidget(time_label)
hlayout1.addWidget(y_size_input)
widget.setLayout(hlayout1)
return widget
@staticmethod
def get_params(widget:QtWidgets.QWidget=None):
if widget is None:
return dict(x_size=3, y_size=3)
x_input = widget.findChild(QtWidgets.QLineEdit, 'xinput')
y_input = widget.findChild(QtWidgets.QLineEdit, 'yinput')
if x_input is None or y_input is None:
return dict(x_size=3, y_size=3)
x_size = int(x_input.text())
y_size = int(y_input.text())
return dict(x_size=x_size, y_size=y_size)
@staticmethod
def run_alg(pth, x_size, y_size, *args, **kargs):
x_size = int(x_size)
y_size = int(y_size)
# pth = layer.path
if pth is None:
return
ds = gdal.Open(pth)
band_count = ds.RasterCount
out_path = os.path.join(Project().other_path, 'mean_filter_{}.tif'.format(int(datetime.now().timestamp() * 1000)))
out_ds = gdal.GetDriverByName('GTiff').Create(out_path, ds.RasterXSize, ds.RasterYSize, band_count, ds.GetRasterBand(1).DataType)
out_ds.SetProjection(ds.GetProjection())
out_ds.SetGeoTransform(ds.GetGeoTransform())
for i in range(band_count):
band = ds.GetRasterBand(i+1)
data = band.ReadAsArray()
data = rank.mean(data, rectangle(y_size, x_size))
out_band = out_ds.GetRasterBand(i+1)
out_band.WriteArray(data)
out_ds.FlushCache()
del out_ds
del ds
return out_path

View File

@ -51,7 +51,7 @@ class FollowPlugin(BasicPlugin):
def set_action(self):
follow_box:QtWidgets.QWidget = ActionManager().follow_box
toolbar = ActionManager().add_toolbar('Follow')
# toolbar = ActionManager().add_toolbar('Follow')
vbox = QtWidgets.QVBoxLayout(follow_box)
combox = QtWidgets.QComboBox(follow_box)
@ -65,19 +65,16 @@ class FollowPlugin(BasicPlugin):
name = alg.get_name()
combox.addItem(name, key)
action = QtWidgets.QAction(alg.get_icon(), name, self.mainwindow)
func = partial(self.run_dialog, alg)
action.triggered.connect(func)
toolbar.addAction(action)
# action = QtWidgets.QAction(alg.get_icon(), name, self.mainwindow)
# func = partial(self.run_dialog, alg)
# action.triggered.connect(func)
# toolbar.addAction(action)
combox.currentIndexChanged.connect(self.on_change)
vbox.addWidget(combox)
self.current_widget = None
self.combox = combox
self.layout = vbox
@ -128,7 +125,7 @@ class FollowPlugin(BasicPlugin):
params = alg.get_params(self.current_widget)
t = Thread(target=self.run_alg, args=(params,))
t = Thread(target=self.run_alg, args=(alg, params,))
t.start()
def run_dialog(self, alg:AlgFrontend):

View File

@ -176,7 +176,7 @@ class UnsupervisedPlugin(BasicPlugin):
unsupervised_menu = QMenu('&无监督变化检测', self.mainwindow)
unsupervised_menu.setIcon(IconInstance().UNSUPERVISED)
ActionManager().change_detection_menu.addMenu(unsupervised_menu)
toolbar = ActionManager().add_toolbar('Unsupervised')
for key in UNSUPER_CD.keys():
alg:AlgFrontend = UNSUPER_CD[key]
if alg.get_name() is None:
@ -187,7 +187,7 @@ class UnsupervisedPlugin(BasicPlugin):
action = QAction(name, unsupervised_menu)
func = partial(self.run_cd, alg)
action.triggered.connect(func)
toolbar.addAction(action)
unsupervised_menu.addAction(action)