# -*- coding: utf-8 -*-
"""
===============================================================================

Shared utility functions for raster/GIS processing.

-------------------------------------------------------------------------------

Authors: Hong Xie

Last Updated: 2025-09-11

===============================================================================
"""
|
|
|
|
import os
|
|
import sys
|
|
import glob
|
|
import json
|
|
import logging
|
|
from datetime import datetime
|
|
import earthaccess
|
|
import numpy as np
|
|
import pandas as pd
|
|
from affine import Affine
|
|
from osgeo import gdal, gdal_array
|
|
from shapely import box
|
|
import xarray as xr
|
|
from rasterio.enums import Resampling
|
|
from rasterio.merge import merge
|
|
from rioxarray.merge import merge_arrays
|
|
from rioxarray import open_rasterio
|
|
import geopandas as gpd
|
|
import matplotlib.pyplot as plt
|
|
|
|
gdal.UseExceptions()
|
|
|
|
|
|
def get_month_from_filenames(file_path):
    """
    Extract the month from a formatted file name.

    The fourth dot-separated field of the base name is expected to hold
    the acquisition date as ``YYYYDDD`` (year + day of year).

    args:
        file_path (str): path to the file
    returns:
        int: month number (1-12)
    """
    # Fourth dot-separated token of the base name carries the YYYYDDD stamp.
    date_token = os.path.basename(file_path).split(".")[3]
    # Parse year + day-of-year and report which month it falls in.
    return datetime.strptime(date_token, "%Y%j").month
|
|
|
|
|
|
def group_by_month(file_list, out_path):
    """
    Group files by the month encoded in their file names.

    The grouping (a list of per-month file lists, ordered by month) is
    also written to ``out_path`` as JSON.
    """
    buckets = {}
    # Drop each file into the bucket of the month it belongs to.
    for path in file_list:
        buckets.setdefault(get_month_from_filenames(path), []).append(path)
    # Flatten the dict into a list ordered by calendar month.
    ordered_groups = [buckets[month] for month in sorted(buckets)]
    # Persist the grouping for later runs.
    with open(out_path, "w") as handle:
        json.dump(ordered_groups, handle)
    return ordered_groups
|
|
|
|
|
|
def time_index_from_filenames(file_list):
    """
    Build a datetime index from file names for time-series analysis.

    The fourth-from-last dot-separated token of each name must look like
    ``YYYYDDDTHHMMSS``.
    """
    timestamps = []
    for name in file_list:
        stamp = name.split(".")[-4]
        timestamps.append(datetime.strptime(stamp, "%Y%jT%H%M%S"))
    return timestamps
|
|
|
|
|
|
def square_kernel(radius):
    """
    Build a square structuring element filled with ones.

    Deriving the side length from a radius guarantees an odd size of
    ``2 * radius + 1``.

    args:
        radius: half-width of the kernel
    """
    side = radius * 2 + 1
    return np.ones((side, side), dtype=np.uint8)
