Source code for xdatasets.spatial

import logging
import warnings

import pandas as pd
import xarray as xr
from clisops.core.subset import shape_bbox_indexer, subset_gridpoint
from tqdm.auto import tqdm

from .utils import HiddenPrints


# xagg is an optional dependency. If it cannot be imported, a warning is
# emitted at import time and ``xa`` is bound to None; create_weights_mask
# checks for that sentinel and raises ImportError when xagg is actually needed.
try:
    import xagg as xa
except ImportError:
    warnings.warn("xagg is not installed. Please install it with `pip install xagg`", stacklevel=2)
    xa = None


def bbox_ds(ds_copy, geom):
    """Subset ``ds_copy`` to the bounding box of ``geom`` and rechunk it spatially.

    Parameters
    ----------
    ds_copy
        Object exposing ``isel`` and ``chunk`` with ``latitude`` and
        ``longitude`` dimensions (presumably an ``xarray.Dataset`` -- confirm
        against callers).
    geom
        Geometry container forwarded unchanged to ``shape_bbox_indexer``.

    Returns
    -------
    The bounding-box subset, with the ``latitude`` and ``longitude``
    dimensions each held in a single chunk (``-1``).
    """
    bbox_indexer = shape_bbox_indexer(ds_copy, geom)
    subset = ds_copy.isel(bbox_indexer)
    # -1 asks for one chunk spanning the whole dimension.
    return subset.chunk({"latitude": -1, "longitude": -1})
def clip_by_bbox(ds, space, dataset_name):
    """Return a copy of ``ds`` restricted to the bounding box of ``space["geometry"]``.

    Parameters
    ----------
    ds
        Object exposing ``isel`` (presumably an ``xarray.Dataset`` -- confirm
        against callers).
    space
        Mapping whose ``"geometry"`` entry is forwarded to
        ``shape_bbox_indexer``.
    dataset_name
        Name interpolated into the informational log message only.

    Returns
    -------
    A copy of the bounding-box subset of ``ds``.
    """
    logging.info(f"Spatial operations: processing bbox with {dataset_name}")

    bbox_indexer = shape_bbox_indexer(ds, space["geometry"])
    clipped = ds.isel(bbox_indexer)
    return clipped.copy()
def create_weights_mask(da, poly):
    """Build a ``(latitude, longitude)`` weights dataset for ``poly`` over ``da``.

    The pixel/polygon overlaps are computed by ``xagg.pixel_overlaps``; only
    the first entry (index 0) of its aggregated result is used, so ``poly``
    is expected to hold a single polygon -- confirm against callers.

    Parameters
    ----------
    da
        Gridded object handed to ``xagg.pixel_overlaps``.
    poly
        Polygon container handed to ``xagg.pixel_overlaps``.

    Returns
    -------
    Dataset indexed by ``latitude`` and ``longitude`` with a single
    ``weights`` variable taken from xagg's ``rel_area`` entry.

    Raises
    ------
    ImportError
        If the optional ``xagg`` dependency is not installed.
    """
    if xa is None:
        raise ImportError(
            "xagg is not installed. Please install it with `pip install xagg`",
        )

    overlap = xa.pixel_overlaps(da, poly, subset_bbox=True).agg
    pixel_ids = overlap["pix_idxs"][0]

    # One row per overlapping pixel: its coordinates ...
    coord_frame = pd.DataFrame(
        data=[list(pair) for pair in overlap["coords"][0]],
        index=pixel_ids,
        columns=["latitude", "longitude"],
    )
    # ... and the matching ``rel_area`` value reported by xagg.
    weight_frame = pd.DataFrame(
        data=overlap["rel_area"][0][0].tolist(),
        index=pixel_ids,
        columns=["weights"],
    )

    combined = coord_frame.merge(weight_frame, left_index=True, right_index=True)
    indexed = combined.set_index(["latitude", "longitude"])
    return indexed.to_xarray()
def aggregate(ds_in, ds_weights):
    """Collapse ``ds_in`` over latitude/longitude as a weighted sum.

    Parameters
    ----------
    ds_in
        Gridded data with ``latitude`` and ``longitude`` dimensions.
    ds_weights
        Object exposing a ``weights`` attribute that is multiplied with
        ``ds_in`` before summing.

    Returns
    -------
    ``ds_in * ds_weights.weights`` summed over ``latitude`` and
    ``longitude`` with ``min_count=1``.
    """
    weighted = ds_in * ds_weights.weights
    spatial_dims = ["latitude", "longitude"]
    # NOTE(review): min_count=1 presumably yields NaN rather than 0 when
    # every contributing cell is missing -- see xarray's ``sum`` docs.
    return weighted.sum(spatial_dims, min_count=1)
