"""
Mosaic a directory of GeoTIFF tiles with GDAL BuildVRT and convert the
mosaic to an 8-bit RGB GeoTIFF with gdal.Translate (light version).
"""

import os
import sys
import logging
from pathlib import Path
from osgeo import gdal
import numpy as np

# Add the parent directory to sys.path so the `utils` package can be imported
# when this file is run directly as a script.
BASE_DIR = Path(__file__).parent.parent
sys.path.append(str(BASE_DIR))

from utils.common_utils import setup_logging

gdal.UseExceptions()
# GDAL configuration options intended to improve throughput.
gdal.SetConfigOption("GDAL_NUM_THREADS", "ALL_CPUS")
gdal.SetConfigOption("GDAL_CACHEMAX", "1024")


def _band_stretch_range(band, x_size, y_size, num, block_size=1024):
    """
    Estimate the percentile stretch range of one raster band block by block.

    The band is read in ``block_size`` x ``block_size`` windows so the whole
    raster never has to be held in memory. For every window the ``num``-th and
    ``(100 - num)``-th percentiles of the valid pixels are computed; the
    returned range is the minimum of the per-window low percentiles and the
    maximum of the per-window high percentiles. This approximates the global
    percentile range, it is not the exact global value.

    Pixels that are NaN, infinite, or equal to the band's declared nodata
    value are ignored.

    Parameters
    ----------
    band : object
        Band-like object providing ``ReadAsArray(xoff, yoff, xsize, ysize)``
        and ``GetNoDataValue()``.
    x_size : int
        Raster width in pixels.
    y_size : int
        Raster height in pixels.
    num : int | float
        Lower percentile; the upper percentile is ``100 - num``.
    block_size : int, optional
        Edge length of the read window in pixels.

    Returns
    -------
    tuple[float, float] | None
        ``(low, high)``, or ``None`` when no window contained a valid pixel.
    """
    nodata = band.GetNoDataValue()
    # A NaN nodata value is already covered by the finite-value mask below.
    mask_nodata = nodata is not None and not np.isnan(nodata)

    lows = []
    highs = []
    for y_off in range(0, y_size, block_size):
        rows = min(block_size, y_size - y_off)
        for x_off in range(0, x_size, block_size):
            cols = min(block_size, x_size - x_off)

            data = band.ReadAsArray(x_off, y_off, cols, rows)
            if data is None:
                continue

            valid = np.isfinite(data)
            if mask_nodata:
                valid &= data != nodata
            values = data[valid]
            if values.size == 0:
                continue

            # Both percentiles in a single pass over the window.
            low, high = np.percentile(values, [num, 100 - num])
            lows.append(low)
            highs.append(high)

    if not lows:
        return None
    return float(min(lows)), float(max(highs))


