Open_EarthData_Tools/utils/sr2rgb_light.py

230 lines
8.3 KiB
Python

"""
使用 GDAL BuildVRT 与 Translate 快速将影像批量镶嵌并转换为 8bit RGB (Light version)
1. 将输入目录下的所有 tif 文件合并为一个 VRT 文件;
2. 采用 GDAL ComputeStatistics 进行快速统计计算, 并根据指定的百分比截断值计算有效数据范围;
3. 将 VRT 文件转换为 8bit RGB 格式的 COG 文件, 并保存到指定目录.
"""
import os
import sys
import logging
from pathlib import Path
from osgeo import gdal
import numpy as np
# 添加父目录到 sys.path 以导入 utils
BASE_DIR = Path(__file__).parent.parent
sys.path.append(str(BASE_DIR))
from utils.common_utils import setup_logging
gdal.UseExceptions()
# 设置 GDAL 选项以优化性能
gdal.SetConfigOption("GDAL_NUM_THREADS", "ALL_CPUS")
gdal.SetConfigOption("GDAL_CACHEMAX", "1024")
def compute_band_stats(band: gdal.Band, approx_ok: bool = True, no_data: float = np.nan, percentile: float = 2.0):
"""
使用 GDAL ComputeStatistics 计算波段的统计信息
实现类似 QGIS 图层符号化中的累计计数削减效果
Parameters
----------
band : gdal.Band
GDAL 波段对象
approx_ok : bool, optional
是否允许近似统计, 速度快, by default True
no_data : float, optional
无效值, 默认 np.nan, by default np.nan
percentile : float, optional
百分比截断值 (2 表示 2%-98% 范围), by default 2.0
Returns
-------
tuple
(min_val, max_val)
"""
try:
# 使用 GDAL 的 ComputeStatistics 计算统计信息
stats = band.ComputeStatistics(approx_ok)
if stats is not None and len(stats) >= 2:
min_val, max_val = stats[0], stats[1]
# 如果需要更精确的百分比统计, 使用直方图
if percentile > 0:
# 创建直方图
hist_min = min_val
hist_max = max_val
bucket_count = 256 # 直方图桶数
# 计算直方图
hist = band.GetHistogram(
hist_min, hist_max, bucket_count,
include_out_of_range=False, approx_ok=True
)
if hist and len(hist) > 0:
# 计算累积直方图
hist_array = np.array(hist, dtype=np.float32)
cum_hist = np.cumsum(hist_array)
total = cum_hist[-1]
# 找到百分比位置
# 考虑到无效值为 -9999, 所以直接从 0 开始, 尽量还原原始数据范围
if no_data <= 0:
lower_thresh = 0.0
else:
lower_thresh = total * (percentile / 100.0)
upper_thresh = total * (1 - percentile / 100.0)
# 找到对应的值
lower_idx = np.searchsorted(cum_hist, lower_thresh)
upper_idx = np.searchsorted(cum_hist, upper_thresh)
if lower_idx < len(hist_array) and upper_idx < len(hist_array):
# 将索引映射回实际值
bin_width = (hist_max - hist_min) / bucket_count
min_val = hist_min + lower_idx * bin_width
max_val = hist_min + upper_idx * bin_width
return min_val, max_val
except Exception as e:
logging.warning(f"计算统计信息失败: {str(e)}")
# 如果失败, 返回默认值
return 0.0, 1.0
def vrt_to_8bit_simple(input_dir: str | Path, output_path: str | Path,
no_data: float = np.nan, percentile: float = 1.0):
"""
使用 gdal.Translate 将指定目录下的 tif 文件合并并转换为 8bit
使用 GDAL 内置统计功能快速计算并转换, NaN 值将被赋值为 0
Parameters
----------
input_dir : str | Path
输入 TIF 文件所在目录
output_path : str | Path
输出文件路径
no_data : float, optional
输入影像中的无效值, 默认 np.nan
percentile : float, optional
百分比截断值, 默认 0% 到 99%
"""
input_dir = Path(input_dir)
output_path = Path(output_path)
# 1. 获取所有tif文件
tif_files = [
str(f)
for f in input_dir.iterdir()
if f.is_file() and f.suffix.lower() == ".tif"
]
if not tif_files:
logging.warning(f"{input_dir} 中没有找到 tif 文件.")
return
logging.info(f"1) 找到 {len(tif_files)} 个 tif 文件")
vrt_path = input_dir / "merged.vrt"
vrt_ds = None
try:
# 2. 创建VRT
logging.info("2) 构建 VRT 镶嵌...")
vrt_options = gdal.BuildVRTOptions(
srcNodata=np.nan, # 输入影像中的无效值, 设为 NaN 防止拼接处存在缝隙
VRTNodata=no_data, # 输出 VRT 中的无效值
)
vrt_ds = gdal.BuildVRT(str(vrt_path), tif_files, options=vrt_options)
if vrt_ds is None:
logging.error("构建 VRT 失败.")
return
# 获取波段数与影像尺寸
num_bands = vrt_ds.RasterCount
width = vrt_ds.RasterXSize
height = vrt_ds.RasterYSize
logging.info(f"波段数: {num_bands}, 影像尺寸: {width}x{height}")
# 3. 使用 GDAL 快速计算统计信息
logging.info("3) 使用 GDAL 计算影像统计信息...")
scale_params = []
for band_idx in range(1, num_bands + 1):
band = vrt_ds.GetRasterBand(band_idx)
# 使用 GDAL 内置函数计算统计
min_val, max_val = compute_band_stats(
band, no_data=no_data, percentile=percentile)
logging.info(
f"波段 {band_idx}: 有效值范围 [{min_val:.4f}, {max_val:.4f}]")
# 将有效值映射到 1-255, 保留 0 作为 NoData 专用值
# 避免暗部像素 (如水体、阴影) 会被映射为 0, 从而被误判为透明缺失
scale_params.append([min_val, max_val, 1, 255])
# 4. 使用gdal_translate转换为8bit, 无效值统一设为0
if output_path.exists():
logging.warning(f"结果文件已存在: {output_path}")
return
logging.info("4) 使用 gdal_translate 转换为 8bit COG (NaN->0)...")
translate_options = gdal.TranslateOptions(
format="COG", # 输出为 COG 格式, 自动构建金字塔
scaleParams=scale_params,
outputType=gdal.GDT_Byte,
noData=0, # 输出 NoData 设为 0
creationOptions=[
"COMPRESS=DEFLATE",
"ZLEVEL=4", # DEFLATE 压缩级别, 支持 1-9, 默认为 6
"PREDICTOR=2", # 差值预测, 利于影像压缩
"NUM_THREADS=ALL_CPUS",
"BIGTIFF=IF_SAFER",
# "TILED=YES", # COG 格式自带分块, 不支持手动设置
# "PHOTOMETRIC=RGB", # COG 格式不支持 photometric 参数
],
callback=gdal.TermProgress_nocb # 进度回调, 不显示进度条
)
gdal.Translate(str(output_path), vrt_ds, options=translate_options)
logging.info(f"已保存: {output_path}")
# 释放VRT数据集, 确保文件句柄释放
vrt_ds = None
except Exception as e:
logging.error(f"处理过程中出现异常: {str(e)}")
import traceback
logging.error(traceback.format_exc())
finally:
# 清理资源
vrt_ds = None
def main(input_dir, output_path):
input_dir = Path(input_dir)
output_path = Path(output_path)
output_root = output_path.parent
os.makedirs(output_root, exist_ok=True)
log_file = output_root / "sr2rgb_light.log"
setup_logging(str(log_file))
logging.info("开始批量处理 (Light Mode)...")
logging.info(f"输入目录: {input_dir}")
logging.info(f"输出文件: {output_path}")
vrt_to_8bit_simple(input_dir, output_path, no_data=-9999.0, percentile=1.0)
if __name__ == "__main__":
# 输入目录: 包含分块tif影像的根目录
input_root = Path(r"D:\CVEOdata\RS_Data\2025_S2_COG")
# 输出目录: 存放最终RGB镶嵌结果的目录
output_root = Path(r"D:\CVEOdata\RS_Data\2025_S2_RGB_COG")
rgb_file = output_root / "Hubei_Sentinel-2_2025_RGB_COG.tif"
main(input_root, rgb_file)