# Source: NASA_EarthData_Script/utils/common_utils.py
# -*- coding: utf-8 -*-
"""
===============================================================================
本模块为公共工具函数模块
-------------------------------------------------------------------------------
Authors: Hong Xie
Last Updated: 2025-09-11
===============================================================================
"""
import os
import sys
import glob
import json
import logging
from datetime import datetime
import earthaccess
import numpy as np
import pandas as pd
from affine import Affine
from osgeo import gdal, gdal_array
from shapely import box
import xarray as xr
from rasterio.enums import Resampling
from rasterio.merge import merge
from rioxarray.merge import merge_arrays
from rioxarray import open_rasterio
import geopandas as gpd
import matplotlib.pyplot as plt
gdal.UseExceptions()
def get_month_from_filenames(file_path):
    """
    Extract the month from a formatted file name.

    The 4th dot-separated token of the base name is expected to be a
    "%Y%j" (year + day-of-year) string, e.g. "2023100".

    args:
        file_path (str): file path
    returns:
        int: month number (1-12)
    """
    # Token index 3 of the base name holds the YYYYDDD date string.
    date_token = os.path.basename(file_path).split(".")[3]
    # Parse year + DOY and read off the calendar month.
    return datetime.strptime(date_token, "%Y%j").month
def group_by_month(file_list, out_path):
    """
    Group files by the month encoded in their file names.

    args:
        file_list (list[str]): file paths whose names carry a "%Y%j" date
        out_path (str): path of the JSON file the grouping is written to
    returns:
        list[list[str]]: per-month file lists, ordered by month number
    """
    buckets = {}
    for file_path in file_list:
        # Bucket each file under its calendar month.
        month = get_month_from_filenames(file_path)
        buckets.setdefault(month, []).append(file_path)
    # Flatten the dict into a list sorted by month number.
    grouped = [buckets[month] for month in sorted(buckets)]
    # Persist the grouping as JSON.
    with open(out_path, "w") as fp:
        json.dump(grouped, fp)
    return grouped
def time_index_from_filenames(file_list):
    """
    Build a datetime index from file names for time-series analysis.

    The 4th dot-separated token from the end of each name must be a
    "%Y%jT%H%M%S" timestamp, e.g. "2023100T123456".
    """
    fmt = "%Y%jT%H%M%S"
    return [datetime.strptime(name.split(".")[-4], fmt) for name in file_list]
def square_kernel(radius):
    """
    Build a square kernel of ones.

    The kernel has side length 2 * radius + 1, so using a radius
    guarantees an odd-sized kernel centred on a single pixel.

    args:
        radius (int): kernel radius
    returns:
        numpy.ndarray: uint8 array of ones with shape (2r+1, 2r+1)
    """
    side = 2 * radius + 1
    return np.ones((side, side), dtype=np.uint8)
def array_calc_to_xr(
    arr1: xr.DataArray | xr.Dataset,
    arr2: xr.DataArray | xr.Dataset,
    func: str,
) -> xr.DataArray:
    """
    Element-wise array math, returning an xarray.DataArray with attributes.

    Only arrays with identical shape/coords are supported; -9999 values are
    masked to NaN in both inputs before computing.

    Parameters
    ----------
    arr1 : xr.DataArray | xr.Dataset
        First operand; its coords, dims, attrs and CRS are carried over
        to the result.
    arr2 : xr.DataArray | xr.Dataset
        Second operand.
    func : str
        Operation: "subtract", "add", "multiply" or "divide"
        (case-insensitive).

    Returns
    -------
    xr.DataArray
        Computation result with arr1's coords/attrs/CRS restored.

    Raises
    ------
    ValueError
        On an unsupported operation name.
    """
    if isinstance(arr1, xr.Dataset):
        arr1 = arr1.to_array()
    if isinstance(arr2, xr.Dataset):
        arr2 = arr2.to_array()
    # Keep arr1's coords/attrs/CRS so they can be restored afterwards.
    org_attrs = arr1.attrs
    org_dims = arr1.dims
    org_coords = arr1.coords
    org_crs = arr1.rio.crs if arr1.rio.crs else None
    # FIX: always copy. astype(..., copy=False) returned a view when the
    # data was already float32, so the -9999 -> NaN substitution below
    # silently mutated the caller's DataArrays.
    data1 = arr1.values.astype(np.float32, copy=True)
    data2 = arr2.values.astype(np.float32, copy=True)
    # Treat -9999 as NoData in both arrays.
    data1[data1 == -9999] = np.nan
    data2[data2 == -9999] = np.nan
    # Mask positions where both arrays are empty (only for matching shapes).
    if data1.shape == data2.shape:
        mask = np.isnan(data1) & np.isnan(data2)
        data1[mask] = np.nan
        data2[mask] = np.nan
    else:
        mask = None
    func = func.lower()
    if func == "subtract":
        result = data1 - data2
    elif func == "add":
        result = data1 + data2
    elif func == "multiply":
        result = data1 * data2
    elif func == "divide":
        # A zero denominator yields NaN instead of a divide warning.
        data2[data2 == 0] = np.nan
        result = data1 / data2
    else:
        raise ValueError("Unsupported operation")
    # Release working copies before building the result.
    del data1, data2, mask
    # Restore the attributes and spatial reference lost by the raw math.
    result = xr.DataArray(
        data=result,
        coords=org_coords,
        dims=org_dims,
        attrs=org_attrs,
    )
    if org_crs:
        result.rio.write_crs(org_crs, inplace=True)
    return result
