389 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
===============================================================================
HLS Processing and Exporting Reformatted Data (HLS_PER)
This module contains functions to conduct subsetting and quality filtering of
search results.
-------------------------------------------------------------------------------
Authors: Cole Krehbiel, Mahsa Jami, and Erik Bolch
Editor: Hong Xie
Last Updated: 2025-01-06
===============================================================================
"""
import os
import sys
import logging
import numpy as np
from datetime import datetime as dt
import xarray as xr
import rioxarray as rxr
import dask.distributed
def create_output_name(url, band_dict):
    """
    Uses HLS default naming scheme to generate an output name with common band names.
    This allows for easier stacking of bands from both collections.

    Args:
        url (str): HLS asset URL. Element 4 of ``url.split("/")`` must be the
            collection directory (its part before the first "." is used as the
            product key, e.g. ``HLSL30`` from ``HLSL30.020``). The file name must
            end in ``.<asset>.<ext>``.
        band_dict (dict): Mapping ``{product: {common_name: asset_name}}``.

    Returns:
        str: ``<granule id>.<common name>.subset.tif``, where the granule id is the
        file name without its last two dot-separated fields. Fmask assets always
        map to ``<granule id>.FMASK.subset.tif``.

    Raises:
        KeyError: If the product key is missing from ``band_dict``.
        ValueError: If no common name in ``band_dict[product]`` maps to the asset.
    """
    name_parts = url.split("/")[-1].split(".")
    asset = name_parts[-2]
    # Granule id = file name without the trailing ".<asset>.<ext>".
    granule_id = ".".join(name_parts[:-2])

    # Hard-coded one-off for Fmask in case it is not in band_dict but is still
    # needed for masking.
    if asset == "Fmask":
        return f"{granule_id}.FMASK.subset.tif"

    prod = url.split("/")[4].split(".")[0]
    common_name = None
    # No early break: if several common names map to the same asset, the last
    # one wins, as before.
    for key, value in band_dict[prod].items():
        if value == asset:
            common_name = key
    if common_name is None:
        # Previously this fell through to an UnboundLocalError on the return.
        raise ValueError(
            f"Asset '{asset}' of {url} has no common band name in band_dict['{prod}']."
        )
    return f"{granule_id}.{common_name}.subset.tif"
def open_hls(url, roi=None, clip=False, scale=True, chunk_size=None):
    """
    Generic function to open an HLS COG and optionally clip it to an ROI.

    For consistent scaling this must be done manually: the file is opened with
    ``mask_and_scale=False`` because some HLS Landsat scenes have the scaling
    metadata in the wrong location.

    Args:
        url (str): Asset URL or local path. The second dot-separated field of the
            file name is the product (e.g. ``L30``). The band name is the
            second-to-last field, or the third-to-last for ``*.subset.tif`` outputs.
        roi (geopandas.GeoDataFrame, optional): Region of interest. It is only
            used when ``clip`` is True.
        clip (bool): If True and ``roi`` is given, reproject the ROI to the
            raster CRS and clip (``all_touched=True``).
        scale (bool): Ignored for the Fmask layer. For other bands:
            - If True, -9999 fill values become NaN and the data are multiplied
              by the scale factor (0.01 for the L30 thermal bands B10/B11/TIR1/TIR2,
              0.0001 otherwise). ``attrs["scale_factor"]`` is then set to 1.0 so
              the data are not scaled twice.
            - If False, the data are left untouched and only
              ``attrs["scale_factor"]`` records the factor.
        chunk_size (dict, optional): Dask chunks passed to
            ``rioxarray.open_rasterio``. Defaults to ``{"band": 1, "x": 512, "y": 512}``.

    Returns:
        xarray.DataArray with the ``band`` dimension squeezed out, or None if
        nothing was opened.
    """
    if chunk_size is None:
        chunk_size = dict(band=1, x=512, y=512)

    # Open using rioxarray
    da_org = rxr.open_rasterio(url, chunks=chunk_size, mask_and_scale=False).squeeze(
        "band", drop=True
    )
    # Return None when no data was obtained (callers check for None).
    if da_org is None:
        return None
    # Work on a copy so the source attributes/CRS can be restored after scaling.
    da = da_org.copy()

    # Derive product and band name from the file name.
    name_parts = url.split("/")[-1].split(".")
    asset_name = name_parts[1]
    band_name = name_parts[-2] if name_parts[-2] != "subset" else name_parts[-3]
    is_fmask = band_name in ("Fmask", "FMASK")
    # Thermal bands exist only in L30 and use a different scale factor.
    is_tir = band_name in ("B10", "B11", "TIR1", "TIR2") and asset_name == "L30"
    scale_factor = 0.01 if is_tir else 0.0001

    # Reproject ROI and clip if an ROI is provided and clip is True
    if roi is not None and clip:
        roi = roi.to_crs(da.spatial_ref.crs_wkt)
        da = da.rio.clip(roi.geometry.values, roi.crs, all_touched=True)

    # The quality layer is a bit field and is never scaled.
    if is_fmask:
        return da

    # Even though most granules are already scaled / fill-filled upstream, some
    # may still need manual scaling (e.g. later in a local GIS).
    if scale:
        # Mask fill values, then scale.
        da = xr.where(da == -9999, np.nan, da)
        da = da * scale_factor
        # The computation drops the CRS and attributes, so restore them from the
        # source. When clipping, the ROI was reprojected to this same CRS, so the
        # source CRS is correct in both cases (and no ROI is needed here).
        da.rio.write_crs(da_org.rio.crs, inplace=True)
        da.attrs = da_org.attrs.copy()
        # Prevent double scaling downstream.
        da.attrs["scale_factor"] = 1.0
    else:
        # Record the scale factor so the data can be scaled manually later.
        da.attrs["scale_factor"] = scale_factor
    return da
def create_quality_mask(quality_data, bit_nums=(0, 1, 2, 3, 4, 5)):
    """
    Uses the Fmask layer and bit numbers to create a boolean mask of flagged pixels.

    A pixel is True when ANY of the selected bits is set (i.e. it should be
    removed); callers keep the good pixels with ``data.where(~mask)``.
    By default, bits 0-5 are used.

    Args:
        quality_data: Fmask values (array-like or DataArray, any shape). NaNs are
            treated as 0 (no flags). The input is not modified.
        bit_nums (iterable of int): Bit positions to test, each in 0-7.

    Returns:
        numpy.ndarray of bool with the same shape as ``quality_data``.

    Raises:
        ValueError: If a bit number is outside 0-7 (Fmask is an 8-bit layer).
    """
    bits = set(bit_nums)
    invalid = sorted(bit for bit in bits if not 0 <= bit <= 7)
    if invalid:
        raise ValueError(f"Fmask bit numbers must be in 0-7, got {invalid}.")

    # Remove/mask fill values and convert to unsigned 8-bit. nan= must be a
    # keyword: the second positional argument of nan_to_num is `copy`, and
    # passing 0 there modifies the caller's array in place. uint8 (not int8)
    # keeps values 128-255 and bit 7 intact.
    qa = np.nan_to_num(np.asarray(quality_data), nan=0).astype(np.uint8)

    # Combine all selected bits into one mask value and test them at once.
    bitmask = 0
    for bit in bits:
        bitmask |= 1 << bit
    return (qa & np.uint8(bitmask)) != 0
def _load_quality_mask(
    asset_url, roi, clip, scale, output_dir, band_dict, bit_nums, chunk_size
):
    """Write (first time) or reopen the granule's Fmask subset and return its bit mask."""
    # Derive the Fmask URL from any asset of the same granule by swapping the
    # asset field of the file name.
    split_asset = asset_url.split("/")[-1].split(".")
    split_asset[-2] = "Fmask"
    quality_url = f"{'/'.join(asset_url.split('/')[:-1])}/{'.'.join(split_asset)}"

    quality_output_name = create_output_name(quality_url, band_dict)
    quality_output_file = os.path.join(output_dir, quality_output_name)
    if not os.path.isfile(quality_output_file):
        qa_da = open_hls(quality_url, roi, clip, scale, chunk_size)
        # `compress` is a rasterio creation option forwarded by rioxarray. If it is
        # not given, LZW is used: faster to write but larger. DEFLATE matches the
        # files downloaded directly from the official site and is smaller, at
        # roughly 15 s extra per image.
        qa_da.rio.to_raster(
            raster_path=quality_output_file, driver="COG", compress="DEFLATE"
        )
    else:
        qa_da = open_hls(quality_output_file, roi, clip, scale, chunk_size)
        logging.info(
            f"Existing quality file {quality_output_name} found in {output_dir}."
        )
    # TODO: the mask is materialized in memory and may overflow Dask worker
    # memory for large scenes; needs optimization.
    return create_quality_mask(qa_da, bit_nums=bit_nums)