|
|
|
|
|
|
def array_calc_to_xr(
    arr1: xr.DataArray | xr.Dataset,
    arr2: xr.DataArray | xr.Dataset,
    func: str,
) -> xr.DataArray:
    """
    Element-wise array arithmetic returning an attributed xarray.DataArray.

    Both operands should share the same shape/coordinates; the -9999 fill
    value is masked to NaN in both before the operation is applied.

    Parameters
    ----------
    arr1 : xr.DataArray | xr.Dataset
        First operand; its coords/attrs/CRS are copied onto the result.
    arr2 : xr.DataArray | xr.Dataset
        Second operand.
    func : str
        Operation name: "subtract", "add", "multiply" or "divide".

    Returns
    -------
    xr.DataArray
        Result carrying the coordinates, dims, attrs and CRS of ``arr1``.

    Raises
    ------
    ValueError
        If ``func`` is not one of the supported operations.
    """
    if isinstance(arr1, xr.Dataset):
        arr1 = arr1.to_array()
    if isinstance(arr2, xr.Dataset):
        arr2 = arr2.to_array()
    # Keep the original coordinates/attributes so they can be restored
    # after the raw numpy computation strips them.
    org_attrs = arr1.attrs
    org_dim = arr1.dims
    org_coords = arr1.coords
    org_crs = arr1.rio.crs if arr1.rio.crs else None
    # Work on float32 *copies*: the previous astype(..., copy=False) could
    # alias the source buffers, so the in-place -9999 replacement below
    # silently mutated the callers' DataArrays.
    arr1 = np.array(arr1.values, dtype=np.float32)
    arr2 = np.array(arr2.values, dtype=np.float32)
    # Replace the -9999 fill value with NaN in both operands; NaN then
    # propagates through the arithmetic naturally.  (The old extra
    # "both-NaN" masking step was a no-op and has been removed.)
    arr1[arr1 == -9999] = np.nan
    arr2[arr2 == -9999] = np.nan
    func = func.lower()
    if func == "subtract":
        result = arr1 - arr2
    elif func == "add":
        result = arr1 + arr2
    elif func == "multiply":
        result = arr1 * arr2
    elif func == "divide":
        # Zero denominators become NaN rather than +/-inf.
        arr2[arr2 == 0] = np.nan
        result = arr1 / arr2
    else:
        raise ValueError("Unsupported operation")
    # Free the temporaries before building the (potentially large) result.
    del arr1, arr2
    # Restore the attributes and spatial reference lost during computation.
    result = xr.DataArray(
        data=result,
        coords=org_coords,
        dims=org_dim,
        attrs=org_attrs,
    )
    if org_crs:
        result.rio.write_crs(org_crs, inplace=True)
    return result
|
|
|
|
|
|
def array_calc(
    arr1: xr.DataArray | xr.Dataset | np.ndarray,
    arr2: xr.DataArray | xr.Dataset | np.ndarray,
    func: str,
) -> np.ndarray:
    """
    Element-wise array arithmetic returning a plain numpy array.

    The -9999 fill value is masked to NaN in both operands before the
    operation is applied.

    Parameters
    ----------
    arr1 : xr.DataArray | xr.Dataset | np.ndarray
        First operand.
    arr2 : xr.DataArray | xr.Dataset | np.ndarray
        Second operand.
    func : str
        Operation name: "subtract", "add", "multiply" or "divide"
        (case-insensitive).

    Returns
    -------
    np.ndarray
        float32 result of the operation.

    Raises
    ------
    ValueError
        If ``func`` is not one of the supported operations.
    """
    # Normalise the operation name once, up front: the original checked the
    # raw string for the divide pre-masking but the lowered string for
    # dispatch, so a mixed-case "Divide" skipped the zero-to-NaN step.
    func = func.lower()
    # Unwrap xarray containers down to raw numpy data.
    if isinstance(arr1, xr.Dataset):
        arr1 = arr1.to_array()
    if isinstance(arr2, xr.Dataset):
        arr2 = arr2.to_array()
    if isinstance(arr1, xr.DataArray):
        arr1 = arr1.values
    if isinstance(arr2, xr.DataArray):
        arr2 = arr2.values
    # Always take float32 *copies*: the previous in-place -9999/NaN
    # replacement could mutate callers' arrays (astype(..., copy=False)
    # may alias) and raised for integer inputs, which cannot store NaN.
    arr1 = np.array(arr1, dtype=np.float32)
    arr2 = np.array(arr2, dtype=np.float32)
    # Mask the -9999 fill value in both operands.
    arr1[arr1 == -9999] = np.nan
    arr2[arr2 == -9999] = np.nan
    if func == "divide" and arr1.shape == arr2.shape:
        # Zero denominators become NaN rather than +/-inf.
        arr2[arr2 == 0] = np.nan
    if func == "subtract":
        result = arr1 - arr2
    elif func == "add":
        result = arr1 + arr2
    elif func == "multiply":
        result = arr1 * arr2
    elif func == "divide":
        result = arr1 / arr2
    else:
        raise ValueError("Unsupported operation")
    # Free the temporaries before returning the (potentially large) result.
    del arr1, arr2
    return result
