feat(utils): 重构原始影像数据转换RGBA图像的工具函数.

This commit is contained in:
谢泓 2026-01-08 10:13:03 +08:00
parent 7347cd60bb
commit 32d629aede

View File

@ -1,19 +1,26 @@
# -*- coding: utf-8 -*-
""" """
将地表反射率图像转换为 8bit RGB 图像 ===============================================================================
将原始数值影像转换为 8bit RGBA 图像
- Red, Green, Blue 单波段图 - 支持 Red, Green, Blue 等单波段影
- 多波段图 - 支持多波段影
1. 对比度线性拉伸至 0-255; 1. 将原始数据中 NoData 值替换为 NaN, 并将其设置为 Alpha 通道
2. 合并波段; 2. 对比度线性拉伸至 0-255;
3. 保存为 uint8 格式 RGB 图像; 3. 合并包含 Alpha 通道在内的 4 个波段;
4. 保存为 uint8 格式 RGBA 图像
-------------------------------------------------------------------------------
Authors: CVEO Team
Last Updated: 2025-01-08
===============================================================================
""" """
import os import os
import sys import sys
from pathlib import Path from pathlib import Path
from typing import List, Optional from typing import List, Optional
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from osgeo import gdal from osgeo import gdal
@ -82,7 +89,7 @@ def render_image_rgb(
output_suffix: str = "_rgb.png", output_suffix: str = "_rgb.png",
) -> str: ) -> str:
""" """
通用 GDAL 数据源处理函数, 读取多波段图像并转换为 RGB 图像. 通用 GDAL 数据源处理函数, 读取多波段图像并转换为 RGBA 图像.
Parameters Parameters
---------- ----------
@ -95,7 +102,7 @@ def render_image_rgb(
percentile : float, optional percentile : float, optional
百分比截断参数 (仅用于 'percentile' 方法), by default 0. 百分比截断参数 (仅用于 'percentile' 方法), by default 0.
save_tiff : bool, optional save_tiff : bool, optional
是否保存为 RGB TIFF 文件, by default True. 是否保存为 RGBA TIFF 文件, by default True.
output_suffix : str, optional output_suffix : str, optional
输出文件后缀, by default "_rgb.png". 输出文件后缀, by default "_rgb.png".
@ -116,17 +123,32 @@ def render_image_rgb(
bands = bands_lst[:3] bands = bands_lst[:3]
chans = [] chans = []
# 初始化 Alpha 掩膜 (全 True 表示全不透明/有效)
alpha_mask = np.ones((ysize, xsize), dtype=bool)
for b in bands: for b in bands:
band = ds.GetRasterBand(int(b)) band = ds.GetRasterBand(int(b))
nodata = band.GetNoDataValue() nodata = band.GetNoDataValue()
arr = band.ReadAsArray(0, 0, xsize, ysize) arr = band.ReadAsArray(0, 0, xsize, ysize)
# 更新 Alpha 掩膜: 任何波段为 NoData 或 NaN则该像素透明
if nodata is not None:
alpha_mask &= ~np.isclose(arr, nodata)
if np.issubdtype(arr.dtype, np.floating):
alpha_mask &= np.isfinite(arr)
chans.append( chans.append(
normalize_band( normalize_band(
arr, method=normalization_method, percentile=percentile, nodata=nodata arr, method=normalization_method, percentile=percentile, nodata=nodata
) )
) )
rgb = np.dstack(chans) # 生成包含 Alpha 通道的 RGBA
alpha_channel = np.where(alpha_mask, 255, 0).astype(np.uint8)
chans.append(alpha_channel)
rgba = np.dstack(chans)
# 构建输出路径 # 构建输出路径
base_path = str(tiff_path) base_path = str(tiff_path)
@ -138,7 +160,8 @@ def render_image_rgb(
if save_tiff: if save_tiff:
tiff_output_path = base_path + "_rgb.tif" tiff_output_path = base_path + "_rgb.tif"
driver = gdal.GetDriverByName("GTiff") driver = gdal.GetDriverByName("GTiff")
out_ds = driver.Create(tiff_output_path, xsize, ysize, 3, gdal.GDT_Byte) # 创建 4 波段 (RGBA)
out_ds = driver.Create(tiff_output_path, xsize, ysize, 4, gdal.GDT_Byte)
out_ds.SetProjection(ds.GetProjection()) out_ds.SetProjection(ds.GetProjection())
out_ds.SetGeoTransform(ds.GetGeoTransform()) out_ds.SetGeoTransform(ds.GetGeoTransform())
for i, chan in enumerate(chans): for i, chan in enumerate(chans):
@ -147,8 +170,8 @@ def render_image_rgb(
out_ds = None out_ds = None
ds = None ds = None
Image.fromarray(rgb, mode="RGB").save(png_path, format="PNG") Image.fromarray(rgba, mode="RGBA").save(png_path, format="PNG")
print(f"已完成RGB图像转换, 保存至: {png_path}") print(f"已完成RGBA图像转换, 保存至: {png_path}")
return png_path return png_path
@ -156,7 +179,7 @@ def combine_bands_to_rgb(
red_path: str, green_path: str, blue_path: str, output_path: str red_path: str, green_path: str, blue_path: str, output_path: str
) -> None: ) -> None:
""" """
将红, 绿, 蓝三个单波段图像文件合并为 RGB 图像 (GeoTIFF). 将红, 绿, 蓝三个单波段图像文件合并为 RGBA 图像 (GeoTIFF).
Parameters Parameters
---------- ----------
@ -167,7 +190,7 @@ def combine_bands_to_rgb(
blue_path : str blue_path : str
蓝色波段文件路径. 蓝色波段文件路径.
output_path : str output_path : str
输出 RGB 图像路径. 输出 RGBA 图像路径.
""" """
paths = [red_path, green_path, blue_path] paths = [red_path, green_path, blue_path]
for path in paths: for path in paths:
@ -177,29 +200,41 @@ def combine_bands_to_rgb(
# 读取三个波段数据 # 读取三个波段数据
bands_data = [] bands_data = []
ref_band = None ref_band = None
mask = None
for path in paths: for path in paths:
da = open_rasterio(path, masked=True).squeeze(dim="band", drop=True) da = open_rasterio(path, masked=True).squeeze(dim="band", drop=True)
if ref_band is None: if ref_band is None:
ref_band = da ref_band = da
# 使用 minmax 方法 (对应原 stretch_band)
# 更新掩膜: masked=True 时 nodata 会被转换为 NaN
current_mask = np.isfinite(da.values)
if mask is None:
mask = current_mask
else:
mask &= current_mask
# 使用 minmax 方法
bands_data.append(normalize_band(da.values, method="minmax")) bands_data.append(normalize_band(da.values, method="minmax"))
# 合并三个波段为RGB图像 # 不再设置 NoData 值,使用 Alpha 通道表示透明度
alpha_channel = np.where(mask, 255, 0).astype(np.uint8)
bands_data.append(alpha_channel)
# 合并四个波段为 RGBA 图像
rgb_array = xr.DataArray( rgb_array = xr.DataArray(
np.dstack(bands_data), np.dstack(bands_data),
dims=("y", "x", "band"), dims=("y", "x", "band"),
coords={"band": [1, 2, 3], "y": ref_band.y, "x": ref_band.x}, coords={"band": [1, 2, 3, 4], "y": ref_band.y, "x": ref_band.x},
) )
# 转置维度顺序以符合rioxarray要求 # 转置维度顺序以符合rioxarray要求
rgb_array = rgb_array.transpose("band", "y", "x") rgb_array = rgb_array.transpose("band", "y", "x")
# 写入元数据 # 写入元数据
rgb_array.rio.write_crs(ref_band.rio.crs, inplace=True) rgb_array.rio.write_crs(ref_band.rio.crs, inplace=True)
rgb_array.rio.write_transform(ref_band.rio.transform(), inplace=True) rgb_array.rio.write_transform(ref_band.rio.transform(), inplace=True)
rgb_array.rio.write_nodata(0, inplace=True)
# 保存为TIFF文件 # 保存为TIFF文件
rgb_array.rio.to_raster(output_path, driver="COG", dtype="uint8") rgb_array.rio.to_raster(output_path, driver="COG", dtype="uint8")
print(f"RGB图像已保存到: {output_path}") print(f"RGBA图像已保存到: {output_path}")
return return
@ -231,10 +266,10 @@ if __name__ == "__main__":
# combine_bands_to_rgb(red_path, green_path, blue_path, output_path) # combine_bands_to_rgb(red_path, green_path, blue_path, output_path)
imgs_dir = Path(r"D:\CVEOProjects\prjaef\aef-backend-demo\media\imgs") imgs_dir = Path(r"D:\CVEOProjects\prjaef\aef-backend-demo\media\imgs")
region_name = "Target4" for region_name in ["Target1", "Target2", "Target3", "Target4"]:
year = 2025 for year in range(2015, 2025 + 1):
img_name = f"S1_S2_{region_name}_{year}.tif" img_name = f"S1_S2_{region_name}_{year}.tif"
img_tif_path = imgs_dir / region_name / img_name img_tif_path = imgs_dir / region_name / img_name
bands_lst = [3, 2, 1] # GDAL 波段索引从1开始 bands_lst = [3, 2, 1] # GDAL 波段索引从1开始
# render_image_rgb(img_tif_path, bands_lst, "clip") # render_image_rgb(img_tif_path, bands_lst, "clip")
render_image_rgb(img_tif_path, bands_lst, "percentile", percentile=1) render_image_rgb(img_tif_path, bands_lst, "percentile", percentile=1)