def array_calc(
    arr1: xr.DataArray | xr.Dataset | np.ndarray,
    arr2: xr.DataArray | xr.Dataset | np.ndarray,
    func: str,
) -> np.ndarray:
    """
    Element-wise array math, returning a plain numpy.ndarray.

    Only arrays with identical shape are supported for masking; -9999
    values are converted to NaN in both inputs before computing.

    Parameters
    ----------
    arr1 : xr.DataArray | xr.Dataset | np.ndarray
        First operand.
    arr2 : xr.DataArray | xr.Dataset | np.ndarray
        Second operand.
    func : str
        Operation: "subtract", "add", "multiply" or "divide"
        (case-insensitive).

    Returns
    -------
    np.ndarray
        Computation result (float32 inputs -> float32/float64 result).

    Raises
    ------
    ValueError
        On an unsupported operation name.
    """

    def _to_float32(arr):
        # Normalise xarray objects (duck-typed: Dataset -> DataArray ->
        # ndarray) and plain arrays to float32. FIX: always copy, so the
        # caller's data is never mutated (the old copy=False path wrote
        # NaNs back into the caller's arrays), and integer inputs become
        # float before NaN substitution instead of raising.
        if hasattr(arr, "to_array"):
            arr = arr.to_array()
        if hasattr(arr, "values"):
            arr = arr.values
        return np.asarray(arr).astype(np.float32, copy=True)

    arr1 = _to_float32(arr1)
    arr2 = _to_float32(arr2)
    # FIX: lower-case *before* the "divide" check below; previously a
    # mixed-case "Divide" skipped the zero-denominator handling.
    func = func.lower()
    # Treat -9999 as NoData in both arrays.
    arr1[arr1 == -9999] = np.nan
    arr2[arr2 == -9999] = np.nan
    # Mask positions where both arrays are empty (only for matching shapes).
    if arr1.shape == arr2.shape:
        if func == "divide":
            # A zero denominator yields NaN instead of a divide warning.
            arr2[arr2 == 0] = np.nan
        mask = np.isnan(arr1) & np.isnan(arr2)
        arr1[mask] = np.nan
        arr2[mask] = np.nan
    if func == "subtract":
        return arr1 - arr2
    elif func == "add":
        return arr1 + arr2
    elif func == "multiply":
        return arr1 * arr2
    elif func == "divide":
        return arr1 / arr2
    raise ValueError("Unsupported operation")
def setup_dask_environment():
    """
    Passes RIO environment variables to dask workers for authentication.

    Opens a module-global rasterio.Env so every task running in the
    worker inherits the Earthdata cookie file and the HTTP retry/timeout
    settings; the environment is deliberately left entered for the
    worker's lifetime.
    """
    import os
    import rasterio

    global env
    cookie_file_path = os.path.expanduser("~/cookies.txt")
    env = rasterio.Env(
        GDAL_HTTP_UNSAFESSL="YES",
        GDAL_HTTP_COOKIEFILE=cookie_file_path,
        GDAL_HTTP_COOKIEJAR=cookie_file_path,
        GDAL_DISABLE_READDIR_ON_OPEN="YES",
        CPL_VSIL_CURL_ALLOWED_EXTENSIONS="TIF",
        GDAL_HTTP_MAX_RETRY="10",
        GDAL_HTTP_RETRY_DELAY="0.5",
        GDAL_HTTP_TIMEOUT="300",
    )
    # Enter without exiting: keeps the GDAL config active in this worker.
    env.__enter__()
