276 lines
8.7 KiB
Python
276 lines
8.7 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
===============================================================================
|
||
将原始数值影像转换为 8bit RGBA 图像
|
||
|
||
- 支持 Red, Green, Blue 等单波段影像
|
||
- 支持多波段影像
|
||
|
||
1. 将原始数据中 NoData 值替换为 NaN, 并将其设置为 Alpha 通道
|
||
2. 对比度线性拉伸至 0-255;
|
||
3. 合并包含 Alpha 通道在内的 4 个波段;
|
||
4. 保存为 uint8 格式 RGBA 图像
|
||
|
||
-------------------------------------------------------------------------------
|
||
Authors: CVEO Team
|
||
Last Updated: 2025-01-08
|
||
===============================================================================
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
from pathlib import Path
|
||
from typing import List, Optional
|
||
import numpy as np
|
||
from PIL import Image
|
||
from osgeo import gdal
|
||
import xarray as xr
|
||
from rioxarray import open_rasterio
|
||
|
||
gdal.UseExceptions()
|
||
|
||
|
||
def normalize_band(
    band: np.ndarray,
    method: str = "minmax",
    percentile: float = 2.0,
    nodata: Optional[float] = None,
) -> np.ndarray:
    """
    Normalize a band to 0-255 uint8.

    Parameters
    ----------
    band: np.ndarray
        Input band data. It is never modified; all work happens on a copy.
    method: str, optional
        Normalization method, by default "minmax":

        - 'clip': non-finite values become 0, then values are clipped to
          [0, 255] and cast.
        - 'percentile': linear stretch between the ``percentile`` and
          ``100 - percentile`` percentiles of the finite pixels.
        - anything else (including 'minmax'): non-finite and negative values
          are set to 0, then a linear stretch between min and max is applied.
    percentile: float, optional
        Percentile cut-off (only used by 'percentile'), by default 2.0.
    nodata: float, optional
        NoData value; pixels close to it are treated as NaN, by default None.

    Returns
    -------
    np.ndarray
        uint8 array with the same shape as ``band``. Invalid pixels map to 0.
        An empty, constant, or all-invalid band yields all zeros.
    """
    # Always work on a private float32 copy: np.asarray would hand back the
    # caller's array unchanged when it is already float32, and the in-place
    # assignments below would then silently modify the caller's data.
    band = np.array(band, dtype=np.float32, copy=True)

    if nodata is not None:
        band[np.isclose(band, nodata)] = np.nan

    if method == "clip":
        # Truncate directly, no stretching.
        band[~np.isfinite(band)] = 0.0
        return np.clip(band, 0.0, 255.0).astype(np.uint8)

    if band.size == 0:
        return np.zeros(band.shape, dtype=np.uint8)

    if method == "percentile":
        # Statistics over finite pixels only: NaN/NoData and +/-inf must not
        # drag the cut-offs (np.nanpercentile would still include inf).
        valid = band[np.isfinite(band)]
        if valid.size == 0:
            return np.zeros(band.shape, dtype=np.uint8)
        lower, upper = np.percentile(valid, [percentile, 100.0 - percentile])
    else:  # minmax
        # Non-finite (NaN, +/-inf) and negative values are replaced by 0.
        # NOTE: these zeroed pixels take part in the min/max as well.
        band = np.where(~np.isfinite(band) | (band < 0), 0.0, band)
        lower, upper = band.min(), band.max()

    if upper == lower:
        return np.zeros(band.shape, dtype=np.uint8)

    stretched = (band - lower) / (upper - lower) * 255.0
    # NaNs from the input propagate through the stretch; map them to 0.
    stretched[~np.isfinite(stretched)] = 0.0
    return np.clip(stretched, 0, 255).astype(np.uint8)