def process_granule(
    granule_urls,
    roi,
    clip,
    quality_filter,
    scale,
    output_dir,
    band_dict,
    bit_nums=(0, 1, 2, 3, 4, 5),
    chunk_size=None,
):
    """
    Processes a list of HLS asset urls for a single granule.

    args:
        granule_urls (list): List of HLS asset urls to process.
        roi (geopandas.GeoDataFrame): ROI to filter data.
        clip (bool): If True, ROI will be clipped to the image.
        quality_filter (bool): If True, quality layer will be used to mask data.
            The Fmask subset is written once and its URL is not processed as a
            regular band.
        scale (bool): If True, data will be scaled to reflectance.
        output_dir (str): Directory to save output files.
        band_dict (dict): Dictionary of band names and asset names.
        bit_nums (iterable of int): List of bit numbers to use for quality mask.
            - 0: Cirrus
            - 1: Cloud
            - 2: Adjacent to cloud/shadow
            - 3: Cloud shadow
            - 4: Snow/ice
            - 5: Water
        chunk_size (dict, optional): Dictionary of chunk sizes for dask.
            Defaults to ``{"band": 1, "x": 512, "y": 512}``.

    Returns:
        None. Outputs are written as COG GeoTIFFs to ``output_dir``; existing
        outputs are skipped.
    """
    # Setup Logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(levelname)s:%(asctime)s ||| %(message)s",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    if chunk_size is None:
        chunk_size = dict(band=1, x=512, y=512)

    # Nothing to do (previously an empty list raised IndexError below).
    if not granule_urls:
        logging.warning("No asset urls were provided; nothing to process.")
        return

    # Skip the whole granule when every expected output already exists.
    if all(
        os.path.isfile(os.path.join(output_dir, create_output_name(url, band_dict)))
        for url in granule_urls
    ):
        logging.info(
            f"All assets related to {granule_urls[0].split('/')[-1]} are already processed, skipping."
        )
        return

    # Handle the quality layer first; then drop the Fmask URL so it is not
    # processed a second time as a regular band.
    qa_mask = None
    if quality_filter:
        qa_mask = _load_quality_mask(
            granule_urls[0], roi, clip, scale, output_dir, band_dict, bit_nums, chunk_size
        )
        granule_urls = [url for url in granule_urls if "Fmask" not in url]

    # Process remaining assets
    for url in granule_urls:
        output_name = create_output_name(url, band_dict)
        output_file = os.path.join(output_dir, output_name)
        if os.path.isfile(output_file):
            logging.info(
                f"Existing file {output_name} found in {output_dir}. Skipping."
            )
            continue

        da = open_hls(url, roi, clip, scale, chunk_size)
        # None means the asset could not be opened; retry once, then skip.
        if da is None:
            logging.warning(f"Asset {url} could not be opened. Try again.")
            da = open_hls(url, roi, clip, scale, chunk_size)
            if da is None:
                logging.warning(f"Asset {url} still not found. Skipping.")
                continue

        # Apply quality mask if desired (flagged pixels become NaN).
        if quality_filter:
            da = da.where(~qa_mask)

        if "FMASK" in output_name and not quality_filter:
            # Without quality filtering the Fmask layer is saved as its own
            # output, keeping its native dtype.
            da.rio.to_raster(raster_path=output_file, driver="COG", compress="DEFLATE")
        else:
            # Force float32 output; otherwise it would be written as float64.
            da.rio.to_raster(
                raster_path=output_file, driver="COG", dtype="float32", compress="DEFLATE"
            )
