230 lines
8.3 KiB
Python
230 lines
8.3 KiB
Python
"""
|
|
使用 GDAL BuildVRT 与 Translate 快速将影像批量镶嵌并转换为 8bit RGB (Light version)
|
|
|
|
1. 将输入目录下的所有 tif 文件合并为一个 VRT 文件;
|
|
2. 采用 GDAL ComputeStatistics 进行快速统计计算, 并根据指定的百分比截断值计算有效数据范围;
|
|
3. 将 VRT 文件转换为 8bit RGB 格式的 COG 文件, 并保存到指定目录.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import logging
|
|
from pathlib import Path
|
|
from osgeo import gdal
|
|
import numpy as np
|
|
|
|
# 添加父目录到 sys.path 以导入 utils
|
|
BASE_DIR = Path(__file__).parent.parent
|
|
sys.path.append(str(BASE_DIR))
|
|
|
|
from utils.common_utils import setup_logging
|
|
|
|
gdal.UseExceptions()
|
|
# 设置 GDAL 选项以优化性能
|
|
gdal.SetConfigOption("GDAL_NUM_THREADS", "ALL_CPUS")
|
|
gdal.SetConfigOption("GDAL_CACHEMAX", "1024")
|
|
|
|
|
|
def compute_band_stats(band: gdal.Band, approx_ok: bool = True, no_data: float = np.nan, percentile: float = 2.0):
|
|
"""
|
|
使用 GDAL ComputeStatistics 计算波段的统计信息
|
|
|
|
实现类似 QGIS 图层符号化中的累计计数削减效果
|
|
|
|
Parameters
|
|
----------
|
|
band : gdal.Band
|
|
GDAL 波段对象
|
|
approx_ok : bool, optional
|
|
是否允许近似统计, 速度快, by default True
|
|
no_data : float, optional
|
|
无效值, 默认 np.nan, by default np.nan
|
|
percentile : float, optional
|
|
百分比截断值 (2 表示 2%-98% 范围), by default 2.0
|
|
|
|
Returns
|
|
-------
|
|
tuple
|
|
(min_val, max_val)
|
|
"""
|
|
try:
|
|
# 使用 GDAL 的 ComputeStatistics 计算统计信息
|
|
stats = band.ComputeStatistics(approx_ok)
|
|
if stats is not None and len(stats) >= 2:
|
|
min_val, max_val = stats[0], stats[1]
|
|
|
|
# 如果需要更精确的百分比统计, 使用直方图
|
|
if percentile > 0:
|
|
# 创建直方图
|
|
hist_min = min_val
|
|
hist_max = max_val
|
|
bucket_count = 256 # 直方图桶数
|
|
|
|
# 计算直方图
|
|
hist = band.GetHistogram(
|
|
hist_min, hist_max, bucket_count,
|
|
include_out_of_range=False, approx_ok=True
|
|
)
|
|
|
|
if hist and len(hist) > 0:
|
|
# 计算累积直方图
|
|
hist_array = np.array(hist, dtype=np.float32)
|
|
cum_hist = np.cumsum(hist_array)
|
|
total = cum_hist[-1]
|
|
|
|
# 找到百分比位置
|
|
# 考虑到无效值为 -9999, 所以直接从 0 开始, 尽量还原原始数据范围
|
|
if no_data <= 0:
|
|
lower_thresh = 0.0
|
|
else:
|
|
lower_thresh = total * (percentile / 100.0)
|
|
upper_thresh = total * (1 - percentile / 100.0)
|
|
|
|
# 找到对应的值
|
|
lower_idx = np.searchsorted(cum_hist, lower_thresh)
|
|
upper_idx = np.searchsorted(cum_hist, upper_thresh)
|
|
|
|
if lower_idx < len(hist_array) and upper_idx < len(hist_array):
|
|
# 将索引映射回实际值
|
|
bin_width = (hist_max - hist_min) / bucket_count
|
|
min_val = hist_min + lower_idx * bin_width
|
|
max_val = hist_min + upper_idx * bin_width
|
|
|
|
return min_val, max_val
|
|
except Exception as e:
|
|
logging.warning(f"计算统计信息失败: {str(e)}")
|
|
|
|
# 如果失败, 返回默认值
|
|
return 0.0, 1.0
|
|
|
|
|
|
def vrt_to_8bit_simple(input_dir: str | Path, output_path: str | Path,
|
|
no_data: float = np.nan, percentile: float = 1.0):
|
|
"""
|
|
使用 gdal.Translate 将指定目录下的 tif 文件合并并转换为 8bit
|
|
使用 GDAL 内置统计功能快速计算并转换, NaN 值将被赋值为 0
|
|
|
|
Parameters
|
|
----------
|
|
input_dir : str | Path
|
|
输入 TIF 文件所在目录
|
|
output_path : str | Path
|
|
输出文件路径
|
|
no_data : float, optional
|
|
输入影像中的无效值, 默认 np.nan
|
|
percentile : float, optional
|
|
百分比截断值, 默认 0% 到 99%
|
|
"""
|
|
input_dir = Path(input_dir)
|
|
output_path = Path(output_path)
|
|
|
|
# 1. 获取所有tif文件
|
|
tif_files = [
|
|
str(f)
|
|
for f in input_dir.iterdir()
|
|
if f.is_file() and f.suffix.lower() == ".tif"
|
|
]
|
|
|
|
if not tif_files:
|
|
logging.warning(f"{input_dir} 中没有找到 tif 文件.")
|
|
return
|
|
|
|
logging.info(f"1) 找到 {len(tif_files)} 个 tif 文件")
|
|
|
|
vrt_path = input_dir / "merged.vrt"
|
|
vrt_ds = None
|
|
|
|
try:
|
|
# 2. 创建VRT
|
|
logging.info("2) 构建 VRT 镶嵌...")
|
|
vrt_options = gdal.BuildVRTOptions(
|
|
srcNodata=np.nan, # 输入影像中的无效值, 设为 NaN 防止拼接处存在缝隙
|
|
VRTNodata=no_data, # 输出 VRT 中的无效值
|
|
)
|
|
vrt_ds = gdal.BuildVRT(str(vrt_path), tif_files, options=vrt_options)
|
|
|
|
if vrt_ds is None:
|
|
logging.error("构建 VRT 失败.")
|
|
return
|
|
|
|
# 获取波段数与影像尺寸
|
|
num_bands = vrt_ds.RasterCount
|
|
width = vrt_ds.RasterXSize
|
|
height = vrt_ds.RasterYSize
|
|
logging.info(f"波段数: {num_bands}, 影像尺寸: {width}x{height}")
|
|
|
|
# 3. 使用 GDAL 快速计算统计信息
|
|
logging.info("3) 使用 GDAL 计算影像统计信息...")
|
|
scale_params = []
|
|
for band_idx in range(1, num_bands + 1):
|
|
band = vrt_ds.GetRasterBand(band_idx)
|
|
# 使用 GDAL 内置函数计算统计
|
|
min_val, max_val = compute_band_stats(
|
|
band, no_data=no_data, percentile=percentile)
|
|
logging.info(
|
|
f"波段 {band_idx}: 有效值范围 [{min_val:.4f}, {max_val:.4f}]")
|
|
# 将有效值映射到 1-255, 保留 0 作为 NoData 专用值
|
|
# 避免暗部像素 (如水体、阴影) 会被映射为 0, 从而被误判为透明缺失
|
|
scale_params.append([min_val, max_val, 1, 255])
|
|
|
|
# 4. 使用gdal_translate转换为8bit, 无效值统一设为0
|
|
if output_path.exists():
|
|
logging.warning(f"结果文件已存在: {output_path}")
|
|
return
|
|
logging.info("4) 使用 gdal_translate 转换为 8bit COG (NaN->0)...")
|
|
translate_options = gdal.TranslateOptions(
|
|
format="COG", # 输出为 COG 格式, 自动构建金字塔
|
|
scaleParams=scale_params,
|
|
outputType=gdal.GDT_Byte,
|
|
noData=0, # 输出 NoData 设为 0
|
|
creationOptions=[
|
|
"COMPRESS=DEFLATE",
|
|
"ZLEVEL=4", # DEFLATE 压缩级别, 支持 1-9, 默认为 6
|
|
"PREDICTOR=2", # 差值预测, 利于影像压缩
|
|
"NUM_THREADS=ALL_CPUS",
|
|
"BIGTIFF=IF_SAFER",
|
|
# "TILED=YES", # COG 格式自带分块, 不支持手动设置
|
|
# "PHOTOMETRIC=RGB", # COG 格式不支持 photometric 参数
|
|
],
|
|
callback=gdal.TermProgress_nocb # 进度回调, 不显示进度条
|
|
)
|
|
gdal.Translate(str(output_path), vrt_ds, options=translate_options)
|
|
logging.info(f"已保存: {output_path}")
|
|
|
|
# 释放VRT数据集, 确保文件句柄释放
|
|
vrt_ds = None
|
|
|
|
except Exception as e:
|
|
logging.error(f"处理过程中出现异常: {str(e)}")
|
|
import traceback
|
|
logging.error(traceback.format_exc())
|
|
finally:
|
|
# 清理资源
|
|
vrt_ds = None
|
|
|
|
|
|
def main(input_dir, output_path):
|
|
input_dir = Path(input_dir)
|
|
output_path = Path(output_path)
|
|
|
|
output_root = output_path.parent
|
|
os.makedirs(output_root, exist_ok=True)
|
|
|
|
log_file = output_root / "sr2rgb_light.log"
|
|
setup_logging(str(log_file))
|
|
|
|
logging.info("开始批量处理 (Light Mode)...")
|
|
logging.info(f"输入目录: {input_dir}")
|
|
logging.info(f"输出文件: {output_path}")
|
|
|
|
vrt_to_8bit_simple(input_dir, output_path, no_data=-9999.0, percentile=1.0)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
# 输入目录: 包含分块tif影像的根目录
|
|
input_root = Path(r"D:\CVEOdata\RS_Data\2025_S2_COG")
|
|
# 输出目录: 存放最终RGB镶嵌结果的目录
|
|
output_root = Path(r"D:\CVEOdata\RS_Data\2025_S2_RGB_COG")
|
|
rgb_file = output_root / "Hubei_Sentinel-2_2025_RGB_COG.tif"
|
|
main(input_root, rgb_file)
|