"""
使用 GDAL BuildVRT 与 Translate 快速将影像批量镶嵌并转换为 8bit RGB (Light version)

Quickly mosaic a directory of GeoTIFF tiles with GDAL BuildVRT and convert the
result to an 8-bit RGB GeoTIFF with gdal.Translate (light version).
"""

import argparse
import logging
import os
import sys
from pathlib import Path

import numpy as np
from osgeo import gdal

# Make the directory two levels above this file (presumably the project root,
# since the shared ``utils`` package is imported from it) importable. It is put
# at the FRONT of sys.path so an unrelated third-party package that happens to
# be called ``utils`` cannot shadow the project's own one, and it is added only
# once so re-importing this module does not keep growing sys.path.
BASE_DIR = Path(__file__).parent.parent
if str(BASE_DIR) not in sys.path:
    sys.path.insert(0, str(BASE_DIR))

from utils.common_utils import setup_logging  # noqa: E402  (needs the sys.path entry above)

# Make GDAL raise Python exceptions on errors instead of returning None.
gdal.UseExceptions()

# GDAL performance tuning: allow multithreading on all CPU cores where GDAL
# supports it, and a block cache of 1024 (interpreted by GDAL as megabytes).
gdal.SetConfigOption("GDAL_NUM_THREADS", "ALL_CPUS")
gdal.SetConfigOption("GDAL_CACHEMAX", "1024")

