def apply_kmeans(
    dataset: Union[xr.Dataset, str],
    n_clusters: int = 6,
    plot: bool = True,
    figsize: tuple[int, int] = (8, 6),
    extent: list[float] | None = None,
    title: str | None = None,
    **kwargs,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Applies K-means clustering to the dataset and optionally plots the results.

    The ``"Rrs"`` variable is flattened so that every pixel becomes one sample
    whose features are the values along the last axis (presumably wavelength;
    confirm against ``read_pace``). Pixels with a NaN in any feature are left
    out of the fit and are labelled NaN in the output.

    Args:
        dataset (xr.Dataset | str): The dataset containing the PACE data or the file path to the dataset.
        n_clusters (int, optional): Number of clusters for K-means. Defaults to 6.
        plot (bool, optional): Whether to plot the data. Defaults to True.
        figsize (tuple[int, int], optional): Figure size for the plot. Defaults to (8, 6).
        extent (list[float] | None, optional): The extent to zoom in to the specified region. Defaults to None.
        title (str | None, optional): Title for the plot. Defaults to None,
            in which case a title mentioning ``n_clusters`` is generated.
        **kwargs: Additional keyword arguments to pass to the `plt.subplots` function.
            ``dpi`` defaults to 100 and ``subplot_kw`` to a PlateCarree
            projection when not supplied.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]: The cluster labels, latitudes, and longitudes.
            The labels are a float array shaped like ``Rrs`` without its last
            axis, holding values ``0 .. n_clusters - 1`` or NaN.

    Raises:
        ValueError: If ``dataset`` is neither a string nor an xarray Dataset,
            or if no pixel has a complete (NaN-free) set of feature values.
    """

    import numpy as np
    from sklearn.cluster import KMeans

    if isinstance(dataset, str):
        dataset = read_pace(dataset)
    elif not isinstance(dataset, xr.Dataset):
        raise ValueError("dataset must be an xarray Dataset")

    if title is None:
        title = f"K-means Clustering with {n_clusters} Clusters"

    da = dataset["Rrs"]

    # One row per pixel, one column per entry of the last axis.
    samples = da.values.reshape(-1, da.shape[-1])

    # Compute the validity mask once; it is needed both to select the training
    # rows and to scatter the labels back into the full-size array.
    valid_mask = ~np.isnan(samples).any(axis=1)
    if not valid_mask.any():
        raise ValueError("No valid (non-NaN) pixels found in the Rrs data")

    # Fixed random_state keeps the cluster assignment reproducible between runs.
    kmeans = KMeans(n_clusters=n_clusters, random_state=0)
    kmeans.fit(samples[valid_mask])

    # Start from an all-NaN label array so that masked pixels stay NaN.
    labels = np.full(samples.shape[0], np.nan)
    labels[valid_mask] = kmeans.labels_

    # Reshape the labels back to the original spatial dimensions
    cluster_labels = labels.reshape(da.shape[:-1])

    latitudes = da.coords["latitude"].values
    longitudes = da.coords["longitude"].values

    if plot:
        # Plotting dependencies are only needed on this branch, so importing
        # them here lets plot=False work without matplotlib/cartopy installed.
        import matplotlib.pyplot as plt
        import matplotlib.colors as mcolors
        import cartopy.crs as ccrs
        import cartopy.feature as cfeature

        base_colors = [
            "#377eb8",
            "#ff7f00",
            "#4daf4a",
            "#f781bf",
            "#a65628",
            "#984ea3",
        ]
        if n_clusters <= len(base_colors):
            # Same palette and norm as before for up to six clusters.
            cmap = mcolors.ListedColormap(base_colors)
        else:
            # BoundaryNorm raises when there are more bins than colours, so
            # for larger cluster counts build a palette with exactly one
            # colour per cluster by sampling evenly spaced hues.
            hues = np.linspace(0, 1, n_clusters, endpoint=False)
            cmap = mcolors.ListedColormap(plt.get_cmap("hsv")(hues))

        # Bin edges at -0.5, 0.5, ... so each integer label sits mid-bin.
        bounds = np.arange(-0.5, n_clusters, 1)
        norm = mcolors.BoundaryNorm(bounds, cmap.N)

        # Create a figure and axis with the correct map projection
        if "dpi" not in kwargs:
            kwargs["dpi"] = 100

        if "subplot_kw" not in kwargs:
            kwargs["subplot_kw"] = {"projection": ccrs.PlateCarree()}

        fig, ax = plt.subplots(
            figsize=figsize,
            **kwargs,
        )

        # Plot the K-means classification results on the map
        im = ax.pcolormesh(
            longitudes,
            latitudes,
            cluster_labels,
            cmap=cmap,
            norm=norm,
            transform=ccrs.PlateCarree(),
        )

        # Add geographic features for context
        ax.add_feature(cfeature.COASTLINE)
        ax.add_feature(cfeature.BORDERS, linestyle=":")
        ax.add_feature(cfeature.STATES, linestyle="--")

        # Add gridlines
        ax.gridlines(draw_labels=True)

        # Set the extent to zoom in to the specified region
        if extent is not None:
            ax.set_extent(extent, crs=ccrs.PlateCarree())

        # Add color bar with one tick per cluster, labelled from "Class 1"
        cbar = plt.colorbar(
            im,
            ax=ax,
            orientation="vertical",
            pad=0.02,
            fraction=0.05,
            ticks=np.arange(n_clusters),
        )
        cbar.ax.set_yticklabels([f"Class {i+1}" for i in range(n_clusters)])
        cbar.set_label("Water Types", rotation=270, labelpad=10)

        # Add title
        ax.set_title(title, fontsize=14)
        ax.set_xlabel("Longitude")
        ax.set_ylabel("Latitude")

        # Show the plot
        plt.show()

    return cluster_labels, latitudes, longitudes