2023-07-29 17:39:51 +08:00

191 lines
6.9 KiB
Python

from . import AI_METHOD
from rscder.plugins.misc import AlgFrontend
from rscder.utils.icons import IconInstance
from rscder.utils.project import PairLayer
from osgeo import gdal, gdal_array
import os
from rscder.utils.project import Project
from rscder.utils.geomath import geo2imageRC, imageRC2geo
import math
from .packages import get_model
import numpy as np
class BasicAICD(AlgFrontend):
    """Shared front end for the model-based change-detection methods.

    The subclasses in this module obtain a model via ``get_model`` and
    delegate to :meth:`run_alg`, which applies the model tile by tile to an
    image pair and writes the result as an 8-bit raster.
    """

    @staticmethod
    def get_icon():
        """Return the icon shown for this algorithm."""
        return IconInstance().ARITHMETIC3

    @staticmethod
    def run_alg(pth1: str, pth2: str, layer_parent: PairLayer, send_message=None, model=None, *args, **kargs):
        """Run ``model`` over an image pair and write a byte-scaled result raster.

        The two rasters are read in 512x512 tiles, ``model(tile1, tile2)`` is
        called for every tile and its result is written to a temporary
        Float32 GeoTIFF. That file is then rescaled to 0-255, strip by strip,
        into a new ``*_cmi.tif`` and the temporary file is removed.

        Args:
            pth1: Path of the first raster, opened with ``gdal.Open``.
            pth2: Path of the second raster, opened with ``gdal.Open``.
            layer_parent: Pair layer providing ``size``, ``grid.geo``,
                ``grid.proj``, ``mask.xy`` and ``name``.
            send_message: Optional object with an ``emit(str)`` method used
                for progress reporting.
            model: Callable invoked as ``model(block1, block2)``; its result
                is passed to ``WriteArray`` (presumably a 2-D array of the
                tile size - TODO confirm against ``get_model``).
            *args: Unused.
            **kargs: Unused.

        Returns:
            The path of the written ``*_cmi.tif``, or ``None`` when no model
            was supplied.
        """
        if model is None:
            # Nothing can be computed without a model. Bail out even when no
            # reporter is attached instead of failing later on ``model(...)``.
            if send_message is not None:
                send_message.emit('未能加载模型!')
            return None

        ds1: gdal.Dataset = gdal.Open(pth1)
        ds2: gdal.Dataset = gdal.Open(pth2)

        cell_size = (512, 512)
        xsize = layer_parent.size[0]
        ysize = layer_parent.size[1]
        band = ds1.RasterCount
        yblocks = ysize // cell_size[1]
        xblocks = xsize // cell_size[0]

        driver = gdal.GetDriverByName('GTiff')
        out_tif = os.path.join(Project().other_path, 'temp.tif')
        out_ds = driver.Create(out_tif, xsize, ysize, 1, gdal.GDT_Float32)
        geo = layer_parent.grid.geo
        proj = layer_parent.grid.proj
        out_ds.SetGeoTransform(geo)
        out_ds.SetProjection(proj)

        # Convert the mask corners into pixel coordinates of each input image.
        mask_xy = layer_parent.mask.xy
        start1x, start1y = geo2imageRC(ds1.GetGeoTransform(), mask_xy[0], mask_xy[1])
        end1x, end1y = geo2imageRC(ds1.GetGeoTransform(), mask_xy[2], mask_xy[3])
        start2x, start2y = geo2imageRC(ds2.GetGeoTransform(), mask_xy[0], mask_xy[1])
        end2x, end2y = geo2imageRC(ds2.GetGeoTransform(), mask_xy[2], mask_xy[3])

        for j in range(yblocks + 1):  # original author's note: "this is the place to change"
            if send_message is not None:
                send_message.emit(f'计算{j}/{yblocks}')
            for i in range(xblocks + 1):
                block_xy1 = [start1x + i * cell_size[0], start1y + j * cell_size[1]]
                block_xy2 = [start2x + i * cell_size[0], start2y + j * cell_size[1]]
                block_xy = [i * cell_size[0], j * cell_size[1]]

                # Stop this row once the tile origin lies past the masked
                # region of either input.
                if block_xy1[1] > end1y or block_xy2[1] > end2y:
                    break
                if block_xy1[0] > end1x or block_xy2[0] > end2x:
                    break

                block_size = list(cell_size)

                # A tile that would overrun an edge is shifted back so that a
                # full-size tile is always read and written (edge tiles
                # overlap their neighbours).
                # NOTE(review): this yields negative offsets when an extent is
                # smaller than one tile - confirm inputs are at least 512 px.
                if block_xy[1] + block_size[1] > ysize:
                    block_xy[1] = ysize - block_size[1]
                if block_xy[0] + block_size[0] > xsize:
                    block_xy[0] = xsize - block_size[0]
                if block_xy1[1] + block_size[1] > end1y:
                    block_xy1[1] = end1y - block_size[1]
                if block_xy1[0] + block_size[0] > end1x:
                    block_xy1[0] = end1x - block_size[0]
                if block_xy2[1] + block_size[1] > end2y:
                    block_xy2[1] = end2y - block_size[1]
                if block_xy2[0] + block_size[0] > end2x:
                    block_xy2[0] = end2x - block_size[0]

                block_data1 = ds1.ReadAsArray(*block_xy1, *block_size)
                block_data2 = ds2.ReadAsArray(*block_xy2, *block_size)

                if band == 1:
                    # Add a leading axis so single-band tiles have the same
                    # rank as multi-band ones.
                    block_data1 = block_data1[None, ...]
                    block_data2 = block_data2[None, ...]

                block_diff = model(block_data1, block_data2)
                out_ds.GetRasterBand(1).WriteArray(block_diff, *block_xy)
            if send_message is not None:
                send_message.emit(f'完成{j}/{yblocks}')

        del ds2
        del ds1
        out_ds.FlushCache()
        del out_ds

        if send_message is not None:
            send_message.emit('归一化概率中...')

        temp_in_ds = gdal.Open(out_tif)
        out_normal_tif = os.path.join(Project().cmi_path, '{}_{}_cmi.tif'.format(
            layer_parent.name, int(np.random.rand() * 100000)))
        out_normal_ds = driver.Create(
            out_normal_tif, xsize, ysize, 1, gdal.GDT_Byte)
        out_normal_ds.SetGeoTransform(geo)
        out_normal_ds.SetProjection(proj)

        for j in range(yblocks + 1):
            block_xy = (0, j * cell_size[1])
            # ``>=`` rather than ``>``: when ysize is an exact multiple of the
            # tile height the last iteration would request a zero-height strip.
            if block_xy[1] >= ysize:
                break
            block_size = (xsize, cell_size[1])
            if block_xy[1] + block_size[1] > ysize:
                block_size = (xsize, ysize - block_xy[1])

            block_data = temp_in_ds.ReadAsArray(*block_xy, *block_size)
            # NOTE(review): min/max are taken per strip, so each strip is
            # stretched independently - confirm a global stretch is not wanted.
            max_diff = block_data.max()
            min_diff = block_data.min()
            value_range = max_diff - min_diff
            if value_range > 0:
                # Scoped error state instead of a process-wide np.seterr().
                with np.errstate(divide='ignore', invalid='ignore'):
                    block_data = (block_data - min_diff) / value_range * 255
            else:
                # Constant (or all-NaN) strip: avoid 0/0 and write zeros.
                block_data = np.zeros_like(block_data)
            block_data = block_data.astype(np.uint8)
            out_normal_ds.GetRasterBand(1).WriteArray(block_data, *block_xy)

        del temp_in_ds
        del out_normal_ds

        try:
            os.remove(out_tif)
        except OSError:
            # Best effort: a leftover temporary file must not fail the run.
            pass

        if send_message is not None:
            send_message.emit('计算完成')
        return out_normal_tif