def setup_logging(log_file: str = "dask_worker.log"):
    """
    Configure logging inside a Dask worker process.

    Logs at INFO level to both stdout and a file.

    Parameters
    ----------
    log_file : str, optional
        Log file path, by default "dask_worker.log"
    """
    handlers = [
        logging.StreamHandler(sys.stdout),
        logging.FileHandler(log_file),
    ]
    logging.basicConfig(
        level=logging.INFO,
        format="%(levelname)s:%(asctime)s ||| %(message)s",
        handlers=handlers,
    )
def load_band_as_arr(org_tif_path, band_num=1):
    """
    Read one raster band as a numpy array with NoData mapped to NaN.

    args:
        org_tif_path (str): path of the source GeoTIFF
        band_num (int): band number, defaults to 1
    returns:
        numpy.ndarray: band data; promoted to float32 when an integer
        band carries a NoData value (NaN cannot live in an int array)
    raises:
        ValueError: when GDAL cannot open the file
    """
    org_tif = gdal.Open(org_tif_path)
    if not org_tif:
        raise ValueError(f"GDAL could not open {org_tif_path}")
    try:
        band = org_tif.GetRasterBand(band_num)
        data = band.ReadAsArray()
        nodata = band.GetNoDataValue()
        if nodata is not None:
            # FIX: assigning NaN into an integer array raises; promote
            # to float32 first instead of crashing.
            if not np.issubdtype(data.dtype, np.floating):
                data = data.astype(np.float32)
            # Replace NoData values with NaN.
            data[data == nodata] = np.nan
        return data
    finally:
        # FIX: drop the dataset reference so GDAL closes the file.
        org_tif = None
def get_proj_info(org_tif_path):
    """
    Read the projection and geotransform of a raster.

    args:
        org_tif_path (str): path of the source GeoTIFF
    returns:
        tuple: (projection WKT string, GDAL geotransform tuple)
    raises:
        ValueError: when GDAL cannot open the file
    """
    org_tif = gdal.Open(org_tif_path)
    # FIX: fail loudly on open failure, consistent with load_band_as_arr,
    # instead of dying later with an AttributeError on None. (With
    # gdal.UseExceptions() active Open already raises; this is a guard
    # for exception-free GDAL configurations.)
    if not org_tif:
        raise ValueError(f"GDAL could not open {org_tif_path}")
    projection = org_tif.GetProjection()
    transform = org_tif.GetGeoTransform()
    # Drop the reference so GDAL closes the dataset.
    org_tif = None
    return projection, transform
def calc_time_series(file_list, calc_method):
    """
    Composite a stack of rasters along the time axis.

    args:
        file_list (list): raster file paths
        calc_method (str): one of "mean", "median", "max", "min"
    returns:
        numpy.ndarray: float32 composite
    raises:
        ValueError: on an empty file list or an unknown method
    """
    if not file_list:
        raise ValueError("file_list is empty.")
    # Dispatch table of NaN-aware reducers.
    reducers = {
        "mean": np.nanmean,
        "median": np.nanmedian,
        "max": np.nanmax,
        "min": np.nanmin,
    }
    reducer = reducers.get(calc_method.lower())
    if reducer is None:
        raise ValueError("Invalid calc_method.")
    stack = [load_band_as_arr(path) for path in file_list]
    return reducer(stack, axis=0).astype(np.float32)
def save_as_tif(data, projection, transform, file_path):
    """
    Write a 2-D array to a single-band, DEFLATE-compressed GeoTIFF.

    args:
        data (numpy.ndarray): 2-D array to write; None is silently ignored
        projection (str): projection WKT
        transform (tuple): GDAL geotransform
        file_path (str): full output path
    """
    if data is None:
        return
    rows, cols = data.shape
    driver = gdal.GetDriverByName("GTiff")
    out_ds = driver.Create(
        file_path, cols, rows, 1, gdal.GDT_Float32, options=["COMPRESS=DEFLATE"]
    )
    out_ds.SetGeoTransform(transform)
    out_ds.SetProjection(projection)
    out_band = out_ds.GetRasterBand(1)
    out_band.WriteArray(data)
    out_band.FlushCache()
    # Drop the reference so GDAL flushes and closes the file properly.
    out_ds = None