|
|
|
|
|
|
def setup_dask_environment():
    """
    Passes RIO environment variables to dask workers for authentication.
    """
    import os
    import rasterio

    cookies = os.path.expanduser("~/cookies.txt")

    global env
    config = {
        "GDAL_HTTP_UNSAFESSL": "YES",
        "GDAL_HTTP_COOKIEFILE": cookies,
        "GDAL_HTTP_COOKIEJAR": cookies,
        "GDAL_DISABLE_READDIR_ON_OPEN": "YES",
        "CPL_VSIL_CURL_ALLOWED_EXTENSIONS": "TIF",
        "GDAL_HTTP_MAX_RETRY": "10",
        "GDAL_HTTP_RETRY_DELAY": "0.5",
        "GDAL_HTTP_TIMEOUT": "300",
    }

    # Enter the rasterio environment and deliberately keep it open for the
    # lifetime of the worker process (no matching __exit__).
    env = rasterio.Env(**config)
    env.__enter__()
|
|
|
|
|
|
def setup_logging(log_file: str = "dask_worker.log"):
    """
    Configure logging inside a Dask worker process.

    Messages go both to stdout and to ``log_file``.

    Parameters
    ----------
    log_file : str, optional
        Path of the log file, by default "dask_worker.log"
    """
    handlers = [
        logging.StreamHandler(sys.stdout),
        logging.FileHandler(log_file),
    ]
    logging.basicConfig(
        level=logging.INFO,
        format="%(levelname)s:%(asctime)s ||| %(message)s",
        handlers=handlers,
    )
|
|
|
|
|
|
def load_band_as_arr(org_tif_path, band_num=1):
    """
    Read a single raster band as a numpy array.

    args:
        org_tif_path (str): path to the source GeoTIFF
        band_num (int): band number to read, defaults to 1
    returns:
        numpy.ndarray: band data with NoData replaced by NaN (promoted to
        float32 when the band is stored as integers)
    raises:
        ValueError: if GDAL cannot open the file
    """
    org_tif = gdal.Open(org_tif_path)
    if not org_tif:
        raise ValueError(f"GDAL could not open {org_tif_path}")
    band = org_tif.GetRasterBand(band_num)
    data = band.ReadAsArray()
    # Fetch the declared NoData value, if any.
    nodata = band.GetNoDataValue()
    if nodata is not None:
        # Integer arrays cannot hold NaN; the original in-place assignment
        # raised for integer rasters, so promote to float32 first.
        if not np.issubdtype(data.dtype, np.floating):
            data = data.astype(np.float32)
        # Replace NoData cells with NaN.
        data[data == nodata] = np.nan
    return data
|
|
|
|
|
|
def get_proj_info(org_tif_path):
    """
    Fetch the projection and geotransform of a source image.

    args:
        org_tif_path (str): path to the source GeoTIFF
    returns:
        tuple: (projection WKT string, geotransform tuple)
    """
    dataset = gdal.Open(org_tif_path)
    try:
        return dataset.GetProjection(), dataset.GetGeoTransform()
    finally:
        # Drop the reference so GDAL closes the file handle.
        dataset = None
|
|
|
|
|
|
def calc_time_series(file_list, calc_method):
    """
    Composite a stack of rasters along the time axis.

    args:
        file_list (list): raster paths to composite
        calc_method (str): one of "mean", "median", "max", "min"
    returns:
        numpy.ndarray: float32 composite
    raises:
        ValueError: for an empty file list or an unknown method
    """
    if not file_list:
        raise ValueError("file_list is empty.")
    # NaN-aware reducers keyed by method name.
    reducers = {
        "mean": np.nanmean,
        "median": np.nanmedian,
        "max": np.nanmax,
        "min": np.nanmin,
    }
    reducer = reducers.get(calc_method.lower())
    if reducer is None:
        raise ValueError("Invalid calc_method.")
    # Load every band into memory and reduce across the stack axis.
    stack = [load_band_as_arr(path) for path in file_list]
    return reducer(stack, axis=0).astype(np.float32)
