# -*- coding: utf-8 -*-
"""
===============================================================================
HLS Processing and Exporting Reformatted Data (HLS_PER)

This module contains functions to conduct subsetting and quality filtering of
search results.
-------------------------------------------------------------------------------
Authors: Cole Krehbiel, Mahsa Jami, and Erik Bolch
Editor: Hong Xie
Last Updated: 2025-01-06
===============================================================================
"""

import os
import sys
import logging

import numpy as np
from datetime import datetime as dt
import xarray as xr
import rioxarray as rxr
import dask.distributed


def create_output_name(url, band_dict):
    """
    Uses HLS default naming scheme to generate an output name with common band names.
    This allows for easier stacking of bands from both collections.
    """
    # Get Necessary Strings
    prod = url.split("/")[4].split(".")[0]
    asset = url.split("/")[-1].split(".")[-2]
    # Hard-coded one-off for the Fmask name in case it is not in the band_dict but is needed for masking
    if asset == "Fmask":
        output_name = f"{'.'.join(url.split('/')[-1].split('.')[:-2])}.FMASK.subset.tif"
    else:
        for key, value in band_dict[prod].items():
            if value == asset:
                output_name = (
                    f"{'.'.join(url.split('/')[-1].split('.')[:-2])}.{key}.subset.tif"
                )
    return output_name


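# Illustrative example (hypothetical URL and band_dict, for reference only):
# for an asset URL ending in .../HLSL30.020/HLS.L30.T11SQA.2024001T181011.v2.0/
# HLS.L30.T11SQA.2024001T181011.v2.0.B05.tif, with
# band_dict = {"HLSL30": {"NIR1": "B05"}}, create_output_name(url, band_dict)
# returns "HLS.L30.T11SQA.2024001T181011.v2.0.NIR1.subset.tif", so the L30 B05 and
# S30 B8A assets share the common band name NIR1 in the output files.

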
def open_hls(
    url, roi=None, clip=False, scale=True, chunk_size=dict(band=1, x=512, y=512)
):
    """
    Generic function to open an HLS COG and clip it to an ROI. For consistent scaling, this must be done manually,
    as some HLS Landsat scenes have the scale metadata in the wrong location.
    """
    # Open using rioxarray
    da = rxr.open_rasterio(url, chunks=chunk_size, mask_and_scale=False).squeeze(
        "band", drop=True
    )

    # (Add) Read the collection and band names from the URL
    split_asset = url.split("/")[-1].split(".")
    asset_name = split_asset[1]
    band_name = split_asset[-2]

    # Reproject ROI and Clip if ROI is provided and clip is True
    if roi is not None and clip:
        roi = roi.to_crs(da.spatial_ref.crs_wkt)
        da = da.rio.clip(roi.geometry.values, roi.crs, all_touched=True)

    # (Add) Even though most scenes are already scaled and have fill values applied, some scenes
    # may still need to be scaled and have fill values handled manually in local GIS software
    # Apply Scale Factor if desired for non-quality layer
    if band_name != "Fmask":
        if scale:
            # Mask Fill Values
            da = xr.where(da == -9999, np.nan, da)
            # Scale Data
            # (Add) Except for the quality layer and the two L30 thermal infrared bands,
            # all other spectral bands use a scale factor of 0.0001
            # (Add) The two L30 thermal infrared bands use a scale factor of 0.01
            # NOTE: the thermal infrared bands have not been renamed at this point (still B10/B11)
            if (band_name == "B10" or band_name == "B11") and (asset_name == "L30"):
                da = da * 0.01
            else:
                da = da * 0.0001
            # Remove Scale Factor After Scaling - Prevents Double Scaling
            da.attrs["scale_factor"] = 1.0

        # Add Scale Factor to Attributes Manually - This will overwrite/add if the data is missing.
        # (Add) If scaling is to be applied manually later, the scale factor must be added here
        else:
            if (band_name == "B10" or band_name == "B11") and (asset_name == "L30"):
                da.attrs["scale_factor"] = 0.01
            else:
                da.attrs["scale_factor"] = 0.0001

    return da


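# Example call (hypothetical variables): a reflectance band opened with scale=True has
# its -9999 fill values replaced with NaN and its integer DNs multiplied by 0.0001,
# so a stored value of 2345 becomes 0.2345; L30 B10/B11 use 0.01 instead.
#   red_da = open_hls(red_url, roi=roi_gdf, clip=True, scale=True)

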
def create_quality_mask(quality_data, bit_nums: list = [0, 1, 2, 3, 4, 5]):
    """
    Uses the Fmask layer and bit numbers to create a binary mask of flagged (poor-quality) pixels;
    the mask is True where any of the selected bits are set. By default, bits 0-5 are used.
    """
    mask_array = np.zeros((quality_data.shape[0], quality_data.shape[1]))
    # Remove/Mask Fill Values and Convert to Integer (NaN is treated as 0, i.e. clear)
    quality_data = np.nan_to_num(quality_data, nan=0).astype(np.int8)
    for bit in bit_nums:
        # Create a Single Binary Mask Layer
        mask_temp = (np.array(quality_data) & (1 << bit)) > 0
        mask_array = np.logical_or(mask_array, mask_temp)
    return mask_array


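# Worked example (hypothetical values): with bit_nums=[1] (Cloud), an Fmask value of
# 2 (binary 00000010) has bit 1 set and is masked (True), while 64 (binary 01000000,
# an aerosol-level bit) is not, and NaN is treated as clear:
#   create_quality_mask(np.array([[0, 2], [64, np.nan]]), bit_nums=[1])
#   -> array([[False,  True], [False, False]])

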
def process_granule(
    granule_urls,
    roi,
    clip,
    quality_filter,
    scale,
    output_dir,
    band_dict,
    bit_nums=[0, 1, 2, 3, 4, 5],
    chunk_size=dict(band=1, x=512, y=512),
):
"""
|
||
Processes a list of HLS asset urls for a single granule.
|
||
|
||
args:
|
||
granule_urls (list): List of HLS asset urls to process.
|
||
|
||
roi (geopandas.GeoDataFrame): ROI to filter data.
|
||
|
||
clip (bool): If True, ROI will be clipped to the image.
|
||
|
||
quality_filter (bool): If True, quality layer will be used to mask data.
|
||
|
||
scale (bool): If True, data will be scaled to reflectance.
|
||
|
||
output_dir (str): Directory to save output files.
|
||
|
||
band_dict (dict): Dictionary of band names and asset names.
|
||
|
||
bit_nums (list): List of bit numbers to use for quality mask.
|
||
- 0: Cirrus
|
||
- 1: Cloud
|
||
- 2: Adjacent to cloud/shadow
|
||
- 3: Cloud shadow
|
||
- 4: Snow/ice
|
||
- 5: Water
|
||
|
||
chunk_size (dict): Dictionary of chunk sizes for dask.
|
||
"""
|
||
|
||
    # Setup Logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(levelname)s:%(asctime)s ||| %(message)s",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    # Check if all Outputs Exist for a Granule
    if not all(
        os.path.isfile(f"{output_dir}/{create_output_name(url, band_dict)}")
        for url in granule_urls
    ):

        # First Handle Quality Layer
        # (Add) Simplified from the original redundant handling: process only the quality
        # layer here, then remove the quality-layer URL from the list at the end
        if quality_filter:
            # Generate Quality Layer URL
            split_asset = granule_urls[0].split("/")[-1].split(".")
            split_asset[-2] = "Fmask"
            quality_url = (
                f"{'/'.join(granule_urls[0].split('/')[:-1])}/{'.'.join(split_asset)}"
            )
            # Check if File exists in Output Directory
            quality_output_name = create_output_name(quality_url, band_dict)
            quality_output_file = f"{output_dir}/{quality_output_name}"
            # Open Quality Layer (opened unconditionally so qa_da is available for the
            # mask below, even when the quality COG has already been written)
            qa_da = open_hls(quality_url, roi, clip, scale, chunk_size)
            # Check if quality asset is already processed
            if not os.path.isfile(quality_output_file):
                # Write Output
                qa_da.rio.to_raster(raster_path=quality_output_file, driver="COG")
            else:
                logging.info(
                    f"Existing quality file {quality_output_name} found in {output_dir}."
                )
            # Create Quality Mask
            # TODO: The in-memory mask array may cause Dask memory overflow and needs optimization
            qa_mask = create_quality_mask(qa_da, bit_nums=bit_nums)
            # (Add) When quality_filter=True, remove the quality layer URL after the mask
            # is built to avoid reprocessing it below
            granule_urls = [url for url in granule_urls if "Fmask" not in url]

        # Process Remaining Assets
        for url in granule_urls:
            # Check if File exists in Output Directory
            output_name = create_output_name(url, band_dict)
            output_file = f"{output_dir}/{output_name}"

            # Check if scene is already processed
            if not os.path.isfile(output_file):
                # Open Asset
                da = open_hls(url, roi, clip, scale, chunk_size)

                # Apply Quality Mask if Desired
                if quality_filter:
                    da = da.where(~qa_mask)

                # Write Output
                if "FMASK" in output_name:
                    da.rio.to_raster(raster_path=output_file, driver="COG")
                else:
                    # (Add) Force the output dtype to float32, otherwise it defaults to float64
                    da.rio.to_raster(
                        raster_path=output_file, driver="COG", dtype="float32"
                    )
            else:
                logging.info(
                    f"Existing file {output_name} found in {output_dir}. Skipping."
                )
    else:
        logging.info(
            f"All assets related to {granule_urls[0].split('/')[-1]} are already processed, skipping."
        )


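# Sketch of a typical call (hypothetical inputs; in practice granule_urls, the ROI, and
# band_dict come from the accompanying HLS SuPER search/configuration step):
#   roi = gpd.read_file("roi.geojson")           # requires geopandas
#   band_dict = {"HLSL30": {"RED": "B04", "NIR1": "B05", "FMASK": "Fmask"}}
#   process_granule(
#       granule_urls, roi, clip=True, quality_filter=True, scale=True,
#       output_dir="output", band_dict=band_dict,
#   )
# Granules are independent, so a driver script may also wrap these calls in
# dask.delayed and compute several granules in parallel.

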
def build_hls_xarray_timeseries(
    hls_cog_list, mask_and_scale=True, chunk_size=dict(band=1, x=512, y=512)
):
    """
    Builds a single-band timeseries using xarray from a list of HLS COGs. Dependent on file naming convention.
    Works on SuPER-script-named files. Files need common band names that map corresponding HLSS and HLSL bands,
    e.g. HLSL30 Band 5 (NIR1) and HLSS30 Band 8A (NIR1).
    """
    # Define Band(s)
    bands = [filename.split(".")[6] for filename in hls_cog_list]

    # Make sure all files in list are the same band
    if not all(band == bands[0] for band in bands):
        raise ValueError("All listed files must be of the same band.")

    band_name = bands[0]

    # Create Time Variable
    try:
        time_list = [
            dt.strptime(filename.split(".")[3], "%Y%jT%H%M%S")
            for filename in hls_cog_list
        ]
    except ValueError:
        print("A COG does not have a valid date string in the filename.")
        raise

    time = xr.Variable("time", time_list)

    timeseries_da = xr.concat(
        [
            rxr.open_rasterio(
                filename, mask_and_scale=mask_and_scale, chunks=chunk_size
            ).squeeze("band", drop=True)
            for filename in hls_cog_list
        ],
        dim=time,
    )
    timeseries_da.name = band_name

    return timeseries_da


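# Illustrative call (hypothetical file list): band and date are parsed by splitting the
# full path on ".", so paths should not contain extra "." characters.
#   nir_files = sorted(glob.glob(os.path.join("output", "HLS.*.NIR1.subset.tif")))  # requires glob
#   nir_ts = build_hls_xarray_timeseries(nir_files)  # DataArray with a "time" dimension

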
def create_timeseries_dataset(hls_file_dir, output_type, output_dir=None):
    """
    Creates an xarray dataset timeseries from a directory of HLS COGs.
    Writes to a netcdf output. Currently only works for HLS SuPER outputs.
    """

    # Setup Logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(levelname)s:%(asctime)s ||| %(message)s",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    # List Files in Directory
    all_files = [file for file in os.listdir(hls_file_dir) if file.endswith(".tif")]

    # Create Dictionary of Files by Band
    file_dict = {}
    for file in all_files:
        tile = file.split(".")[2]
        band = file.split(".")[6]
        full_path = os.path.join(hls_file_dir, file)
        if tile not in file_dict:
            file_dict[tile] = {}
        if band not in file_dict[tile]:
            file_dict[tile][band] = []
        file_dict[tile][band].append(full_path)

    # logging.info(f"{file_dict}")

    # Check that all bands within each tile have the same number of observations
    for tile, bands in file_dict.items():
        q_obs = {band: len(files) for band, files in bands.items()}
        if not all(q == list(q_obs.values())[0] for q in q_obs.values()):
            logging.info(
                f"Not all bands in {tile} have the same number of observations."
            )
            logging.info(f"{q_obs}")

    # Loop through each tile and build timeseries output

    for tile, bands in file_dict.items():
        timeseries_dict = {
            band: dask.delayed(build_hls_xarray_timeseries)(files)
            for band, files in bands.items()
        }
        timeseries_dict = dask.compute(timeseries_dict)[0]
        dataset = xr.Dataset(timeseries_dict)

        # Set up CF-Compliant Coordinate Attributes
        dataset.attrs["Conventions"] = "CF-1.6"
        dataset.attrs["title"] = "HLS SuPER Timeseries Dataset"
        dataset.attrs["institution"] = "LP DAAC"

        dataset.x.attrs["axis"] = "X"
        dataset.x.attrs["standard_name"] = "projection_x_coordinate"
        dataset.x.attrs["long_name"] = "x-coordinate in projected coordinate system"
        dataset.x.attrs["units"] = "m"

        dataset.y.attrs["axis"] = "Y"
        dataset.y.attrs["standard_name"] = "projection_y_coordinate"
        dataset.y.attrs["long_name"] = "y-coordinate in projected coordinate system"
        dataset.y.attrs["units"] = "m"

        # CF conventions use "T" for the time axis
        dataset.time.attrs["axis"] = "T"
        dataset.time.attrs["standard_name"] = "time"
        dataset.time.attrs["long_name"] = "time"

        # Get first and last date
        first_date = (
            dataset.time.data[0].astype("M8[ms]").astype(dt).strftime("%Y-%m-%d")
        )
        final_date = (
            dataset.time.data[-1].astype("M8[ms]").astype(dt).strftime("%Y-%m-%d")
        )

        # Write Outputs
        # if output_type == "NC4":
        output_path = os.path.join(
            output_dir, f"HLS.{tile}.{first_date}.{final_date}.subset.nc"
        )
        dataset.to_netcdf(output_path)
        # elif output_type == "ZARR":
        #     output_path = os.path.join(output_dir, "hls_timeseries_dataset.zarr")
        #     dataset.to_zarr(output_path)
        logging.info(f"Output saved to {output_path}")
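

# Illustrative final step (hypothetical directory): once process_granule has written the
# per-band COGs, a per-tile NetCDF timeseries can be produced with
#   create_timeseries_dataset("output", output_type="NC4", output_dir="output")
# Note that output_type is currently unused; a NetCDF file is always written and the
# Zarr branch above remains commented out.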