def array_to_raster(
    data: np.ndarray, transform, wkt, dtype: str = None, nodata=-9999
) -> gdal.Dataset:
    """
    Convert a numpy array to an in-memory gdal.Dataset.

    reference: https://github.com/arthur-e/pyl4c/blob/master/pyl4c/spatial.py

    args:
        data (numpy.ndarray): array to convert
        transform: (affine) geotransform matrix
        wkt (str): projection information as WKT
        dtype (str): target dtype to cast to first, defaults to None
        nodata (float): NoData value set on every band, defaults to -9999
    returns:
        gdal.Dataset: the converted dataset
    """
    if dtype is not None:
        data = data.astype(dtype)
    try:
        rast = gdal_array.OpenNumPyArray(data)
    except AttributeError:
        # For backwards compatibility with older version of GDAL
        rast = gdal.Open(gdal_array.GetArrayFilename(data))
    except Exception:
        # FIX: was a bare `except:` that also swallowed SystemExit and
        # KeyboardInterrupt; keep the fallback but stop catching those.
        rast = gdal_array.OpenArray(data)
    rast.SetGeoTransform(transform)
    rast.SetProjection(wkt)
    if nodata is not None:
        # Stamp the NoData value on every band.
        for band in range(1, rast.RasterCount + 1):
            rast.GetRasterBand(band).SetNoDataValue(nodata)
    return rast
def create_quality_mask(quality_data, bit_nums: list = [0, 1, 2, 3, 4, 5]):
    """
    Uses the Fmask layer and bit numbers to create a binary mask of good pixels.
    By default, bits 0-5 are used.

    A pixel is flagged True when ANY of the requested bits is set.

    args:
        quality_data (numpy.ndarray): 2-D quality layer; NaN = fill value
        bit_nums (list): bit positions to test (never mutated)
    returns:
        numpy.ndarray: boolean mask, True where any tested bit is set
    """
    mask_array = np.zeros(quality_data.shape[:2], dtype=bool)
    # FIX: the old call was nan_to_num(quality_data, 0), where the second
    # positional argument is `copy` — passing 0 meant copy=False, so the
    # fill-value conversion ran in place and silently overwrote the
    # caller's array. Use the keyword form (and a real copy) instead.
    quality_int = np.nan_to_num(quality_data, nan=0).astype(np.int8)
    for bit in bit_nums:
        # OR in the pixels whose bit `bit` is set.
        mask_array |= (quality_int & (1 << bit)) > 0
    return mask_array
def clip_image(
    image: xr.DataArray | xr.Dataset, roi: gpd.GeoDataFrame = None, clip_by_box=True
) -> xr.DataArray | xr.Dataset | None:
    """
    Clip Image data to ROI.

    Parameters
    ----------
    image : xarray.DataArray | xarray.Dataset
        Image data loaded via rioxarray.open_rasterio.
    roi : gpd.GeoDataFrame, optional
        Region-of-interest geometries; when None the image is returned
        unchanged.
    clip_by_box : bool, optional
        Clip with the ROI's bounding box instead of its exact geometry,
        by default True.

    Returns
    -------
    xarray.DataArray | xarray.Dataset | None
        The clipped image.
    """
    if roi is None:
        return image
    image_crs = image.rio.crs
    # Reproject the ROI onto the image CRS when they differ.
    roi_bound = roi if roi.crs == image_crs else roi.to_crs(image_crs)
    if clip_by_box:
        clip_area = [box(*roi_bound.total_bounds)]
    else:
        clip_area = roi_bound.geometry.values
    # Write a nodata value so non-intersecting pixels do not become NaN
    # (only supported on DataArray).
    if isinstance(image, xr.DataArray):
        image.rio.write_nodata(-9999, inplace=True)
    return image.rio.clip(
        clip_area, roi_bound.crs, all_touched=True, from_disk=True
    )