def vrt_to_8bit_simple(input_dir: str | Path, output_path: str | Path, num: int = 2):
    """
    Mosaic all ``.tif`` files of a directory and write them as one 8-bit GeoTIFF.

    Steps: build a temporary VRT mosaic, estimate a per-band percentile
    stretch range, translate the mosaic to Byte with that range, then build
    overviews on the output file. The temporary VRT is removed afterwards.

    Any exception raised during processing is logged and swallowed; the
    function always returns ``None``.

    Parameters
    ----------
    input_dir : str | Path
        Directory containing the input TIF files (not searched recursively).
    output_path : str | Path
        Path of the output file.
    num : int, optional
        Percentile used for the stretch; the default 2 stretches the
        2%..98% range to 0..255.
    """
    input_dir = Path(input_dir)
    output_path = Path(output_path)

    # 1. Collect all tif files. Sorted so that the stacking order of
    #    overlapping tiles in the mosaic is reproducible between runs.
    tif_files = sorted(
        str(f)
        for f in input_dir.iterdir()
        if f.is_file() and f.suffix.lower() == ".tif"
    )

    if not tif_files:
        logging.warning(f"{input_dir} 中没有找到 tif 文件.")
        return

    logging.info(f"1) 找到 {len(tif_files)} 个 tif 文件")

    vrt_path = input_dir / "temp.vrt"
    vrt_ds = None

    try:
        # 2. Build the VRT mosaic.
        logging.info("2) 构建 VRT 镶嵌...")
        vrt_ds = gdal.BuildVRT(str(vrt_path), tif_files)

        if vrt_ds is None:
            logging.error("构建 VRT 失败.")
            return

        # Band count and image size of the mosaic.
        num_bands = vrt_ds.RasterCount
        width = vrt_ds.RasterXSize
        height = vrt_ds.RasterYSize
        logging.info(f"波段数: {num_bands}, 影像尺寸: {width}x{height}")

        # 3. Per-band stretch range, ignoring NaN / inf / nodata pixels.
        logging.info("3) 计算影像统计信息...")
        scale_params = []
        for band_idx in range(1, num_bands + 1):
            band = vrt_ds.GetRasterBand(band_idx)
            stretch = _band_stretch_range(band, width, height, num)

            if stretch is not None:
                min_val, max_val = stretch
                logging.info(f"波段 {band_idx}: 有效值范围 [{min_val:.4f}, {max_val:.4f}]")
                scale_params.append([min_val, max_val, 0, 255])
            else:
                logging.warning(f"波段 {band_idx}: 未找到有效数据, 使用默认范围 [0, 1]")
                scale_params.append([0, 1, 0, 255])

        # 4. Translate to 8 bit with 0 declared as the output nodata value.
        # NOTE(review): NaN inputs are expected to end up as 0 after the Byte
        # conversion — confirm against the GDAL version in use.
        logging.info("4) 使用 gdal_translate 转换为 8bit (NaN->0)...")
        translate_options = gdal.TranslateOptions(
            scaleParams=scale_params,
            outputType=gdal.GDT_Byte,
            noData=0,
            creationOptions=[
                "COMPRESS=DEFLATE",
                "TILED=YES",
                "BIGTIFF=IF_SAFER",
                "PHOTOMETRIC=RGB",
            ],
        )

        translated_ds = gdal.Translate(str(output_path), vrt_ds, options=translate_options)
        # Drop the reference so the output is flushed and closed before it is
        # reopened in update mode below.
        translated_ds = None
        logging.info(f"已保存: {output_path}")

        # 5. Build overviews on the output file.
        logging.info("5) 构建金字塔...")
        # Release the VRT dataset so its file handle is freed.
        vrt_ds = None

        out_ds = gdal.Open(str(output_path), gdal.GA_Update)
        if out_ds:
            try:
                # Fixed overview levels 2, 4, 8, 16 (not adapted to image size).
                out_ds.BuildOverviews("AVERAGE", [2, 4, 8, 16])
                logging.info("金字塔构建完成")
            finally:
                out_ds = None
        else:
            logging.error(f"无法打开文件构建金字塔: {output_path}")

    except Exception as e:
        # Top-level boundary of the conversion: log with traceback, do not raise.
        logging.exception(f"处理过程中出现异常: {str(e)}")
    finally:
        # Release the dataset first, then remove the temporary VRT so it does
        # not accumulate in the input directory.
        vrt_ds = None
        try:
            vrt_path.unlink(missing_ok=True)
        except OSError as e:
            logging.warning(f"无法删除临时 VRT 文件 {vrt_path}: {e}")


def main(input_dir, output_path):
    """
    Set up logging next to the output file and run the conversion once.

    The output directory is created if needed. If ``output_path`` already
    exists, a warning is logged and nothing is processed.

    Parameters
    ----------
    input_dir : str | Path
        Directory containing the input TIF tiles.
    output_path : str | Path
        Path of the 8-bit mosaic to create.
    """
    input_dir = Path(input_dir)
    output_path = Path(output_path)

    output_root = output_path.parent
    os.makedirs(output_root, exist_ok=True)

    log_file = output_root / "sr2rgb_light.log"
    setup_logging(str(log_file))

    logging.info("开始批量处理 (Light Mode)...")
    logging.info(f"输入目录: {input_dir}")
    logging.info(f"输出文件: {output_path}")

    if output_path.exists():
        logging.warning(f"结果文件已存在: {output_path}")
        return

    vrt_to_8bit_simple(input_dir, output_path)


if __name__ == "__main__":
    # Input directory: root directory containing the tiled tif images.
    input_root = Path(r"D:\CVEOdata\RS_Data\2025_S2")
    # Output directory: where the final RGB mosaic is stored.
    output_root = Path(r"D:\CVEOdata\RS_Data\2025_S2_RGB")
    rgb_file = output_root / "Hubei_Sentinel-2_2025_RGB.tif"
    main(input_root, rgb_file)