-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
245 fix querying public extraction function #246
Changes from all commits
3141a75
f24569e
c55a9c6
cb9bddf
6886157
44934e8
925ff21
de55b03
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -168,7 +168,9 @@ def query_public_extractions( | |
) | ||
|
||
# Process the parquet into the format we need for training | ||
processed_public_df = process_parquet(public_df_raw, processing_period) | ||
processed_public_df = process_public_extractions_df( | ||
public_df_raw, processing_period | ||
) | ||
|
||
return processed_public_df | ||
|
||
|
@@ -191,7 +193,7 @@ def month_diff(month1: int, month2: int) -> int: | |
The difference between `month1` and `month2`. | ||
""" | ||
|
||
return month2 - month1 if month2 >= month1 else 12 - month1 + month2 | ||
return (month2 - month1) % 12 | ||
|
||
|
||
def get_best_valid_date(row: pd.Series): | ||
|
@@ -219,57 +221,59 @@ def get_best_valid_date(row: pd.Series): | |
|
||
from presto.dataops import MIN_EDGE_BUFFER, NUM_TIMESTEPS | ||
|
||
# check if shift forward will fit into existing extractions | ||
# allow buffer of MIN_EDGE_BUFFER months at the start and end of the extraction period | ||
temp_end_date = row["valid_date"] + pd.DateOffset( | ||
months=row["valid_month_shift_forward"] + NUM_TIMESTEPS // 2 - MIN_EDGE_BUFFER | ||
def is_within_period(proposed_date, start_date, end_date): | ||
return (proposed_date - pd.DateOffset(months=MIN_EDGE_BUFFER) >= start_date) & ( | ||
proposed_date + pd.DateOffset(months=MIN_EDGE_BUFFER) <= end_date | ||
) | ||
|
||
def check_shift(proposed_date, valid_date, start_date, end_date): | ||
proposed_start_date = proposed_date - pd.DateOffset( | ||
months=(NUM_TIMESTEPS // 2 - 1) | ||
) | ||
proposed_end_date = proposed_date + pd.DateOffset(months=(NUM_TIMESTEPS // 2)) | ||
return ( | ||
is_within_period(proposed_date, start_date, end_date) | ||
& (valid_date >= proposed_start_date) | ||
& (valid_date <= proposed_end_date) | ||
) | ||
|
||
valid_date = row["valid_date"] | ||
start_date = row["start_date"] | ||
end_date = row["end_date"] | ||
|
||
proposed_valid_date_fwd = valid_date + pd.DateOffset( | ||
months=row["valid_month_shift_forward"] | ||
) | ||
proposed_valid_date_bwd = valid_date - pd.DateOffset( | ||
months=row["valid_month_shift_backward"] | ||
) | ||
temp_start_date = temp_end_date - pd.DateOffset(months=NUM_TIMESTEPS) | ||
if (temp_end_date <= row["end_date"]) & (temp_start_date >= row["start_date"]): | ||
shift_forward_ok = True | ||
else: | ||
shift_forward_ok = False | ||
|
||
# check if shift backward will fit into existing extractions | ||
# allow buffer of MIN_EDGE_BUFFER months at the start and end of the extraction period | ||
temp_start_date = row["valid_date"] - pd.DateOffset( | ||
months=row["valid_month_shift_backward"] + NUM_TIMESTEPS // 2 - MIN_EDGE_BUFFER | ||
shift_forward_ok = check_shift( | ||
proposed_valid_date_fwd, valid_date, start_date, end_date | ||
) | ||
shift_backward_ok = check_shift( | ||
proposed_valid_date_bwd, valid_date, start_date, end_date | ||
) | ||
temp_end_date = temp_start_date + pd.DateOffset(months=NUM_TIMESTEPS) | ||
if (temp_end_date <= row["end_date"]) & (temp_start_date >= row["start_date"]): | ||
shift_backward_ok = True | ||
else: | ||
shift_backward_ok = False | ||
|
||
if (not shift_forward_ok) & (not shift_backward_ok): | ||
if not shift_forward_ok and not shift_backward_ok: | ||
return np.nan | ||
|
||
if shift_forward_ok & (not shift_backward_ok): | ||
return row["valid_date"] + pd.DateOffset( | ||
months=row["valid_month_shift_forward"] | ||
if shift_forward_ok and not shift_backward_ok: | ||
return proposed_valid_date_fwd | ||
if not shift_forward_ok and shift_backward_ok: | ||
return proposed_valid_date_bwd | ||
if shift_forward_ok and shift_backward_ok: | ||
return ( | ||
proposed_valid_date_bwd | ||
if (row["valid_month_shift_backward"] - row["valid_month_shift_forward"]) | ||
<= MIN_EDGE_BUFFER | ||
else proposed_valid_date_fwd | ||
) | ||
|
||
if (not shift_forward_ok) & shift_backward_ok: | ||
return row["valid_date"] - pd.DateOffset( | ||
months=row["valid_month_shift_backward"] | ||
) | ||
|
||
if shift_forward_ok & shift_backward_ok: | ||
# if shift backward is not too much bigger than shift forward, choose backward | ||
if ( | ||
row["valid_month_shift_backward"] - row["valid_month_shift_forward"] | ||
) <= MIN_EDGE_BUFFER: | ||
return row["valid_date"] - pd.DateOffset( | ||
months=row["valid_month_shift_backward"] | ||
) | ||
else: | ||
return row["valid_date"] + pd.DateOffset( | ||
months=row["valid_month_shift_forward"] | ||
) | ||
|
||
|
||
def process_parquet( | ||
public_df_raw: pd.DataFrame, processing_period: TemporalContext = None | ||
def process_public_extractions_df( | ||
public_df_raw: pd.DataFrame, | ||
processing_period: TemporalContext = None, | ||
freq: str = "MS", | ||
) -> pd.DataFrame: | ||
"""Method to transform the raw parquet data into a format that can be used for | ||
training. Includes pivoting of the dataframe and mapping of the crop types. | ||
|
@@ -278,13 +282,19 @@ def process_parquet( | |
---------- | ||
public_df_raw : pd.DataFrame | ||
Input raw flattened dataframe from the global database. | ||
|
||
Returns | ||
------- | ||
pd.DataFrame | ||
processed dataframe with the necessary columns for training. | ||
processing_period: TemporalContext, optional | ||
User-defined temporal extent to align the samples with, by default None, | ||
which means that 12-month processing window will be aligned around each sample's original valid_date. | ||
If provided, the processing window will be aligned with the middle of the user-defined temporal extent, according to the | ||
following principles: | ||
- the original valid_date of the sample should remain within the processing window | ||
- the center of the user-defined temporal extent should be not closer than MIN_EDGE_BUFFER (by default 2 months) | ||
to the start or end of the extraction period | ||
freq : str, optional | ||
Frequency of the time series, by default "MS". Provided frequency alias should be compatible with pandas. | ||
https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#timeseries-offset-aliases | ||
""" | ||
from presto.utils import process_parquet as process_parquet_for_presto | ||
from presto.utils import process_parquet | ||
|
||
logger.info("Processing selected samples ...") | ||
There was a problem hiding this comment. Choose a reason for hiding this comment The reason will be displayed to describe this comment to others. Learn more. The way processing_period_middle_ts is defined a few lines below doesn't seem to be future-proof? There was a problem hiding this comment. Choose a reason for hiding this comment The reason will be displayed to describe this comment to others. Learn more. That's a good point. I tried to make a small step towards more generic time series handling here (925ff21). There was a problem hiding this comment. Choose a reason for hiding this comment The reason will be displayed to describe this comment to others. Learn more. Do we anticipate that this will happen within the CCN timeframe? Because if not, we may want to keep this as a nice-to-have (track an issue somewhere) but not get distracted by this future option too much. |
||
|
||
|
@@ -293,7 +303,16 @@ def process_parquet( | |
|
||
# get the middle of the user-defined temporal extent | ||
start_date, end_date = processing_period.to_datetime() | ||
processing_period_middle_ts = start_date + pd.DateOffset(months=6) | ||
|
||
# sanity check to make sure freq is not something we still don't support in Presto | ||
if freq not in ["MS", "10D"]: | ||
raise ValueError( | ||
f"Unsupported frequency alias: {freq}. Please use 'MS' or '10D'." | ||
) | ||
|
||
date_range = pd.date_range(start=start_date, end=end_date, freq=freq) | ||
middle_index = len(date_range) // 2 - 1 | ||
processing_period_middle_ts = date_range[middle_index] | ||
processing_period_middle_month = processing_period_middle_ts.month | ||
|
||
# get a lighter subset with only the necessary columns | ||
|
@@ -309,13 +328,13 @@ def process_parquet( | |
# calculate the shifts and assign new valid date | ||
sample_dates["true_valid_date_month"] = public_df_raw["valid_date"].dt.month | ||
sample_dates["proposed_valid_date_month"] = processing_period_middle_month | ||
sample_dates["valid_month_shift_forward"] = sample_dates.apply( | ||
sample_dates["valid_month_shift_backward"] = sample_dates.apply( | ||
lambda xx: month_diff( | ||
xx["proposed_valid_date_month"], xx["true_valid_date_month"] | ||
), | ||
axis=1, | ||
) | ||
sample_dates["valid_month_shift_backward"] = sample_dates.apply( | ||
sample_dates["valid_month_shift_forward"] = sample_dates.apply( | ||
lambda xx: month_diff( | ||
xx["true_valid_date_month"], xx["proposed_valid_date_month"] | ||
), | ||
|
@@ -342,7 +361,7 @@ def process_parquet( | |
f"Removed {invalid_samples.shape[0]} samples that do not fit into selected temporal extent." | ||
) | ||
|
||
public_df = process_parquet_for_presto(public_df_raw) | ||
public_df = process_parquet(public_df_raw) | ||
|
||
if processing_period is not None: | ||
# put back the true valid_date | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this correct? Shouldn't it be '-' in the first one and '+' in the second one?
I guess I don't understand what it is used for.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, but this is exactly the case, no? For example, here in these lines:
worldcereal-classification/src/worldcereal/utils/refdata.py
Lines 244 to 249 in 44934e8
Or are you referring to something else?