def vrt_to_8bit_simple(input_dir: str | Path, output_path: str | Path, num: int = 2):
|
||
"""
|
||
使用 gdal.Translate 将指定目录下的 tif 文件合并并转换为 8bit
|
||
NaN 值将被赋值为 0
|
||
|
||
Parameters
|
||
----------
|
||
input_dir : str | Path
|
||
输入 TIF 文件所在目录
|
||
output_path : str | Path
|
||
输出文件路径
|
||
num : int, optional
|
||
拉伸的百分比范围, 默认2%到98%
|
||
"""
|
||
input_dir = Path(input_dir)
|
||
output_path = Path(output_path)
|
||
|
||
# 1. 获取所有tif文件
|
||
tif_files = [
|
||
str(f)
|
||
for f in input_dir.iterdir()
|
||
if f.is_file() and f.suffix.lower() == ".tif"
|
||
]
|
||
|
||
if not tif_files:
|
||
logging.warning(f"{input_dir} 中没有找到 tif 文件.")
|
||
return
|
||
|
||
logging.info(f"1) 找到 {len(tif_files)} 个 tif 文件")
|
||
|
||
vrt_path = input_dir / "temp.vrt"
|
||
vrt_ds = None
|
||
|
||
try:
|
||
# 2. 创建VRT
|
||
logging.info("2) 构建 VRT 镶嵌...")
|
||
vrt_ds = gdal.BuildVRT(str(vrt_path), tif_files)
|
||
|
||
if vrt_ds is None:
|
||
logging.error("构建 VRT 失败.")
|
||
return
|
||
|
||
# 获取波段数与影像尺寸
|
||
num_bands = vrt_ds.RasterCount
|
||
width = vrt_ds.RasterXSize
|
||
height = vrt_ds.RasterYSize
|
||
logging.info(f"波段数: {num_bands}, 影像尺寸: {width}x{height}")
|
||
|
||
# 3. 计算每个波段的统计信息并处理 NaN 值
|
||
logging.info("3) 计算影像统计信息...")
|
||
scale_params = []
|
||
for band_idx in range(1, num_bands + 1):
|
||
band = vrt_ds.GetRasterBand(band_idx)
|
||
|
||
# 读取数据块以计算统计 (避免加载全部数据)
|
||
block_size = 1024
|
||
stats_min = []
|
||
stats_max = []
|
||
|
||
for y in range(0, vrt_ds.RasterYSize, block_size):
|
||
height = min(block_size, vrt_ds.RasterYSize - y)
|
||
for x in range(0, vrt_ds.RasterXSize, block_size):
|
||
width = min(block_size, vrt_ds.RasterXSize - x)
|
||
|
||
data = band.ReadAsArray(x, y, width, height)
|
||
if data is not None:
|
||
# 计算非NaN的最小最大值
|
||
valid_data = data[~np.isnan(data)]
|
||
# 读取2%-98%的有效数据范围
|
||
if len(valid_data) > 0:
|
||
stats_min.append(np.nanpercentile(valid_data, num))
|
||
stats_max.append(np.nanpercentile(valid_data, 100 - num))
|
||
|
||
if stats_min and stats_max:
|
||
min_val = min(stats_min)
|
||
max_val = max(stats_max)
|
||
logging.info(f"波段 {band_idx}: 有效值范围 [{min_val:.4f}, {max_val:.4f}]")
|
||
scale_params.append([min_val, max_val, 0, 255])
|
||
else:
|
||
logging.warning(f"波段 {band_idx}: 未找到有效数据, 使用默认范围 [0, 1]")
|
||
scale_params.append([0, 1, 0, 255])
|
||
|
||
# 4. 使用gdal_translate转换为8bit, NaN值设为0
|
||
logging.info("4) 使用 gdal_translate 转换为 8bit (NaN->0)...")
|
||
translate_options = gdal.TranslateOptions(
|
||
scaleParams=scale_params,
|
||
outputType=gdal.GDT_Byte,
|
||
noData=0, # 将NaN转换为0
|
||
creationOptions=[
|
||
"COMPRESS=DEFLATE",
|
||
"TILED=YES",
|
||
"BIGTIFF=IF_SAFER",
|
||
"PHOTOMETRIC=RGB",
|
||
]
|
||
)
|
||
|
||
gdal.Translate(str(output_path), vrt_ds, options=translate_options)
|
||
logging.info(f"已保存: {output_path}")
|
||
|
||
# 5. 构建金字塔 (基于输出文件)
|
||
logging.info("5) 构建金字塔...")
|
||
# 释放VRT数据集,确保文件句柄释放
|
||
vrt_ds = None
|
||
|
||
# 打开生成的输出文件进行更新
|
||
out_ds = gdal.Open(str(output_path), gdal.GA_Update)
|
||
if out_ds:
|
||
try:
|
||
# 构建金字塔: 2, 4, 8, 16...
|
||
# 这里根据影像大小自适应或者固定层级
|
||
out_ds.BuildOverviews("AVERAGE", [2, 4, 8, 16])
|
||
logging.info("金字塔构建完成")
|
||
finally:
|
||
out_ds = None
|
||
else:
|
||
logging.error(f"无法打开文件构建金字塔: {output_path}")
|
||
|
||
except Exception as e:
|
||
logging.error(f"处理过程中出现异常: {str(e)}")
|
||
import traceback
|
||
logging.error(traceback.format_exc())
|
||
finally:
|
||
# 清理资源
|
||
vrt_ds = None
|
||
|
||
|
||
def main(input_dir, output_path, num: float = 2):
    """
    Run the mosaic-to-8-bit conversion as a logged batch job.

    Steps:

    - Creates the output file's parent directory.
    - Hands ``<output directory>/sr2rgb_light.log`` to ``setup_logging``.
    - Calls ``vrt_to_8bit_simple``.

    Nothing is converted when the input directory does not exist or when the
    output file is already present, so finished jobs are not redone.

    Parameters
    ----------
    input_dir : str | Path
        Directory holding the input tiles.
    output_path : str | Path
        GeoTIFF to create.
    num : float, optional
        Percentage clipped at each end of every band's histogram before
        stretching. Forwarded to ``vrt_to_8bit_simple``; the default 2 means
        2%-98%.
    """
    input_dir = Path(input_dir)
    output_path = Path(output_path)

    output_root = output_path.parent
    os.makedirs(output_root, exist_ok=True)

    log_file = output_root / "sr2rgb_light.log"
    setup_logging(str(log_file))

    logging.info("开始批量处理 (Light Mode)...")
    logging.info(f"输入目录: {input_dir}")
    logging.info(f"输出文件: {output_path}")

    # Report a bad input path through the log instead of letting the directory
    # listing inside vrt_to_8bit_simple raise an uncaught FileNotFoundError.
    if not input_dir.is_dir():
        logging.error(f"输入目录不存在: {input_dir}")
        return

    if output_path.exists():
        logging.warning(f"结果文件已存在: {output_path}")
        return

    vrt_to_8bit_simple(input_dir, output_path, num=num)

if __name__ == "__main__":
    # The defaults reproduce the original hard-coded Hubei Sentinel-2 2025 job.
    # Both paths can now be overridden from the command line.
    parser = argparse.ArgumentParser(
        description="Mosaic GeoTIFF tiles with GDAL and convert them to an 8-bit RGB GeoTIFF."
    )
    parser.add_argument(
        "--input-dir",
        type=Path,
        default=Path(r"D:\CVEOdata\RS_Data\2025_S2"),
        help="directory holding the tiled input GeoTIFFs (direct children only)",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path(r"D:\CVEOdata\RS_Data\2025_S2_RGB") / "Hubei_Sentinel-2_2025_RGB.tif",
        help="path of the mosaicked 8-bit RGB GeoTIFF to create",
    )
    args = parser.parse_args()
    main(args.input_dir, args.output)