276 lines
8.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
===============================================================================
将原始数值影像转换为 8bit RGBA 图像
- 支持 Red, Green, Blue 等单波段影像
- 支持多波段影像
1. 将原始数据中 NoData 值替换为 NaN, 并将其设置为 Alpha 通道
2. 对比度线性拉伸至 0-255;
3. 合并包含 Alpha 通道在内的 4 个波段;
4. 保存为 uint8 格式 RGBA 图像
-------------------------------------------------------------------------------
Authors: CVEO Team
Last Updated: 2025-01-08
===============================================================================
"""
import os
import sys
from pathlib import Path
from typing import List, Optional
import numpy as np
from PIL import Image
from osgeo import gdal
import xarray as xr
from rioxarray import open_rasterio
gdal.UseExceptions()
def normalize_band(
    band: np.ndarray,
    method: str = "minmax",
    percentile: float = 2.0,
    nodata: Optional[float] = None,
) -> np.ndarray:
    """
    Linearly stretch a single band to the 0-255 range and return it as uint8.

    Parameters
    ----------
    band : np.ndarray
        Input band data. It is converted to a private float32 copy, so the
        caller's array is never modified.
    method : str, optional
        Normalization method, by default "minmax":

        - 'clip': non-finite values become 0, then values are clipped to
          [0, 255] and cast directly (no stretch).
        - 'percentile': stretch between the ``percentile`` and
          ``100 - percentile`` percentiles of the finite pixels.
        - any other value ('minmax'): NaN, infinite and negative values are
          replaced by 0, then the band is stretched between its min and max.
    percentile : float, optional
        Cut-off percentage for the 'percentile' method, by default 2.0.
    nodata : float, optional
        NoData value. If not None, pixels close to it are treated as NaN.
        By default None.

    Returns
    -------
    np.ndarray
        uint8 array with the same shape as ``band``. Invalid (NaN/inf/NoData)
        pixels end up as 0. An empty, constant, or entirely invalid band
        yields all zeros.
    """
    # np.array(..., copy=True) guarantees a private buffer. np.asarray would
    # return the caller's own array when it is already float32, and the
    # in-place assignments below would then corrupt the caller's data.
    band = np.array(band, dtype=np.float32, copy=True)
    if band.size == 0:
        # Nothing to stretch; np.min/np.max would raise on an empty array.
        return np.zeros(band.shape, dtype=np.uint8)

    if nodata is not None:
        band[np.isclose(band, nodata)] = np.nan

    if method == "clip":
        # Direct truncation, no stretch.
        band[~np.isfinite(band)] = 0.0
        return np.clip(band, 0.0, 255.0).astype(np.uint8)

    if method == "percentile":
        # Compute percentiles over finite pixels only:
        # - infinities would otherwise drag a bound to +/-inf;
        # - an all-NaN band would make nanpercentile warn and return NaN.
        valid = band[np.isfinite(band)]
        if valid.size == 0:
            return np.zeros(band.shape, dtype=np.uint8)
        lower = np.percentile(valid, percentile)
        upper = np.percentile(valid, 100 - percentile)
    else:  # minmax
        # NaN, +/-inf and negative values are all mapped to 0 before taking
        # the value range.
        band = np.where(~np.isfinite(band) | (band < 0), 0, band)
        lower = np.min(band)
        upper = np.max(band)

    if upper == lower:
        return np.zeros(band.shape, dtype=np.uint8)

    stretched = (band - lower) / (upper - lower) * 255.0
    # Non-finite inputs (possible in the percentile branch) propagate through
    # the stretch; map them to 0.
    stretched[~np.isfinite(stretched)] = 0.0
    return np.clip(stretched, 0, 255).astype(np.uint8)
def render_image_rgb(
    tiff_path: Path,
    bands_lst: List[int],
    normalization_method: str,
    percentile: float = 0,
    save_tiff: bool = True,
    output_suffix: str = "_rgb.png",
) -> str:
    """
    Read three bands from a GDAL raster and save them as an RGBA image.

    Each selected band is stretched to uint8 with ``normalize_band``. The
    alpha channel is 255 where all three bands hold valid data. It is 0 where
    any of them equals its NoData value or (for float bands) is non-finite.

    Parameters
    ----------
    tiff_path : Path
        Input raster path. A plain ``str`` is also accepted.
    bands_lst : List[int]
        Band indices used for R, G and B (GDAL convention, 1-based). Only the
        first three entries are used.
    normalization_method : str
        Normalization method forwarded to ``normalize_band``
        ('minmax', 'percentile', 'clip').
    percentile : float, optional
        Cut-off percentage for the 'percentile' method, by default 0.
    save_tiff : bool, optional
        Also write a 4-band RGBA GeoTIFF ``<base>_rgb.tif``. It carries the
        source projection and geotransform. By default True.
    output_suffix : str, optional
        Suffix appended to the base path to form the PNG path,
        by default "_rgb.png".

    Returns
    -------
    str
        Path of the generated PNG file.

    Raises
    ------
    FileNotFoundError
        If ``tiff_path`` does not exist.
    ValueError
        If fewer than three band indices are given.
    """
    tiff_path = Path(tiff_path)
    if not tiff_path.exists():
        raise FileNotFoundError(f"input not found: {tiff_path}")
    if len(bands_lst) < 3:
        raise ValueError("rgb need at least 3 bands")

    # Output files share the input path minus a .tif/.tiff extension (any case).
    if tiff_path.suffix.lower() in (".tif", ".tiff"):
        base_path = str(tiff_path.with_suffix(""))
    else:
        base_path = str(tiff_path)
    png_path = base_path + output_suffix

    ds = gdal.Open(str(tiff_path))
    try:
        xsize = ds.RasterXSize
        ysize = ds.RasterYSize

        chans = []
        # Start fully opaque; a pixel turns transparent if ANY band is invalid.
        alpha_mask = np.ones((ysize, xsize), dtype=bool)
        for b in bands_lst[:3]:
            band = ds.GetRasterBand(int(b))
            nodata = band.GetNoDataValue()
            arr = band.ReadAsArray(0, 0, xsize, ysize)
            if nodata is not None:
                alpha_mask &= ~np.isclose(arr, nodata)
            if np.issubdtype(arr.dtype, np.floating):
                alpha_mask &= np.isfinite(arr)
            chans.append(
                normalize_band(
                    arr,
                    method=normalization_method,
                    percentile=percentile,
                    nodata=nodata,
                )
            )
        chans.append(np.where(alpha_mask, 255, 0).astype(np.uint8))
        rgba = np.dstack(chans)

        if save_tiff:
            tiff_output_path = base_path + "_rgb.tif"
            driver = gdal.GetDriverByName("GTiff")
            # Tag the layout explicitly so GIS readers treat band 4 as alpha
            # rather than as an unspecified extra sample.
            out_ds = driver.Create(
                tiff_output_path,
                xsize,
                ysize,
                4,
                gdal.GDT_Byte,
                options=["PHOTOMETRIC=RGB", "ALPHA=YES"],
            )
            try:
                out_ds.SetProjection(ds.GetProjection())
                out_ds.SetGeoTransform(ds.GetGeoTransform())
                for i, chan in enumerate(chans, start=1):
                    out_ds.GetRasterBand(i).WriteArray(chan)
                out_ds.FlushCache()
            finally:
                # Dropping the last reference closes the GDAL dataset.
                out_ds = None
    finally:
        # Release the source dataset even if reading or writing fails.
        ds = None

    # An (H, W, 4) uint8 array is inferred as RGBA; the explicit `mode`
    # argument is deprecated in recent Pillow releases.
    Image.fromarray(rgba).save(png_path, format="PNG")
    print(f"已完成RGBA图像转换, 保存至: {png_path}")
    return png_path
def combine_bands_to_rgb(
    red_path: str, green_path: str, blue_path: str, output_path: str
) -> None:
    """
    Merge three single-band rasters (red, green, blue) into one RGBA COG.

    Each input is opened with ``masked=True``, so NoData becomes NaN. It is
    then stretched with the 'minmax' method of ``normalize_band``, and the
    three results are stacked. The alpha channel is 255 where all three
    inputs are finite and 0 elsewhere. CRS, transform and x/y coordinates
    are copied from the red band.

    Parameters
    ----------
    red_path : str
        Path of the red band raster.
    green_path : str
        Path of the green band raster.
    blue_path : str
        Path of the blue band raster.
    output_path : str
        Path of the output RGBA raster (written with the COG driver).

    Raises
    ------
    FileNotFoundError
        If any input path does not exist.
    ValueError
        If the three bands do not share the same shape.
    """
    paths = [red_path, green_path, blue_path]
    for path in paths:
        if not os.path.exists(path):
            raise FileNotFoundError(f"文件不存在: {path}")

    channels = []
    mask = None
    ref_shape = ref_x = ref_y = ref_crs = ref_transform = None
    for path in paths:
        # Load pixels and georeferencing into memory inside the context
        # manager, so the underlying file handle is closed afterwards.
        with open_rasterio(path, masked=True) as raw:
            da = raw.squeeze(dim="band", drop=True)
            values = da.values
            if ref_shape is None:
                ref_shape = values.shape
                ref_x, ref_y = da.x, da.y
                ref_crs = da.rio.crs
                ref_transform = da.rio.transform()

        if values.shape != ref_shape:
            raise ValueError(
                f"波段尺寸不一致: {path} 的形状为 {values.shape}, 期望 {ref_shape}"
            )

        # With masked=True, NoData pixels are NaN.
        valid = np.isfinite(values)
        mask = valid if mask is None else mask & valid
        channels.append(normalize_band(values, method="minmax"))

    # Transparency comes from the alpha channel; no NoData value is written.
    channels.append(np.where(mask, 255, 0).astype(np.uint8))

    # rioxarray expects (band, y, x) ordering.
    rgba = xr.DataArray(
        np.dstack(channels),
        dims=("y", "x", "band"),
        coords={"band": [1, 2, 3, 4], "y": ref_y, "x": ref_x},
    ).transpose("band", "y", "x")
    rgba.rio.write_crs(ref_crs, inplace=True)
    rgba.rio.write_transform(ref_transform, inplace=True)
    rgba.rio.to_raster(output_path, driver="COG", dtype="uint8")
    print(f"RGBA图像已保存到: {output_path}")
if __name__ == "__main__":
    # Example: merge separate HLS RED/GREEN/BLUE files into one RGBA COG.
    # tif_dir = "D:\\Open_EarthData_Tools\\data\\HLS\\2024\\2024012"
    # red_path = os.path.join(
    #     tif_dir, "HLS.S30.T49RGP.2024012T031101.v2.0.RED.subset.tif"
    # )
    # green_path = os.path.join(
    #     tif_dir, "HLS.S30.T49RGP.2024012T031101.v2.0.GREEN.subset.tif"
    # )
    # blue_path = os.path.join(
    #     tif_dir, "HLS.S30.T49RGP.2024012T031101.v2.0.BLUE.subset.tif"
    # )
    # output_path = os.path.join(tif_dir, "HLS.S30.T49RGP.2024012T031101.v2.0.RGB.tif")
    # tif_dir = "D:\\Open_EarthData_Tools\\data\\HLS\\2025\\2025011"
    # red_path = os.path.join(
    #     tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.RED.subset.tif"
    # )
    # green_path = os.path.join(
    #     tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.GREEN.subset.tif"
    # )
    # blue_path = os.path.join(
    #     tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.BLUE.subset.tif"
    # )
    # output_path = os.path.join(tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.RGB.tif")
    # combine_bands_to_rgb(red_path, green_path, blue_path, output_path)

    # Batch-render every region/year stack to an RGBA PNG (+ GeoTIFF).
    imgs_dir = Path(r"D:\CVEOProjects\prjaef\aef-backend-demo\media\imgs")
    bands_lst = [3, 2, 1]  # GDAL band indices are 1-based
    for region_name in ["Target1", "Target2", "Target3", "Target4"]:
        for year in range(2015, 2025 + 1):
            img_tif_path = imgs_dir / region_name / f"S1_S2_{region_name}_{year}.tif"
            if not img_tif_path.exists():
                # A single missing region/year should not abort the batch.
                print(f"文件不存在, 跳过: {img_tif_path}")
                continue
            # render_image_rgb(img_tif_path, bands_lst, "clip")
            render_image_rgb(img_tif_path, bands_lst, "percentile", percentile=1)