diff --git a/plugins/filter_collection/bilater_filter.py b/plugins/filter_collection/bilater_filter.py index 5a10be0..65fe534 100644 --- a/plugins/filter_collection/bilater_filter.py +++ b/plugins/filter_collection/bilater_filter.py @@ -1,4 +1,5 @@ from misc import AlgFrontend +from misc.utils import format_now from osgeo import gdal, gdal_array from skimage.filters import rank from skimage.morphology import rectangle @@ -23,27 +24,42 @@ class BilaterFilter(AlgFrontend): @staticmethod def get_widget(parent=None): widget = QtWidgets.QWidget(parent) - x_size_input = QtWidgets.QLineEdit(widget) - x_size_input.setText('3') - x_size_input.setValidator(QtGui.QIntValidator()) - x_size_input.setObjectName('xinput') - y_size_input = QtWidgets.QLineEdit(widget) - y_size_input.setValidator(QtGui.QIntValidator()) - y_size_input.setObjectName('yinput') - y_size_input.setText('3') + filter_window_r = QtWidgets.QLineEdit(widget) + filter_window_r.setText('6') + filter_window_r.setValidator(QtGui.QIntValidator()) + filter_window_r.setObjectName('filter_window_r') - size_label = QtWidgets.QLabel(widget) - size_label.setText('窗口大小:') + sigma_color = QtWidgets.QLineEdit(widget) + sigma_color.setValidator(QtGui.QIntValidator()) + sigma_color.setObjectName('sigma_color') + sigma_color.setText('50') + + sigma_space = QtWidgets.QLineEdit(widget) + sigma_space.setValidator(QtGui.QIntValidator()) + sigma_space.setObjectName('sigma_space') + sigma_space.setText('50') + + filter_window_r_label = QtWidgets.QLabel(widget) + filter_window_r_label.setText('滤波窗口直径:') + + sigma_space_label = QtWidgets.QLabel(widget) + sigma_space_label.setText('空间域方差:') + + sigma_color_label = QtWidgets.QLabel(widget) + sigma_color_label.setText('像素域方差:') - time_label = QtWidgets.QLabel(widget) - time_label.setText('X') hlayout1 = QtWidgets.QHBoxLayout() - hlayout1.addWidget(size_label) - hlayout1.addWidget(x_size_input) - hlayout1.addWidget(time_label) - hlayout1.addWidget(y_size_input) + hlayout1.addWidget(filter_window_r_label) + hlayout1.addWidget(filter_window_r) + hlayout1.addWidget(sigma_space_label) + hlayout1.addWidget(sigma_space) + hlayout1.addWidget(sigma_color_label) + hlayout1.addWidget(sigma_color) + # hlayout1.addWidget(x_size_input) + # hlayout1.addWidget(time_label) + # hlayout1.addWidget(y_size_input) widget.setLayout(hlayout1) @@ -52,31 +68,33 @@ class BilaterFilter(AlgFrontend): @staticmethod def get_params(widget:QtWidgets.QWidget=None): if widget is None: - return dict(x_size=3, y_size=3) - - x_input = widget.findChild(QtWidgets.QLineEdit, 'xinput') - y_input = widget.findChild(QtWidgets.QLineEdit, 'yinput') + return dict(w=6, sigma_color=50, sigma_space=50) + def default(o, v=None): + if o is None: + return v + else: + return o.text() - if x_input is None or y_input is None: - return dict(x_size=3, y_size=3) + w = int(default(widget.findChild(QtWidgets.QLineEdit, 'filter_window_r'), 6)) + sigma_space = int(default(widget.findChild(QtWidgets.QLineEdit, 'sigma_space'), 50)) + sigma_color = int(default(widget.findChild(QtWidgets.QLineEdit, 'sigma_color'), 50)) + # y_input = widget.findChild(QtWidgets.QLineEdit, 'yinput') - x_size = int(x_input.text()) - y_size = int(y_input.text()) - return dict(x_size=x_size, y_size=y_size) + return dict(w=w, sigma_space=sigma_space, sigma_color=sigma_color) @staticmethod - def run_alg(pth, x_size, y_size, *args, **kargs): - x_size = int(x_size) - y_size = int(y_size) + def run_alg(pth, w, sigma_space, sigma_color, *args, **kargs): + # x_size = int(x_size) + # y_size = int(y_size) # pth = layer.path if pth is None: return ds = gdal.Open(pth) band_count = ds.RasterCount - - out_path = os.path.join(Project().other_path, 'bilater_filter_{}.tif'.format(int(datetime.now().timestamp() * 1000))) + name = os.path.splitext(os.path.basename(pth))[0] + out_path = os.path.join(Project().other_path, '{}_bilater_filter_{}.tif'.format(name, format_now())) out_ds = gdal.GetDriverByName('GTiff').Create(out_path, ds.RasterXSize, ds.RasterYSize, band_count, ds.GetRasterBand(1).DataType) out_ds.SetProjection(ds.GetProjection()) out_ds.SetGeoTransform(ds.GetGeoTransform()) @@ -85,9 +103,8 @@ class BilaterFilter(AlgFrontend): band = ds.GetRasterBand(i+1) data = band.ReadAsArray() #进行双边滤波处理cv2.bilateralFilter(影像,滤波窗口直径(0-255),像素域方差(0-255),空间域方差(0-255)) - data=cv2.bilateralFilter(data,6,50,50) + data=cv2.bilateralFilter(data,w,sigma_space,sigma_color) - out_band = out_ds.GetRasterBand(i+1) out_band.WriteArray(data) diff --git a/plugins/filter_collection/morphology_filter.py b/plugins/filter_collection/morphology_filter.py index d8ec9ea..14ced24 100644 --- a/plugins/filter_collection/morphology_filter.py +++ b/plugins/filter_collection/morphology_filter.py @@ -80,10 +80,6 @@ class MorphologyFilter(AlgFrontend): out_ds = gdal.GetDriverByName('GTiff').Create(out_path, ds.RasterXSize, ds.RasterYSize, band_count, ds.GetRasterBand(1).DataType) out_ds.SetProjection(ds.GetProjection()) out_ds.SetGeoTransform(ds.GetGeoTransform()) - - - - for i in range(band_count): band = ds.GetRasterBand(i+1)