feat(sr2rgb): 重构地表反射率转RGB图像功能并添加多波段支持.
This commit is contained in:
parent
c8cdbff7b9
commit
eeed789d5b
235
utils/sr2rgb.py
235
utils/sr2rgb.py
@ -1,5 +1,8 @@
|
||||
"""
|
||||
将 COG 格式的 Red, Green, Blue 单波段地表反射率图像合成为 RGB 图像
|
||||
将地表反射率图像转换为 8bit RGB 图像
|
||||
|
||||
- Red, Green, Blue 单波段图像
|
||||
- 多波段图像
|
||||
|
||||
1. 对比度线性拉伸至 0-255;
|
||||
2. 合并波段;
|
||||
@ -7,77 +10,186 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from osgeo import gdal
|
||||
import xarray as xr
|
||||
from rioxarray import open_rasterio
|
||||
import numpy as np
|
||||
|
||||
gdal.UseExceptions()
|
||||
|
||||
|
||||
def sr2rgb(red_path: str, green_path: str, blue_path: str, output_path: str) -> None:
|
||||
def normalize_band(
|
||||
band: np.ndarray, method: str = "minmax", percentile: float = 2.0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
将红、绿、蓝三个单波段地表反射率图像合成为RGB图像
|
||||
将波段数据归一化到 0-255 uint8
|
||||
|
||||
参数:
|
||||
red_path (str): 红色波段文件路径
|
||||
green_path (str): 绿色波段文件路径
|
||||
blue_path (str): 蓝色波段文件路径
|
||||
output_path (str): 输出RGB图像路径
|
||||
Parameters
|
||||
----------
|
||||
band: np.ndarray
|
||||
输入波段数据
|
||||
method: str, optional
|
||||
归一化方法 ('minmax', 'percentile', 'clip'), by default "minmax"
|
||||
percentile: float, optional
|
||||
百分比截断参数 (仅用于 'percentile' 方法), by default 2.0
|
||||
"""
|
||||
# 检查文件是否存在
|
||||
for path in [red_path, green_path, blue_path]:
|
||||
band = np.asarray(band, dtype=np.float32)
|
||||
|
||||
if method == "clip":
|
||||
# 对应原 render8bit 逻辑: 直接截断并转换
|
||||
band[~np.isfinite(band)] = 0.0
|
||||
return np.clip(band, 0.0, 255.0).astype(np.uint8)
|
||||
|
||||
if method == "percentile":
|
||||
# 对应原 linear_stretch 逻辑
|
||||
lower = np.nanpercentile(band, percentile)
|
||||
upper = np.nanpercentile(band, 100 - percentile)
|
||||
else: # minmax
|
||||
# 使用0值处理NaN和负值
|
||||
band = np.where(np.isnan(band) | (band < 0), 0, band)
|
||||
lower = np.min(band)
|
||||
upper = np.max(band)
|
||||
|
||||
if upper == lower:
|
||||
return np.zeros_like(band, dtype=np.uint8)
|
||||
|
||||
stretched = (band - lower) / (upper - lower) * 255.0
|
||||
return np.clip(stretched, 0, 255).astype(np.uint8)
|
||||
|
||||
|
||||
def render_image_rgb(
|
||||
tiff_path: Path,
|
||||
bands_lst: List[int],
|
||||
normalization_method: str,
|
||||
percentile: float = 0,
|
||||
save_tiff: bool = True,
|
||||
output_suffix: str = "_rgb.png",
|
||||
) -> str:
|
||||
"""
|
||||
通用 GDAL 数据源处理函数, 读取多波段图像并转换为 RGB 图像.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tiff_path : Path
|
||||
输入 TIFF 图像路径.
|
||||
bands_lst : List[int]
|
||||
RGB 波段索引列表 (GDAL 索引, 从 1 开始).
|
||||
normalization_method : str
|
||||
归一化方法 ('minmax', 'percentile', 'clip').
|
||||
percentile : float, optional
|
||||
百分比截断参数 (仅用于 'percentile' 方法), by default 0.
|
||||
save_tiff : bool, optional
|
||||
是否保存为 RGB TIFF 文件, by default True.
|
||||
output_suffix : str, optional
|
||||
输出文件后缀, by default "_rgb.png".
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
生成的 PNG 文件路径.
|
||||
"""
|
||||
if not tiff_path.exists():
|
||||
raise FileNotFoundError(f"input not found: {tiff_path}")
|
||||
|
||||
if len(bands_lst) < 3:
|
||||
raise ValueError("rgb need at least 3 bands")
|
||||
|
||||
ds = gdal.Open(str(tiff_path))
|
||||
xsize = ds.RasterXSize
|
||||
ysize = ds.RasterYSize
|
||||
bands = bands_lst[:3]
|
||||
|
||||
chans = []
|
||||
for b in bands:
|
||||
band = ds.GetRasterBand(int(b))
|
||||
arr = band.ReadAsArray(0, 0, xsize, ysize)
|
||||
chans.append(
|
||||
normalize_band(arr, method=normalization_method, percentile=percentile)
|
||||
)
|
||||
|
||||
rgb = np.dstack(chans)
|
||||
|
||||
# 构建输出路径
|
||||
base_path = str(tiff_path)
|
||||
if base_path.lower().endswith(".tif"):
|
||||
base_path = base_path[:-4]
|
||||
|
||||
png_path = base_path + output_suffix
|
||||
|
||||
if save_tiff:
|
||||
tiff_output_path = base_path + "_rgb.tif"
|
||||
driver = gdal.GetDriverByName("GTiff")
|
||||
out_ds = driver.Create(tiff_output_path, xsize, ysize, 3, gdal.GDT_Byte)
|
||||
out_ds.SetProjection(ds.GetProjection())
|
||||
out_ds.SetGeoTransform(ds.GetGeoTransform())
|
||||
for i, chan in enumerate(chans):
|
||||
out_ds.GetRasterBand(i + 1).WriteArray(chan)
|
||||
out_ds.FlushCache()
|
||||
out_ds = None
|
||||
|
||||
ds = None
|
||||
Image.fromarray(rgb, mode="RGB").save(png_path, format="PNG")
|
||||
return png_path
|
||||
|
||||
|
||||
def combine_bands_to_rgb(
|
||||
red_path: str, green_path: str, blue_path: str, output_path: str
|
||||
) -> None:
|
||||
"""
|
||||
将红, 绿, 蓝三个单波段图像文件合并为 RGB 图像 (GeoTIFF).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
red_path : str
|
||||
红色波段文件路径.
|
||||
green_path : str
|
||||
绿色波段文件路径.
|
||||
blue_path : str
|
||||
蓝色波段文件路径.
|
||||
output_path : str
|
||||
输出 RGB 图像路径.
|
||||
"""
|
||||
paths = [red_path, green_path, blue_path]
|
||||
for path in paths:
|
||||
if not os.path.exists(path):
|
||||
raise FileNotFoundError(f"文件不存在: {path}")
|
||||
|
||||
# 读取三个波段数据
|
||||
red_band = open_rasterio(red_path, masked=True).squeeze(dim="band", drop=True)
|
||||
green_band = open_rasterio(green_path, masked=True).squeeze(dim="band", drop=True)
|
||||
blue_band = open_rasterio(blue_path, masked=True).squeeze(dim="band", drop=True)
|
||||
# 暂存元数据
|
||||
y_coords = red_band.y
|
||||
x_coords = red_band.x
|
||||
crs = red_band.rio.crs
|
||||
transform = red_band.rio.transform()
|
||||
bands_data = []
|
||||
ref_band = None
|
||||
|
||||
def stretch_band(band):
|
||||
"""
|
||||
线性拉伸到0-255范围
|
||||
"""
|
||||
# 处理NaN值与负值
|
||||
band_no_nan = np.where(np.isnan(band), 0, band)
|
||||
band_no_nan = np.where(band_no_nan < 0, 0, band_no_nan)
|
||||
band_min = np.min(band_no_nan)
|
||||
band_max = np.max(band_no_nan)
|
||||
# 避免除零错误
|
||||
if band_max == band_min:
|
||||
stretched = np.zeros_like(band_no_nan, dtype=np.uint8)
|
||||
else:
|
||||
stretched = ((band_no_nan - band_min) / (band_max - band_min) * 255).astype(
|
||||
np.uint8
|
||||
)
|
||||
return stretched
|
||||
for path in paths:
|
||||
da = open_rasterio(path, masked=True).squeeze(dim="band", drop=True)
|
||||
if ref_band is None:
|
||||
ref_band = da
|
||||
# 使用 minmax 方法 (对应原 stretch_band)
|
||||
bands_data.append(normalize_band(da.values, method="minmax"))
|
||||
|
||||
red_stretched = stretch_band(red_band.values)
|
||||
green_stretched = stretch_band(green_band.values)
|
||||
blue_stretched = stretch_band(blue_band.values)
|
||||
# 合并三个波段为RGB图像
|
||||
rgb_array = xr.DataArray(
|
||||
np.dstack((red_stretched, green_stretched, blue_stretched)),
|
||||
np.dstack(bands_data),
|
||||
dims=("y", "x", "band"),
|
||||
coords={"band": [1, 2, 3], "y": y_coords, "x": x_coords},
|
||||
coords={"band": [1, 2, 3], "y": ref_band.y, "x": ref_band.x},
|
||||
)
|
||||
# 转置维度顺序以符合rioxarray要求
|
||||
rgb_array = rgb_array.transpose("band", "y", "x")
|
||||
# 写入元数据
|
||||
rgb_array.rio.write_crs(crs, inplace=True)
|
||||
rgb_array.rio.write_transform(transform, inplace=True)
|
||||
rgb_array.rio.write_crs(ref_band.rio.crs, inplace=True)
|
||||
rgb_array.rio.write_transform(ref_band.rio.transform(), inplace=True)
|
||||
rgb_array.rio.write_nodata(0, inplace=True)
|
||||
# 保存为TIFF文件
|
||||
rgb_array.rio.to_raster(output_path, dtype="uint8")
|
||||
rgb_array.rio.to_raster(output_path, driver="COG", dtype="uint8")
|
||||
print(f"RGB图像已保存到: {output_path}")
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# tif_dir = "D:\\NASA_EarthData_Script\\data\\HLS\\2024\\2024012"
|
||||
# tif_dir = "D:\\Open_EarthData_Tools\\data\\HLS\\2024\\2024012"
|
||||
# red_path = os.path.join(
|
||||
# tif_dir, "HLS.S30.T49RGP.2024012T031101.v2.0.RED.subset.tif"
|
||||
# )
|
||||
@ -89,16 +201,25 @@ if __name__ == "__main__":
|
||||
# )
|
||||
# output_path = os.path.join(tif_dir, "HLS.S30.T49RGP.2024012T031101.v2.0.RGB.tif")
|
||||
|
||||
tif_dir = "D:\\NASA_EarthData_Script\\data\\HLS\\2025\\2025011"
|
||||
red_path = os.path.join(
|
||||
tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.RED.subset.tif"
|
||||
)
|
||||
green_path = os.path.join(
|
||||
tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.GREEN.subset.tif"
|
||||
)
|
||||
blue_path = os.path.join(
|
||||
tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.BLUE.subset.tif"
|
||||
)
|
||||
output_path = os.path.join(tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.RGB.tif")
|
||||
# tif_dir = "D:\\Open_EarthData_Tools\\data\\HLS\\2025\\2025011"
|
||||
# red_path = os.path.join(
|
||||
# tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.RED.subset.tif"
|
||||
# )
|
||||
# green_path = os.path.join(
|
||||
# tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.GREEN.subset.tif"
|
||||
# )
|
||||
# blue_path = os.path.join(
|
||||
# tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.BLUE.subset.tif"
|
||||
# )
|
||||
# output_path = os.path.join(tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.RGB.tif")
|
||||
|
||||
sr2rgb(red_path, green_path, blue_path, output_path)
|
||||
# combine_bands_to_rgb(red_path, green_path, blue_path, output_path)
|
||||
|
||||
imgs_dir = Path(r"D:\CVEOProjects\prjaef\aef-backend-demo\media\imgs")
|
||||
region_name = "Target3"
|
||||
year = 2025
|
||||
img_name = f"S1_S2_{region_name}_{year}.tif"
|
||||
img_tif_path = imgs_dir / region_name / img_name
|
||||
bands_lst = [3, 2, 1] # GDAL 波段索引从1开始
|
||||
# render_image_rgb(img_tif_path, bands_lst, "clip")
|
||||
render_image_rgb(img_tif_path, bands_lst, "percentile", percentile=1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user