def clip_roi_image(
    file_path: str, grid: gpd.GeoDataFrame = None
) -> xr.DataArray | None:
    """
    Clip a raster file to the study-area grid.

    Parameters
    ----------
    file_path : str
        Path of the raster to clip.
    grid : gpd.GeoDataFrame, optional
        Grid extent, defaults to None (no clipping).

    Returns
    -------
    raster_cliped : xr.DataArray
        The clipped raster with provenance attributes attached.
    """
    raster = open_rasterio(file_path)
    # The 4th dot-separated token of the name is the YYYYDDD acquisition date.
    try:
        doy = os.path.basename(file_path).split(".")[3]
    except Exception:
        doy = None
    if doy:
        raster.attrs["DOY"] = doy
    # Collapse a singleton band dimension down to a 2-D array.
    if "band" in raster.dims and raster.sizes["band"] == 1:
        raster = raster.squeeze("band")
    # Processing currently runs within a single grid tile, and the MODIS
    # source data is in sinusoidal projection: ROI clipping during the
    # download/preprocess stage did not produce the expected extent, so
    # clip by grid first and by ROI later.
    # TODO: switch to ROI clipping when running across grid tiles.
    raster_cliped = clip_image(raster, grid) if grid is not None else raster
    # TODO: skip rasters whose clipped content is entirely (or mostly) nodata.
    raster_cliped.attrs = raster.attrs.copy()
    raster_cliped.attrs["file_name"] = os.path.basename(file_path)
    return raster_cliped
def reproject_image(
    image: xr.DataArray,
    target_crs: str = None,
    target_shape: tuple = None,
    target_image: xr.DataArray = None,
) -> xr.DataArray:
    """
    Reproject Image data to target CRS or target data.

    Parameters
    ----------
    image : xarray.DataArray
        Image data loaded via rioxarray.open_rasterio.
    target_crs : str, optional
        Target CRS, e.g. EPSG:4326.
    target_shape : tuple, optional
        Target shape, e.g. (1000, 1000).
    target_image : xarray.DataArray, optional
        Target image whose grid the input is matched to.

    Returns
    -------
    xarray.DataArray
        The reprojected image (the input itself when no target is given).
    """
    if target_image is not None:
        # Match the target image's grid; pick the resampler by scale change.
        tgt_rows, tgt_cols = target_image.shape[1], target_image.shape[2]
        src_rows, src_cols = image.shape[1], image.shape[2]
        same_size = tgt_rows == src_rows and tgt_cols == src_cols
        downscale = tgt_rows > src_rows or tgt_cols > src_cols
        if downscale or same_size:
            # Down-/equal-scaling: cubic interpolation onto the target grid.
            resampling = Resampling.cubic
        else:
            # Upscaling: aggregate with a median to limit pixel loss.
            resampling = Resampling.med
        return image.rio.reproject_match(target_image, resampling=resampling)
    if target_crs is not None:
        # Reproject to an explicit CRS (optionally with a fixed shape).
        reproject_kwargs = {
            "dst_crs": target_crs,
            "resampling": Resampling.cubic,
        }
        if target_shape is not None:
            reproject_kwargs["shape"] = target_shape
        return image.rio.reproject(**reproject_kwargs)
    # No reprojection parameters given: return the original image.
    return image
def mosaic_images(
    tif_list: list[str | xr.DataArray],
    nodata=np.nan,
    tif_crs: any = None,
    method: str = "first",
) -> xr.DataArray:
    """
    Mosaic a set of images.

    Merges same-day / same-region rasters into a single (1, y, x) array
    together with its transform.

    Parameters
    ----------
    tif_list : list[str | xr.DataArray]
        Paths of the rasters to mosaic, or already-loaded
        xarray.DataArray images.
    nodata : float | int
        Fill value for empty pixels.
    tif_crs : Any, optional
        CRS of the rasters.
    method : str, optional
        Compositing method, "first" by default; "last", "min", "max"
        are also accepted.

    Returns
    -------
    ds : xr.DataArray
        The mosaicked image.
    """
    if isinstance(tif_list[0], str):
        # Path inputs: let rasterio.merge read and composite from disk.
        ds_np, transform = merge(
            tif_list,
            nodata=nodata,
            method=method,
            resampling=Resampling.cubic,
        )
        # Rebuild the merge result as an xarray object.
        # A single SAR image read directly has transform:
        #   233400.0 30.0 0.0 3463020.0 0.0 -30.0
        # while merge() outputs transform:
        #   (30.0, 0.0, 221250.0, 0.0, -30.0, 3536970.0)
        x_scale, _, x_origin, _, y_scale, y_origin = transform[:6]
        # NOTE(review): Affine.from_gdal expects (c, a, b, f, d, e) =
        # (x_origin, x_scale, 0, y_origin, 0, y_scale); here y_origin and
        # x_origin occupy each other's slots, which looks swapped —
        # verify the written transform against a known-good mosaic.
        affine_transform = Affine.from_gdal(
            *(y_origin, x_scale, 0.0, x_origin, 0.0, y_scale)
        )
        # Coordinate vectors derived from origin + pixel size.
        y = np.arange(ds_np.shape[1]) * y_scale + y_origin
        x = np.arange(ds_np.shape[2]) * x_scale + x_origin
        ds = xr.DataArray(
            ds_np,
            coords={
                "band": np.arange(ds_np.shape[0]),
                "y": y,
                "x": x,
            },
            dims=["band", "y", "x"],
        )
        ds.rio.write_transform(affine_transform, inplace=True)
    else:
        # Already-loaded DataArrays: mosaic in memory via rioxarray.
        ds = merge_arrays(tif_list, nodata=nodata, method=method)
    if tif_crs is not None:
        ds.rio.write_crs(tif_crs, inplace=True)
    return ds
