diff --git a/hypercoast/pace.py b/hypercoast/pace.py index ae454e7..28b530b 100644 --- a/hypercoast/pace.py +++ b/hypercoast/pace.py @@ -976,6 +976,7 @@ def apply_sam( n_components: int = 3, n_clusters: int = 6, random_state: int = 0, + filter_condition: Optional[Callable[[xr.DataArray], xr.DataArray]] = None, plot: bool = True, figsize: tuple[int, int] = (8, 6), extent: list[float] = None, @@ -991,6 +992,7 @@ def apply_sam( n_components (int, optional): Number of principal components to compute. Defaults to 3. n_clusters (int, optional): Number of clusters for K-means. Defaults to 6. random_state (int, optional): Random state for K-means. Defaults to 0. + filter_condition (Callable[[xr.DataArray], xr.DataArray], optional): A function to filter the data. Defaults to None. plot (bool, optional): Whether to plot the data. Defaults to True. figsize (Tuple[int, int], optional): Figure size for the plot. Defaults to (8, 6). extent (List[float], optional): The extent to zoom in to the specified region. Defaults to None. @@ -1058,6 +1060,9 @@ def spectral_angle_mapper(pixel, reference): best_match_full[~np.isnan(reshaped_data).any(axis=1)] = best_match best_match_full = best_match_full.reshape(original_shape) + if filter_condition is not None: + best_match_full = np.where(filter_condition, best_match_full, np.nan) + latitudes = da.coords["latitude"].values longitudes = da.coords["longitude"].values