|
|
|
|
|
|
def save_as_tif(data, projection, transform, file_path):
    """
    Write a 2-D array to a DEFLATE-compressed float32 GeoTIFF.

    args:
        data (numpy.ndarray): data to write; nothing happens when None
        projection (str): projection WKT
        transform: GDAL geotransform tuple
        file_path (str): full output path
    """
    if data is None:
        return
    rows, cols = data.shape
    driver = gdal.GetDriverByName("GTiff")
    out_ds = driver.Create(
        file_path, cols, rows, 1, gdal.GDT_Float32, options=["COMPRESS=DEFLATE"]
    )
    out_ds.SetGeoTransform(transform)
    out_ds.SetProjection(projection)
    out_band = out_ds.GetRasterBand(1)
    out_band.WriteArray(data)
    out_band.FlushCache()
    # Dropping the reference closes the dataset and flushes it to disk.
    out_ds = None
    return
|
|
|
|
|
|
def array_to_raster(
    data: np.ndarray, transform, wkt, dtype: str = None, nodata=-9999
) -> gdal.Dataset:
    """
    Convert a numpy array into an in-memory gdal.Dataset.

    reference: https://github.com/arthur-e/pyl4c/blob/master/pyl4c/spatial.py

    args:
        data (numpy.ndarray): array to convert
        transform: (affine) geotransform matrix
        wkt (str): projection definition
        dtype (str): optional target dtype, default None (keep as-is)
        nodata (float): NoData value stamped on every band, default -9999
    returns:
        gdal.Dataset: the wrapped dataset
    """
    if dtype is not None:
        data = data.astype(dtype)
    try:
        rast = gdal_array.OpenNumPyArray(data)
    except AttributeError:
        # For backwards compatibility with older version of GDAL
        rast = gdal.Open(gdal_array.GetArrayFilename(data))
    except Exception:
        # Narrowed from a bare `except:`, which would also have swallowed
        # KeyboardInterrupt/SystemExit.
        rast = gdal_array.OpenArray(data)
    rast.SetGeoTransform(transform)
    rast.SetProjection(wkt)
    if nodata is not None:
        # Stamp the NoData value on every band.
        for band in range(1, rast.RasterCount + 1):
            rast.GetRasterBand(band).SetNoDataValue(nodata)
    return rast
|
|
|
|
|
|
def create_quality_mask(quality_data, bit_nums: list = [0, 1, 2, 3, 4, 5]):
    """
    Uses the Fmask layer and bit numbers to create a binary mask of good pixels.
    By default, bits 0-5 are used.

    Returns a boolean array that is True where any of the requested bits
    is set (i.e. the pixel is flagged).
    """
    mask_array = np.zeros((quality_data.shape[0], quality_data.shape[1]))
    # Remove/Mask Fill Values and Convert to Integer.
    # NOTE: the original passed 0 positionally, which is np.nan_to_num's
    # `copy` flag (not the fill value) and therefore modified the caller's
    # array in place; `nan=0` yields the same values without the side effect.
    quality_data = np.nan_to_num(quality_data, nan=0).astype(np.int8)
    for bit in bit_nums:
        # Create a single binary mask layer: True where this bit is set.
        # (Parenthesised for clarity; precedence already evaluated it so.)
        mask_temp = (quality_data & (1 << bit)) > 0
        mask_array = np.logical_or(mask_array, mask_temp)
    return mask_array
|
|
|
|
|
|
def clip_image(
    image: xr.DataArray | xr.Dataset, roi: gpd.GeoDataFrame = None, clip_by_box=True
) -> xr.DataArray | xr.Dataset | None:
    """
    Clip image data to a region of interest.

    Parameters
    ----------

    image : xarray.DataArray | xarray.Dataset
        Image loaded through rioxarray.open_rasterio.
    roi : gpd.GeoDataFrame, optional
        Region of interest; when None the image is returned untouched.
    clip_by_box : bool, optional
        Clip with the ROI bounding box instead of its exact geometries,
        default True.

    Returns
    -------

    xarray.DataArray | xarray.Dataset | None
        Clipped image data.
    """
    if roi is None:
        return image
    # Reproject the ROI onto the image CRS when they differ.
    image_crs = image.rio.crs
    roi_bound = roi if roi.crs == image_crs else roi.to_crs(image_crs)
    if clip_by_box:
        clip_area = [box(*roi_bound.total_bounds)]
    else:
        clip_area = roi_bound.geometry.values
    # Write a nodata value so non-intersecting areas do not become NaN
    # (only possible on a DataArray).
    if isinstance(image, xr.DataArray):
        image.rio.write_nodata(-9999, inplace=True)
    return image.rio.clip(
        clip_area, roi_bound.crs, all_touched=True, from_disk=True
    )
