From 3141a751e5a8561219cdb113737e49fda6c2d553 Mon Sep 17 00:00:00 2001 From: Butsko Christina Date: Thu, 9 Jan 2025 12:43:59 +0100 Subject: [PATCH 1/7] complete refactoring of get_best_valid_date function --- src/worldcereal/utils/refdata.py | 83 ++++++++++++++++---------------- 1 file changed, 41 insertions(+), 42 deletions(-) diff --git a/src/worldcereal/utils/refdata.py b/src/worldcereal/utils/refdata.py index bd9067f..220aa77 100644 --- a/src/worldcereal/utils/refdata.py +++ b/src/worldcereal/utils/refdata.py @@ -10,7 +10,6 @@ from loguru import logger from openeo_gfmap import TemporalContext from shapely.geometry import Polygon - from worldcereal.data import croptype_mappings @@ -219,54 +218,54 @@ def get_best_valid_date(row: pd.Series): from presto.dataops import MIN_EDGE_BUFFER, NUM_TIMESTEPS - # check if shift forward will fit into existing extractions - # allow buffer of MIN_EDGE_BUFFER months at the start and end of the extraction period - temp_end_date = row["valid_date"] + pd.DateOffset( - months=row["valid_month_shift_forward"] + NUM_TIMESTEPS // 2 - MIN_EDGE_BUFFER + def is_within_period(proposed_date, start_date, end_date): + return (proposed_date - pd.DateOffset(months=MIN_EDGE_BUFFER) >= start_date) & ( + proposed_date + pd.DateOffset(months=MIN_EDGE_BUFFER) <= end_date + ) + + def check_shift(proposed_date, valid_date, start_date, end_date): + proposed_start_date = proposed_date - pd.DateOffset( + months=(NUM_TIMESTEPS // 2 - 1) + ) + proposed_end_date = proposed_date + pd.DateOffset(months=(NUM_TIMESTEPS // 2)) + return ( + is_within_period(proposed_date, start_date, end_date) + & (valid_date >= proposed_start_date) + & (valid_date <= proposed_end_date) + ) + + valid_date = row["valid_date"] + start_date = row["start_date"] + end_date = row["end_date"] + + proposed_valid_date_fwd = valid_date + pd.DateOffset( + months=row["valid_month_shift_forward"] + ) + proposed_valid_date_bwd = valid_date - pd.DateOffset( + months=row["valid_month_shift_backward"] ) - temp_start_date = temp_end_date - pd.DateOffset(months=NUM_TIMESTEPS) - if (temp_end_date <= row["end_date"]) & (temp_start_date >= row["start_date"]): - shift_forward_ok = True - else: - shift_forward_ok = False - # check if shift backward will fit into existing extractions - # allow buffer of MIN_EDGE_BUFFER months at the start and end of the extraction period - temp_start_date = row["valid_date"] - pd.DateOffset( - months=row["valid_month_shift_backward"] + NUM_TIMESTEPS // 2 - MIN_EDGE_BUFFER + shift_forward_ok = check_shift( + proposed_valid_date_fwd, valid_date, start_date, end_date + ) + shift_backward_ok = check_shift( + proposed_valid_date_bwd, valid_date, start_date, end_date ) - temp_end_date = temp_start_date + pd.DateOffset(months=NUM_TIMESTEPS) - if (temp_end_date <= row["end_date"]) & (temp_start_date >= row["start_date"]): - shift_backward_ok = True - else: - shift_backward_ok = False - if (not shift_forward_ok) & (not shift_backward_ok): + if not shift_forward_ok and not shift_backward_ok: return np.nan - - if shift_forward_ok & (not shift_backward_ok): - return row["valid_date"] + pd.DateOffset( - months=row["valid_month_shift_forward"] - ) - - if (not shift_forward_ok) & shift_backward_ok: - return row["valid_date"] - pd.DateOffset( - months=row["valid_month_shift_backward"] + if shift_forward_ok and not shift_backward_ok: + return proposed_valid_date_fwd + if not shift_forward_ok and shift_backward_ok: + return proposed_valid_date_bwd + if shift_forward_ok and shift_backward_ok: + return ( + proposed_valid_date_bwd + if (row["valid_month_shift_backward"] - row["valid_month_shift_forward"]) + <= MIN_EDGE_BUFFER + else proposed_valid_date_fwd ) - if shift_forward_ok & shift_backward_ok: - # if shift backward is not too much bigger than shift forward, choose backward - if ( - row["valid_month_shift_backward"] - row["valid_month_shift_forward"] - ) <= MIN_EDGE_BUFFER: - return row["valid_date"] - pd.DateOffset( - months=row["valid_month_shift_backward"] - ) - else: - return row["valid_date"] + pd.DateOffset( - months=row["valid_month_shift_forward"] - ) - def process_parquet( public_df_raw: pd.DataFrame, processing_period: TemporalContext = None From f24569ed90185264df97b365e26e00756c7098c2 Mon Sep 17 00:00:00 2001 From: Butsko Christina Date: Thu, 9 Jan 2025 12:44:50 +0100 Subject: [PATCH 2/7] even more concise computation of month_diff --- src/worldcereal/utils/refdata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/worldcereal/utils/refdata.py b/src/worldcereal/utils/refdata.py index 220aa77..09e9618 100644 --- a/src/worldcereal/utils/refdata.py +++ b/src/worldcereal/utils/refdata.py @@ -190,7 +190,7 @@ def month_diff(month1: int, month2: int) -> int: The difference between `month1` and `month2`. """ - return month2 - month1 if month2 >= month1 else 12 - month1 + month2 + return (month2 - month1) % 12 def get_best_valid_date(row: pd.Series): From c55a9c6ffacae8bac0d334217a79374fef93de0f Mon Sep 17 00:00:00 2001 From: Butsko Christina Date: Thu, 9 Jan 2025 12:57:16 +0100 Subject: [PATCH 3/7] adding test for get_best_valid_date function --- tests/worldcerealtests/test_refdata.py | 100 ++++++++++++++++++++++++- 1 file changed, 98 insertions(+), 2 deletions(-) diff --git a/tests/worldcerealtests/test_refdata.py b/tests/worldcerealtests/test_refdata.py index c5767bc..c5ee158 100644 --- a/tests/worldcerealtests/test_refdata.py +++ b/tests/worldcerealtests/test_refdata.py @@ -1,6 +1,7 @@ +import pandas as pd from shapely.geometry import Polygon - -from worldcereal.utils.refdata import query_public_extractions +from worldcereal.utils.refdata import (get_best_valid_date, month_diff, + query_public_extractions) def test_query_public_extractions(): @@ -14,3 +15,98 @@ def test_query_public_extractions(): # Check if dataframe has samples assert not df.empty + + +def test_get_best_valid_date(): + def process_test_case(test_case: pd.Series) -> pd.DataFrame: + test_case_res = [] + for processing_period_middle_month in range(1, 13): + test_case["true_valid_date_month"] = test_case["valid_date"].month + test_case["proposed_valid_date_month"] = processing_period_middle_month + test_case["valid_month_shift_backward"] = month_diff( + test_case["proposed_valid_date_month"], test_case["true_valid_date_month"] + ) + test_case["valid_month_shift_forward"] = month_diff( + test_case["true_valid_date_month"], test_case["proposed_valid_date_month"] + ) + proposed_valid_date = get_best_valid_date(test_case) + test_case_res.append([processing_period_middle_month, proposed_valid_date]) + return pd.DataFrame( + test_case_res, columns=["proposed_valid_month", "resulting_valid_date"] + ) + + test_case1 = pd.Series( + { + "start_date": pd.to_datetime("2019-01-01"), + "end_date": pd.to_datetime("2019-12-01"), + "valid_date": pd.to_datetime("2019-06-01"), + } + ) + test_case2 = pd.Series( + { + "start_date": pd.to_datetime("2019-01-01"), + "end_date": pd.to_datetime("2019-12-01"), + "valid_date": pd.to_datetime("2019-10-01"), + } + ) + test_case3 = pd.Series( + { + "start_date": pd.to_datetime("2019-01-01"), + "end_date": pd.to_datetime("2019-12-01"), + "valid_date": pd.to_datetime("2019-03-01"), + } + ) + + # Process test cases + test_case1_res = process_test_case(test_case1) + test_case2_res = process_test_case(test_case2) + test_case3_res = process_test_case(test_case3) + + # Asserts are valid for default MIN_EDGE_BUFFER and NUM_TIMESTEPS values + # Assertions for test case 1 + assert ( + test_case1_res[test_case1_res["proposed_valid_month"].isin([1, 2, 11, 12])][ + "resulting_valid_date" + ] + .isna() + .all() + ) + assert ( + test_case1_res[test_case1_res["proposed_valid_month"].isin(range(3, 11))][ + "resulting_valid_date" + ] + .notna() + .all() + ) + + # Assertions for test case 2 + assert ( + test_case2_res[test_case2_res["proposed_valid_month"].isin([1, 2, 3, 11, 12])][ + "resulting_valid_date" + ] + .isna() + .all() + ) + assert ( + test_case2_res[test_case2_res["proposed_valid_month"].isin(range(4, 11))][ + "resulting_valid_date" + ] + .notna() + .all() + ) + + # Assertions for test case 3 + assert ( + test_case3_res[test_case3_res["proposed_valid_month"].isin([1, 2, 9, 10, 11, 12])][ + "resulting_valid_date" + ] + .isna() + .all() + ) + assert ( + test_case3_res[test_case3_res["proposed_valid_month"].isin(range(3, 9))][ + "resulting_valid_date" + ] + .notna() + .all() + ) \ No newline at end of file From cb9bddfbc0fcd846b1e9522ccc1ce1a48c19acf3 Mon Sep 17 00:00:00 2001 From: Butsko Christina Date: Thu, 9 Jan 2025 13:20:30 +0100 Subject: [PATCH 4/7] formatting --- src/worldcereal/utils/refdata.py | 1 + tests/worldcerealtests/test_refdata.py | 22 ++++++++++++++-------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/worldcereal/utils/refdata.py b/src/worldcereal/utils/refdata.py index 09e9618..7090865 100644 --- a/src/worldcereal/utils/refdata.py +++ b/src/worldcereal/utils/refdata.py @@ -10,6 +10,7 @@ from loguru import logger from openeo_gfmap import TemporalContext from shapely.geometry import Polygon + from worldcereal.data import croptype_mappings diff --git a/tests/worldcerealtests/test_refdata.py b/tests/worldcerealtests/test_refdata.py index c5ee158..e390e62 100644 --- a/tests/worldcerealtests/test_refdata.py +++ b/tests/worldcerealtests/test_refdata.py @@ -1,7 +1,11 @@ import pandas as pd from shapely.geometry import Polygon -from worldcereal.utils.refdata import (get_best_valid_date, month_diff, - query_public_extractions) + +from worldcereal.utils.refdata import ( + get_best_valid_date, + month_diff, + query_public_extractions, +) def test_query_public_extractions(): @@ -24,10 +28,12 @@ def process_test_case(test_case: pd.Series) -> pd.DataFrame: test_case["true_valid_date_month"] = test_case["valid_date"].month test_case["proposed_valid_date_month"] = processing_period_middle_month test_case["valid_month_shift_backward"] = month_diff( - test_case["proposed_valid_date_month"], test_case["true_valid_date_month"] + test_case["proposed_valid_date_month"], + test_case["true_valid_date_month"], ) test_case["valid_month_shift_forward"] = month_diff( - test_case["true_valid_date_month"], test_case["proposed_valid_date_month"] + test_case["true_valid_date_month"], + test_case["proposed_valid_date_month"], ) proposed_valid_date = get_best_valid_date(test_case) test_case_res.append([processing_period_middle_month, proposed_valid_date]) @@ -97,9 +103,9 @@ def process_test_case(test_case: pd.Series) -> pd.DataFrame: # Assertions for test case 3 assert ( - test_case3_res[test_case3_res["proposed_valid_month"].isin([1, 2, 9, 10, 11, 12])][ - "resulting_valid_date" - ] + test_case3_res[ + test_case3_res["proposed_valid_month"].isin([1, 2, 9, 10, 11, 12]) + ]["resulting_valid_date"] .isna() .all() ) @@ -109,4 +115,4 @@ def process_test_case(test_case: pd.Series) -> pd.DataFrame: ] .notna() .all() - ) \ No newline at end of file + ) From 68861574bd5e376b63a9dadecca8a60bec4b5e73 Mon Sep 17 00:00:00 2001 From: Butsko Christina Date: Thu, 9 Jan 2025 13:36:52 +0100 Subject: [PATCH 5/7] renamed duplicated process_parquet function --- src/worldcereal/utils/refdata.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/worldcereal/utils/refdata.py b/src/worldcereal/utils/refdata.py index 7090865..21dda4d 100644 --- a/src/worldcereal/utils/refdata.py +++ b/src/worldcereal/utils/refdata.py @@ -168,7 +168,9 @@ def query_public_extractions( ) # Process the parquet into the format we need for training - processed_public_df = process_parquet(public_df_raw, processing_period) + processed_public_df = process_public_extractions_df( + public_df_raw, processing_period + ) return processed_public_df @@ -268,7 +270,7 @@ def check_shift(proposed_date, valid_date, start_date, end_date): ) -def process_parquet( +def process_public_extractions_df( public_df_raw: pd.DataFrame, processing_period: TemporalContext = None ) -> pd.DataFrame: """Method to transform the raw parquet data into a format that can be used for @@ -284,7 +286,7 @@ def process_parquet( pd.DataFrame processed dataframe with the necessary columns for training. """ - from presto.utils import process_parquet as process_parquet_for_presto + from presto.utils import process_parquet logger.info("Processing selected samples ...") @@ -342,7 +344,7 @@ def process_parquet( f"Removed {invalid_samples.shape[0]} samples that do not fit into selected temporal extent." ) - public_df = process_parquet_for_presto(public_df_raw) + public_df = process_parquet(public_df_raw) if processing_period is not None: # put back the true valid_date From 925ff21f3b7ba4580fd1b9a84ec6c26e463c50c0 Mon Sep 17 00:00:00 2001 From: Butsko Christina Date: Mon, 13 Jan 2025 10:35:16 +0100 Subject: [PATCH 6/7] more generic definition of processing_period_middle_ts --- src/worldcereal/utils/refdata.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/src/worldcereal/utils/refdata.py b/src/worldcereal/utils/refdata.py index 21dda4d..60499ee 100644 --- a/src/worldcereal/utils/refdata.py +++ b/src/worldcereal/utils/refdata.py @@ -271,7 +271,9 @@ def check_shift(proposed_date, valid_date, start_date, end_date): def process_public_extractions_df( - public_df_raw: pd.DataFrame, processing_period: TemporalContext = None + public_df_raw: pd.DataFrame, + processing_period: TemporalContext = None, + freq: str = "MS", ) -> pd.DataFrame: """Method to transform the raw parquet data into a format that can be used for training. Includes pivoting of the dataframe and mapping of the crop types. @@ -280,11 +282,17 @@ def process_public_extractions_df( ---------- public_df_raw : pd.DataFrame Input raw flattened dataframe from the global database. - - Returns - ------- - pd.DataFrame - processed dataframe with the necessary columns for training. + processing_period: TemporalContext, optional + User-defined temporal extent to align the samples with, by default None, + which means that 12-month processing window will be aligned around each sample's original valid_date. + If provided, the processing window will be aligned with the middle of the user-defined temporal extent, according to the + following principles: + - the original valid_date of the sample should remain within the processing window + - the center of the user-defined temporal extent should be not closer than MIN_EDGE_BUFFER (by default 2 months) + to the start or end of the extraction period + freq : str, optional + Frequency of the time series, by default "MS". Provided frequency alias should be compatible with pandas. + https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#timeseries-offset-aliases """ from presto.utils import process_parquet @@ -295,7 +303,16 @@ def process_public_extractions_df( # get the middle of the user-defined temporal extent start_date, end_date = processing_period.to_datetime() - processing_period_middle_ts = start_date + pd.DateOffset(months=6) + + # sanity check to make sure freq is not something we still don't support in Presto + if freq not in ["MS", "10D"]: + raise ValueError( + f"Unsupported frequency alias: {freq}. Please use 'MS' or '10D'." + ) + + date_range = pd.date_range(start=start_date, end=end_date, freq=freq) + middle_index = len(date_range) // 2 - 1 + processing_period_middle_ts = date_range[middle_index] processing_period_middle_month = processing_period_middle_ts.month # get a lighter subset with only the necessary columns From de55b03b8b826de60bf7f9d5489f5514247b114b Mon Sep 17 00:00:00 2001 From: Butsko Christina Date: Mon, 13 Jan 2025 13:38:58 +0100 Subject: [PATCH 7/7] =?UTF-8?q?forward=20and=20backward=20mixup=20?= =?UTF-8?q?=F0=9F=A4=A6=E2=80=8D=E2=99=80=EF=B8=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/worldcereal/utils/refdata.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/worldcereal/utils/refdata.py b/src/worldcereal/utils/refdata.py index 60499ee..a6a29d2 100644 --- a/src/worldcereal/utils/refdata.py +++ b/src/worldcereal/utils/refdata.py @@ -328,13 +328,13 @@ def process_public_extractions_df( # calculate the shifts and assign new valid date sample_dates["true_valid_date_month"] = public_df_raw["valid_date"].dt.month sample_dates["proposed_valid_date_month"] = processing_period_middle_month - sample_dates["valid_month_shift_forward"] = sample_dates.apply( + sample_dates["valid_month_shift_backward"] = sample_dates.apply( lambda xx: month_diff( xx["proposed_valid_date_month"], xx["true_valid_date_month"] ), axis=1, ) - sample_dates["valid_month_shift_backward"] = sample_dates.apply( + sample_dates["valid_month_shift_forward"] = sample_dates.apply( lambda xx: month_diff( xx["true_valid_date_month"], xx["proposed_valid_date_month"] ),