Skip to content

Commit

Permalink
Add apply_kmeans function
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs committed Oct 12, 2024
1 parent 36bcc7e commit 86f6ec6
Showing 1 changed file with 129 additions and 0 deletions.
129 changes: 129 additions & 0 deletions hypercoast/pace.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,3 +770,132 @@ def cyano_band_ratios(
plt.show()

return data


def apply_kmeans(
dataset: Union[xr.Dataset, str],
n_clusters: int = 6,
plot: bool = True,
figsize: tuple[int, int] = (8, 6),
extent: list[float] | None = None,
title: str | None = None,
**kwargs,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Applies K-means clustering to the dataset and optionally plots the results.
Args:
dataset (xr.Dataset | str): The dataset containing the PACE data or the file path to the dataset.
n_clusters (int, optional): Number of clusters for K-means. Defaults to 6.
plot (bool, optional): Whether to plot the data. Defaults to True.
figsize (tuple[int, int], optional): Figure size for the plot. Defaults to (8, 6).
extent (list[float] | None, optional): The extent to zoom in to the specified region. Defaults to None.
title (str | None, optional): Title for the plot. Defaults to None.
**kwargs: Additional keyword arguments to pass to the `plt.subplots` function.
Returns:
tuple[np.ndarray, np.ndarray, np.ndarray]: The cluster labels, latitudes, and longitudes.
"""

import numpy as np
from sklearn.cluster import KMeans

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import cartopy.crs as ccrs
import cartopy.feature as cfeature

if isinstance(dataset, str):
dataset = read_pace(dataset)
elif not isinstance(dataset, xr.Dataset):
raise ValueError("dataset must be an xarray Dataset")

if title is None:
title = f"K-means Clustering with {n_clusters} Clusters"

da = dataset["Rrs"]

reshaped_data = da.values.reshape(-1, da.shape[-1])
reshaped_data_no_nan = reshaped_data[~np.isnan(reshaped_data).any(axis=1)]

# Apply K-means clustering to classify into 5-6 water types.
kmeans = KMeans(n_clusters=n_clusters, random_state=0)
kmeans.fit(reshaped_data_no_nan)

# Initialize an array for cluster labels with NaN
labels = np.full(reshaped_data.shape[0], np.nan)

# Assign the computed cluster labels to the non-NaN positions
labels[~np.isnan(reshaped_data).any(axis=1)] = kmeans.labels_

# Reshape the labels back to the original spatial dimensions
cluster_labels = labels.reshape(da.shape[:-1])

latitudes = da.coords["latitude"].values
longitudes = da.coords["longitude"].values

if plot:

# Create a custom discrete color map for K-means clusters
cmap = mcolors.ListedColormap(
["#377eb8", "#ff7f00", "#4daf4a", "#f781bf", "#a65628", "#984ea3"]
)
bounds = np.arange(-0.5, n_clusters, 1)
norm = mcolors.BoundaryNorm(bounds, cmap.N)

# Create a figure and axis with the correct map projection

if "dpi" not in kwargs:
kwargs["dpi"] = 100

if "subplot_kw" not in kwargs:
kwargs["subplot_kw"] = {"projection": ccrs.PlateCarree()}

fig, ax = plt.subplots(
figsize=figsize,
**kwargs,
)

# Plot the K-means classification results on the map
im = ax.pcolormesh(
longitudes,
latitudes,
cluster_labels,
cmap=cmap,
norm=norm,
transform=ccrs.PlateCarree(),
)

# Add geographic features for context
ax.add_feature(cfeature.COASTLINE)
ax.add_feature(cfeature.BORDERS, linestyle=":")
ax.add_feature(cfeature.STATES, linestyle="--")

# Add gridlines
ax.gridlines(draw_labels=True)

# Set the extent to zoom in to the specified region
if extent is not None:
ax.set_extent(extent, crs=ccrs.PlateCarree())

# Add color bar with labels
cbar = plt.colorbar(
im,
ax=ax,
orientation="vertical",
pad=0.02,
fraction=0.05,
ticks=np.arange(n_clusters),
)
cbar.ax.set_yticklabels([f"Class {i+1}" for i in range(n_clusters)])
cbar.set_label("Water Types", rotation=270, labelpad=10)

# Add title
ax.set_title(title, fontsize=14)
ax.set_xlabel("Longitude")
ax.set_ylabel("Latitude")

# Show the plot
plt.show()

return cluster_labels, latitudes, longitudes

0 comments on commit 86f6ec6

Please sign in to comment.