def build_hls_xarray_timeseries(hls_cog_list, mask_and_scale=True, chunk_size=None):
    """
    Builds a single band timeseries using xarray for a list of HLS COGs.

    Dependent on file naming convention: works on SuPER-named files, where
    (0-based, dot-separated) field 3 of the file name is the acquisition time
    (``%Y%jT%H%M%S``) and field 6 is the common band name. Files need common
    band names corresponding to HLSS and HLSL bands, e.g. HLSL30 Band 5 (NIR1)
    and HLSS30 Band 8A (NIR1).

    Args:
        hls_cog_list (list of str): Paths/URLs of the COGs, all of the same band.
        mask_and_scale (bool): Forwarded to ``rioxarray.open_rasterio``.
        chunk_size (dict, optional): Dask chunks, defaulting to
            ``{"band": 1, "x": 512, "y": 512}``.

    Returns:
        xarray.DataArray concatenated along a new ``time`` dimension, named after
        the band.

    Raises:
        ValueError: If the list is empty, the files are not all the same band, or
            a file name has no valid date string.
    """
    if chunk_size is None:
        chunk_size = dict(band=1, x=512, y=512)
    if not hls_cog_list:
        raise ValueError("hls_cog_list must contain at least one file.")

    # Parse only the file name: dots in parent directories (e.g. "./out/") would
    # otherwise shift the field positions.
    name_fields = [os.path.basename(path).split(".") for path in hls_cog_list]

    # Make sure all files in list are the same band
    band_name = name_fields[0][6]
    if any(fields[6] != band_name for fields in name_fields):
        raise ValueError("All listed files must be of the same band.")

    # Create time variable. Previously the error was only printed, which then
    # crashed with a NameError on the undefined time_list.
    try:
        time_list = [dt.strptime(fields[3], "%Y%jT%H%M%S") for fields in name_fields]
    except ValueError as err:
        raise ValueError(
            "A COG does not have a valid date string in the filename."
        ) from err
    time = xr.Variable("time", time_list)

    timeseries_da = xr.concat(
        [
            rxr.open_rasterio(
                filename, mask_and_scale=mask_and_scale, chunks=chunk_size
            ).squeeze("band", drop=True)
            for filename in hls_cog_list
        ],
        dim=time,
    )
    timeseries_da.name = band_name
    return timeseries_da