@AI_METHOD.register
class DVCA(BasicAICD):
    """Registered change-detection method that runs the 'DVCA' model."""

    @staticmethod
    def get_name():
        """Return the name of this method."""
        return 'DVCA'

    @staticmethod
    def run_alg(pth1: str, pth2: str, layer_parent: PairLayer, send_message=None, *args, **kargs):
        """Fetch the 'DVCA' model and hand it to ``BasicAICD.run_alg``.

        All arguments are forwarded unchanged; the return value is whatever
        ``BasicAICD.run_alg`` returns.
        """
        network = get_model('DVCA')
        return BasicAICD.run_alg(
            pth1,
            pth2,
            layer_parent,
            send_message,
            network,
            *args,
            **kargs,
        )
@AI_METHOD.register
class DPFCN(BasicAICD):
    """Registered change-detection method that runs the 'DPFCN' model."""

    @staticmethod
    def get_name():
        """Return the name of this method."""
        return 'DPFCN'

    @staticmethod
    def run_alg(pth1: str, pth2: str, layer_parent: PairLayer, send_message=None, *args, **kargs):
        """Fetch the 'DPFCN' model and hand it to ``BasicAICD.run_alg``.

        All arguments are forwarded unchanged; the return value is whatever
        ``BasicAICD.run_alg`` returns.
        """
        network = get_model('DPFCN')
        return BasicAICD.run_alg(
            pth1,
            pth2,
            layer_parent,
            send_message,
            network,
            *args,
            **kargs,
        )
@AI_METHOD.register
class RCNN(BasicAICD):
    """Registered change-detection method that runs the 'RCNN' model."""

    @staticmethod
    def get_name():
        """Return the name of this method."""
        return 'RCNN'

    @staticmethod
    def run_alg(pth1: str, pth2: str, layer_parent: PairLayer, send_message=None, *args, **kargs):
        """Fetch the 'RCNN' model and hand it to ``BasicAICD.run_alg``.

        All arguments are forwarded unchanged; the return value is whatever
        ``BasicAICD.run_alg`` returns.
        """
        network = get_model('RCNN')
        return BasicAICD.run_alg(
            pth1,
            pth2,
            layer_parent,
            send_message,
            network,
            *args,
            **kargs,
        )