add __array__ for pairs, loops & fix small bugs
Fanchengyan committed Nov 20, 2023
1 parent 4b6217f commit f055034
Showing 2 changed files with 60 additions and 60 deletions.
114 changes: 56 additions & 58 deletions faninsar/_core/pair_tools.py
@@ -2,7 +2,7 @@
from collections.abc import Iterable
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union
from typing import Any, Callable, Literal, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
@@ -45,6 +45,9 @@ def __eq__(self, other: "Pair") -> bool:
def __hash__(self) -> int:
return hash(self._name)

def __array__(self) -> np.ndarray:
return self._values
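
For context, __array__ is the NumPy conversion hook: once it is defined, np.asarray(pair) and np.array([...]) can consume the object directly. A minimal sketch of the pattern using a simplified stand-in class (not the library code):

import numpy as np

class DatePair:
    """Toy stand-in for Pair: holds two acquisition dates as datetime64[D]."""

    def __init__(self, primary: str, secondary: str) -> None:
        self._values = np.array([primary, secondary], dtype="M8[D]")

    def __array__(self) -> np.ndarray:
        # NumPy calls this when the object is passed to np.array/np.asarray
        return self._values

pair = DatePair("2023-01-01", "2023-01-13")
print(np.asarray(pair))                             # ['2023-01-01' '2023-01-13']
print(np.array([pair, pair], dtype="M8[D]").shape)  # (2, 2), each element converted via __array__
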

@property
def values(self):
"""return the values of the pair.
@@ -195,22 +198,11 @@ def __init__(
"""
if pairs is None or len(pairs) == 0:
raise ValueError("pairs cannot be None.")
pairs_ls = []
for pair in pairs:
if isinstance(pair, Pair):
_pair = pair
elif isinstance(pair, Iterable):
_pair = Pair(pair)
else:
raise TypeError(
f"pairs should be an Iterable containing Iterable or Pair object, but got {type(pair)}."
)
pairs_ls.append(_pair.values)

_values = np.array(pairs_ls)
_values = np.array(pairs, dtype="M8[D]")

self._values = _values
self._dates = np.unique(pairs_ls)
self._dates = np.unique(_values.flatten())
self._length = self._values.shape[0]
self._edge_index = np.searchsorted(self._dates, self._values)
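
A rough illustration of what the simplified constructor relies on: NumPy coerces nested date-like input straight to an (n, 2) datetime64[D] array, np.unique of the flattened values gives the sorted acquisition dates, and np.searchsorted maps each pair entry to its date index. Toy input below, not taken from the repository:

import numpy as np

raw = [["2023-01-01", "2023-01-13"], ["2023-01-13", "2023-01-25"]]
values = np.array(raw, dtype="M8[D]")        # shape (2, 2), datetime64[D]

dates = np.unique(values.flatten())          # sorted unique acquisition dates
edge_index = np.searchsorted(dates, values)  # per-pair (start, end) indices into dates
print(dates)        # ['2023-01-01' '2023-01-13' '2023-01-25']
print(edge_index)   # [[0 1]
                    #  [1 2]]
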

@@ -328,6 +320,9 @@ def __contains__(self, item):

return np.any(np.all(item == self.values, axis=1))

def __array__(self) -> np.ndarray:
return self._values

@property
def values(self) -> np.ndarray:
"""return the pairs array in type of np.datetime64[D]"""
@@ -561,6 +556,9 @@ def __eq__(self, other) -> bool:
def __hash__(self) -> int:
return hash(self.name)

def __array__(self) -> np.ndarray:
return self._values

@property
def values(self) -> np.ndarray:
"""return the values array of the loop.
@@ -683,24 +681,13 @@ def __init__(
"""
if loops is None or len(loops) == 0:
raise ValueError("loops cannot be None.")
loops_ls = []
for loop in loops:
if isinstance(loop, Loop):
_loop = loop
elif isinstance(loop, Iterable):
_loop = Loop(loop)
else:
raise TypeError(
f"loops should be an Iterable containing Iterable or Poop object, but got {type(loop)}."
)
loops_ls.append(_loop.values)

_values = np.array(loops_ls)
_values = np.array(loops, dtype="M8[D]")

self._values = _values
self._dates = np.unique(loops_ls)
self._dates = np.unique(_values)
self._length = self._values.shape[0]

if sort:
self.sort(inplace=True)
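
One small difference from the Pairs constructor above: np.unique flattens its input by default, so calling it on the (n, 3) loop array directly is equivalent to flattening first. A tiny illustration with toy dates:

import numpy as np

loops = np.array(
    [["2023-01-01", "2023-01-13", "2023-01-25"],
     ["2023-01-13", "2023-01-25", "2023-02-06"]],
    dtype="M8[D]",
)
print(np.unique(loops))   # ['2023-01-01' '2023-01-13' '2023-01-25' '2023-02-06']
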

@@ -814,6 +801,9 @@ def __contains__(self, item):

return np.any(np.all(item == self.values, axis=1))

def __array__(self) -> np.ndarray:
return self._values

@property
def values(self) -> np.ndarray:
"""return the values of the loops.
@@ -1070,7 +1060,7 @@ def sort(
Returns
-------
None or (Loops, np.ndarray). if inplace is True, return the sorted loops
None or (Loops, np.ndarray). if inplace is True, return the sorted loops
and the index of the sorted loops in the original loops. Otherwise,
return None.
"""
@@ -1085,10 +1075,10 @@ def sort(
"days23": self.days23,
"days13": self.days13,
"date": self._values,
"pairs": zip(
"pairs": np.hstack(
[self.pairs12.values, self.pairs23.values, self.pairs13.values]
),
"days": zip([self.days12, self.days23, self.days13]),
"days": np.hstack([self.days12, self.days23, self.days13]),
}
if isinstance(order, str):
order = [order]
@@ -1098,7 +1088,7 @@
raise ValueError(
f"order should be one of {list(item_map.keys())}, but got {order}."
)
_values.append(item_map[i].reshape(self._length, -1))
_values.append(item_map[i])
_values = np.hstack(_values)
_, _index = np.unique(_values, axis=0, return_index=True)
if not ascending:
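
The switch from zip(...) to np.hstack matters because zip returns a lazy iterator of tuples rather than an array, so it could not feed the row-wise sort above. A hedged illustration with toy integer blocks standing in for the pair values (shapes assumed, not checked against the library):

import numpy as np

# toy (n, 2) blocks standing in for pairs12/pairs23/pairs13 values
p12 = np.array([[0, 12], [0, 24]])
p23 = np.array([[12, 36], [24, 36]])
p13 = np.array([[0, 36], [0, 36]])

key = np.hstack([p12, p23, p13])   # real (n, 6) sort-key array, one row per loop
_, order = np.unique(key, axis=0, return_index=True)
print(key.shape)   # (2, 6)
print(order)       # [0 1]
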
@@ -1303,11 +1293,11 @@ def sort(
Whether to sort ascending. Default is True.
inplace: bool, optional
Whether to sort the pairs and loops inplace. Default is True.
Returns
-------
None or (SBASNetwork, np.ndarray). if inplace is True, return the sorted
SBASNetwork and the index of the sorted pairs in the original pairs.
None or (SBASNetwork, np.ndarray). if inplace is True, return the sorted
SBASNetwork and the index of the sorted pairs in the original pairs.
Otherwise, return None.
"""
pairs, _index = self._pairs.sort(order, ascending, inplace=False)
@@ -1358,7 +1348,7 @@ def __init__(self, dates: Iterable, **kwargs) -> None:
date_args: dict, optional
Keyword arguments for pd.to_datetime().
"""
self.dates = pd.to_datetime(dates, **kwargs)
self.dates = pd.to_datetime(dates, **kwargs).unique().sort_values()
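
Normalizing the dates up front means any duplicate or out-of-order acquisitions are handled once, before pair generation; for example:

import pandas as pd

dates = pd.to_datetime(["2023-01-25", "2023-01-01", "2023-01-01"]).unique().sort_values()
print(dates)   # DatetimeIndex(['2023-01-01', '2023-01-25'], ...)
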

def from_interval(self, max_interval: int = 2, max_day: int = 180) -> Pairs:
"""generate interferometric pairs by SAR acquisition interval. SAR
@@ -1406,38 +1396,44 @@ def from_period(
period_start: str = "1201",
period_end: str = "0331",
n_per_period: str = 3,
n_primary_period: str = 2,
n_primary_period: Optional[str] = None,
primary_years: Optional[list[int]] = None,
) -> Pairs:
"""generate interferometric pairs between periods for all years. period is defined by month
and day for each year. For example, period_start='1201', period_end='0331' means the period
is from Dec 1 to Mar 31 for each year in the time series. This function will randomly select
n_per_period dates in each period and generate interferometric pairs between those dates. This
will be useful to mitigate the temporal cumulative bias.
"""generate interferometric pairs between periods for all years.
period is defined by month and day for each year. For example,
period_start='1201', period_end='0331' means the period is from Dec 1
to Mar 31 for each year in the time series. This function will randomly
select n_per_period dates in each period and generate interferometric
pairs between those dates. This will be useful to mitigate the temporal
cumulative bias.
Parameters
----------
period_start, period_end: str
start and end date for the period which expressed as month and day with format '%m%d'
start and end date for the period which expressed as month and day
with format '%m%d'
n_per_period: int
how many dates will be used for each period. Those dates will be selected randomly
in each period. Default is 3
n_primary_period: int
how many periods used as primary date of ifg. Default is 2. For example, if n_primary_period=2,
then the interferometric pairs will be generated between the first two periods and the rest
periods.
how many dates will be used for each period. Those dates will be
selected randomly in each period. Default is 3
n_primary_period: int, optional
how many periods/years used as primary date of ifg. For example, if
n_primary_period=2, then the interferometric pairs will be generated
between the first two periods and the rest periods. If None, all
periods will be used. Default is None.
primary_years: list, optional
years used as primary date of ifg. If None, all years in the time
series will be used. Default is None.
Returns
-------
pairs: Pairs object
"""
years = sorted(set(self.dates.year))
df_dates = pd.Series(self.dates.strftime("%Y%m%d"), index=self.dates)
df_dates = pd.Series(self.dates, index=self.dates)

# check if period_start and period_end are in the same year. If not, the period_end should be
# in the next year
same_year = False
if int(period_start) < int(period_end):
same_year = True
# check if period_start and period_end are in the same year. If not,
# the period_end should be in the next year
same_year = True if int(period_start) < int(period_end) else False

# randomly select n_per_period dates in each period/year
date_years = []
@@ -1457,8 +1453,10 @@ def from_period(
_pairs = []
for i, date_year in enumerate(date_years):
# only generate pairs for n_primary_period
if i + 1 > n_primary_period:
if n_primary_period is not None and i + 1 > n_primary_period:
break
if primary_years is not None and years[i] not in primary_years:
continue
for date_primary in date_year:
# all rest periods
for date_year1 in date_years[i + 1 :]:
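
A hedged usage sketch of the updated from_period signature; the class that hosts it is not visible in these hunks, so PairsFactory below is a placeholder name and import path, not the library's confirmed API:

import pandas as pd

from faninsar._core.pair_tools import PairsFactory  # hypothetical name/path for the host class

# one SAR acquisition every 12 days across three winters
dates = pd.date_range("2020-11-01", "2023-04-30", freq="12D")
gen = PairsFactory(dates)

# winter window Dec 1 - Mar 31, 3 randomly drawn dates per winter;
# n_primary_period=None (the new default) lets every period act as primary,
# while primary_years restricts primaries to the listed years
pairs = gen.from_period(
    period_start="1201",
    period_end="0331",
    n_per_period=3,
    primary_years=[2021, 2022],
)
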
@@ -1489,10 +1487,10 @@ def from_summer_winter(
Returns
-------
Pairs object
Pairs object
"""
years = sorted(set(self.dates.year))
df_dates = pd.Series(self.dates.strftime('%Y%m%d'), index=self.dates)
df_dates = pd.Series(self.dates.strftime("%Y%m%d"), index=self.dates)

_pairs = []
for year in years:
6 changes: 4 additions & 2 deletions tests/sampler.py
@@ -1,9 +1,11 @@
from pathlib import Path

from faninsar import samplers
from faninsar.datasets import RasterDataset,BoundingBox,Points
from tqdm import tqdm

from faninsar import samplers
from faninsar.datasets import RasterDataset
from faninsar.query import BoundingBox, Points

home_dir = Path("/home/fancy/work/data/test")
files = list(home_dir.rglob("*unw_phase_clip.tif"))