def create_timeseries_dataset(hls_file_dir, output_type, output_dir=None):
    """
    Creates an xarray dataset timeseries from a directory of HLS COGs.

    Writes one NetCDF output per tile, named
    ``HLS.<tile>.<first date>.<last date>.subset.nc``. Currently only works for
    HLS SuPER outputs, where (0-based, dot-separated) field 2 of the file name is
    the tile and field 6 the band.

    Args:
        hls_file_dir (str): Directory containing the ``*.tif`` outputs.
        output_type (str): Currently ignored; NetCDF is always written.
        output_dir (str, optional): Destination directory (created if missing).
            Defaults to ``hls_file_dir``.
    """
    # Setup Logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(levelname)s:%(asctime)s ||| %(message)s",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    # A None output_dir previously crashed in os.path.join.
    if output_dir is None:
        output_dir = hls_file_dir

    # Group files as {tile: {band: [paths]}}.
    file_dict = {}
    for file in sorted(os.listdir(hls_file_dir)):
        if not file.endswith(".tif"):
            continue
        fields = file.split(".")
        if len(fields) < 7:
            logging.warning(
                f"Skipping {file}: name does not follow the HLS SuPER naming scheme."
            )
            continue
        tile, band = fields[2], fields[6]
        file_dict.setdefault(tile, {}).setdefault(band, []).append(
            os.path.join(hls_file_dir, file)
        )

    # Check that all bands within each tile have the same number of observations
    for tile, bands in file_dict.items():
        q_obs = {band: len(files) for band, files in bands.items()}
        if len(set(q_obs.values())) > 1:
            logging.warning(
                f"Not all bands in {tile} have the same number of observations."
            )
            logging.warning(f"{q_obs}")

    if file_dict:
        os.makedirs(output_dir, exist_ok=True)

    # Loop through each tile and build timeseries output
    for tile, bands in file_dict.items():
        delayed_series = {
            band: dask.delayed(build_hls_xarray_timeseries)(files)
            for band, files in bands.items()
        }
        (timeseries_dict,) = dask.compute(delayed_series)
        # Directory order is arbitrary and L30/S30 granules interleave in time:
        # sort so the series is chronological and the first/last dates are right.
        dataset = xr.Dataset(timeseries_dict).sortby("time")

        # Set up CF-Compliant Coordinate Attributes
        dataset.attrs["Conventions"] = "CF-1.6"
        dataset.attrs["title"] = "HLS SuPER Timeseries Dataset"
        dataset.attrs["institution"] = "LP DAAC"

        dataset.x.attrs["axis"] = "X"
        dataset.x.attrs["standard_name"] = "projection_x_coordinate"
        dataset.x.attrs["long_name"] = "x-coordinate in projected coordinate system"
        dataset.x.attrs["units"] = "m"

        dataset.y.attrs["axis"] = "Y"
        dataset.y.attrs["standard_name"] = "projection_y_coordinate"
        dataset.y.attrs["long_name"] = "y-coordinate in projected coordinate system"
        dataset.y.attrs["units"] = "m"

        dataset.time.attrs["axis"] = "Z"
        dataset.time.attrs["standard_name"] = "time"
        dataset.time.attrs["long_name"] = "time"

        # First and last date as YYYY-MM-DD
        first_date = np.datetime_as_string(dataset.time.values[0], unit="D")
        final_date = np.datetime_as_string(dataset.time.values[-1], unit="D")

        output_path = os.path.join(
            output_dir, f"HLS.{tile}.{first_date}.{final_date}.subset.nc"
        )
        dataset.to_netcdf(output_path)
        logging.info(f"Output saved to {output_path}")