Skip to content

Commit

Permalink
Improve apply_kmeans function
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs committed Oct 13, 2024
1 parent 7a50660 commit 5df1bd5
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions hypercoast/pace.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from typing import List, Tuple, Union, Optional, Any
from typing import List, Tuple, Union, Optional, Any, Callable
from .common import extract_date_from_filename


Expand Down Expand Up @@ -775,8 +775,10 @@ def cyano_band_ratios(
def apply_kmeans(
dataset: Union[xr.Dataset, str],
n_clusters: int = 6,
filter_condition: Optional[Callable[[xr.DataArray], xr.DataArray]] = None,
plot: bool = True,
figsize: tuple[int, int] = (8, 6),
colors: list[str] = None,
extent: list[float] = None,
title: str = None,
**kwargs,
Expand All @@ -789,6 +791,7 @@ def apply_kmeans(
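n_clusters (int, optional): Number of clusters for K-means. Defaults to 6.
filter_condition (optional): Condition passed to `np.where` to mask the cluster labels; labels are kept where the condition holds and set to NaN elsewhere. Defaults to None (no masking).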
plot (bool, optional): Whether to plot the data. Defaults to True.
figsize (tuple[int, int], optional): Figure size for the plot. Defaults to (8, 6).
colors (list[str], optional): List of colors to use for the clusters. Defaults to None, in which case a built-in six-color palette is used.
extent (list[float] | None, optional): Extent of the region to zoom in to. Defaults to None.
title (str | None, optional): Title for the plot. Defaults to None.
**kwargs: Additional keyword arguments to pass to the `plt.subplots` function.
Expand All @@ -807,6 +810,8 @@ def apply_kmeans(

if isinstance(dataset, str):
dataset = read_pace(dataset)
elif isinstance(dataset, xr.DataArray):
dataset = dataset.to_dataset()
elif not isinstance(dataset, xr.Dataset):
raise ValueError("dataset must be an xarray Dataset")

Expand All @@ -831,15 +836,18 @@ def apply_kmeans(
# Reshape the labels back to the original spatial dimensions
cluster_labels = labels.reshape(da.shape[:-1])

if filter_condition is not None:
cluster_labels = np.where(filter_condition, cluster_labels, np.nan)

latitudes = da.coords["latitude"].values
longitudes = da.coords["longitude"].values

if plot:

# Create a custom discrete color map for K-means clusters
cmap = mcolors.ListedColormap(
["#377eb8", "#ff7f00", "#4daf4a", "#f781bf", "#a65628", "#984ea3"]
)
if colors is None:
colors = ["#377eb8", "#ff7f00", "#4daf4a", "#f781bf", "#a65628", "#984ea3"]
cmap = mcolors.ListedColormap(colors)
bounds = np.arange(-0.5, n_clusters, 1)
norm = mcolors.BoundaryNorm(bounds, cmap.N)

Expand Down

0 comments on commit 5df1bd5

Please sign in to comment.