def merge_time_series(raster_dir: str) -> xr.Dataset:
    """
    Merge all GeoTIFFs in a directory into a time-indexed array.

    Reads every *.tif whose file name carries a "%Y%j" date in its 4th
    dot-separated token and concatenates them along a new "time"
    dimension, sorted chronologically.

    Parameters
    ----------
    raster_dir : str
        Directory containing the tif files.

    Returns
    -------
    raster_dataset : xr.Dataset
        Time-stacked data sorted along the time dimension.

    Raises
    ------
    ValueError
        When the directory contains no usable dated tif file.
    """
    raster_list = []
    # Read each file, derive its timestamp from the name, tag and collect.
    for file in glob.glob(os.path.join(raster_dir, "*.tif")):
        date = os.path.basename(file).split(".")[3]
        # Skip names whose date token is not a full YYYYDDD string.
        if len(str(date)) < 7:
            continue
        raster_data = open_rasterio(file, masked=True).squeeze(dim="band", drop=True)
        time_attr = datetime.strptime(date, "%Y%j")
        # Attach the timestamp as a coordinate for the later concat.
        raster_list.append(raster_data.assign_coords(time=time_attr))
    # FIX: xr.concat on an empty list raises an opaque error; fail with a
    # message that names the offending directory instead.
    if not raster_list:
        raise ValueError(f"No dated .tif files found in {raster_dir}")
    return xr.concat(raster_list, dim="time").sortby("time")
def get_value_at_point(
    dataset: xr.Dataset, value_name: str, lon: float, lat: float
) -> pd.DataFrame:
    """
    Extract the time series of values at a point location from a Dataset.

    Parameters
    ----------
    dataset : xr.Dataset
        Data with x, y and time dimensions.
        NOTE(review): `.to_dataset(dim="time")` below is a DataArray
        method, so in practice this appears to receive a DataArray
        (e.g. the output of merge_time_series) — confirm with callers.
    value_name : str
        Name used for the value column of the result.
    lon : float
        Longitude (x coordinate).
    lat : float
        Latitude (y coordinate).

    Returns
    -------
    df : pd.DataFrame
        DataFrame holding the point's time series (TIME + value columns).
    """
    # Nearest-neighbour point selection, then spread the time dimension
    # into one column per timestamp.
    point_values = dataset.sel(x=lon, y=lat, method="nearest").to_dataset(dim="time")
    # Convert to long format: one (TIME, value) row per timestamp.
    df = (
        point_values.to_dataframe()
        .reset_index()
        .melt(id_vars=["y", "x"], var_name="TIME", value_name=value_name)
        .drop(columns=["y", "x"])
    )
    # Drop the spatial_ref pseudo-column and -9999 nodata rows.
    df = df[df["TIME"] != "spatial_ref"]
    df = df[df[value_name] != -9999]
    # Drop NaN values.
    df.dropna(subset=[value_name], inplace=True)
    df["TIME"] = pd.to_datetime(df["TIME"])
    return df
