From eeed789d5b6e32be648ce519012b75beea52b8b6 Mon Sep 17 00:00:00 2001 From: xhong Date: Tue, 6 Jan 2026 20:59:32 +0800 Subject: [PATCH] =?UTF-8?q?feat(sr2rgb):=20=E9=87=8D=E6=9E=84=E5=9C=B0?= =?UTF-8?q?=E8=A1=A8=E5=8F=8D=E5=B0=84=E7=8E=87=E8=BD=ACRGB=E5=9B=BE?= =?UTF-8?q?=E5=83=8F=E5=8A=9F=E8=83=BD=E5=B9=B6=E6=B7=BB=E5=8A=A0=E5=A4=9A?= =?UTF-8?q?=E6=B3=A2=E6=AE=B5=E6=94=AF=E6=8C=81.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- utils/sr2rgb.py | 235 ++++++++++++++++++++++++++++++++++++------------ 1 file changed, 178 insertions(+), 57 deletions(-) diff --git a/utils/sr2rgb.py b/utils/sr2rgb.py index 23a5fb1..43a92e3 100644 --- a/utils/sr2rgb.py +++ b/utils/sr2rgb.py @@ -1,5 +1,8 @@ """ -将 COG 格式的 Red, Green, Blue 单波段地表反射率图像合成为 RGB 图像 +将地表反射率图像转换为 8bit RGB 图像 + +- Red, Green, Blue 单波段图像 +- 多波段图像 1. 对比度线性拉伸至 0-255; 2. 合并波段; @@ -7,77 +10,186 @@ """ import os +import sys +from pathlib import Path +from typing import List, Optional + +import numpy as np +from PIL import Image +from osgeo import gdal import xarray as xr from rioxarray import open_rasterio -import numpy as np + +gdal.UseExceptions() -def sr2rgb(red_path: str, green_path: str, blue_path: str, output_path: str) -> None: +def normalize_band( + band: np.ndarray, method: str = "minmax", percentile: float = 2.0 +) -> np.ndarray: """ - 将红、绿、蓝三个单波段地表反射率图像合成为RGB图像 + 将波段数据归一化到 0-255 uint8 - 参数: - red_path (str): 红色波段文件路径 - green_path (str): 绿色波段文件路径 - blue_path (str): 蓝色波段文件路径 - output_path (str): 输出RGB图像路径 + Parameters + ---------- + band: np.ndarray + 输入波段数据 + method: str, optional + 归一化方法 ('minmax', 'percentile', 'clip'), by default "minmax" + percentile: float, optional + 百分比截断参数 (仅用于 'percentile' 方法), by default 2.0 """ - # 检查文件是否存在 - for path in [red_path, green_path, blue_path]: + band = np.asarray(band, dtype=np.float32) + + if method == "clip": + # 对应原 render8bit 逻辑: 直接截断并转换 + band[~np.isfinite(band)] = 0.0 + return np.clip(band, 0.0, 255.0).astype(np.uint8) + + if method == "percentile": + # 对应原 linear_stretch 逻辑 + lower = np.nanpercentile(band, percentile) + upper = np.nanpercentile(band, 100 - percentile) + else: # minmax + # 使用0值处理NaN和负值 + band = np.where(np.isnan(band) | (band < 0), 0, band) + lower = np.min(band) + upper = np.max(band) + + if upper == lower: + return np.zeros_like(band, dtype=np.uint8) + + stretched = (band - lower) / (upper - lower) * 255.0 + return np.clip(stretched, 0, 255).astype(np.uint8) + + +def render_image_rgb( + tiff_path: Path, + bands_lst: List[int], + normalization_method: str, + percentile: float = 0, + save_tiff: bool = True, + output_suffix: str = "_rgb.png", +) -> str: + """ + 通用 GDAL 数据源处理函数, 读取多波段图像并转换为 RGB 图像. + + Parameters + ---------- + tiff_path : Path + 输入 TIFF 图像路径. + bands_lst : List[int] + RGB 波段索引列表 (GDAL 索引, 从 1 开始). + normalization_method : str + 归一化方法 ('minmax', 'percentile', 'clip'). + percentile : float, optional + 百分比截断参数 (仅用于 'percentile' 方法), by default 0. + save_tiff : bool, optional + 是否保存为 RGB TIFF 文件, by default True. + output_suffix : str, optional + 输出文件后缀, by default "_rgb.png". + + Returns + ------- + str + 生成的 PNG 文件路径. + """ + if not tiff_path.exists(): + raise FileNotFoundError(f"input not found: {tiff_path}") + + if len(bands_lst) < 3: + raise ValueError("rgb need at least 3 bands") + + ds = gdal.Open(str(tiff_path)) + xsize = ds.RasterXSize + ysize = ds.RasterYSize + bands = bands_lst[:3] + + chans = [] + for b in bands: + band = ds.GetRasterBand(int(b)) + arr = band.ReadAsArray(0, 0, xsize, ysize) + chans.append( + normalize_band(arr, method=normalization_method, percentile=percentile) + ) + + rgb = np.dstack(chans) + + # 构建输出路径 + base_path = str(tiff_path) + if base_path.lower().endswith(".tif"): + base_path = base_path[:-4] + + png_path = base_path + output_suffix + + if save_tiff: + tiff_output_path = base_path + "_rgb.tif" + driver = gdal.GetDriverByName("GTiff") + out_ds = driver.Create(tiff_output_path, xsize, ysize, 3, gdal.GDT_Byte) + out_ds.SetProjection(ds.GetProjection()) + out_ds.SetGeoTransform(ds.GetGeoTransform()) + for i, chan in enumerate(chans): + out_ds.GetRasterBand(i + 1).WriteArray(chan) + out_ds.FlushCache() + out_ds = None + + ds = None + Image.fromarray(rgb, mode="RGB").save(png_path, format="PNG") + return png_path + + +def combine_bands_to_rgb( + red_path: str, green_path: str, blue_path: str, output_path: str +) -> None: + """ + 将红, 绿, 蓝三个单波段图像文件合并为 RGB 图像 (GeoTIFF). + + Parameters + ---------- + red_path : str + 红色波段文件路径. + green_path : str + 绿色波段文件路径. + blue_path : str + 蓝色波段文件路径. + output_path : str + 输出 RGB 图像路径. + """ + paths = [red_path, green_path, blue_path] + for path in paths: if not os.path.exists(path): raise FileNotFoundError(f"文件不存在: {path}") # 读取三个波段数据 - red_band = open_rasterio(red_path, masked=True).squeeze(dim="band", drop=True) - green_band = open_rasterio(green_path, masked=True).squeeze(dim="band", drop=True) - blue_band = open_rasterio(blue_path, masked=True).squeeze(dim="band", drop=True) - # 暂存元数据 - y_coords = red_band.y - x_coords = red_band.x - crs = red_band.rio.crs - transform = red_band.rio.transform() + bands_data = [] + ref_band = None - def stretch_band(band): - """ - 线性拉伸到0-255范围 - """ - # 处理NaN值与负值 - band_no_nan = np.where(np.isnan(band), 0, band) - band_no_nan = np.where(band_no_nan < 0, 0, band_no_nan) - band_min = np.min(band_no_nan) - band_max = np.max(band_no_nan) - # 避免除零错误 - if band_max == band_min: - stretched = np.zeros_like(band_no_nan, dtype=np.uint8) - else: - stretched = ((band_no_nan - band_min) / (band_max - band_min) * 255).astype( - np.uint8 - ) - return stretched + for path in paths: + da = open_rasterio(path, masked=True).squeeze(dim="band", drop=True) + if ref_band is None: + ref_band = da + # 使用 minmax 方法 (对应原 stretch_band) + bands_data.append(normalize_band(da.values, method="minmax")) - red_stretched = stretch_band(red_band.values) - green_stretched = stretch_band(green_band.values) - blue_stretched = stretch_band(blue_band.values) # 合并三个波段为RGB图像 rgb_array = xr.DataArray( - np.dstack((red_stretched, green_stretched, blue_stretched)), + np.dstack(bands_data), dims=("y", "x", "band"), - coords={"band": [1, 2, 3], "y": y_coords, "x": x_coords}, + coords={"band": [1, 2, 3], "y": ref_band.y, "x": ref_band.x}, ) # 转置维度顺序以符合rioxarray要求 rgb_array = rgb_array.transpose("band", "y", "x") # 写入元数据 - rgb_array.rio.write_crs(crs, inplace=True) - rgb_array.rio.write_transform(transform, inplace=True) + rgb_array.rio.write_crs(ref_band.rio.crs, inplace=True) + rgb_array.rio.write_transform(ref_band.rio.transform(), inplace=True) rgb_array.rio.write_nodata(0, inplace=True) # 保存为TIFF文件 - rgb_array.rio.to_raster(output_path, dtype="uint8") + rgb_array.rio.to_raster(output_path, driver="COG", dtype="uint8") print(f"RGB图像已保存到: {output_path}") return if __name__ == "__main__": - # tif_dir = "D:\\NASA_EarthData_Script\\data\\HLS\\2024\\2024012" + # tif_dir = "D:\\Open_EarthData_Tools\\data\\HLS\\2024\\2024012" # red_path = os.path.join( # tif_dir, "HLS.S30.T49RGP.2024012T031101.v2.0.RED.subset.tif" # ) @@ -89,16 +201,25 @@ if __name__ == "__main__": # ) # output_path = os.path.join(tif_dir, "HLS.S30.T49RGP.2024012T031101.v2.0.RGB.tif") - tif_dir = "D:\\NASA_EarthData_Script\\data\\HLS\\2025\\2025011" - red_path = os.path.join( - tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.RED.subset.tif" - ) - green_path = os.path.join( - tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.GREEN.subset.tif" - ) - blue_path = os.path.join( - tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.BLUE.subset.tif" - ) - output_path = os.path.join(tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.RGB.tif") + # tif_dir = "D:\\Open_EarthData_Tools\\data\\HLS\\2025\\2025011" + # red_path = os.path.join( + # tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.RED.subset.tif" + # ) + # green_path = os.path.join( + # tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.GREEN.subset.tif" + # ) + # blue_path = os.path.join( + # tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.BLUE.subset.tif" + # ) + # output_path = os.path.join(tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.RGB.tif") - sr2rgb(red_path, green_path, blue_path, output_path) + # combine_bands_to_rgb(red_path, green_path, blue_path, output_path) + + imgs_dir = Path(r"D:\CVEOProjects\prjaef\aef-backend-demo\media\imgs") + region_name = "Target3" + year = 2025 + img_name = f"S1_S2_{region_name}_{year}.tif" + img_tif_path = imgs_dir / region_name / img_name + bands_lst = [3, 2, 1] # GDAL 波段索引从1开始 + # render_image_rgb(img_tif_path, bands_lst, "clip") + render_image_rgb(img_tif_path, bands_lst, "percentile", percentile=1)