|
|
|
|
|
|
def clip_roi_image(
    file_path: str, grid: gpd.GeoDataFrame = None
) -> xr.DataArray | None:
    """
    Clip an image to the study-area extent.

    Parameters
    ----------

    file_path : str
        Path of the image to clip.
    grid : gpd.GeoDataFrame, optional
        Grid extent, default None.

    Returns
    -------

    raster_cliped : xr.DataArray
        The clipped image.
    """
    raster = open_rasterio(file_path)
    # The fourth dot-separated token of the name holds the DOY stamp;
    # tolerate names that do not follow the convention.
    try:
        doy = os.path.basename(file_path).split(".")[3]
    except Exception:
        doy = None
    if doy:
        raster.attrs["DOY"] = doy
    # Squeeze a singleton band dimension down to a 2-D array.
    if "band" in raster.dims and raster.sizes["band"] == 1:
        raster = raster.squeeze("band")
    # Processing currently happens within a single grid tile, and MODIS
    # source data in sinusoidal projection did not clip as intended during
    # crawling/preprocessing, so clip by the grid first (the ROI clip
    # happens elsewhere).
    # TODO: when running across grid tiles, switch to clipping by ROI.
    raster_cliped = clip_image(raster, grid) if grid is not None else raster
    # TODO: skip rasters whose clipped content is entirely invalid
    # (all -9999 / NaN) instead of returning them.
    raster_cliped.attrs = raster.attrs.copy()
    raster_cliped.attrs["file_name"] = os.path.basename(file_path)
    return raster_cliped
|
|
|
|
|
|
def reproject_image(
    image: xr.DataArray,
    target_crs: str = None,
    target_shape: tuple = None,
    target_image: xr.DataArray = None,
) -> xr.DataArray:
    """
    Reproject image data to a target CRS or match a target image.

    Parameters
    ----------

    image : xarray.DataArray
        Image loaded through rioxarray.open_rasterio.
    target_crs : str, optional
        Target CRS, e.g. EPSG:4326.
    target_shape : tuple, optional
        Target shape, e.g. (1000, 1000).
    target_image : xarray.DataArray, optional
        Target image to match, loaded through rioxarray.open_rasterio.

    Returns
    -------

    xarray.DataArray
        The reprojected image.
    """
    if target_image is not None:
        t_shape, i_shape = target_image.shape, image.shape
        # Down-/equal-scaling keeps detail, so cubic interpolation is fine;
        # upscaling aggregates pixels, so a median resampler limits pixel
        # loss.
        downscale_or_same = (
            t_shape[1] > i_shape[1]
            or t_shape[2] > i_shape[2]
            or (t_shape[1] == i_shape[1] and t_shape[2] == i_shape[2])
        )
        resampler = Resampling.cubic if downscale_or_same else Resampling.med
        return image.rio.reproject_match(target_image, resampling=resampler)
    if target_crs is not None:
        # Reproject onto an explicit CRS, optionally forcing a shape.
        reproject_kwargs = {"dst_crs": target_crs, "resampling": Resampling.cubic}
        if target_shape is not None:
            reproject_kwargs["shape"] = target_shape
        return image.rio.reproject(**reproject_kwargs)
    # No reprojection parameters given: return the image unchanged.
    return image
