# -*- coding: utf-8 -*- """ =============================================================================== 本模块为公共工具函数模块 ------------------------------------------------------------------------------- Authors: Hong Xie Last Updated: 2025-08-13 =============================================================================== """ import os import sys import glob import json import logging from datetime import datetime import earthaccess import numpy as np import pandas as pd from affine import Affine from osgeo import gdal, gdal_array from shapely import box import xarray as xr from rasterio.enums import Resampling from rasterio.merge import merge from rioxarray.merge import merge_arrays from rioxarray import open_rasterio import geopandas as gpd import matplotlib.pyplot as plt gdal.UseExceptions() def get_month_from_filenames(file_path): """ 从格式化后的文件名中提取月份 args: file_path (str): 文件路径 returns: int: 月份 """ # 获取文件名中的年份与DOY date = os.path.basename(file_path).split(".")[3] # 结合年份和DOY, 判断当前文件的月份 month = datetime.strptime(date, "%Y%j").month return month def group_by_month(file_list, out_path): """ 根据文件名中的日期, 将文件按月分组 """ grouped_files = {} # 遍历文件列表, 将文件按月分组 for file in file_list: month = get_month_from_filenames(file) # 将文件添加到对应月份的列表中 if month not in grouped_files: grouped_files[month] = [] grouped_files[month].append(file) # 将字典转换为按月份排序的列表 grouped_files = [grouped_files[month] for month in sorted(grouped_files.keys())] # 将结果存入json文件中 with open(out_path, "w") as f: json.dump(grouped_files, f) return grouped_files def time_index_from_filenames(file_list): """ 根据文件名创建时间索引, 用于时间序列分析 """ return [datetime.strptime(file.split(".")[-4], "%Y%jT%H%M%S") for file in file_list] def square_kernel(radius): """ 方形核函数 生成一个边长为 2 * radius + 1 且值全为 1 的方形核 args: radius: 方形核半径, 使用半径可以确保核的大小为奇数 """ # 方形核的边长 = 2 * 核半径 + 1 kernel_size = 2 * radius + 1 kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8) return kernel def array_calc_to_xr( arr1: xr.DataArray | xr.Dataset, arr2: xr.DataArray | xr.Dataset, func: str, ) -> xr.DataArray: """ 数组计算 [返回带有属性的xarray.DataArray] 仅支持形状坐标一致的两数组, 先进行非空掩膜, 再进行计算 Parameters ---------- arr1 : xr.DataArray | xr.Dataset 数组1 arr2 : xr.DataArray | xr.Dataset 数组2 func : str 计算函数, 支持 "subtract", "add", "multiply", "divide" Returns ------- xr.DataArray 计算结果 """ if isinstance(arr1, xr.Dataset): arr1 = arr1.to_array() if isinstance(arr2, xr.Dataset): arr2 = arr2.to_array() # 备份源数据坐标与属性 org_attrs = arr1.attrs org_dim = arr1.dims org_coords = arr1.coords org_crs = arr1.rio.crs if arr1.rio.crs else None # 将输入数组转换为纯numpy数组, 计算时仅关注数组本身 arr1 = arr1.values.astype(np.float32, copy=False) arr2 = arr2.values.astype(np.float32, copy=False) # 先过滤掉数组1中值为-9999的元素, 然后使用数组1对数组2进行非空掩膜 arr1[arr1 == -9999] = np.nan arr2[arr2 == -9999] = np.nan # 再过滤数组1与数组2均为空的元素 if arr1.shape == arr2.shape: mask = np.isnan(arr1) & np.isnan(arr2) arr1[mask] = np.nan arr2[mask] = np.nan else: mask = None # 根据计算函数进行计算 func = func.lower() if func == "subtract": result = arr1 - arr2 elif func == "add": result = arr1 + arr2 elif func == "multiply": result = arr1 * arr2 elif func == "divide": # 若分母为0,则结果为nan arr2[arr2 == 0] = np.nan result = arr1 / arr2 else: raise ValueError("Unsupported operation") # 释放不再使用的内存 del arr1, arr2, mask # 恢复因计算而丢失的属性与空间坐标系 result = xr.DataArray( data=result, coords=org_coords, dims=org_dim, attrs=org_attrs, ) if org_crs: result.rio.write_crs(org_crs, inplace=True) return result def array_calc( arr1: xr.DataArray | xr.Dataset | np.ndarray, arr2: xr.DataArray | xr.Dataset | np.ndarray, func: str, ) -> np.ndarray: """ 数组计算 [仅返回数组本身即np.ndarray] 仅支持形状坐标一致的两数组, 先进行非空掩膜, 再进行计算 Parameters ---------- arr1 : xr.DataArray | xr.Dataset | np.ndarray 数组1 [支持xarray.DataArray, xarray.Dataset, numpy.ndarray] arr2 : xr.DataArray | xr.Dataset | np.ndarray 数组2 [支持xarray.DataArray, xarray.Dataset, numpy.ndarray] func : str 计算函数, 支持 "subtract", "add", "multiply", "divide" Returns ------- np.ndarray 计算结果 """ # 将输入数组转换为纯numpy数组, 计算时仅关注数组本身 if isinstance(arr1, xr.Dataset): arr1 = arr1.to_array().values.astype(np.float32, copy=False) if isinstance(arr2, xr.Dataset): arr2 = arr2.to_array().values.astype(np.float32, copy=False) if isinstance(arr1, xr.DataArray): arr1 = arr1.values.astype(np.float32, copy=False) if isinstance(arr2, xr.DataArray): arr2 = arr2.values.astype(np.float32, copy=False) # 先过滤掉数组1中值为-9999的元素, 然后使用数组1对数组2进行非空掩膜 arr1[arr1 == -9999] = np.nan arr2[arr2 == -9999] = np.nan # 再过滤数组1与数组2均为空的元素 if arr1.shape == arr2.shape: if func == "divide": # 若分母为0,则结果为nan arr2[arr2 == 0] = np.nan mask = np.isnan(arr1) & np.isnan(arr2) arr1[mask] = np.nan arr2[mask] = np.nan else: mask = None # 根据计算函数进行计算 func = func.lower() if func == "subtract": result = arr1 - arr2 elif func == "add": result = arr1 + arr2 elif func == "multiply": result = arr1 * arr2 elif func == "divide": result = arr1 / arr2 else: raise ValueError("Unsupported operation") # 释放不再使用的内存 del arr1, arr2, mask return result def setup_dask_environment(): """ Passes RIO environment variables to dask workers for authentication. """ import os import rasterio cookie_file_path = os.path.expanduser("~/cookies.txt") global env gdal_config = { "GDAL_HTTP_UNSAFESSL": "YES", "GDAL_HTTP_COOKIEFILE": cookie_file_path, "GDAL_HTTP_COOKIEJAR": cookie_file_path, "GDAL_DISABLE_READDIR_ON_OPEN": "YES", "CPL_VSIL_CURL_ALLOWED_EXTENSIONS": "TIF", "GDAL_HTTP_MAX_RETRY": "10", "GDAL_HTTP_RETRY_DELAY": "0.5", "GDAL_HTTP_TIMEOUT": "300", } env = rasterio.Env(**gdal_config) env.__enter__() def setup_logging(log_file: str = "dask_worker.log"): """ 在Dask工作进程中设置logging Parameters ---------- log_file : str, optional 日志文件路径, by default "dask_worker.log" """ logging.basicConfig( level=logging.INFO, format="%(levelname)s:%(asctime)s ||| %(message)s", handlers=[ logging.StreamHandler(sys.stdout), logging.FileHandler(log_file), ], ) def load_band_as_arr(org_tif_path, band_num=1): """ 读取波段数据 args: org_tif_path (str): 原始tif文件路径 band_num (int): 波段号, 默认为 1 returns: numpy.ndarray: 数组化波段数据 """ org_tif = gdal.Open(org_tif_path) if not org_tif: raise ValueError(f"GDAL could not open {org_tif_path}") band = org_tif.GetRasterBand(band_num) data = band.ReadAsArray() # 获取 NoData 值 nodata = band.GetNoDataValue() if nodata is not None: # 将 NoData 值替换为 NaN data[data == nodata] = np.nan return data def get_proj_info(org_tif_path): """ 获取原始影像的投影和变换 args: org_tif_path (str): 原始tif文件路径 returns: str: 投影与变换 """ org_tif = gdal.Open(org_tif_path) projection = org_tif.GetProjection() transform = org_tif.GetGeoTransform() org_tif = None return projection, transform def calc_time_series(file_list, calc_method): """ 时间序列合成 args: file_list (list): 文件列表 calc_method (str): 计算方法, 包括 "mean", "median", "max", "min" returns: numpy.ndarray: 时间序列合成结果 """ if not file_list: raise ValueError("file_list is empty.") calc_method = calc_method.lower() if calc_method == "mean": data = np.nanmean([load_band_as_arr(file) for file in file_list], axis=0) elif calc_method == "median": data = np.nanmedian([load_band_as_arr(file) for file in file_list], axis=0) elif calc_method == "max": data = np.nanmax([load_band_as_arr(file) for file in file_list], axis=0) elif calc_method == "min": data = np.nanmin([load_band_as_arr(file) for file in file_list], axis=0) else: raise ValueError("Invalid calc_method.") return data.astype(np.float32) def save_as_tif(data, projection, transform, file_path): """ 保存为tif args: data (numpy.ndarray): 要保存的数据 projection (str): 投影 transform (str): 变换 file_path (str): 文件输出完整路径 """ if data is None: return y, x = data.shape gtiff_driver = gdal.GetDriverByName("GTiff") out_ds = gtiff_driver.Create( file_path, x, y, 1, gdal.GDT_Float32, options=["COMPRESS=DEFLATE"] ) out_ds.SetGeoTransform(transform) out_ds.SetProjection(projection) out_band = out_ds.GetRasterBand(1) out_band.WriteArray(data) out_band.FlushCache() out_ds = None # 确保文件正确关闭 return def array_to_raster( data: np.ndarray, transform, wkt, dtype: str = None, nodata=-9999 ) -> gdal.Dataset: """ 将 numpy 数组转换为 gdal.Dataset 对象 reference: https://github.com/arthur-e/pyl4c/blob/master/pyl4c/spatial.py args: data (numpy.ndarray): 待转换的 numpy 数组 transform: (仿射)投影变换矩阵 wkt (str): 投影坐标系信息 dtype (str): 数据类型, 默认为 None nodata (float): NoData 值, 默认为 -9999 returns: gdal.Dataset: 转换后的 gdal.Dataset 对象 """ if dtype is not None: data = data.astype(dtype) try: rast = gdal_array.OpenNumPyArray(data) except AttributeError: # For backwards compatibility with older version of GDAL rast = gdal.Open(gdal_array.GetArrayFilename(data)) except: rast = gdal_array.OpenArray(data) rast.SetGeoTransform(transform) rast.SetProjection(wkt) if nodata is not None: for band in range(1, rast.RasterCount + 1): rast.GetRasterBand(band).SetNoDataValue(nodata) return rast def create_quality_mask(quality_data, bit_nums: list = [0, 1, 2, 3, 4, 5]): """ Uses the Fmask layer and bit numbers to create a binary mask of good pixels. By default, bits 0-5 are used. """ mask_array = np.zeros((quality_data.shape[0], quality_data.shape[1])) # Remove/Mask Fill Values and Convert to Integer quality_data = np.nan_to_num(quality_data, 0).astype(np.int8) for bit in bit_nums: # Create a Single Binary Mask Layer mask_temp = np.array(quality_data) & 1 << bit > 0 mask_array = np.logical_or(mask_array, mask_temp) return mask_array def clip_image( image: xr.DataArray | xr.Dataset, roi: gpd.GeoDataFrame, clip_by_box=True ): """ Clip Image data to ROI. args: image (xarray.DataArray | xarray.Dataset): 通过 rioxarray.open_rasterio 加载的图像数据. roi (gpd.GeoDataFrame): 感兴趣区数据. clip_by_box (bool): 是否使用 bbox 进行裁剪, 默认为 True. """ if roi is None: return image org_crs = image.rio.crs if not roi.crs == org_crs: # 若 crs 不一致, 则重新投影感兴趣区数据 roi_bound = roi.to_crs(org_crs) else: roi_bound = roi if clip_by_box: clip_area = [box(*roi_bound.total_bounds)] else: clip_area = roi_bound.geometry.values # 设置nodata值防止裁剪时未相交部分出现nan值 [仅对DataArray有效] if isinstance(image, xr.DataArray): nodata_value = -9999 image.rio.write_nodata(nodata_value, inplace=True) image_cliped = image.rio.clip( clip_area, roi_bound.crs, all_touched=True, from_disk=True ) return image_cliped def clip_roi_image(file_path: str, grid: gpd.GeoDataFrame) -> xr.DataArray | None: """ 按研究区范围裁剪影像 args: file_path (str): 待裁剪影像路径 grid (gpd.GeoDataFrame): 格网范围 return: raster_cliped (xr.DataArray): 裁剪后的影像 """ raster = open_rasterio(file_path) try: doy = os.path.basename(file_path).split(".")[3] except Exception as e: doy = None if doy: raster.attrs["DOY"] = doy # 先对数据进行降维, 若存在band波段, 则降维为二维数组; 若不存在band波段, 则继续计算 if "band" in raster.dims and raster.sizes["band"] == 1: raster = raster.squeeze("band") # 由于当前实施均在同一格网下进行, 且MODIS数据原始坐标为正弦投影, 在原始影像爬取与预处理阶段使用格网裁剪时并未裁剪出目标效果 # 所以需要先使用格网裁剪, 再后使用感兴趣区域裁剪 # TODO: 跨格网实施算法时, 修改为使用感兴趣区域裁剪 if grid is not None: raster_cliped = clip_image(raster, grid) else: raster_cliped = raster # TODO: 待完善, 若裁剪后的数据全为空或空值数量大于裁剪后数据总数量, 则跳过 # raster_clip_grid = raster_clip_grid.where(raster_clip_grid != -9999) # if ( # raster_clip_grid.count().item() == 0 # or raster_clip_grid.isnull().sum() > raster_clip_grid.count().item() # ): # return raster_cliped.attrs = raster.attrs.copy() raster_cliped.attrs["file_name"] = os.path.basename(file_path) return raster_cliped def reproject_image( image: xr.DataArray, target_crs: str = None, target_shape: tuple = None, target_image: xr.DataArray = None, ): """ Reproject Image data to target CRS or target data. args: image (xarray.DataArray): 通过 rioxarray.open_rasterio 加载的图像数据. target_crs (str): Target CRS, eg. EPSG:4326. target_shape (tuple): Target shape, eg. (1000, 1000). target_image (xarray.DataArray): Target image, eg. rioxarray.open_rasterio 加载的图像数据. """ if target_image is not None: # 使用 target_image 进行重投影匹配 if ( target_image.shape[1] > image.shape[1] or target_image.shape[2] > image.shape[2] ) or ( target_image.shape[1] == image.shape[1] and target_image.shape[2] == image.shape[2] ): # 若判断为降尺度/等尺度, 则直接使用 cubic 重采样投影到目标影像 image_reprojed = image.rio.reproject_match( target_image, resampling=Resampling.cubic ) else: # print("target_image shape is not match with image shape", image.shape, "to", target_image.shape) # 若判断为升尺度, 为减少图像像元丢失, 需要使用聚合类方法, 考虑使用 average/mode/med 重采样投影到目标影像 image_reprojed = image.rio.reproject_match( target_image, resampling=Resampling.med ) elif target_crs is not None: # 使用 target_crs 进行重投影 reproject_kwargs = { "dst_crs": target_crs, "resampling": Resampling.cubic, } if target_shape is not None: reproject_kwargs["shape"] = target_shape image_reprojed = image.rio.reproject(**reproject_kwargs) else: # 没有任何重投影参数, 返回原始图像 image_reprojed = image return image_reprojed def mosaic_images( tif_list: list[str | xr.DataArray], nodata=np.nan, tif_crs: any = None, method: str = "first", ) -> xr.DataArray: """ 影像镶嵌合成 将同一天/同一区域的数据集进行镶嵌, 生成形状为 (1, y, x) 的 numpy 数组, 以及 transform Parameters ---------- tif_list : list[str | xr.DataArray] 待镶嵌的影像文件路径列表 或已加载的 xarray.DataArray 影像数据列表 nodata : float | int 空值填充值 tif_crs : Any, optional tif影像的crs method : str, optional 合成方法, 默认为 "first", 可选 "last", "min", "max" Returns ------- ds : xr.DataArray 镶嵌后的影像数据 """ if isinstance(tif_list[0], str): ds_np, transform = merge( tif_list, nodata=nodata, method=method, resampling=Resampling.cubic, ) # 将结果重新构建为 xarray 数据集 # 单张SAR影像直接读取 transform: 233400.0 30.0 0.0 3463020.0 0.0 -30.0 # 而镶嵌后输出 transform 为: (30.0, 0.0, 221250.0, 0.0, -30.0, 3536970.0) x_scale, _, x_origin, _, y_scale, y_origin = transform[:6] affine_transform = Affine.from_gdal( *(y_origin, x_scale, 0.0, x_origin, 0.0, y_scale) ) y = np.arange(ds_np.shape[1]) * y_scale + y_origin x = np.arange(ds_np.shape[2]) * x_scale + x_origin ds = xr.DataArray( ds_np, coords={ "band": np.arange(ds_np.shape[0]), "y": y, "x": x, }, dims=["band", "y", "x"], ) ds.rio.write_transform(affine_transform, inplace=True) else: ds = merge_arrays(tif_list, nodata=nodata, method=method) if tif_crs is not None: ds.rio.write_crs(tif_crs, inplace=True) return ds def merge_time_series(raster_dir: str) -> xr.Dataset: """ 合成时间序列 读取指定目录内的所有tif文件, 将它们按照时间方向合并为一个xarray.Dataset Parameters ---------- raster_dir : str 包含tif文件的目录路径 Returns ------- raster_dataset : xr.Dataset 包含时间维度的xarray.Dataset """ raster_list = [] # 遍历每个文件, 读取时间属性, 并将其作为新波段合并到dataset中 for file in glob.glob(os.path.join(raster_dir, "*.tif")): date = os.path.basename(file).split(".")[3] if len(str(date)) < 7: continue # 读取影像 raster_data = open_rasterio(file, masked=True).squeeze(dim="band", drop=True) # 获取时间属性 time_attr = datetime.strptime(date, "%Y%j") # 将时间属性作为新的维度 raster_data = raster_data.assign_coords(time=time_attr) # 将新的波段添加到dataset中 raster_list.append(raster_data) raster_dataset = xr.concat(raster_list, dim="time").sortby("time") return raster_dataset def get_value_at_point( dataset: xr.Dataset, value_name: str, lon: float, lat: float ) -> pd.DataFrame: """ 从 Dataset 中提取点位置的值, 提取矢量区域所在的值的时间序列 Parameters ---------- dataset : xr.Dataset 包含 x, y, time 维度的数据集 value_name : str 数据集的变量名 lon : float 经度 lat : float 纬度 Returns ------- df : pd.DataFrame 包含时间序列的 DataFrame """ point_values = dataset.sel(x=lon, y=lat, method="nearest").to_dataset(dim="time") # 将结果转换为 DataFrame df = ( point_values.to_dataframe() .reset_index() .melt(id_vars=["y", "x"], var_name="TIME", value_name=value_name) .drop(columns=["y", "x"]) ) # 剔除time列为spatial_ref和值为-9999无效值的行 df = df[df["TIME"] != "spatial_ref"] df = df[df[value_name] != -9999] # 过滤掉Nan值 df.dropna(subset=[value_name], inplace=True) df["TIME"] = pd.to_datetime(df["TIME"]) return df def valid_raster(raster_path, threshold=0.4) -> tuple[str, float] | None: """ 判断栅格数据是否有效 有效数据占比超过阈值则认为有效, 返回栅格数据路径, 否则返回 None Parameters ---------- raster_path : str 栅格数据路径 threshold : float 有效数据占比阈值 Returns ------- raster_path : str 栅格数据路径 valid_data_percentage : float 有效数据占比 """ with open_rasterio(raster_path, masked=True) as raster: raster = raster.where(raster != -9999) all_data_count = raster.size valid_data_count = raster.count().item() valid_data_percentage = valid_data_count / all_data_count if valid_data_percentage >= threshold: return raster_path, valid_data_percentage else: return None def valid_raster_list( raster_path_list, threshold=0.4, only_path=False ) -> list[tuple[str, float]]: """ 判断栅格数据列表是否有效 有效数据占比超过阈值则认为有效, 返回栅格数据路径列表, 否则返回空列表 Parameters ---------- raster_path_list : list[str] 栅格数据路径列表 threshold : float 有效数据占比阈值 only_path : bool, optional 仅保留有效数据路径列表, 默认为 False Returns ------- valid_raster_path_list : list[tuple[str, float]] | list[str] 有效栅格数据路径列表, 每个元素为 (栅格数据路径, 有效数据占比) 或栅格数据路径 [] 无效栅格数据 """ raster_path_list = list( map( lambda x: valid_raster(x, threshold), raster_path_list, ) ) # 剔除列表中的 None 值 raster_path_list = list(filter(None, raster_path_list)) if only_path: raster_path_list = list(map(lambda x: x[0], raster_path_list)) return raster_path_list def match_utm_grid(roi: gpd.GeoDataFrame, mgrs_kml_file: str) -> gpd.GeoDataFrame: """ 根据 ROI 匹配对应的 ESA Sentinel-2 UTM 格网信息. NOTE: 需要在 fiona > 1.9 环境下才能运行. """ import fiona # enable KML support which is disabled by default fiona.drvsupport.supported_drivers["LIBKML"] = "rw" # bbox = tuple(list(roi.total_bounds)) if not os.path.isfile(mgrs_kml_file): kml_url = "https://hls.gsfc.nasa.gov/wp-content/uploads/2016/03/S2A_OPER_GIP_TILPAR_MPC__20151209T095117_V20150622T000000_21000101T000000_B00.kml" earthaccess.download([kml_url], mgrs_kml_file) # 尽管 geopandas.read_file() 可以直接读取在线资源, 但是由于格网 kml 文件过大 (约 106M), 导致加载会非常慢, 所以先下载到本地再进行读取 mgrs_gdf = gpd.read_file(mgrs_kml_file, roi) # 空间连接, 筛选网格中与ROI边界相交的部分 grid_in_roi = gpd.sjoin(mgrs_gdf, roi, predicate="intersects", how="left") # 剔除连接后的产生的冗余属性 grid_in_roi = grid_in_roi[mgrs_gdf.columns].drop( columns=[ "description", "timestamp", "begin", "end", "altitudeMode", "drawOrder", ] ) # 处理GeometryCollection类型 for i in range(len(grid_in_roi)): grid = grid_in_roi.iloc[i] # 将 GeometryCollection 转换为 Polygon if grid.geometry.geom_type == "GeometryCollection": grid_in_roi.at[i, "geometry"] = grid.geometry.geoms[0] return grid_in_roi def plot(data, title=None, cmap="gray"): """ 绘制影像图像 args: data (numpy.ndarray): 要绘制的数据 title (str): 标题 cmap (str): 颜色映射 """ plt.imshow(data) plt.title(title) plt.axis("off") # 关闭坐标轴 plt.colorbar() plt.show()