def valid_raster(raster_path, threshold=0.4) -> tuple[str, float] | None:
    """
    Check whether a raster holds enough valid data.

    A raster counts as valid when the share of non-nodata (-9999 /
    masked) pixels is at least `threshold`.

    Parameters
    ----------
    raster_path : str
        Path of the raster file.
    threshold : float
        Minimum fraction of valid pixels.

    Returns
    -------
    raster_path : str
        The input path (when valid).
    valid_data_percentage : float
        Fraction of valid pixels (when valid); otherwise the function
        returns None.
    """
    with open_rasterio(raster_path, masked=True) as raster:
        # Map -9999 to NaN, then compare valid count against total size.
        data = raster.where(raster != -9999)
        fraction = data.count().item() / data.size
    if fraction >= threshold:
        return raster_path, fraction
    return None
def valid_raster_list(
    raster_path_list, threshold=0.4, only_path=False
) -> list[tuple[str, float]]:
    """
    Filter a list of rasters down to the valid ones.

    Parameters
    ----------
    raster_path_list : list[str]
        Raster file paths.
    threshold : float
        Minimum fraction of valid pixels per raster.
    only_path : bool, optional
        Return bare paths instead of (path, fraction) tuples,
        by default False.

    Returns
    -------
    valid_raster_path_list : list[tuple[str, float]] | list[str]
        Valid rasters as (path, fraction) tuples — or just paths — and
        an empty list when none qualify.
    """
    # valid_raster returns None for rasters below the threshold.
    checked = [valid_raster(path, threshold) for path in raster_path_list]
    valid_entries = [entry for entry in checked if entry is not None]
    if only_path:
        return [path for path, _ in valid_entries]
    return valid_entries
def match_utm_grid(roi: gpd.GeoDataFrame, mgrs_kml_file: str) -> gpd.GeoDataFrame:
    """
    Match the ESA Sentinel-2 UTM (MGRS) grid tiles covering an ROI.

    Downloads the official MGRS tiling KML on first use, spatially joins
    it against the ROI and returns the intersecting grid cells.

    NOTE: requires fiona > 1.9.
    """
    import fiona

    # enable KML support which is disabled by default
    fiona.drvsupport.supported_drivers["LIBKML"] = "rw"
    # bbox = tuple(list(roi.total_bounds))
    if not os.path.isfile(mgrs_kml_file):
        kml_url = "https://hls.gsfc.nasa.gov/wp-content/uploads/2016/03/S2A_OPER_GIP_TILPAR_MPC__20151209T095117_V20150622T000000_21000101T000000_B00.kml"
        earthaccess.download([kml_url], mgrs_kml_file)
    # geopandas.read_file() could stream the URL directly, but the grid
    # KML is large (~106 MB) and loads very slowly, so it is downloaded
    # first and read locally.
    # NOTE(review): the second positional argument of read_file is `bbox`;
    # passing the ROI GeoDataFrame relies on geopandas accepting a
    # GeoDataFrame as a bbox filter — confirm with the pinned version.
    mgrs_gdf = gpd.read_file(mgrs_kml_file, roi)
    # Spatial join: keep grid cells intersecting the ROI boundary.
    grid_in_roi = gpd.sjoin(mgrs_gdf, roi, predicate="intersects", how="left")
    # Drop the redundant attributes introduced by the join.
    grid_in_roi = grid_in_roi[mgrs_gdf.columns].drop(
        columns=[
            "description",
            "timestamp",
            "begin",
            "end",
            "altitudeMode",
            "drawOrder",
        ]
    )
    # Handle GeometryCollection geometries.
    for i in range(len(grid_in_roi)):
        grid = grid_in_roi.iloc[i]
        # Replace a GeometryCollection with its first member geometry.
        # NOTE(review): `.at[i, ...]` addresses by *label* while
        # `.iloc[i]` is positional; these only line up when the index is
        # a clean RangeIndex — verify the index state after sjoin.
        if grid.geometry.geom_type == "GeometryCollection":
            grid_in_roi.at[i, "geometry"] = grid.geometry.geoms[0]
    return grid_in_roi
def plot(data, title=None, cmap="gray"):
    """
    Display a raster image with a colorbar.

    args:
        data (numpy.ndarray): data to plot
        title (str): figure title
        cmap (str): matplotlib colormap name, defaults to "gray"
    """
    # FIX: the cmap argument was accepted but never passed to imshow,
    # so every plot used the default colormap.
    plt.imshow(data, cmap=cmap)
    plt.title(title)
    plt.axis("off")  # hide the axes
    plt.colorbar()
    plt.show()