226 lines
6.9 KiB
Python

"""
将地表反射率图像转换为 8bit RGB 图像
- Red, Green, Blue 单波段图像
- 多波段图像
1. 对比度线性拉伸至 0-255;
2. 合并波段;
3. 保存为 uint8 格式 RGB 图像;
"""
import os
import sys
from pathlib import Path
from typing import List, Optional
import numpy as np
from PIL import Image
from osgeo import gdal
import xarray as xr
from rioxarray import open_rasterio
gdal.UseExceptions()
def normalize_band(
    band: np.ndarray, method: str = "minmax", percentile: float = 2.0
) -> np.ndarray:
    """
    Scale one band of raster data to 0-255 and return it as uint8.

    The input array is never modified: all clean-up happens on a float32
    copy. Non-finite pixels (NaN, +/-inf) always map to 0 in the output.

    Parameters
    ----------
    band : np.ndarray
        Input band data (anything ``np.array`` can convert to float32).
    method : str, optional
        Normalization method, by default "minmax".

        - ``"clip"``: no stretching; values are clamped to [0, 255] and cast.
        - ``"percentile"``: linear stretch between the ``percentile``-th and
          the ``(100 - percentile)``-th percentile of the finite pixels.
        - any other value is treated as ``"minmax"``: negative values are
          replaced by 0, then the band is stretched between its minimum
          and maximum.
    percentile : float, optional
        Cut-off percentile, only used by the "percentile" method,
        by default 2.0.

    Returns
    -------
    np.ndarray
        uint8 array with the same shape as ``band``. It is all zeros when
        the band is empty, has no finite pixel, or the stretch range is
        degenerate (``upper == lower``).
    """
    # np.array (unlike np.asarray) always copies, so the in-place clean-up
    # below can never leak into the caller's array.
    data = np.array(band, dtype=np.float32)
    if data.size == 0:
        return np.zeros(data.shape, dtype=np.uint8)

    invalid = ~np.isfinite(data)

    if method == "clip":
        data[invalid] = 0.0
        return np.clip(data, 0.0, 255.0).astype(np.uint8)

    if method == "percentile":
        if invalid.all():
            # No usable pixel: the percentiles would be NaN.
            return np.zeros(data.shape, dtype=np.uint8)
        valid = data[~invalid]
        lower = np.percentile(valid, percentile)
        upper = np.percentile(valid, 100 - percentile)
    else:  # minmax
        # Non-finite and negative values are treated as 0 before stretching.
        data[invalid | (data < 0)] = 0.0
        lower = data.min()
        upper = data.max()

    if upper == lower:
        return np.zeros(data.shape, dtype=np.uint8)

    stretched = (data - lower) / (upper - lower) * 255.0
    # Casting NaN/inf to uint8 is undefined, so pin those pixels to 0 first.
    stretched = np.where(invalid, 0.0, stretched)
    return np.clip(stretched, 0, 255).astype(np.uint8)
def render_image_rgb(
    tiff_path: Path,
    bands_lst: List[int],
    normalization_method: str,
    percentile: float = 0,
    save_tiff: bool = True,
    output_suffix: str = "_rgb.png",
) -> str:
    """
    Read three bands of a GDAL raster and render them as an 8-bit RGB image.

    Each selected band is normalized independently with ``normalize_band``
    and the three results are stacked as R, G, B. A PNG is always written;
    a georeferenced 3-band GeoTIFF is written as well when ``save_tiff`` is
    true. Both outputs are placed next to the input, using the input path
    without its ``.tif`` / ``.tiff`` extension as the base name.

    Parameters
    ----------
    tiff_path : Path
        Path of the input raster (a plain string is accepted too).
    bands_lst : List[int]
        Band indices to use as R, G, B (GDAL indices, starting at 1).
        Only the first three entries are used.
    normalization_method : str
        Normalization method ('minmax', 'percentile', 'clip').
    percentile : float, optional
        Cut-off percentile (only used by the 'percentile' method),
        by default 0.
    save_tiff : bool, optional
        Whether to also write ``<base>_rgb.tif``, by default True.
    output_suffix : str, optional
        Suffix appended to the base name for the PNG, by default "_rgb.png".

    Returns
    -------
    str
        Path of the generated PNG file.

    Raises
    ------
    FileNotFoundError
        If ``tiff_path`` does not exist.
    ValueError
        If fewer than three band indices are given.
    RuntimeError
        If GDAL cannot open the input.
    """
    tiff_path = Path(tiff_path)
    if not tiff_path.exists():
        raise FileNotFoundError(f"input not found: {tiff_path}")
    if len(bands_lst) < 3:
        raise ValueError("rgb need at least 3 bands")

    # Build the output base name; check ".tiff" first so it is stripped whole.
    base_path = str(tiff_path)
    lowered = base_path.lower()
    for ext in (".tiff", ".tif"):
        if lowered.endswith(ext):
            base_path = base_path[: -len(ext)]
            break
    png_path = base_path + output_suffix

    ds = gdal.Open(str(tiff_path))
    if ds is None:
        # Only reachable when GDAL exceptions are disabled.
        raise RuntimeError(f"GDAL could not open: {tiff_path}")
    try:
        xsize = ds.RasterXSize
        ysize = ds.RasterYSize
        chans = [
            normalize_band(
                ds.GetRasterBand(int(b)).ReadAsArray(0, 0, xsize, ysize),
                method=normalization_method,
                percentile=percentile,
            )
            for b in bands_lst[:3]
        ]

        if save_tiff:
            driver = gdal.GetDriverByName("GTiff")
            out_ds = driver.Create(
                base_path + "_rgb.tif", xsize, ysize, 3, gdal.GDT_Byte
            )
            try:
                # Carry the georeferencing of the source over to the output.
                out_ds.SetProjection(ds.GetProjection())
                out_ds.SetGeoTransform(ds.GetGeoTransform())
                for i, chan in enumerate(chans):
                    out_ds.GetRasterBand(i + 1).WriteArray(chan)
                out_ds.FlushCache()
            finally:
                # Dropping the reference is how the GDAL bindings close a dataset.
                out_ds = None
    finally:
        ds = None

    rgb = np.dstack(chans)
    Image.fromarray(rgb, mode="RGB").save(png_path, format="PNG")
    return png_path
