diff --git a/utils/sr2rgb.py b/utils/sr2rgb.py index 6934fa2..23a5fb1 100644 --- a/utils/sr2rgb.py +++ b/utils/sr2rgb.py @@ -1,5 +1,5 @@ """ -将 COG 格式的 Red, Green, Blue 单波段地表反射率图像合成 RGB 图像 +将 COG 格式的 Red, Green, Blue 单波段地表反射率图像合成为 RGB 图像 1. 对比度线性拉伸至 0-255; 2. 合并波段; @@ -7,16 +7,18 @@ """ import os -import rasterio as rio +import xarray as xr +from rioxarray import open_rasterio import numpy as np -def sr2rgb(red_path, green_path, blue_path, output_path): + +def sr2rgb(red_path: str, green_path: str, blue_path: str, output_path: str) -> None: """ 将红、绿、蓝三个单波段地表反射率图像合成为RGB图像 - + 参数: red_path (str): 红色波段文件路径 - green_path (str): 绿色波段文件路径 + green_path (str): 绿色波段文件路径 blue_path (str): 蓝色波段文件路径 output_path (str): 输出RGB图像路径 """ @@ -24,74 +26,79 @@ def sr2rgb(red_path, green_path, blue_path, output_path): for path in [red_path, green_path, blue_path]: if not os.path.exists(path): raise FileNotFoundError(f"文件不存在: {path}") - - print(f"正在处理文件:") - print(f" 红波段: {red_path}") - print(f" 绿波段: {green_path}") - print(f" 蓝波段: {blue_path}") - + # 读取三个波段数据 - with rio.open(red_path) as red_src: - red_band = red_src.read(1) - profile = red_src.profile - print(f"红波段形状: {red_band.shape}, 数据类型: {red_band.dtype}") - - with rio.open(green_path) as green_src: - green_band = green_src.read(1) - print(f"绿波段形状: {green_band.shape}, 数据类型: {green_band.dtype}") - - with rio.open(blue_path) as blue_src: - blue_band = blue_src.read(1) - print(f"蓝波段形状: {blue_band.shape}, 数据类型: {blue_band.dtype}") - - # 线性拉伸到0-255范围 + red_band = open_rasterio(red_path, masked=True).squeeze(dim="band", drop=True) + green_band = open_rasterio(green_path, masked=True).squeeze(dim="band", drop=True) + blue_band = open_rasterio(blue_path, masked=True).squeeze(dim="band", drop=True) + # 暂存元数据 + y_coords = red_band.y + x_coords = red_band.x + crs = red_band.rio.crs + transform = red_band.rio.transform() + def stretch_band(band): - # 处理NaN值 + """ + 线性拉伸到0-255范围 + """ + # 处理NaN值与负值 band_no_nan = np.where(np.isnan(band), 0, band) + band_no_nan = np.where(band_no_nan < 0, 0, band_no_nan) band_min = np.min(band_no_nan) band_max = np.max(band_no_nan) - # 避免除零错误 if band_max == band_min: stretched = np.zeros_like(band_no_nan, dtype=np.uint8) else: - stretched = ((band_no_nan - band_min) / (band_max - band_min) * 255).astype(np.uint8) + stretched = ((band_no_nan - band_min) / (band_max - band_min) * 255).astype( + np.uint8 + ) return stretched - - red_stretched = stretch_band(red_band) - green_stretched = stretch_band(green_band) - blue_stretched = stretch_band(blue_band) - + + red_stretched = stretch_band(red_band.values) + green_stretched = stretch_band(green_band.values) + blue_stretched = stretch_band(blue_band.values) # 合并三个波段为RGB图像 - rgb_array = np.dstack((red_stretched, green_stretched, blue_stretched)) - - # 更新元数据 - profile.update( - dtype=rio.uint8, - count=3, - photometric='RGB', - nodata=0 # 设置nodata为0,因为uint8的范围是0-255 + rgb_array = xr.DataArray( + np.dstack((red_stretched, green_stretched, blue_stretched)), + dims=("y", "x", "band"), + coords={"band": [1, 2, 3], "y": y_coords, "x": x_coords}, ) - - # 保存RGB图像 - with rio.open(output_path, 'w', **profile) as dst: - for i in range(3): - dst.write(rgb_array[:, :, i], i + 1) - + # 转置维度顺序以符合rioxarray要求 + rgb_array = rgb_array.transpose("band", "y", "x") + # 写入元数据 + rgb_array.rio.write_crs(crs, inplace=True) + rgb_array.rio.write_transform(transform, inplace=True) + rgb_array.rio.write_nodata(0, inplace=True) + # 保存为TIFF文件 + rgb_array.rio.to_raster(output_path, dtype="uint8") print(f"RGB图像已保存到: {output_path}") + return if __name__ == "__main__": # tif_dir = "D:\\NASA_EarthData_Script\\data\\HLS\\2024\\2024012" - # red_path = os.path.join(tif_dir, "HLS.S30.T49RGP.2024012T031101.v2.0.RED.subset.tif") - # green_path = os.path.join(tif_dir, "HLS.S30.T49RGP.2024012T031101.v2.0.GREEN.subset.tif") - # blue_path = os.path.join(tif_dir, "HLS.S30.T49RGP.2024012T031101.v2.0.BLUE.subset.tif") + # red_path = os.path.join( + # tif_dir, "HLS.S30.T49RGP.2024012T031101.v2.0.RED.subset.tif" + # ) + # green_path = os.path.join( + # tif_dir, "HLS.S30.T49RGP.2024012T031101.v2.0.GREEN.subset.tif" + # ) + # blue_path = os.path.join( + # tif_dir, "HLS.S30.T49RGP.2024012T031101.v2.0.BLUE.subset.tif" + # ) # output_path = os.path.join(tif_dir, "HLS.S30.T49RGP.2024012T031101.v2.0.RGB.tif") tif_dir = "D:\\NASA_EarthData_Script\\data\\HLS\\2025\\2025011" - red_path = os.path.join(tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.RED.subset.tif") - green_path = os.path.join(tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.GREEN.subset.tif") - blue_path = os.path.join(tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.BLUE.subset.tif") + red_path = os.path.join( + tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.RED.subset.tif" + ) + green_path = os.path.join( + tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.GREEN.subset.tif" + ) + blue_path = os.path.join( + tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.BLUE.subset.tif" + ) output_path = os.path.join(tif_dir, "HLS.S30.T49RGP.2025011T031009.v2.0.RGB.tif") - - sr2rgb(red_path, green_path, blue_path, output_path) \ No newline at end of file + + sr2rgb(red_path, green_path, blue_path, output_path)