diff --git a/faninsar/_core/pair_tools.py b/faninsar/_core/pair_tools.py index cc8d3ff..12bfd9b 100644 --- a/faninsar/_core/pair_tools.py +++ b/faninsar/_core/pair_tools.py @@ -2,7 +2,7 @@ from collections.abc import Iterable from datetime import datetime from pathlib import Path -from typing import Any, Callable, Literal, Optional, Union +from typing import Any, Callable, Literal, Optional, Union import matplotlib.pyplot as plt import numpy as np @@ -45,6 +45,9 @@ def __eq__(self, other: "Pair") -> bool: def __hash__(self) -> int: return hash(self._name) + def __array__(self) -> np.ndarray: + return self._values + @property def values(self): """return the values of the pair. @@ -195,22 +198,11 @@ def __init__( """ if pairs is None or len(pairs) == 0: raise ValueError("pairs cannot be None.") - pairs_ls = [] - for pair in pairs: - if isinstance(pair, Pair): - _pair = pair - elif isinstance(pair, Iterable): - _pair = Pair(pair) - else: - raise TypeError( - f"pairs should be an Iterable containing Iterable or Pair object, but got {type(pair)}." - ) - pairs_ls.append(_pair.values) - _values = np.array(pairs_ls) + _values = np.array(pairs, dtype="M8[D]") self._values = _values - self._dates = np.unique(pairs_ls) + self._dates = np.unique(_values.flatten()) self._length = self._values.shape[0] self._edge_index = np.searchsorted(self._dates, self._values) @@ -328,6 +320,9 @@ def __contains__(self, item): return np.any(np.all(item == self.values, axis=1)) + def __array__(self) -> np.ndarray: + return self._values + @property def values(self) -> np.ndarray: """return the pairs array in type of np.datetime64[D]""" @@ -561,6 +556,9 @@ def __eq__(self, other) -> bool: def __hash__(self) -> int: return hash(self.name) + def __array__(self) -> np.ndarray: + return self._values + @property def values(self) -> np.ndarray: """return the values array of the loop. @@ -683,24 +681,13 @@ def __init__( """ if loops is None or len(loops) == 0: raise ValueError("loops cannot be None.") - loops_ls = [] - for loop in loops: - if isinstance(loop, Loop): - _loop = loop - elif isinstance(loop, Iterable): - _loop = Loop(loop) - else: - raise TypeError( - f"loops should be an Iterable containing Iterable or Poop object, but got {type(loop)}." - ) - loops_ls.append(_loop.values) - _values = np.array(loops_ls) + _values = np.array(loops, dtype="M8[D]") self._values = _values - self._dates = np.unique(loops_ls) + self._dates = np.unique(_values) self._length = self._values.shape[0] - + if sort: self.sort(inplace=True) @@ -814,6 +801,9 @@ def __contains__(self, item): return np.any(np.all(item == self.values, axis=1)) + def __array__(self) -> np.ndarray: + return self._values + @property def values(self) -> np.ndarray: """return the values of the loops. @@ -1070,7 +1060,7 @@ def sort( Returns ------- - None or (Loops, np.ndarray). if inplace is True, return the sorted loops + None or (Loops, np.ndarray). if inplace is True, return the sorted loops and the index of the sorted loops in the original loops. Otherwise, return None. """ @@ -1085,10 +1075,10 @@ def sort( "days23": self.days23, "days13": self.days13, "date": self._values, - "pairs": zip( + "pairs": np.hstack( [self.pairs12.values, self.pairs23.values, self.pairs13.values] ), - "days": zip([self.days12, self.days23, self.days13]), + "days": np.hstack([self.days12, self.days23, self.days13]), } if isinstance(order, str): order = [order] @@ -1098,7 +1088,7 @@ def sort( raise ValueError( f"order should be one of {list(item_map.keys())}, but got {order}." ) - _values.append(item_map[i].reshape(self._length, -1)) + _values.append(item_map[i]) _values = np.hstack(_values) _, _index = np.unique(_values, axis=0, return_index=True) if not ascending: @@ -1303,11 +1293,11 @@ def sort( Whether to sort ascending. Default is True. inplace: bool, optional Whether to sort the pairs and loops inplace. Default is True. - + Returns ------- - None or (SBASNetwork, np.ndarray). if inplace is True, return the sorted - SBASNetwork and the index of the sorted pairs in the original pairs. + None or (SBASNetwork, np.ndarray). if inplace is True, return the sorted + SBASNetwork and the index of the sorted pairs in the original pairs. Otherwise, return None. """ pairs, _index = self._pairs.sort(order, ascending, inplace=False) @@ -1358,7 +1348,7 @@ def __init__(self, dates: Iterable, **kwargs) -> None: date_args: dict, optional Keyword arguments for pd.to_datetime(). """ - self.dates = pd.to_datetime(dates, **kwargs) + self.dates = pd.to_datetime(dates, **kwargs).unique().sort_values() def from_interval(self, max_interval: int = 2, max_day: int = 180) -> Pairs: """generate interferometric pairs by SAR acquisition interval. SAR @@ -1406,38 +1396,44 @@ def from_period( period_start: str = "1201", period_end: str = "0331", n_per_period: str = 3, - n_primary_period: str = 2, + n_primary_period: Optional[str] = None, + primary_years: Optional[list[int]] = None, ) -> Pairs: - """generate interferometric pairs between periods for all years. period is defined by month - and day for each year. For example, period_start='1201', period_end='0331' means the period - is from Dec 1 to Mar 31 for each year in the time series. This function will randomly select - n_per_period dates in each period and generate interferometric pairs between those dates. This - will be useful to mitigate the temporal cumulative bias. + """generate interferometric pairs between periods for all years. + period is defined by month and day for each year. For example, + period_start='1201', period_end='0331' means the period is from Dec 1 + to Mar 31 for each year in the time series. This function will randomly + select n_per_period dates in each period and generate interferometric + pairs between those dates. This will be useful to mitigate the temporal + cumulative bias. Parameters ---------- period_start, period_end: str - start and end date for the period which expressed as month and day with format '%m%d' + start and end date for the period which expressed as month and day + with format '%m%d' n_per_period: int - how many dates will be used for each period. Those dates will be selected randomly - in each period. Default is 3 - n_primary_period: int - how many periods used as primary date of ifg. Default is 2. For example, if n_primary_period=2, - then the interferometric pairs will be generated between the first two periods and the rest - periods. + how many dates will be used for each period. Those dates will be + selected randomly in each period. Default is 3 + n_primary_period: int, optional + how many periods/years used as primary date of ifg. For example, if + n_primary_period=2, then the interferometric pairs will be generated + between the first two periods and the rest periods. If None, all + periods will be used. Default is None. + primary_years: list, optional + years used as primary date of ifg. If None, all years in the time + series will be used. Default is None. Returns ------- pairs: Pairs object """ years = sorted(set(self.dates.year)) - df_dates = pd.Series(self.dates.strftime("%Y%m%d"), index=self.dates) + df_dates = pd.Series(self.dates, index=self.dates) - # check if period_start and period_end are in the same year. If not, the period_end should be - # in the next year - same_year = False - if int(period_start) < int(period_end): - same_year = True + # check if period_start and period_end are in the same year. If not, + # the period_end should be in the next year + same_year = True if int(period_start) < int(period_end) else False # randomly select n_per_period dates in each period/year date_years = [] @@ -1457,8 +1453,10 @@ def from_period( _pairs = [] for i, date_year in enumerate(date_years): # only generate pairs for n_primary_period - if i + 1 > n_primary_period: + if n_primary_period is not None and i + 1 > n_primary_period: break + if primary_years is not None and years[i] not in primary_years: + continue for date_primary in date_year: # all rest periods for date_year1 in date_years[i + 1 :]: @@ -1489,10 +1487,10 @@ def from_summer_winter( Returns ------- - Pairs object + Pairs object """ years = sorted(set(self.dates.year)) - df_dates = pd.Series(self.dates.strftime('%Y%m%d'), index=self.dates) + df_dates = pd.Series(self.dates.strftime("%Y%m%d"), index=self.dates) _pairs = [] for year in years: diff --git a/tests/sampler.py b/tests/sampler.py index 7305f54..42ab051 100644 --- a/tests/sampler.py +++ b/tests/sampler.py @@ -1,9 +1,11 @@ from pathlib import Path -from faninsar import samplers -from faninsar.datasets import RasterDataset,BoundingBox,Points from tqdm import tqdm +from faninsar import samplers +from faninsar.datasets import RasterDataset +from faninsar.query import BoundingBox, Points + home_dir = Path("/home/fancy/work/data/test") files = list(home_dir.rglob("*unw_phase_clip.tif"))