|
||
|
||
|
||
def render_image_rgb(
    tiff_path: Path,
    bands_lst: List[int],
    normalization_method: str,
    percentile: float = 0,
    save_tiff: bool = True,
    output_suffix: str = "_rgb.png",
) -> str:
    """
    Read three bands from a GDAL raster and render them as an RGBA image.

    Each selected band is stretched to 0-255 with ``normalize_band``. A pixel
    becomes fully transparent (alpha 0) when any of the three bands is NoData
    there or, for floating-point bands, holds a non-finite value.

    Parameters
    ----------
    tiff_path : Path
        Input raster path.
    bands_lst : List[int]
        Band indices used as R, G, B (GDAL indices, 1-based). Only the first
        three entries are used.
    normalization_method : str
        Normalization method forwarded to ``normalize_band``
        ('minmax', 'percentile', 'clip').
    percentile : float, optional
        Percentile cut-off (only used by 'percentile'), by default 0.
    save_tiff : bool, optional
        Also write a georeferenced RGBA GeoTIFF (``<base>_rgb.tif``) whose
        4th band is tagged as alpha, by default True.
    output_suffix : str, optional
        Suffix appended to the input path (without its ``.tif``/``.tiff``
        extension) to build the PNG path, by default "_rgb.png".

    Returns
    -------
    str
        Path of the written PNG file.

    Raises
    ------
    FileNotFoundError
        If ``tiff_path`` does not exist.
    ValueError
        If fewer than three band indices are given.
    """
    if not tiff_path.exists():
        raise FileNotFoundError(f"input not found: {tiff_path}")
    if len(bands_lst) < 3:
        raise ValueError("rgb need at least 3 bands")

    ds = gdal.Open(str(tiff_path))
    xsize = ds.RasterXSize
    ysize = ds.RasterYSize

    chans = []
    # True = opaque/valid; cleared wherever any band is NoData or non-finite.
    alpha_mask = np.ones((ysize, xsize), dtype=bool)
    for b in bands_lst[:3]:
        band = ds.GetRasterBand(int(b))
        nodata = band.GetNoDataValue()
        arr = band.ReadAsArray(0, 0, xsize, ysize)

        if nodata is not None:
            alpha_mask &= ~np.isclose(arr, nodata)
        if np.issubdtype(arr.dtype, np.floating):
            alpha_mask &= np.isfinite(arr)

        chans.append(
            normalize_band(
                arr, method=normalization_method, percentile=percentile, nodata=nodata
            )
        )

    chans.append(np.where(alpha_mask, 255, 0).astype(np.uint8))
    rgba = np.dstack(chans)

    # Drop a .tif / .tiff extension (case-insensitive) before adding suffixes.
    base_path = str(tiff_path)
    if tiff_path.suffix.lower() in (".tif", ".tiff"):
        base_path = str(tiff_path.with_suffix(""))
    png_path = base_path + output_suffix

    if save_tiff:
        _write_rgba_geotiff(base_path + "_rgb.tif", chans, ds)

    ds = None
    # Pillow infers "RGBA" from an (H, W, 4) uint8 array; the explicit `mode`
    # argument of Image.fromarray is deprecated in recent Pillow releases.
    Image.fromarray(rgba).save(png_path, format="PNG")
    print(f"已完成RGBA图像转换, 保存至: {png_path}")
    return png_path


def _write_rgba_geotiff(out_path: str, chans: List[np.ndarray], src_ds) -> None:
    """Write 4 uint8 channels as an RGBA GeoTIFF georeferenced like ``src_ds``."""
    driver = gdal.GetDriverByName("GTiff")
    # PHOTOMETRIC=RGB + ALPHA=YES tag band 4 as an alpha extra-sample, so GIS
    # software renders it as transparency instead of a fourth grey band.
    out_ds = driver.Create(
        out_path,
        src_ds.RasterXSize,
        src_ds.RasterYSize,
        len(chans),
        gdal.GDT_Byte,
        options=["PHOTOMETRIC=RGB", "ALPHA=YES"],
    )
    try:
        out_ds.SetProjection(src_ds.GetProjection())
        out_ds.SetGeoTransform(src_ds.GetGeoTransform())
        for idx, chan in enumerate(chans, start=1):
            out_ds.GetRasterBand(idx).WriteArray(chan)
        out_ds.FlushCache()
    finally:
        # Dropping the last reference closes the GDAL dataset and flushes it.
        out_ds = None