|
|
|
|
|
|
def mosaic_images(
    tif_list: list[str | xr.DataArray],
    nodata=np.nan,
    tif_crs: any = None,
    method: str = "first",
) -> xr.DataArray:
    """
    Mosaic images covering the same day/region into one array.

    Produces a (1, y, x)-shaped dataset together with its transform.

    Parameters
    ----------
    tif_list : list[str | xr.DataArray]
        Paths of the images to mosaic, or already-loaded DataArrays.
    nodata : float | int
        Fill value for empty cells.
    tif_crs : Any, optional
        CRS of the input images, written onto the result when given.
    method : str, optional
        Compositing method, default "first"; also "last", "min", "max".

    Returns
    -------
    ds : xr.DataArray
        The mosaicked image.
    """
    if isinstance(tif_list[0], str):
        ds_np, transform = merge(
            tif_list,
            nodata=nodata,
            method=method,
            resampling=Resampling.cubic,
        )
        # rasterio's merge returns an Affine iterated as (a, b, c, d, e, f)
        # = (x_scale, 0, x_origin, 0, y_scale, y_origin), e.g.
        # (30.0, 0.0, 221250.0, 0.0, -30.0, 3536970.0), while a GDAL
        # geotransform reads (x_origin, x_scale, 0, y_origin, 0, y_scale).
        x_scale, _, x_origin, _, y_scale, y_origin = transform[:6]
        # BUGFIX: the previous call passed y_origin/x_origin swapped,
        # shifting the mosaic; Affine.from_gdal expects x_origin first.
        affine_transform = Affine.from_gdal(
            x_origin, x_scale, 0.0, y_origin, 0.0, y_scale
        )
        # Rebuild coordinate vectors from the upper-left origin.
        y = np.arange(ds_np.shape[1]) * y_scale + y_origin
        x = np.arange(ds_np.shape[2]) * x_scale + x_origin
        ds = xr.DataArray(
            ds_np,
            coords={
                "band": np.arange(ds_np.shape[0]),
                "y": y,
                "x": x,
            },
            dims=["band", "y", "x"],
        )
        ds.rio.write_transform(affine_transform, inplace=True)
    else:
        ds = merge_arrays(tif_list, nodata=nodata, method=method)
    if tif_crs is not None:
        ds.rio.write_crs(tif_crs, inplace=True)
    return ds
|
|
|
|
|
|
def merge_time_series(raster_dir: str) -> xr.Dataset:
    """
    Merge rasters into a time series.

    Reads every tif in ``raster_dir`` and concatenates them along a new
    time dimension derived from the file names.

    Parameters
    ----------
    raster_dir : str
        Directory containing the tif files.

    Returns
    -------
    raster_dataset : xr.Dataset
        Dataset with a time dimension, sorted chronologically.
    """
    stack = []
    for tif_path in glob.glob(os.path.join(raster_dir, "*.tif")):
        date_token = os.path.basename(tif_path).split(".")[3]
        # Skip files whose date field is not a full YYYYDDD stamp.
        if len(str(date_token)) < 7:
            continue
        # Load the layer, dropping the singleton band dimension.
        layer = open_rasterio(tif_path, masked=True).squeeze(dim="band", drop=True)
        # Attach the acquisition date as a coordinate.
        acquisition = datetime.strptime(date_token, "%Y%j")
        stack.append(layer.assign_coords(time=acquisition))

    return xr.concat(stack, dim="time").sortby("time")
|
|
|
|
|
|
def get_value_at_point(
    dataset: xr.Dataset, value_name: str, lon: float, lat: float
) -> pd.DataFrame:
    """
    Extract the time series of values at a point location from a Dataset.

    Parameters
    ----------
    dataset : xr.Dataset
        Dataset with x, y and time dimensions.
    value_name : str
        Variable name within the dataset.
    lon : float
        Longitude.
    lat : float
        Latitude.

    Returns
    -------
    df : pd.DataFrame
        Time-series DataFrame with TIME and value columns.
    """
    # Pick the nearest pixel and spread the time dimension into columns.
    nearest = dataset.sel(x=lon, y=lat, method="nearest").to_dataset(dim="time")
    frame = (
        nearest.to_dataframe()
        .reset_index()
        .melt(id_vars=["y", "x"], var_name="TIME", value_name=value_name)
        .drop(columns=["y", "x"])
    )
    # Drop the spatial_ref pseudo-column and -9999 fill rows.
    frame = frame[frame["TIME"] != "spatial_ref"]
    frame = frame[frame[value_name] != -9999]
    # Discard NaNs before parsing timestamps.
    frame = frame.dropna(subset=[value_name])
    frame["TIME"] = pd.to_datetime(frame["TIME"])
    return frame