def clip_by_polygon(ds, space, dataset_name):
    """Clip ``ds`` to every polygon in ``space["geometry"]``, one slice per polygon.

    For each polygon the dataset is reduced to the polygon's bounding box and
    a weights mask is computed with ``create_weights_mask``. Depending on
    ``space["averaging"]`` the data is either collapsed with ``aggregate`` or
    kept gridded with the weights attached. Results are concatenated along a
    new ``geom`` dimension which, when a ``unique_id`` is configured, is
    swapped for that identifier on a best-effort basis.

    Parameters
    ----------
    ds
        Dataset with ``latitude``, ``longitude`` and ``time`` dimensions
        (presumably an ``xarray.Dataset`` -- confirm against callers).
    space
        Mapping with the entries:

        - ``"geometry"``: table of polygons supporting ``iterrows`` and
          ``iloc`` (presumably a ``geopandas.GeoDataFrame``).
        - ``"averaging"``: when exactly ``True``, each polygon is reduced
          with ``aggregate``; otherwise the masked grid is returned.
        - ``"unique_id"`` (optional): name of the column identifying each
          polygon.
    dataset_name
        Name used in progress and log messages only.

    Returns
    -------
    The per-polygon results concatenated along ``geom`` (or along the
    ``unique_id`` dimension when the swap succeeds).
    """
    # We are not using clisops for weighted averages because it is too unstable for now.
    # We use the xagg package instead.
    polygons = space["geometry"]
    # Read once with ``.get`` so a ``space`` lacking "unique_id" behaves the
    # same in the loop as in the post-processing below instead of raising.
    unique_id = space.get("unique_id")

    indexer = shape_bbox_indexer(ds, polygons)
    ds_copy = ds.isel(indexer).copy()

    arrays = []
    pbar = tqdm(polygons.iterrows(), total=len(polygons), position=0, leave=True)
    for position, (idx, row) in enumerate(pbar):
        item = row[unique_id] if unique_id is not None and unique_id in row else idx
        pbar.set_description(
            f"Spatial operations: processing polygon {item} with {dataset_name}",
        )

        # ``iterrows`` yields index *labels* while ``iloc`` is positional, so
        # the row must be selected by its position to stay correct for
        # non-default indexes (e.g. a filtered table).
        geom = polygons.iloc[[position]]
        da = bbox_ds(ds_copy, geom)

        # The weights only depend on the grid, so a single time step is enough.
        with HiddenPrints():
            ds_weights = create_weights_mask(da.isel(time=0), geom)

        if space["averaging"] is True:
            da_tmp = aggregate(da, ds_weights)
            # Copy the attributes of the source variables onto the aggregated ones.
            for var in da_tmp.variables:
                da_tmp[var].attrs = da[var].attrs
            da = da_tmp
        else:
            da = xr.merge([da, ds_weights])
            # Keep only the cells that received a weight for this polygon.
            da = da.where(da.weights.notnull(), drop=True)

        da = da.expand_dims({"geom": geom.index.values})
        arrays.append(da)

    data = xr.concat(arrays, dim="geom")

    if unique_id is not None:
        try:
            data = data.swap_dims({"geom": unique_id})
            data = data.drop_vars("geom")

            if unique_id not in data.coords:
                data = data.assign_coords(
                    {unique_id: (unique_id, polygons[unique_id])},
                )
            data[unique_id].attrs["cf_role"] = "timeseries_id"
        except KeyError:
            # Still best-effort, but no longer silent: report that the
            # identifier could not be attached.
            msg = (
                f"Spatial operations: could not use {unique_id!r} as the "
                f"polygon identifier for {dataset_name}"
            )
            logging.warning(msg)
    return data
def clip_by_point(ds, space, dataset_name):
    """Extract the grid points of ``ds`` matching the sites in ``space["geometry"]``.

    Parameters
    ----------
    ds
        Dataset with ``latitude`` and ``longitude`` coordinates; they are
        renamed to ``lat``/``lon`` for the ``subset_gridpoint`` call and
        renamed back afterwards.
    space
        Mapping whose ``"geometry"`` entry maps a site name to a
        ``(latitude, longitude)`` pair.
    dataset_name
        Name interpolated into the informational log message only.

    Returns
    -------
    The result of ``subset_gridpoint`` with a ``site`` coordinate holding the
    site names and ``cf_role="timeseries_id"`` set on it.
    """
    # TODO : adapt logic for coordinate names
    logging.info(f"Spatial operations: processing points with {dataset_name}")

    sites = space["geometry"]
    # Split the (latitude, longitude) pairs into two parallel sequences.
    latitudes, longitudes = zip(*sites.values(), strict=False)

    short_names = ds.rename({"latitude": "lat", "longitude": "lon"})
    subset = subset_gridpoint(
        short_names,
        lon=list(longitudes),
        lat=list(latitudes),
    )
    subset = subset.rename({"lat": "latitude", "lon": "longitude"})

    site_names = list(sites.keys())
    subset = subset.assign_coords({"site": ("site", site_names)})
    subset["site"].attrs["cf_role"] = "timeseries_id"
    return subset