From 5df1bd543a5b9565f38beaba6dd2a51659585154 Mon Sep 17 00:00:00 2001 From: Qiusheng Wu Date: Sat, 12 Oct 2024 21:17:45 -0400 Subject: [PATCH] Improve apply_kmeans function --- hypercoast/pace.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/hypercoast/pace.py b/hypercoast/pace.py index 636cbaf..21ea508 100644 --- a/hypercoast/pace.py +++ b/hypercoast/pace.py @@ -8,7 +8,7 @@ import numpy as np import xarray as xr import matplotlib.pyplot as plt -from typing import List, Tuple, Union, Optional, Any +from typing import List, Tuple, Union, Optional, Any, Callable from .common import extract_date_from_filename @@ -775,8 +775,10 @@ def cyano_band_ratios( def apply_kmeans( dataset: Union[xr.Dataset, str], n_clusters: int = 6, + filter_condition: Optional[Callable[[xr.DataArray], xr.DataArray]] = None, plot: bool = True, figsize: tuple[int, int] = (8, 6), + colors: list[str] = None, extent: list[float] = None, title: str = None, **kwargs, @@ -789,6 +791,7 @@ def apply_kmeans( n_clusters (int, optional): Number of clusters for K-means. Defaults to 6. plot (bool, optional): Whether to plot the data. Defaults to True. figsize (tuple[int, int], optional): Figure size for the plot. Defaults to (8, 6). + colors (list[str], optional): List of colors to use for the clusters. Defaults to None. extent (list[float] | None, optional): The extent to zoom in to the specified region. Defaults to None. title (str | None, optional): Title for the plot. Defaults to None. **kwargs: Additional keyword arguments to pass to the `plt.subplots` function. @@ -807,6 +810,8 @@ def apply_kmeans( if isinstance(dataset, str): dataset = read_pace(dataset) + elif isinstance(dataset, xr.DataArray): + dataset = dataset.to_dataset() elif not isinstance(dataset, xr.Dataset): raise ValueError("dataset must be an xarray Dataset") @@ -831,15 +836,18 @@ def apply_kmeans( # Reshape the labels back to the original spatial dimensions cluster_labels = labels.reshape(da.shape[:-1]) + if filter_condition is not None: + cluster_labels = np.where(filter_condition, cluster_labels, np.nan) + latitudes = da.coords["latitude"].values longitudes = da.coords["longitude"].values if plot: # Create a custom discrete color map for K-means clusters - cmap = mcolors.ListedColormap( - ["#377eb8", "#ff7f00", "#4daf4a", "#f781bf", "#a65628", "#984ea3"] - ) + if colors is None: + colors = ["#377eb8", "#ff7f00", "#4daf4a", "#f781bf", "#a65628", "#984ea3"] + cmap = mcolors.ListedColormap(colors) bounds = np.arange(-0.5, n_clusters, 1) norm = mcolors.BoundaryNorm(bounds, cmap.N)