Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

245 fix querying public extraction function #246

Merged
merged 8 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 74 additions & 55 deletions src/worldcereal/utils/refdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,9 @@ def query_public_extractions(
)

# Process the parquet into the format we need for training
processed_public_df = process_parquet(public_df_raw, processing_period)
processed_public_df = process_public_extractions_df(
public_df_raw, processing_period
)

return processed_public_df

Expand All @@ -191,7 +193,7 @@ def month_diff(month1: int, month2: int) -> int:
The difference between `month1` and `month2`.
"""

return month2 - month1 if month2 >= month1 else 12 - month1 + month2
return (month2 - month1) % 12


def get_best_valid_date(row: pd.Series):
Expand Down Expand Up @@ -219,57 +221,59 @@ def get_best_valid_date(row: pd.Series):

from presto.dataops import MIN_EDGE_BUFFER, NUM_TIMESTEPS

# check if shift forward will fit into existing extractions
# allow buffer of MIN_EDGE_BUFFER months at the start and end of the extraction period
temp_end_date = row["valid_date"] + pd.DateOffset(
months=row["valid_month_shift_forward"] + NUM_TIMESTEPS // 2 - MIN_EDGE_BUFFER
def is_within_period(proposed_date, start_date, end_date):
return (proposed_date - pd.DateOffset(months=MIN_EDGE_BUFFER) >= start_date) & (
proposed_date + pd.DateOffset(months=MIN_EDGE_BUFFER) <= end_date
)

def check_shift(proposed_date, valid_date, start_date, end_date):
proposed_start_date = proposed_date - pd.DateOffset(
months=(NUM_TIMESTEPS // 2 - 1)
)
proposed_end_date = proposed_date + pd.DateOffset(months=(NUM_TIMESTEPS // 2))
return (
is_within_period(proposed_date, start_date, end_date)
& (valid_date >= proposed_start_date)
& (valid_date <= proposed_end_date)
)

valid_date = row["valid_date"]
start_date = row["start_date"]
end_date = row["end_date"]

proposed_valid_date_fwd = valid_date + pd.DateOffset(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this correct? Shouldn't it be '-' in the first one and '+' in the second one?
I guess I don't understand what it is used for.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, but this is exactly the case, no? For example, in these lines:

proposed_valid_date_fwd = valid_date + pd.DateOffset(
months=row["valid_month_shift_forward"]
)
proposed_valid_date_bwd = valid_date - pd.DateOffset(
months=row["valid_month_shift_backward"]
)

or do you refer to something else?

months=row["valid_month_shift_forward"]
)
proposed_valid_date_bwd = valid_date - pd.DateOffset(
months=row["valid_month_shift_backward"]
)
temp_start_date = temp_end_date - pd.DateOffset(months=NUM_TIMESTEPS)
if (temp_end_date <= row["end_date"]) & (temp_start_date >= row["start_date"]):
shift_forward_ok = True
else:
shift_forward_ok = False

# check if shift backward will fit into existing extractions
# allow buffer of MIN_EDGE_BUFFER months at the start and end of the extraction period
temp_start_date = row["valid_date"] - pd.DateOffset(
months=row["valid_month_shift_backward"] + NUM_TIMESTEPS // 2 - MIN_EDGE_BUFFER
shift_forward_ok = check_shift(
proposed_valid_date_fwd, valid_date, start_date, end_date
)
shift_backward_ok = check_shift(
proposed_valid_date_bwd, valid_date, start_date, end_date
)
temp_end_date = temp_start_date + pd.DateOffset(months=NUM_TIMESTEPS)
if (temp_end_date <= row["end_date"]) & (temp_start_date >= row["start_date"]):
shift_backward_ok = True
else:
shift_backward_ok = False

if (not shift_forward_ok) & (not shift_backward_ok):
if not shift_forward_ok and not shift_backward_ok:
return np.nan

if shift_forward_ok & (not shift_backward_ok):
return row["valid_date"] + pd.DateOffset(
months=row["valid_month_shift_forward"]
if shift_forward_ok and not shift_backward_ok:
return proposed_valid_date_fwd
if not shift_forward_ok and shift_backward_ok:
return proposed_valid_date_bwd
if shift_forward_ok and shift_backward_ok:
return (
proposed_valid_date_bwd
if (row["valid_month_shift_backward"] - row["valid_month_shift_forward"])
<= MIN_EDGE_BUFFER
else proposed_valid_date_fwd
)

if (not shift_forward_ok) & shift_backward_ok:
return row["valid_date"] - pd.DateOffset(
months=row["valid_month_shift_backward"]
)

if shift_forward_ok & shift_backward_ok:
# if shift backward is not too much bigger than shift forward, choose backward
if (
row["valid_month_shift_backward"] - row["valid_month_shift_forward"]
) <= MIN_EDGE_BUFFER:
return row["valid_date"] - pd.DateOffset(
months=row["valid_month_shift_backward"]
)
else:
return row["valid_date"] + pd.DateOffset(
months=row["valid_month_shift_forward"]
)


def process_parquet(
public_df_raw: pd.DataFrame, processing_period: TemporalContext = None
def process_public_extractions_df(
public_df_raw: pd.DataFrame,
processing_period: TemporalContext = None,
freq: str = "MS",
) -> pd.DataFrame:
"""Method to transform the raw parquet data into a format that can be used for
training. Includes pivoting of the dataframe and mapping of the crop types.
Expand All @@ -278,13 +282,19 @@ def process_parquet(
----------
public_df_raw : pd.DataFrame
Input raw flattened dataframe from the global database.

Returns
-------
pd.DataFrame
processed dataframe with the necessary columns for training.
processing_period: TemporalContext, optional
User-defined temporal extent to align the samples with, by default None,
which means that 12-month processing window will be aligned around each sample's original valid_date.
If provided, the processing window will be aligned with the middle of the user-defined temporal extent, according to the
following principles:
- the original valid_date of the sample should remain within the processing window
- the center of the user-defined temporal extent should be not closer than MIN_EDGE_BUFFER (by default 2 months)
to the start or end of the extraction period
freq : str, optional
Frequency of the time series, by default "MS". Provided frequency alias should be compatible with pandas.
https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#timeseries-offset-aliases
"""
from presto.utils import process_parquet as process_parquet_for_presto
from presto.utils import process_parquet

logger.info("Processing selected samples ...")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way processing_period_middle_ts is defined a few lines below doesn't seem to be future-proof.
What if we move away from 12-month processing periods?

Copy link
Contributor Author

@cbutsko cbutsko Jan 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point. I tried to take a small step towards more generic time series handling here (925ff21).
For our default case (12 monthly timesteps), it replicates the previous implementation. It can also handle other frequencies/lengths.
But we might need to add similar logic in other relevant places as well.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we anticipate that this will happen within the CCN timeframe? If not, we may want to keep this as a nice-to-have (and track it in an issue somewhere) rather than get too distracted by this future option.


Expand All @@ -293,7 +303,16 @@ def process_parquet(

# get the middle of the user-defined temporal extent
start_date, end_date = processing_period.to_datetime()
processing_period_middle_ts = start_date + pd.DateOffset(months=6)

# sanity check to make sure freq is not something we still don't support in Presto
if freq not in ["MS", "10D"]:
raise ValueError(
f"Unsupported frequency alias: {freq}. Please use 'MS' or '10D'."
)

date_range = pd.date_range(start=start_date, end=end_date, freq=freq)
middle_index = len(date_range) // 2 - 1
processing_period_middle_ts = date_range[middle_index]
processing_period_middle_month = processing_period_middle_ts.month

# get a lighter subset with only the necessary columns
Expand All @@ -309,13 +328,13 @@ def process_parquet(
# calculate the shifts and assign new valid date
sample_dates["true_valid_date_month"] = public_df_raw["valid_date"].dt.month
sample_dates["proposed_valid_date_month"] = processing_period_middle_month
sample_dates["valid_month_shift_forward"] = sample_dates.apply(
sample_dates["valid_month_shift_backward"] = sample_dates.apply(
lambda xx: month_diff(
xx["proposed_valid_date_month"], xx["true_valid_date_month"]
),
axis=1,
)
sample_dates["valid_month_shift_backward"] = sample_dates.apply(
sample_dates["valid_month_shift_forward"] = sample_dates.apply(
lambda xx: month_diff(
xx["true_valid_date_month"], xx["proposed_valid_date_month"]
),
Expand All @@ -342,7 +361,7 @@ def process_parquet(
f"Removed {invalid_samples.shape[0]} samples that do not fit into selected temporal extent."
)

public_df = process_parquet_for_presto(public_df_raw)
public_df = process_parquet(public_df_raw)

if processing_period is not None:
# put back the true valid_date
Expand Down
104 changes: 103 additions & 1 deletion tests/worldcerealtests/test_refdata.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import pandas as pd
from shapely.geometry import Polygon

from worldcereal.utils.refdata import query_public_extractions
from worldcereal.utils.refdata import (
get_best_valid_date,
month_diff,
query_public_extractions,
)


def test_query_public_extractions():
Expand All @@ -14,3 +19,100 @@ def test_query_public_extractions():

# Check if dataframe has samples
assert not df.empty


def test_get_best_valid_date():
    """Check for which target middle months a sample can be re-aligned.

    Three samples sharing the same extraction period (2019-01-01 to
    2019-12-01) but with different original valid dates are run through
    ``get_best_valid_date`` for every candidate middle month (1..12).
    For each sample, the test asserts which candidate months yield NaN
    and which yield a non-null date.

    The expected month sets hold for the default MIN_EDGE_BUFFER and
    NUM_TIMESTEPS values.
    """

    def process_test_case(test_case: pd.Series) -> pd.DataFrame:
        """Collect the result of get_best_valid_date for middle months 1..12.

        The helper fields are written onto ``test_case`` in place before
        each call, then read back when computing the month shifts.
        """
        collected = []
        for middle_month in range(1, 13):
            test_case["true_valid_date_month"] = test_case["valid_date"].month
            test_case["proposed_valid_date_month"] = middle_month

            proposed_month = test_case["proposed_valid_date_month"]
            true_month = test_case["true_valid_date_month"]
            test_case["valid_month_shift_backward"] = month_diff(
                proposed_month, true_month
            )
            test_case["valid_month_shift_forward"] = month_diff(
                test_case["true_valid_date_month"],
                test_case["proposed_valid_date_month"],
            )

            collected.append([middle_month, get_best_valid_date(test_case)])

        return pd.DataFrame(
            collected, columns=["proposed_valid_month", "resulting_valid_date"]
        )

    def make_sample(valid_date: str) -> pd.Series:
        """Build a sample row with the shared 2019 extraction period."""
        return pd.Series(
            {
                "start_date": pd.to_datetime("2019-01-01"),
                "end_date": pd.to_datetime("2019-12-01"),
                "valid_date": pd.to_datetime(valid_date),
            }
        )

    # (original valid_date, months expected to give NaN, months expected to work)
    scenarios = [
        ("2019-06-01", [1, 2, 11, 12], range(3, 11)),
        ("2019-10-01", [1, 2, 3, 11, 12], range(4, 11)),
        ("2019-03-01", [1, 2, 9, 10, 11, 12], range(3, 9)),
    ]

    # Process all test cases first, then run the assertions
    outcomes = [
        process_test_case(make_sample(valid_date)) for valid_date, _, _ in scenarios
    ]

    # Asserts are valid for default MIN_EDGE_BUFFER and NUM_TIMESTEPS values
    for outcome, (_, rejected_months, accepted_months) in zip(outcomes, scenarios):
        rejected_mask = outcome["proposed_valid_month"].isin(rejected_months)
        assert outcome.loc[rejected_mask, "resulting_valid_date"].isna().all()

        accepted_mask = outcome["proposed_valid_month"].isin(accepted_months)
        assert outcome.loc[accepted_mask, "resulting_valid_date"].notna().all()
Loading