def combine_bands_to_rgb(
    red_path: str, green_path: str, blue_path: str, output_path: str
) -> None:
    """
    Merge three single-band rasters (red, green, blue) into one RGB GeoTIFF.

    Each band is read with rioxarray, stretched to uint8 independently via
    ``normalize_band(..., method="minmax")`` and stacked in R, G, B order.
    Coordinates, CRS and affine transform are taken from the red band, the
    nodata value is set to 0, and the result is written with the COG driver.

    Parameters
    ----------
    red_path : str
        Path of the red band file.
    green_path : str
        Path of the green band file.
    blue_path : str
        Path of the blue band file.
    output_path : str
        Path of the RGB image to write.

    Raises
    ------
    FileNotFoundError
        If any of the three input files does not exist.
    ValueError
        If the three bands do not share the same shape.
    """
    paths = [red_path, green_path, blue_path]
    for path in paths:
        if not os.path.exists(path):
            raise FileNotFoundError(f"文件不存在: {path}")

    # Read and stretch the three bands; the first one is the spatial reference.
    bands_data = []
    ref_band = None
    ref_crs = None
    ref_transform = None
    for path in paths:
        # The context manager closes the underlying file handle once the
        # pixel values have been loaded into memory.
        with open_rasterio(path, masked=True) as src:
            da = src.squeeze(dim="band", drop=True)
            stretched = normalize_band(da.values, method="minmax")
            if ref_band is None:
                ref_band = da
                ref_crs = da.rio.crs
                ref_transform = da.rio.transform()
        if bands_data and stretched.shape != bands_data[0].shape:
            raise ValueError(
                f"band shape mismatch: {path} is {stretched.shape}, "
                f"expected {bands_data[0].shape}"
            )
        bands_data.append(stretched)

    # Stack the three bands into one (y, x, band) array.
    rgb_array = xr.DataArray(
        np.dstack(bands_data),
        dims=("y", "x", "band"),
        coords={"band": [1, 2, 3], "y": ref_band.y, "x": ref_band.x},
    )
    # rioxarray expects the band dimension first.
    rgb_array = rgb_array.transpose("band", "y", "x")

    # Attach the georeferencing metadata.
    rgb_array.rio.write_crs(ref_crs, inplace=True)
    rgb_array.rio.write_transform(ref_transform, inplace=True)
    rgb_array.rio.write_nodata(0, inplace=True)

    # Write the result to disk.
    rgb_array.rio.to_raster(output_path, driver="COG", dtype="uint8")
    print(f"RGB图像已保存到: {output_path}")
if __name__ == "__main__":
    # Disabled example 1: merge the single-band HLS subsets of 2024-012
    # into one RGB GeoTIFF.
    # tif_dir = "D:\\Open_EarthData_Tools\\data\\HLS\\2024\\2024012"
    # red_path = os.path.join(
    #     tif_dir, "HLS.S30.T49RGP.2024012T031101.v2.0.RED.subset.tif"
    # )
    # green_path = os.path.join(
    #     tif_dir, "HLS.S30.T49RGP.2024012T031101.v2.0.GREEN.subset.tif"
    # )
    # blue_path = os.path.join(
    #     tif_dir, "HLS.S30.T49RGP.2024012T031101.v2.0.BLUE.subset.tif"
    # )
    # output_path = os.path.join(tif_dir, "HLS.S30.T49RGP.2024012T031101.v2.0.RGB.tif")

    # Disabled example 2: same merge for the 2025-011 acquisition.
    # tif_dir = "D:\\Open_EarthData_Tools\\data\\HLS\\2025\\2025011"
    # red_path = os.path.join(
    #     tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.RED.subset.tif"
    # )
    # green_path = os.path.join(
    #     tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.GREEN.subset.tif"
    # )
    # blue_path = os.path.join(
    #     tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.BLUE.subset.tif"
    # )
    # output_path = os.path.join(tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.RGB.tif")
    # combine_bands_to_rgb(red_path, green_path, blue_path, output_path)

    # Active run: render one multi-band image as RGB using a 1 % percentile
    # stretch.
    region_name = "Target3"
    year = 2025
    media_root = Path(r"D:\CVEOProjects\prjaef\aef-backend-demo\media\imgs")
    source_tif = media_root / region_name / f"S1_S2_{region_name}_{year}.tif"
    rgb_band_order = [3, 2, 1]  # GDAL band indices start at 1

    # Alternative without stretching:
    # render_image_rgb(source_tif, rgb_band_order, "clip")
    render_image_rgb(source_tif, rgb_band_order, "percentile", percentile=1)