|
||
|
||
|
||
def combine_bands_to_rgb(
    red_path: str, green_path: str, blue_path: str, output_path: str
) -> None:
    """
    Merge separate red, green and blue single-band rasters into an RGBA COG.

    Each band is min-max stretched to 0-255 with ``normalize_band``. The alpha
    band is 0 wherever any input band is NoData/NaN and 255 elsewhere; no
    NoData value is written. CRS and transform are copied from the red band.

    Parameters
    ----------
    red_path : str
        Red band raster path.
    green_path : str
        Green band raster path.
    blue_path : str
        Blue band raster path.
    output_path : str
        Output RGBA raster path (written with the COG driver).

    Raises
    ------
    FileNotFoundError
        If any input path does not exist.
    ValueError
        If the input rasters do not share the same shape.
    """
    paths = [red_path, green_path, blue_path]
    for path in paths:
        if not os.path.exists(path):
            raise FileNotFoundError(f"文件不存在: {path}")

    bands_data = []
    ref_band = None
    mask = None

    for path in paths:
        # Load pixels into memory, then let the context manager close the
        # underlying file handle (otherwise files stay open/locked until GC).
        with open_rasterio(path, masked=True) as src:
            da = src.squeeze(dim="band", drop=True).load()

        if ref_band is None:
            ref_band = da
        elif da.shape != ref_band.shape:
            raise ValueError(
                f"band shape mismatch: {path} has {da.shape}, "
                f"expected {ref_band.shape}"
            )

        values = da.values
        # With masked=True, NoData has been converted to NaN.
        valid = np.isfinite(values)
        mask = valid if mask is None else mask & valid

        bands_data.append(normalize_band(values, method="minmax"))

    # Transparency is carried by the alpha band instead of a NoData value.
    bands_data.append(np.where(mask, 255, 0).astype(np.uint8))

    rgb_array = xr.DataArray(
        np.dstack(bands_data),
        dims=("y", "x", "band"),
        coords={"band": [1, 2, 3, 4], "y": ref_band.y, "x": ref_band.x},
    )
    # rioxarray expects (band, y, x) ordering.
    rgb_array = rgb_array.transpose("band", "y", "x")
    rgb_array.rio.write_crs(ref_band.rio.crs, inplace=True)
    rgb_array.rio.write_transform(ref_band.rio.transform(), inplace=True)
    rgb_array.rio.to_raster(output_path, driver="COG", dtype="uint8")
    print(f"RGBA图像已保存到: {output_path}")
|
||
|
||
|
||
if __name__ == "__main__":
    # Example: merge separate HLS RED / GREEN / BLUE rasters into one RGBA COG.
    # tif_dir = "D:\\Open_EarthData_Tools\\data\\HLS\\2025\\2025011"
    # scene = "HLS.S30.T49RGP.2025011T031009.v2.0"
    # combine_bands_to_rgb(
    #     os.path.join(tif_dir, f"{scene}.RED.subset.tif"),
    #     os.path.join(tif_dir, f"{scene}.GREEN.subset.tif"),
    #     os.path.join(tif_dir, f"{scene}.BLUE.subset.tif"),
    #     os.path.join(tif_dir, f"{scene}.RGB.tif"),
    # )

    imgs_dir = Path(r"D:\CVEOProjects\prjaef\aef-backend-demo\media\imgs")
    # GDAL band indices are 1-based; bands 3, 2, 1 are rendered as R, G, B.
    rgb_band_order = [3, 2, 1]
    years = range(2015, 2025 + 1)

    for region_name in ("Target1", "Target2", "Target3", "Target4"):
        region_dir = imgs_dir / region_name
        for year in years:
            img_tif_path = region_dir / f"S1_S2_{region_name}_{year}.tif"
            # Alternative: render_image_rgb(img_tif_path, rgb_band_order, "clip")
            render_image_rgb(img_tif_path, rgb_band_order, "percentile", percentile=1)
|