feat: 优化项目路径处理和代码结构以及MODIS下载处理逻辑.

This commit is contained in:
谢泓 2025-09-12 10:41:56 +08:00
parent bda6f0a1ef
commit 64f50ffc0a
9 changed files with 141 additions and 90 deletions

2
.gitignore vendored
View File

@ -9,3 +9,5 @@ data/
*.tif
*.tiff
*.ipynb

View File

@ -39,9 +39,7 @@ import geopandas as gpd
import numpy as np
from rioxarray import open_rasterio
# 动态获取项目根目录路径
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.common_utils import setup_dask_environment, clip_image, mosaic_images
from HLS_SuPER.HLS_Su import earthdata_search
@ -139,7 +137,7 @@ def process_granule(
name: str,
roi: list,
clip=True,
tile_id: str = "",
tile_id: str = None,
) -> bool:
"""
读取解压并重命名处理后的指定类型 NASADEM 数据并进行预处理, 包括读取, 裁剪, 镶嵌, 并对坡度坡向进行缩放

View File

@ -28,9 +28,7 @@ import logging
import earthaccess
from xarray import open_dataset
# 动态获取项目根目录路径
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.common_utils import setup_dask_environment
from HLS_SuPER.HLS_Su import earthdata_search

View File

@ -6,7 +6,7 @@ For example, MCD43A3, MCD43A4, MOD11A1.
-------------------------------------------------------------------------------
Authors: Hong Xie
Last Updated: 2025-07-15
Last Updated: 2025-09-11
===============================================================================
"""
@ -20,11 +20,14 @@ import rioxarray as rxr
import dask.distributed
import geopandas as gpd
# 动态获取项目根目录路径
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.common_utils import clip_image, reproject_image, setup_dask_environment
from utils.common_utils import (
clip_image,
reproject_image,
setup_dask_environment,
setup_logging,
)
from HLS_SuPER.HLS_Su import earthdata_search
@ -119,7 +122,15 @@ def open_modis(file_path, prod_name):
raise ValueError(f"Unknown MODIS product: {prod_name}.")
def process_modis(download_file, prod_name, roi, clip=True, scale=True, target_crs=None, output_file=None):
def process_modis(
download_file,
prod_name,
roi,
clip=True,
scale=True,
target_crs=None,
output_file=None,
):
"""
MODIS 数据进行预处理, 包括裁剪, 重投影和缩放.
"""
@ -129,13 +140,14 @@ def process_modis(download_file, prod_name, roi, clip=True, scale=True, target_c
if roi is not None and clip:
modis = clip_image(modis, roi)
if target_crs is not None:
if target_crs is not None and modis is not None:
modis = reproject_image(modis, target_crs)
# 重投影后再裁剪一次
if roi is not None and clip:
modis = clip_image(modis, roi)
# 重投影后再裁剪一次
if roi is not None and clip:
modis = clip_image(modis, roi)
if modis.isnull().all():
logging.error(f"Processing {download_file}. Roi area all pixels are nodata.")
if scale:
# 缩放计算后会丢源属性和坐标系, 需要先备份源数据属性信息
org_attrs = modis.attrs
@ -167,14 +179,9 @@ def process_granule(
clip,
scale,
output_dir,
target_crs="EPSG:4326",
tile_id=None,
target_crs="EPSG:4326",
):
logging.basicConfig(
level=logging.INFO,
format="%(levelname)s:%(asctime)s ||| %(message)s",
handlers=[logging.StreamHandler(sys.stdout)],
)
download_hdf_name = os.path.basename(granule_urls[0])
# 获取名称与日期
@ -193,7 +200,7 @@ def process_granule(
out_tif_name = f"MODIS.{prod_name}.{tile_id}.{date}.NBRDF.tif"
else:
out_tif_name = download_hdf_name.replace(".hdf", ".tif")
# 除 MCD43A4 需用于光谱指数计算外, MOD11A1 日间温度与 MCD43A4 反照率无需再按日期归档
# 除 MCD43A4 需用于光谱指数计算外, MOD11A1 日间温度与 MCD43A3 反照率无需再按日期归档
if prod_name == "MOD11A1" or prod_name == "MCD43A3":
output_path = os.path.join(output_dir, "TIF")
else:
@ -201,25 +208,29 @@ def process_granule(
os.makedirs(output_path, exist_ok=True)
output_file = os.path.join(output_path, out_tif_name)
if not os.path.isfile(output_file):
# Step1: 下载 HDF 文件
if not os.path.isfile(download_file):
try:
earthaccess.download(granule_urls, download_path)
except Exception as e:
logging.error(f"Error downloading {download_file}: {e}")
return
else:
logging.warning(f"{download_file} already exists. Skipping.")
# Step2: 处理 HDF 文件
# Step1: 下载 HDF 文件
if not os.path.isfile(download_file):
try:
process_modis(download_file, prod_name, roi, clip, scale, target_crs, output_file)
earthaccess.download(granule_urls, download_path)
except Exception as e:
logging.error(f"Error downloading {download_file}: {e}")
return
else:
logging.warning(f"{download_file} already exists. Skipping.")
# Step2: 处理 HDF 文件
if not os.path.isfile(output_file):
try:
process_modis(
download_file, prod_name, roi, clip, scale, target_crs, output_file
)
logging.info(f"Processed {output_file} successfully.")
except Exception as e:
os.remove(download_file)
logging.info(f"Removed corrupted file {download_file}. Retrying download.")
process_granule(granule_urls, roi, clip, scale, output_dir, target_crs, tile_id)
logging.info(f"Processed {output_file} successfully.")
process_granule(
granule_urls, roi, clip, scale, output_dir, target_crs, tile_id
)
else:
logging.warning(f"{output_file} already exists. Skipping.")
@ -231,6 +242,7 @@ def main(
years: list,
dates: tuple[str, str],
tile_id: str,
target_crs: str,
output_root_dir: str,
):
bbox = tuple(list(region.total_bounds))
@ -257,20 +269,9 @@ def main(
with open(results_urls_file, "w") as f:
json.dump(results_urls, f)
# 配置日志, 首次配置生效, 后续嵌套配置无效
logging.basicConfig(
level=logging.INFO, # 级别为INFO及以上的日志会被记录
format="%(levelname)s:%(asctime)s ||| %(message)s",
handlers=[
logging.StreamHandler(sys.stdout), # 输出到控制台
logging.FileHandler(
f"{output_dir}\\{asset_name}_{tile_id}_SuPER.log"
), # 输出到日志文件
],
)
client = dask.distributed.Client(timeout=60, memory_limit="8GB")
client.run(setup_dask_environment)
client.run(setup_logging)
all_start_time = time.time()
for year in years:
year_results_dir = os.path.join(output_dir, year)
@ -278,6 +279,11 @@ def main(
year_results_dir, f"{asset_name}_{modis_tile_id}_{year}_results_urls.json"
)
year_results = json.load(open(year_results_file))
# 配置主进程日志
logs_file = os.path.join(
year_results_dir, f"{asset_name}_{tile_id}_{year}_SuPER.log"
)
setup_logging(logs_file)
client.scatter(year_results)
start_time = time.time()
@ -289,14 +295,21 @@ def main(
clip=True,
scale=True,
output_dir=year_results_dir,
target_crs="EPSG:32649",
tile_id=tile_id,
target_crs=target_crs,
)
for granule_url in year_results
]
dask.compute(*tasks)
total_time = time.time() - start_time
# Dask任务结束后读取dask_worker.log日志文件内容, 并注入到logs_file中
with open(logs_file, "a") as f:
with open("dask_worker.log", "r") as f2:
f.write(f2.read())
# 随后清空dask_worker.log文件
with open("dask_worker.log", "w") as f:
f.write("")
logging.info(
f"{year} MODIS {asset_name} Downloading complete and processed. Total time: {total_time} seconds"
)
@ -305,6 +318,9 @@ def main(
logging.info(
f"All MODIS {asset_name} Downloading complete and processed. Total time: {all_total_time} seconds"
)
# 最后删除dask_worker.log文件
os.remove("dask_worker.log")
return
if __name__ == "__main__":
@ -312,12 +328,22 @@ if __name__ == "__main__":
# region = gpd.read_file("./data/vectors/wuling_guanqu_polygon.geojson")
tile_id = "49REL"
region = gpd.read_file(f"./data/vectors/{tile_id}.geojson")
# asset_name = "MOD11A1"
target_crs = "EPSG:32649"
asset_name = "MOD11A1"
# asset_name = "MCD43A3"
asset_name = "MCD43A4"
# asset_name = "MCD43A4"
modis_tile_id = "h27v06"
# 示例文件名称: MCD43A4.A2024001.h27v05.061.2024010140610.hdf
years = ["2024", "2023", "2022"]
years = ["2024"]
dates = ("03-01", "10-31")
output_root_dir = ".\\data\\MODIS\\"
main(region, asset_name, modis_tile_id, years, dates, tile_id, output_root_dir)
main(
region,
asset_name,
modis_tile_id,
years,
dates,
tile_id,
target_crs,
output_root_dir,
)

View File

@ -48,9 +48,7 @@ import numpy as np
import xarray as xr
from rioxarray import open_rasterio
# 动态获取项目根目录路径
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.common_utils import setup_dask_environment, clip_image, mosaic_images
from HLS_SuPER.HLS_Su import earthdata_search

View File

@ -28,9 +28,7 @@ import h5py
from osgeo import gdal
import xarray as xr
# 动态获取项目根目录路径
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.common_params import EASE2_GRID_PARAMS, EPSG
from utils.common_utils import (

View File

@ -55,9 +55,7 @@ import logging
import time
from datetime import datetime, timedelta
# 动态获取项目根目录路径
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
class getInsituData:

View File

@ -29,9 +29,7 @@ import geopandas as gpd
from datetime import datetime as dt
import dask.distributed
# 动态获取项目根目录路径
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.common_utils import setup_dask_environment
from HLS_Su import hls_search

View File

@ -5,7 +5,7 @@
-------------------------------------------------------------------------------
Authors: Hong Xie
Last Updated: 2025-08-13
Last Updated: 2025-09-11
===============================================================================
"""
@ -256,6 +256,7 @@ def setup_logging(log_file: str = "dask_worker.log"):
Parameters
----------
log_file : str, optional
日志文件路径, by default "dask_worker.log"
"""
@ -269,6 +270,7 @@ def setup_logging(log_file: str = "dask_worker.log"):
],
)
def load_band_as_arr(org_tif_path, band_num=1):
"""
读取波段数据
@ -410,15 +412,26 @@ def create_quality_mask(quality_data, bit_nums: list = [0, 1, 2, 3, 4, 5]):
def clip_image(
image: xr.DataArray | xr.Dataset, roi: gpd.GeoDataFrame, clip_by_box=True
):
image: xr.DataArray | xr.Dataset, roi: gpd.GeoDataFrame = None, clip_by_box=True
) -> xr.DataArray | xr.Dataset | None:
"""
Clip Image data to ROI.
args:
image (xarray.DataArray | xarray.Dataset): 通过 rioxarray.open_rasterio 加载的图像数据.
roi (gpd.GeoDataFrame): 感兴趣区数据.
clip_by_box (bool): 是否使用 bbox 进行裁剪, 默认为 True.
Parameters
----------
image : xarray.DataArray | xarray.Dataset
通过 rioxarray.open_rasterio 加载的图像数据.
roi : gpd.GeoDataFrame, optional
感兴趣区数据.
clip_by_box : bool, optional
是否使用 bbox 进行裁剪, 默认为 True.
Returns
-------
xarray.DataArray | xarray.Dataset | None
裁剪后的图像数据. 若裁剪后数据全为无效值, 则返回 None.
"""
if roi is None:
@ -443,15 +456,25 @@ def clip_image(
return image_cliped
def clip_roi_image(file_path: str, grid: gpd.GeoDataFrame) -> xr.DataArray | None:
def clip_roi_image(
file_path: str, grid: gpd.GeoDataFrame = None
) -> xr.DataArray | None:
"""
按研究区范围裁剪影像
args:
file_path (str): 待裁剪影像路径
grid (gpd.GeoDataFrame): 格网范围
return:
raster_cliped (xr.DataArray): 裁剪后的影像
Parameters
----------
file_path : str
待裁剪影像路径
grid : gpd.GeoDataFrame, optional
格网范围, 默认为 None.
Returns
-------
raster_cliped : xr.DataArray
裁剪后的影像
"""
raster = open_rasterio(file_path)
try:
@ -487,15 +510,27 @@ def reproject_image(
target_crs: str = None,
target_shape: tuple = None,
target_image: xr.DataArray = None,
):
) -> xr.DataArray:
"""
Reproject Image data to target CRS or target data.
args:
image (xarray.DataArray): 通过 rioxarray.open_rasterio 加载的图像数据.
target_crs (str): Target CRS, eg. EPSG:4326.
target_shape (tuple): Target shape, eg. (1000, 1000).
target_image (xarray.DataArray): Target image, eg. rioxarray.open_rasterio 加载的图像数据.
Parameters
----------
image : xarray.DataArray
通过 rioxarray.open_rasterio 加载的图像数据.
target_crs : str, optional
Target CRS, eg. EPSG:4326.
target_shape : tuple, optional
Target shape, eg. (1000, 1000).
target_image : xarray.DataArray, optional
Target image, eg. rioxarray.open_rasterio 加载的图像数据.
Returns
-------
xarray.DataArray
重投影后的图像数据.
"""
if target_image is not None:
# 使用 target_image 进行重投影匹配
@ -506,7 +541,7 @@ def reproject_image(
target_image.shape[1] == image.shape[1]
and target_image.shape[2] == image.shape[2]
):
# 若判断为降尺度/等尺度, 则直接使用 cubic 重采样投影到目标影像
# 若判断为降尺度/等尺度, 则直接使用 cubic 双三次插值重采样投影到目标影像
image_reprojed = image.rio.reproject_match(
target_image, resampling=Resampling.cubic
)