feat(sr2rgb): 添加SR地表反射率影像批量镶嵌并转换8bit RGB影像代码.
This commit is contained in:
parent
7b3573e57d
commit
fc5a29696a
177
utils/sr2rgb_light.py
Normal file
177
utils/sr2rgb_light.py
Normal file
@ -0,0 +1,177 @@
|
||||
"""
|
||||
使用 GDAL BuildVRT 与 Translate 快速将影像批量镶嵌并转换为 8bit RGB (Light version)
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from osgeo import gdal
|
||||
import numpy as np
|
||||
|
||||
# 添加父目录到 sys.path 以导入 utils
|
||||
BASE_DIR = Path(__file__).parent.parent
|
||||
sys.path.append(str(BASE_DIR))
|
||||
|
||||
from utils.common_utils import setup_logging
|
||||
|
||||
gdal.UseExceptions()
|
||||
# 设置 GDAL 选项以优化性能
|
||||
gdal.SetConfigOption("GDAL_NUM_THREADS", "ALL_CPUS")
|
||||
gdal.SetConfigOption("GDAL_CACHEMAX", "1024")
|
||||
|
||||
|
||||
def vrt_to_8bit_simple(input_dir: str | Path, output_path: str | Path, num: int = 2):
|
||||
"""
|
||||
使用 gdal.Translate 将指定目录下的 tif 文件合并并转换为 8bit
|
||||
NaN 值将被赋值为 0
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_dir : str | Path
|
||||
输入 TIF 文件所在目录
|
||||
output_path : str | Path
|
||||
输出文件路径
|
||||
num : int, optional
|
||||
拉伸的百分比范围, 默认2%到98%
|
||||
"""
|
||||
input_dir = Path(input_dir)
|
||||
output_path = Path(output_path)
|
||||
|
||||
# 1. 获取所有tif文件
|
||||
tif_files = [
|
||||
str(f)
|
||||
for f in input_dir.iterdir()
|
||||
if f.is_file() and f.suffix.lower() == ".tif"
|
||||
]
|
||||
|
||||
if not tif_files:
|
||||
logging.warning(f"{input_dir} 中没有找到 tif 文件.")
|
||||
return
|
||||
|
||||
logging.info(f"1) 找到 {len(tif_files)} 个 tif 文件")
|
||||
|
||||
vrt_path = input_dir / "temp.vrt"
|
||||
vrt_ds = None
|
||||
|
||||
try:
|
||||
# 2. 创建VRT
|
||||
logging.info("2) 构建 VRT 镶嵌...")
|
||||
vrt_ds = gdal.BuildVRT(str(vrt_path), tif_files)
|
||||
|
||||
if vrt_ds is None:
|
||||
logging.error("构建 VRT 失败.")
|
||||
return
|
||||
|
||||
# 获取波段数与影像尺寸
|
||||
num_bands = vrt_ds.RasterCount
|
||||
width = vrt_ds.RasterXSize
|
||||
height = vrt_ds.RasterYSize
|
||||
logging.info(f"波段数: {num_bands}, 影像尺寸: {width}x{height}")
|
||||
|
||||
# 3. 计算每个波段的统计信息并处理 NaN 值
|
||||
logging.info("3) 计算影像统计信息...")
|
||||
scale_params = []
|
||||
for band_idx in range(1, num_bands + 1):
|
||||
band = vrt_ds.GetRasterBand(band_idx)
|
||||
|
||||
# 读取数据块以计算统计 (避免加载全部数据)
|
||||
block_size = 1024
|
||||
stats_min = []
|
||||
stats_max = []
|
||||
|
||||
for y in range(0, vrt_ds.RasterYSize, block_size):
|
||||
height = min(block_size, vrt_ds.RasterYSize - y)
|
||||
for x in range(0, vrt_ds.RasterXSize, block_size):
|
||||
width = min(block_size, vrt_ds.RasterXSize - x)
|
||||
|
||||
data = band.ReadAsArray(x, y, width, height)
|
||||
if data is not None:
|
||||
# 计算非NaN的最小最大值
|
||||
valid_data = data[~np.isnan(data)]
|
||||
# 读取2%-98%的有效数据范围
|
||||
if len(valid_data) > 0:
|
||||
stats_min.append(np.nanpercentile(valid_data, num))
|
||||
stats_max.append(np.nanpercentile(valid_data, 100 - num))
|
||||
|
||||
if stats_min and stats_max:
|
||||
min_val = min(stats_min)
|
||||
max_val = max(stats_max)
|
||||
logging.info(f"波段 {band_idx}: 有效值范围 [{min_val:.4f}, {max_val:.4f}]")
|
||||
scale_params.append([min_val, max_val, 0, 255])
|
||||
else:
|
||||
logging.warning(f"波段 {band_idx}: 未找到有效数据, 使用默认范围 [0, 1]")
|
||||
scale_params.append([0, 1, 0, 255])
|
||||
|
||||
# 4. 使用gdal_translate转换为8bit, NaN值设为0
|
||||
logging.info("4) 使用 gdal_translate 转换为 8bit (NaN->0)...")
|
||||
translate_options = gdal.TranslateOptions(
|
||||
scaleParams=scale_params,
|
||||
outputType=gdal.GDT_Byte,
|
||||
noData=0, # 将NaN转换为0
|
||||
creationOptions=[
|
||||
"COMPRESS=DEFLATE",
|
||||
"TILED=YES",
|
||||
"BIGTIFF=IF_SAFER",
|
||||
"PHOTOMETRIC=RGB",
|
||||
]
|
||||
)
|
||||
|
||||
gdal.Translate(str(output_path), vrt_ds, options=translate_options)
|
||||
logging.info(f"已保存: {output_path}")
|
||||
|
||||
# 5. 构建金字塔 (基于输出文件)
|
||||
logging.info("5) 构建金字塔...")
|
||||
# 释放VRT数据集,确保文件句柄释放
|
||||
vrt_ds = None
|
||||
|
||||
# 打开生成的输出文件进行更新
|
||||
out_ds = gdal.Open(str(output_path), gdal.GA_Update)
|
||||
if out_ds:
|
||||
try:
|
||||
# 构建金字塔: 2, 4, 8, 16...
|
||||
# 这里根据影像大小自适应或者固定层级
|
||||
out_ds.BuildOverviews("AVERAGE", [2, 4, 8, 16])
|
||||
logging.info("金字塔构建完成")
|
||||
finally:
|
||||
out_ds = None
|
||||
else:
|
||||
logging.error(f"无法打开文件构建金字塔: {output_path}")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"处理过程中出现异常: {str(e)}")
|
||||
import traceback
|
||||
logging.error(traceback.format_exc())
|
||||
finally:
|
||||
# 清理资源
|
||||
vrt_ds = None
|
||||
|
||||
|
||||
def main(input_dir, output_path):
|
||||
input_dir = Path(input_dir)
|
||||
output_path = Path(output_path)
|
||||
|
||||
output_root = output_path.parent
|
||||
os.makedirs(output_root, exist_ok=True)
|
||||
|
||||
log_file = output_root / "sr2rgb_light.log"
|
||||
setup_logging(str(log_file))
|
||||
|
||||
logging.info("开始批量处理 (Light Mode)...")
|
||||
logging.info(f"输入目录: {input_dir}")
|
||||
logging.info(f"输出文件: {output_path}")
|
||||
|
||||
if output_path.exists():
|
||||
logging.warning(f"结果文件已存在: {output_path}")
|
||||
return
|
||||
|
||||
vrt_to_8bit_simple(input_dir, output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 输入目录: 包含分块tif影像的根目录
|
||||
input_root = Path(r"D:\CVEOdata\RS_Data\2025_S2")
|
||||
# 输出目录: 存放最终RGB镶嵌结果的目录
|
||||
output_root = Path(r"D:\CVEOdata\RS_Data\2025_S2_RGB")
|
||||
rgb_file = output_root / "Hubei_Sentinel-2_2025_RGB.tif"
|
||||
main(input_root, rgb_file)
|
||||
Loading…
x
Reference in New Issue
Block a user