|
|
|
|
|
|
def valid_raster(raster_path, threshold=0.4) -> tuple[str, float] | None:
    """
    Check whether a raster holds enough valid data.

    Returns the path and the valid-data fraction when the fraction reaches
    ``threshold``; otherwise returns None.

    Parameters
    ----------
    raster_path : str
        Path of the raster file.
    threshold : float
        Minimum fraction of valid pixels.

    Returns
    -------
    raster_path : str
        Path of the raster file.
    valid_data_percentage : float
        Fraction of valid pixels.
    """
    with open_rasterio(raster_path, masked=True) as raster:
        # Treat the -9999 fill value as invalid alongside masked NaNs.
        cleaned = raster.where(raster != -9999)
        fraction = cleaned.count().item() / cleaned.size
    return (raster_path, fraction) if fraction >= threshold else None
|
|
|
|
|
|
def valid_raster_list(
    raster_path_list, threshold=0.4, only_path=False
) -> list[tuple[str, float]]:
    """
    Filter a list of rasters by their valid-data fraction.

    Parameters
    ----------
    raster_path_list : list[str]
        Raster paths to check.
    threshold : float
        Minimum valid-data fraction.
    only_path : bool, optional
        Return bare paths instead of (path, fraction) tuples, default False.

    Returns
    -------
    valid_raster_path_list : list[tuple[str, float]] | list[str]
        Valid entries, each (path, fraction) or just the path.
    []
        When no raster passes the threshold.
    """
    # Check every raster and keep only the ones that passed (non-None).
    checked = (valid_raster(path, threshold) for path in raster_path_list)
    results = [entry for entry in checked if entry is not None]
    if only_path:
        results = [path for path, _ in results]
    return results
|
|
|
|
|
|
def match_utm_grid(roi: gpd.GeoDataFrame, mgrs_kml_file: str) -> gpd.GeoDataFrame:
    """
    Match the ESA Sentinel-2 UTM grid tiles that intersect the ROI.

    NOTE: requires fiona > 1.9.
    """
    import fiona

    # enable KML support which is disabled by default
    fiona.drvsupport.supported_drivers["LIBKML"] = "rw"
    if not os.path.isfile(mgrs_kml_file):
        kml_url = "https://hls.gsfc.nasa.gov/wp-content/uploads/2016/03/S2A_OPER_GIP_TILPAR_MPC__20151209T095117_V20150622T000000_21000101T000000_B00.kml"
        earthaccess.download([kml_url], mgrs_kml_file)
    # geopandas.read_file() could stream the URL directly, but the grid KML
    # is ~106 MB, so download once and read the local copy instead.
    mgrs_gdf = gpd.read_file(mgrs_kml_file, roi)
    # Spatial join: keep the grid cells intersecting the ROI boundary.
    grid_in_roi = gpd.sjoin(mgrs_gdf, roi, predicate="intersects", how="left")
    # Drop join artefacts and KML bookkeeping columns.
    grid_in_roi = grid_in_roi[mgrs_gdf.columns].drop(
        columns=[
            "description",
            "timestamp",
            "begin",
            "end",
            "altitudeMode",
            "drawOrder",
        ]
    )
    # Convert GeometryCollection rows to their first geometry.
    # BUGFIX: iterate by index label — the original used positional
    # counters with label-based `.at`, which breaks when the joined frame
    # keeps non-sequential index labels.
    for idx in grid_in_roi.index:
        geometry = grid_in_roi.at[idx, "geometry"]
        if geometry.geom_type == "GeometryCollection":
            grid_in_roi.at[idx, "geometry"] = geometry.geoms[0]
    return grid_in_roi
|
|
|
|
|
|
def plot(data, title=None, cmap="gray"):
    """
    Plot an image array.

    args:
        data (numpy.ndarray): data to plot
        title (str): figure title
        cmap (str): colormap name, default "gray"
    """
    # BUGFIX: pass the colormap through — the original accepted `cmap`
    # but never handed it to imshow, so the parameter had no effect.
    plt.imshow(data, cmap=cmap)
    plt.title(title)
    plt.axis("off")  # hide the axes
    plt.colorbar()
    plt.show()
|