add mean filter method
This commit is contained in:
		
							parent
							
								
									80e03e82f7
								
							
						
					
					
						commit
						905c65b277
					
				
							
								
								
									
										174
									
								
								plugins/filter_collection/BilateralFilter.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										174
									
								
								plugins/filter_collection/BilateralFilter.md
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,174 @@
 | 
				
			|||||||
 | 
					# BilateralFilter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					通过调节进度条,改变相关参数
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```python
 | 
				
			||||||
 | 
					import cv2 as cv
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					cv.namedWindow("image")
 | 
				
			||||||
 | 
					cv.createTrackbar("d","image",0,255,print)
 | 
				
			||||||
 | 
					cv.createTrackbar("sigmaColor","image",0,255,print)
 | 
				
			||||||
 | 
					cv.createTrackbar("sigmaSpace","image",0,255,print)
 | 
				
			||||||
 | 
					img = cv.imread("train_306.png",0)
 | 
				
			||||||
 | 
					while(1):
 | 
				
			||||||
 | 
					      d = cv.getTrackbarPos("d","image")
 | 
				
			||||||
 | 
					      sigmaColor = cv.getTrackbarPos("sigmaColor","image")
 | 
				
			||||||
 | 
					      sigmaSpace = cv.getTrackbarPos("sigmaSpace","image")
 | 
				
			||||||
 | 
					      result_img = cv.bilateralFilter(img,d,sigmaColor,sigmaSpace)
 | 
				
			||||||
 | 
					      cv.imshow("result",result_img)
 | 
				
			||||||
 | 
					      k = cv.waitKey(1) & 0xFF
 | 
				
			||||||
 | 
					      if k ==27:
 | 
				
			||||||
 | 
					         break
 | 
				
			||||||
 | 
					cv.destroyAllWindows()
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# BilateralFilter2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					运行速度慢,但是效果好
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```python
 | 
				
			||||||
 | 
					import cv2 as cv
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import math
 | 
				
			||||||
 | 
					import copy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def spilt( a ):
 | 
				
			||||||
 | 
					    if a/2 == 0:
 | 
				
			||||||
 | 
					        x1 = x2 = a/2
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        x1 = math.floor( a/2 )
 | 
				
			||||||
 | 
					        x2 = a - x1
 | 
				
			||||||
 | 
					    return -x1,x2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def d_value():
 | 
				
			||||||
 | 
					    value = np.zeros(256)
 | 
				
			||||||
 | 
					    var_temp = 30
 | 
				
			||||||
 | 
					    for i in range(0,255):
 | 
				
			||||||
 | 
					        t = i*i
 | 
				
			||||||
 | 
					        value[i] = math.e ** (-t / (2 * var_temp * var_temp))
 | 
				
			||||||
 | 
					    return value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def gaussian_b0x(a, b):
 | 
				
			||||||
 | 
					    judge = 10
 | 
				
			||||||
 | 
					    box =[]
 | 
				
			||||||
 | 
					    x1, x2 = spilt(a)
 | 
				
			||||||
 | 
					    y1, y2 = spilt(b)
 | 
				
			||||||
 | 
					    for i in range (x1, x2 ):
 | 
				
			||||||
 | 
					        for j in range(y1, y2):
 | 
				
			||||||
 | 
					            t = i*i + j*j
 | 
				
			||||||
 | 
					            re = math.e ** (-t/(2*judge*judge))
 | 
				
			||||||
 | 
					            box.append(re)
 | 
				
			||||||
 | 
					    # for x in box :
 | 
				
			||||||
 | 
					    #     print (x)
 | 
				
			||||||
 | 
					    return box
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def original (i, j, k, a, b, img):
 | 
				
			||||||
 | 
					    x1, x2 = spilt(a)
 | 
				
			||||||
 | 
					    y1, y2 = spilt(b)
 | 
				
			||||||
 | 
					    temp = np.zeros(a * b)
 | 
				
			||||||
 | 
					    count = 0
 | 
				
			||||||
 | 
					    for m in range(x1, x2):
 | 
				
			||||||
 | 
					        for n in range(y1, y2):
 | 
				
			||||||
 | 
					            if i + m < 0 or i + m > img.shape[0] - 1 or j + n < 0 or j + n > img.shape[1] - 1:
 | 
				
			||||||
 | 
					                temp[count] = img[i, j, k]
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                temp[count] = img[i + m, j + n, k]
 | 
				
			||||||
 | 
					            count += 1
 | 
				
			||||||
 | 
					    return   temp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def bilateral_function(a, b, img, gauss_fun,d_value_e ):
 | 
				
			||||||
 | 
					    x1, x2 = spilt(a)
 | 
				
			||||||
 | 
					    y1, y2 = spilt(b)
 | 
				
			||||||
 | 
					    re = np.zeros(a * b)
 | 
				
			||||||
 | 
					    img0 = copy.copy(img)
 | 
				
			||||||
 | 
					    for i in range(img.shape[0]):
 | 
				
			||||||
 | 
					        for j in range(img.shape[1]):
 | 
				
			||||||
 | 
					            for k in range(0,2):
 | 
				
			||||||
 | 
					                temp = original(i, j, k, a, b, img0)
 | 
				
			||||||
 | 
					                # print("ave:",ave_temp)
 | 
				
			||||||
 | 
					                count = 0
 | 
				
			||||||
 | 
					                for m in  range (x1,x2):
 | 
				
			||||||
 | 
					                    for n in range(y1,y2):
 | 
				
			||||||
 | 
					                        if i+m < 0 or i+m >img.shape[0]-1 or j+n <0 or j+n >img.shape[1]-1:
 | 
				
			||||||
 | 
					                            x = img[i,j,k]
 | 
				
			||||||
 | 
					                        else :
 | 
				
			||||||
 | 
					                            x = img[i+m,j+n,k]
 | 
				
			||||||
 | 
					                        t = int(math.fabs(int(x) - int(img[i,j,k])) )
 | 
				
			||||||
 | 
					                        re[count] =  d_value_e[t]
 | 
				
			||||||
 | 
					                        count += 1
 | 
				
			||||||
 | 
					                evalue = np.multiply(re, gauss_fun)
 | 
				
			||||||
 | 
					                img[i,j,k] = int(np.average(temp, weights = evalue))
 | 
				
			||||||
 | 
					    return  img
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def main():
 | 
				
			||||||
 | 
					    gauss_new = gaussian_b0x(30, 30 )
 | 
				
			||||||
 | 
					    # print(gauss_new)
 | 
				
			||||||
 | 
					    d_value_e = d_value()
 | 
				
			||||||
 | 
					    img0 = cv.imread(r'train_1.png')
 | 
				
			||||||
 | 
					    bilateral_img = bilateral_function(30, 30, copy.copy(img0), gauss_new, d_value_e)
 | 
				
			||||||
 | 
					    cv.imshow("shuangbian", bilateral_img)
 | 
				
			||||||
 | 
					    cv.imshow("yuantu", img0)
 | 
				
			||||||
 | 
					    cv.imwrite("shuangbian.jpg", bilateral_img)
 | 
				
			||||||
 | 
					    cv.waitKey(0)
 | 
				
			||||||
 | 
					    cv.destroyAllWindows()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__  ==  "__main__":
 | 
				
			||||||
 | 
					    main()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# MorphologyEx1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```python
 | 
				
			||||||
 | 
					import cv2
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					#读图
 | 
				
			||||||
 | 
					img = cv2.imread('train_306.png',0)
 | 
				
			||||||
 | 
					#设置核
 | 
				
			||||||
 | 
					kernel = np.ones((5,5),np.uint8)
 | 
				
			||||||
 | 
					#形态学梯度调用
 | 
				
			||||||
 | 
					gradient = cv2.morphologyEx(img, cv2.MORPH_GRADIENT, kernel)
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					#显示效果
 | 
				
			||||||
 | 
					cv2.imshow('src',img)
 | 
				
			||||||
 | 
					cv2.imshow('result',gradient)
 | 
				
			||||||
 | 
					cv2.waitKey()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# MorphologyE2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```python
 | 
				
			||||||
 | 
					import cv2
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 形态学梯度运算=膨胀-腐蚀
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					img = cv2.imread('train_306.png', 0)
 | 
				
			||||||
 | 
					kernel = np.ones((5, 5), np.uint8)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					dilate = cv2.dilate(img, kernel, iterations=5)
 | 
				
			||||||
 | 
					erosion = cv2.erode(img, kernel, iterations=5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					gradient = cv2.morphologyEx(img, cv2.MORPH_GRADIENT, kernel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					res = np.hstack((dilate, erosion, gradient))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cv2.imshow('res', res)
 | 
				
			||||||
 | 
					cv2.imshow('result',gradient)
 | 
				
			||||||
 | 
					cv2.waitKey(0)
 | 
				
			||||||
 | 
					cv2.destroyAllWindows()
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -2,4 +2,5 @@ from misc import Register
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
FILTER = Register('滤波处理算法')
 | 
					FILTER = Register('滤波处理算法')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .mean_filter import MeanFilter
 | 
				
			||||||
from filter_collection.main import *
 | 
					from filter_collection.main import *
 | 
				
			||||||
							
								
								
									
										11
									
								
								plugins/filter_collection/bilater_filter.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								plugins/filter_collection/bilater_filter.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,11 @@
 | 
				
			|||||||
 | 
					from misc import AlgFrontend
 | 
				
			||||||
 | 
					from osgeo import gdal, gdal_array
 | 
				
			||||||
 | 
					from skimage.filters import rank
 | 
				
			||||||
 | 
					from skimage.morphology import  rectangle
 | 
				
			||||||
 | 
					from filter_collection import FILTER
 | 
				
			||||||
 | 
					from PyQt5.QtWidgets import QDialog, QAction
 | 
				
			||||||
 | 
					from PyQt5 import QtCore, QtGui, QtWidgets
 | 
				
			||||||
 | 
					from rscder.utils.project import PairLayer, Project, RasterLayer, ResultPointLayer
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					from datetime import datetime
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1,7 +1,7 @@
 | 
				
			|||||||
from datetime import datetime
 | 
					from datetime import datetime
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
from threading import Thread
 | 
					from threading import Thread
 | 
				
			||||||
from PyQt5.QtWidgets import QDialog, QAction
 | 
					from PyQt5.QtWidgets import QDialog, QAction, QVBoxLayout, QDialogButtonBox, QHBoxLayout, QPushButton
 | 
				
			||||||
from PyQt5 import QtCore, QtGui, QtWidgets
 | 
					from PyQt5 import QtCore, QtGui, QtWidgets
 | 
				
			||||||
from PyQt5.QtCore import Qt, QModelIndex, pyqtSignal
 | 
					from PyQt5.QtCore import Qt, QModelIndex, pyqtSignal
 | 
				
			||||||
from rscder.gui.actions import ActionManager
 | 
					from rscder.gui.actions import ActionManager
 | 
				
			||||||
@ -13,143 +13,55 @@ from osgeo import gdal, gdal_array
 | 
				
			|||||||
from skimage.filters import rank
 | 
					from skimage.filters import rank
 | 
				
			||||||
from skimage.morphology import  rectangle
 | 
					from skimage.morphology import  rectangle
 | 
				
			||||||
from filter_collection import FILTER
 | 
					from filter_collection import FILTER
 | 
				
			||||||
from misc import AlgFrontend
 | 
					from misc import AlgFrontend, AlgSelectWidget
 | 
				
			||||||
 | 
					import functools
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@FILTER.register
 | 
					 | 
				
			||||||
class MainFilter(AlgFrontend):
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					class FilterMethod(QDialog):
 | 
				
			||||||
    def get_name():
 | 
					    def __init__(self,parent=None, alg:AlgFrontend=None):
 | 
				
			||||||
        return '均值滤波'
 | 
					        super(FilterMethod, self).__init__(parent)
 | 
				
			||||||
    
 | 
					        self.alg = alg
 | 
				
			||||||
    @staticmethod
 | 
					        self.setWindowTitle('滤波算法:{}'.format(alg.get_name()))
 | 
				
			||||||
    def get_widget(parent=None):
 | 
					        self.setWindowIcon(IconInstance().LOGO)
 | 
				
			||||||
        widget = QtWidgets.QWidget(parent)
 | 
					 | 
				
			||||||
        x_size_input = QtWidgets.QLineEdit(widget)
 | 
					 | 
				
			||||||
        x_size_input.setText('3')
 | 
					 | 
				
			||||||
        x_size_input.setValidator(QtGui.QIntValidator())
 | 
					 | 
				
			||||||
        x_size_input.setObjectName('xinput')
 | 
					 | 
				
			||||||
        y_size_input = QtWidgets.QLineEdit(widget)
 | 
					 | 
				
			||||||
        y_size_input.setValidator(QtGui.QIntValidator())
 | 
					 | 
				
			||||||
        y_size_input.setObjectName('yinput')
 | 
					 | 
				
			||||||
        y_size_input.setText('3')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        size_label = QtWidgets.QLabel(widget)
 | 
					 | 
				
			||||||
        size_label.setText('窗口大小:')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        time_label = QtWidgets.QLabel(widget)
 | 
					 | 
				
			||||||
        time_label.setText('X')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        hlayout1 = QtWidgets.QHBoxLayout()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        hlayout1.addWidget(size_label)
 | 
					 | 
				
			||||||
        hlayout1.addWidget(x_size_input)
 | 
					 | 
				
			||||||
        hlayout1.addWidget(time_label)
 | 
					 | 
				
			||||||
        hlayout1.addWidget(y_size_input)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        widget.setLayout(hlayout1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return widget
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    def get_params(widget:QtWidgets.QWidget=None):
 | 
					 | 
				
			||||||
        if widget is None:
 | 
					 | 
				
			||||||
            return dict(x_size=3, y_size=3)
 | 
					 | 
				
			||||||
        
 | 
					 | 
				
			||||||
        x_input = widget.findChild(QtWidgets.QLineEdit, 'xinput')
 | 
					 | 
				
			||||||
        y_input = widget.findChild(QtWidgets.QLineEdit, 'yinput')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if x_input is None or y_input is None:
 | 
					 | 
				
			||||||
            return dict(x_size=3, y_size=3)
 | 
					 | 
				
			||||||
        
 | 
					 | 
				
			||||||
        x_size = int(x_input.text())
 | 
					 | 
				
			||||||
        y_size = int(y_input.text())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return dict(x_size=x_size, y_size=y_size)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    def run_alg(pth, x_size, y_size, *args, **kargs):
 | 
					 | 
				
			||||||
        x_size = int(x_size)
 | 
					 | 
				
			||||||
        y_size = int(y_size)
 | 
					 | 
				
			||||||
        # pth = layer.path
 | 
					 | 
				
			||||||
        if pth is None:
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
        
 | 
					 | 
				
			||||||
        ds = gdal.Open(pth)
 | 
					 | 
				
			||||||
        band_count = ds.RasterCount
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        out_path = os.path.join(Project().other_path, 'mean_filter_{}.tif'.format(int(datetime.now().timestamp() * 1000)))
 | 
					 | 
				
			||||||
        out_ds = gdal.GetDriverByName('GTiff').Create(out_path, ds.RasterXSize, ds.RasterYSize, band_count, ds.GetRasterBand(1).DataType)
 | 
					 | 
				
			||||||
        out_ds.SetProjection(ds.GetProjection())
 | 
					 | 
				
			||||||
        out_ds.SetGeoTransform(ds.GetGeoTransform())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        for i in range(band_count):
 | 
					 | 
				
			||||||
            band = ds.GetRasterBand(i+1)
 | 
					 | 
				
			||||||
            data = band.ReadAsArray()
 | 
					 | 
				
			||||||
            
 | 
					 | 
				
			||||||
            data = rank.mean(data, rectangle(y_size, x_size))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            out_band = out_ds.GetRasterBand(i+1)
 | 
					 | 
				
			||||||
            out_band.WriteArray(data)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        out_ds.FlushCache()
 | 
					 | 
				
			||||||
        del out_ds
 | 
					 | 
				
			||||||
        del ds
 | 
					 | 
				
			||||||
        return out_path
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class FilterSetting(QDialog):
 | 
					 | 
				
			||||||
    def __init__(self, parent=None):
 | 
					 | 
				
			||||||
        super(FilterSetting, self).__init__(parent)
 | 
					 | 
				
			||||||
        self.setWindowTitle('滤波设置')
 | 
					 | 
				
			||||||
        self.setWindowIcon(IconInstance().FILTER)
 | 
					 | 
				
			||||||
        self.initUI()
 | 
					        self.initUI()
 | 
				
			||||||
 | 
					        self.setMinimumWidth(500)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def initUI(self):
 | 
					    def initUI(self):
 | 
				
			||||||
 | 
					        #图层
 | 
				
			||||||
        self.layer_combox = RasterLayerCombox(self)
 | 
					        self.layer_combox = RasterLayerCombox(self)
 | 
				
			||||||
        layer_label = QtWidgets.QLabel('图层:')
 | 
					        layerbox = QHBoxLayout()
 | 
				
			||||||
 | 
					        layerbox.addWidget(self.layer_combox)        
 | 
				
			||||||
 | 
					       
 | 
				
			||||||
 | 
					        self.param_widget = self.alg.get_widget(self)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        hbox = QtWidgets.QHBoxLayout()
 | 
					        self.ok_button = QPushButton('确定', self)
 | 
				
			||||||
        hbox.addWidget(layer_label)
 | 
					        self.ok_button.setIcon(IconInstance().OK)
 | 
				
			||||||
        hbox.addWidget(self.layer_combox)
 | 
					        self.ok_button.clicked.connect(self.accept)
 | 
				
			||||||
 | 
					        self.ok_button.setDefault(True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        x_size_input = QtWidgets.QLineEdit(self)
 | 
					        self.cancel_button = QPushButton('取消', self)
 | 
				
			||||||
        x_size_input.setText('3')
 | 
					        self.cancel_button.setIcon(IconInstance().CANCEL)
 | 
				
			||||||
        y_size_input = QtWidgets.QLineEdit(self)
 | 
					        self.cancel_button.clicked.connect(self.reject)
 | 
				
			||||||
        y_size_input.setText('3')
 | 
					        self.cancel_button.setDefault(False)
 | 
				
			||||||
 | 
					        buttonbox=QDialogButtonBox(self)
 | 
				
			||||||
        size_label = QtWidgets.QLabel(self)
 | 
					        buttonbox.addButton(self.ok_button,QDialogButtonBox.NoRole)
 | 
				
			||||||
        size_label.setText('窗口大小:')
 | 
					        buttonbox.addButton(self.cancel_button,QDialogButtonBox.NoRole)
 | 
				
			||||||
 | 
					        buttonbox.setCenterButtons(True)
 | 
				
			||||||
        time_label = QtWidgets.QLabel(self)
 | 
					 | 
				
			||||||
        time_label.setText('X')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.x_size_input = x_size_input
 | 
					 | 
				
			||||||
        self.y_size_input = y_size_input
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        hlayout1 = QtWidgets.QHBoxLayout()
 | 
					 | 
				
			||||||
        hlayout1.addWidget(size_label)
 | 
					 | 
				
			||||||
        hlayout1.addWidget(x_size_input)
 | 
					 | 
				
			||||||
        hlayout1.addWidget(time_label)
 | 
					 | 
				
			||||||
        hlayout1.addWidget(y_size_input)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        ok_button = QtWidgets.QPushButton(self)
 | 
					 | 
				
			||||||
        ok_button.setText('确定')
 | 
					 | 
				
			||||||
        ok_button.clicked.connect(self.accept)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        cancel_button = QtWidgets.QPushButton(self)
 | 
					 | 
				
			||||||
        cancel_button.setText('取消')
 | 
					 | 
				
			||||||
        cancel_button.clicked.connect(self.reject)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        hlayout2 = QtWidgets.QHBoxLayout()
 | 
					 | 
				
			||||||
        hlayout2.addWidget(ok_button,0,alignment=Qt.AlignHCenter)
 | 
					 | 
				
			||||||
        hlayout2.addWidget(cancel_button,0,alignment=Qt.AlignHCenter)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        vlayout = QtWidgets.QVBoxLayout()
 | 
					 | 
				
			||||||
        vlayout.addLayout(hbox)
 | 
					 | 
				
			||||||
        vlayout.addLayout(hlayout1)
 | 
					 | 
				
			||||||
        vlayout.addLayout(hlayout2)
 | 
					 | 
				
			||||||
        self.setLayout(vlayout)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        totalvlayout=QVBoxLayout()
 | 
				
			||||||
 | 
					        totalvlayout.addLayout(layerbox)
 | 
				
			||||||
 | 
					        # totalvlayout.addWidget(self.filter_select)
 | 
				
			||||||
 | 
					        if self.param_widget is not None:
 | 
				
			||||||
 | 
					            totalvlayout.addWidget(self.param_widget)
 | 
				
			||||||
 | 
					        # totalvlayout.addWidget(self.thres_select)
 | 
				
			||||||
 | 
					        totalvlayout.addStretch(1)
 | 
				
			||||||
 | 
					        hbox = QHBoxLayout()
 | 
				
			||||||
 | 
					        hbox.addStretch(1)
 | 
				
			||||||
 | 
					        hbox.addWidget(buttonbox)
 | 
				
			||||||
 | 
					        totalvlayout.addLayout(hbox)
 | 
				
			||||||
 | 
					        # totalvlayout.addStretch()
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        self.setLayout(totalvlayout)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MainPlugin(BasicPlugin):
 | 
					class MainPlugin(BasicPlugin):
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
@ -165,53 +77,49 @@ class MainPlugin(BasicPlugin):
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    def set_action(self):
 | 
					    def set_action(self):
 | 
				
			||||||
        self.action = QAction('均值滤波', self.mainwindow)
 | 
					        
 | 
				
			||||||
        # self.action.setCheckable)
 | 
					        toolbar = ActionManager().add_toolbar('Filter Collection')
 | 
				
			||||||
        # self.action.setChecked(False)
 | 
					        for key in FILTER.keys():
 | 
				
			||||||
        self.action.triggered.connect(self.run)
 | 
					            alg:AlgFrontend = FILTER[key]
 | 
				
			||||||
        ActionManager().filter_menu.addAction(self.action)
 | 
					            name = alg.get_name() or key
 | 
				
			||||||
        self.alg_ok.connect(self.alg_oked)
 | 
					            action = QAction(name, self.mainwindow)
 | 
				
			||||||
 | 
					            func = functools.partial(self.run, key)
 | 
				
			||||||
 | 
					            action.triggered.connect(func)
 | 
				
			||||||
 | 
					            toolbar.addAction(action)
 | 
				
			||||||
 | 
					            ActionManager().filter_menu.addAction(action)
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # self.action = QAction('均值滤波', self.mainwindow)
 | 
				
			||||||
 | 
					        # # self.action.setCheckable)
 | 
				
			||||||
 | 
					        # # self.action.setChecked(False)
 | 
				
			||||||
 | 
					        # self.action.triggered.connect(self.run)
 | 
				
			||||||
 | 
					        # ActionManager().filter_menu.addAction(self.action)
 | 
				
			||||||
 | 
					        # self.alg_ok.connect(self.alg_oked)
 | 
				
			||||||
        # basic
 | 
					        # basic
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    def alg_oked(self, parent, layer:RasterLayer):
 | 
					    def alg_oked(self, parent, layer:RasterLayer):
 | 
				
			||||||
        parent.add_result_layer(layer)
 | 
					        parent.add_result_layer(layer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def run_alg(self, layer:RasterLayer, x_size, y_size, method='mean'):
 | 
					    def run_alg(self, layer:RasterLayer, alg:AlgFrontend, p):
 | 
				
			||||||
        x_size = int(x_size)
 | 
					        
 | 
				
			||||||
        y_size = int(y_size)
 | 
					 | 
				
			||||||
        pth = layer.path
 | 
					        pth = layer.path
 | 
				
			||||||
        if pth is None:
 | 
					        if pth is None:
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
        
 | 
					        out_path = alg.run_alg(pth, **p)
 | 
				
			||||||
        ds = gdal.Open(pth)
 | 
					 | 
				
			||||||
        band_count = ds.RasterCount
 | 
					 | 
				
			||||||
        out_path = os.path.join(Project().other_path, '{}_mean_filter.tif'.format(layer.name))
 | 
					 | 
				
			||||||
        out_ds = gdal.GetDriverByName('GTiff').Create(out_path, ds.RasterXSize, ds.RasterYSize, band_count, ds.GetRasterBand(1).DataType)
 | 
					 | 
				
			||||||
        out_ds.SetProjection(ds.GetProjection())
 | 
					 | 
				
			||||||
        out_ds.SetGeoTransform(ds.GetGeoTransform())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        for i in range(band_count):
 | 
					 | 
				
			||||||
            band = ds.GetRasterBand(i+1)
 | 
					 | 
				
			||||||
            data = band.ReadAsArray()
 | 
					 | 
				
			||||||
            
 | 
					 | 
				
			||||||
            data = rank.mean(data, rectangle(y_size, x_size))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            out_band = out_ds.GetRasterBand(i+1)
 | 
					 | 
				
			||||||
            out_band.WriteArray(data)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        out_ds.FlushCache()
 | 
					 | 
				
			||||||
        del out_ds
 | 
					 | 
				
			||||||
        del ds
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        rlayer = RasterLayer(path = out_path, enable= True, view_mode = layer.view_mode )
 | 
					        rlayer = RasterLayer(path = out_path, enable= True, view_mode = layer.view_mode )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.alg_ok.emit(layer.layer_parent, rlayer)
 | 
					        self.alg_ok.emit(layer.layer_parent, rlayer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def run(self):
 | 
					    def run(self, key):
 | 
				
			||||||
        dialog = FilterSetting(self.mainwindow)
 | 
					        if key not in FILTER:
 | 
				
			||||||
        dialog.show()
 | 
					            self.send_message.emit(f'{key} not in {FILTER.keys()}')
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        alg:AlgFrontend = FILTER[key]
 | 
				
			||||||
 | 
					        dialog = FilterMethod(self.mainwindow, alg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if dialog.exec_():
 | 
					        if dialog.exec_():
 | 
				
			||||||
            x_size = int(dialog.x_size_input.text())
 | 
					            
 | 
				
			||||||
            y_size = int(dialog.y_size_input.text())
 | 
					            t = Thread(target=self.run_alg, args=(dialog.layer_combox.current_layer, alg, alg.get_params(dialog.param_widget)))
 | 
				
			||||||
            t = Thread(target=self.run_alg, args=(dialog.layer_combox.current_layer, x_size, y_size))
 | 
					 | 
				
			||||||
            t.start()
 | 
					            t.start()
 | 
				
			||||||
							
								
								
									
										96
									
								
								plugins/filter_collection/mean_filter.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										96
									
								
								plugins/filter_collection/mean_filter.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,96 @@
 | 
				
			|||||||
 | 
					from misc import AlgFrontend
 | 
				
			||||||
 | 
					from osgeo import gdal, gdal_array
 | 
				
			||||||
 | 
					from skimage.filters import rank
 | 
				
			||||||
 | 
					from skimage.morphology import  rectangle
 | 
				
			||||||
 | 
					from filter_collection import FILTER
 | 
				
			||||||
 | 
					from PyQt5.QtWidgets import QDialog, QAction
 | 
				
			||||||
 | 
					from PyQt5 import QtCore, QtGui, QtWidgets
 | 
				
			||||||
 | 
					from rscder.utils.project import PairLayer, Project, RasterLayer, ResultPointLayer
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					from datetime import datetime
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@FILTER.register
 | 
				
			||||||
 | 
					class MeanFilter(AlgFrontend):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def get_name():
 | 
				
			||||||
 | 
					        return '均值滤波'
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def get_icon():
 | 
				
			||||||
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def get_widget(parent=None):
 | 
				
			||||||
 | 
					        widget = QtWidgets.QWidget(parent)
 | 
				
			||||||
 | 
					        x_size_input = QtWidgets.QLineEdit(widget)
 | 
				
			||||||
 | 
					        x_size_input.setText('3')
 | 
				
			||||||
 | 
					        x_size_input.setValidator(QtGui.QIntValidator())
 | 
				
			||||||
 | 
					        x_size_input.setObjectName('xinput')
 | 
				
			||||||
 | 
					        y_size_input = QtWidgets.QLineEdit(widget)
 | 
				
			||||||
 | 
					        y_size_input.setValidator(QtGui.QIntValidator())
 | 
				
			||||||
 | 
					        y_size_input.setObjectName('yinput')
 | 
				
			||||||
 | 
					        y_size_input.setText('3')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        size_label = QtWidgets.QLabel(widget)
 | 
				
			||||||
 | 
					        size_label.setText('窗口大小:')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        time_label = QtWidgets.QLabel(widget)
 | 
				
			||||||
 | 
					        time_label.setText('X')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        hlayout1 = QtWidgets.QHBoxLayout()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        hlayout1.addWidget(size_label)
 | 
				
			||||||
 | 
					        hlayout1.addWidget(x_size_input)
 | 
				
			||||||
 | 
					        hlayout1.addWidget(time_label)
 | 
				
			||||||
 | 
					        hlayout1.addWidget(y_size_input)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        widget.setLayout(hlayout1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return widget
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def get_params(widget:QtWidgets.QWidget=None):
 | 
				
			||||||
 | 
					        if widget is None:
 | 
				
			||||||
 | 
					            return dict(x_size=3, y_size=3)
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        x_input = widget.findChild(QtWidgets.QLineEdit, 'xinput')
 | 
				
			||||||
 | 
					        y_input = widget.findChild(QtWidgets.QLineEdit, 'yinput')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if x_input is None or y_input is None:
 | 
				
			||||||
 | 
					            return dict(x_size=3, y_size=3)
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        x_size = int(x_input.text())
 | 
				
			||||||
 | 
					        y_size = int(y_input.text())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return dict(x_size=x_size, y_size=y_size)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def run_alg(pth, x_size, y_size, *args, **kargs):
 | 
				
			||||||
 | 
					        x_size = int(x_size)
 | 
				
			||||||
 | 
					        y_size = int(y_size)
 | 
				
			||||||
 | 
					        # pth = layer.path
 | 
				
			||||||
 | 
					        if pth is None:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        ds = gdal.Open(pth)
 | 
				
			||||||
 | 
					        band_count = ds.RasterCount
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        out_path = os.path.join(Project().other_path, 'mean_filter_{}.tif'.format(int(datetime.now().timestamp() * 1000)))
 | 
				
			||||||
 | 
					        out_ds = gdal.GetDriverByName('GTiff').Create(out_path, ds.RasterXSize, ds.RasterYSize, band_count, ds.GetRasterBand(1).DataType)
 | 
				
			||||||
 | 
					        out_ds.SetProjection(ds.GetProjection())
 | 
				
			||||||
 | 
					        out_ds.SetGeoTransform(ds.GetGeoTransform())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for i in range(band_count):
 | 
				
			||||||
 | 
					            band = ds.GetRasterBand(i+1)
 | 
				
			||||||
 | 
					            data = band.ReadAsArray()
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					            data = rank.mean(data, rectangle(y_size, x_size))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            out_band = out_ds.GetRasterBand(i+1)
 | 
				
			||||||
 | 
					            out_band.WriteArray(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        out_ds.FlushCache()
 | 
				
			||||||
 | 
					        del out_ds
 | 
				
			||||||
 | 
					        del ds
 | 
				
			||||||
 | 
					        return out_path
 | 
				
			||||||
@ -51,7 +51,7 @@ class FollowPlugin(BasicPlugin):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def set_action(self):
 | 
					    def set_action(self):
 | 
				
			||||||
        follow_box:QtWidgets.QWidget = ActionManager().follow_box
 | 
					        follow_box:QtWidgets.QWidget = ActionManager().follow_box
 | 
				
			||||||
        toolbar = ActionManager().add_toolbar('Follow')
 | 
					        # toolbar = ActionManager().add_toolbar('Follow')
 | 
				
			||||||
        vbox = QtWidgets.QVBoxLayout(follow_box)
 | 
					        vbox = QtWidgets.QVBoxLayout(follow_box)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        combox = QtWidgets.QComboBox(follow_box)
 | 
					        combox = QtWidgets.QComboBox(follow_box)
 | 
				
			||||||
@ -65,19 +65,16 @@ class FollowPlugin(BasicPlugin):
 | 
				
			|||||||
                name = alg.get_name()
 | 
					                name = alg.get_name()
 | 
				
			||||||
            combox.addItem(name, key)
 | 
					            combox.addItem(name, key)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            action = QtWidgets.QAction(alg.get_icon(), name, self.mainwindow)
 | 
					            # action = QtWidgets.QAction(alg.get_icon(), name, self.mainwindow)
 | 
				
			||||||
            func = partial(self.run_dialog, alg)
 | 
					            # func = partial(self.run_dialog, alg)
 | 
				
			||||||
            action.triggered.connect(func)
 | 
					            # action.triggered.connect(func)
 | 
				
			||||||
            toolbar.addAction(action)
 | 
					            # toolbar.addAction(action)
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
 | 
					 | 
				
			||||||
        combox.currentIndexChanged.connect(self.on_change)
 | 
					        combox.currentIndexChanged.connect(self.on_change)
 | 
				
			||||||
 | 
					 | 
				
			||||||
        vbox.addWidget(combox)
 | 
					        vbox.addWidget(combox)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.current_widget = None
 | 
					        self.current_widget = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        
 | 
					 | 
				
			||||||
        self.combox = combox
 | 
					        self.combox = combox
 | 
				
			||||||
        self.layout = vbox
 | 
					        self.layout = vbox
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -128,7 +125,7 @@ class FollowPlugin(BasicPlugin):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        params = alg.get_params(self.current_widget)
 | 
					        params = alg.get_params(self.current_widget)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        t = Thread(target=self.run_alg, args=(params,))
 | 
					        t = Thread(target=self.run_alg, args=(alg, params,))
 | 
				
			||||||
        t.start()
 | 
					        t.start()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def run_dialog(self, alg:AlgFrontend):
 | 
					    def run_dialog(self, alg:AlgFrontend):
 | 
				
			||||||
 | 
				
			|||||||
@ -176,7 +176,7 @@ class UnsupervisedPlugin(BasicPlugin):
 | 
				
			|||||||
        unsupervised_menu = QMenu('&无监督变化检测', self.mainwindow)
 | 
					        unsupervised_menu = QMenu('&无监督变化检测', self.mainwindow)
 | 
				
			||||||
        unsupervised_menu.setIcon(IconInstance().UNSUPERVISED)
 | 
					        unsupervised_menu.setIcon(IconInstance().UNSUPERVISED)
 | 
				
			||||||
        ActionManager().change_detection_menu.addMenu(unsupervised_menu)
 | 
					        ActionManager().change_detection_menu.addMenu(unsupervised_menu)
 | 
				
			||||||
 | 
					        toolbar = ActionManager().add_toolbar('Unsupervised')
 | 
				
			||||||
        for key in UNSUPER_CD.keys():
 | 
					        for key in UNSUPER_CD.keys():
 | 
				
			||||||
            alg:AlgFrontend = UNSUPER_CD[key]
 | 
					            alg:AlgFrontend = UNSUPER_CD[key]
 | 
				
			||||||
            if alg.get_name() is None:
 | 
					            if alg.get_name() is None:
 | 
				
			||||||
@ -187,7 +187,7 @@ class UnsupervisedPlugin(BasicPlugin):
 | 
				
			|||||||
            action = QAction(name, unsupervised_menu)
 | 
					            action = QAction(name, unsupervised_menu)
 | 
				
			||||||
            func = partial(self.run_cd, alg)
 | 
					            func = partial(self.run_cd, alg)
 | 
				
			||||||
            action.triggered.connect(func)
 | 
					            action.triggered.connect(func)
 | 
				
			||||||
 | 
					            toolbar.addAction(action)
 | 
				
			||||||
            unsupervised_menu.addAction(action)
 | 
					            unsupervised_menu.addAction(action)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user