Merge pull request #235 from WorldCereal/234-refactor-point-extractions
234-refactor-point-extractions
kvantricht authored Jan 9, 2025
2 parents 27bcc41 + f2b22cc commit 2a27eb9
Showing 12 changed files with 341 additions and 265 deletions.
4 changes: 2 additions & 2 deletions environment.yml
@@ -34,8 +34,8 @@ dependencies:
- tqdm
- pip:
- duckdb==1.1.0
- h3==3.7.7
- openeo-gfmap==0.2.0
- h3==4.1.0
- openeo-gfmap==0.3.0
- git+https://github.com/worldcereal/worldcereal-classification
- git+https://github.com/WorldCereal/presto-worldcereal.git@croptype

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -37,13 +37,13 @@ dependencies = [
"cftime",
"geojson",
"geopandas",
"h3==3.7.7",
"h3==4.1.0",
"h5netcdf>=1.1.0",
"loguru>=0.7.2",
"netcdf4<=1.6.4",
"numpy<2.0.0",
"openeo==0.31.0",
"openeo-gfmap==0.2.0",
"openeo-gfmap==0.3.0",
"pyarrow",
"pydantic==2.8.0",
"rioxarray>=0.13.0",
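Both dependency files bump h3 from 3.7.7 to 4.1.0 and openeo-gfmap from 0.2.0 to 0.3.0. The h3 4.x release renamed its core indexing functions, so code written against the 3.x API has to be updated alongside this pin. A minimal sketch of the rename (the coordinates and resolution are arbitrary):

import h3

lat, lng, resolution = 51.05, 3.72, 5

# h3-py 3.x style (pre-bump):  cell = h3.geo_to_h3(lat, lng, resolution)
# h3-py 4.x style (post-bump):
cell = h3.latlng_to_cell(lat, lng, resolution)
print(cell)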
199 changes: 128 additions & 71 deletions scripts/extractions/extract.py
@@ -103,7 +103,7 @@ def prepare_job_dataframe(
pipeline_log.info("Preparing the job dataframe.")

# Filter the input dataframe to only keep the locations to extract
input_df = input_df[input_df["extract"] == extract_value].copy()
input_df = input_df[input_df["extract"] >= extract_value].copy()

# Split the locations into chunks of max_locations
split_dfs = []
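The filter in prepare_job_dataframe switches from strict equality to a greater-than-or-equal comparison, so samples flagged with a higher "extract" value than the requested threshold are now also selected. An illustrative pandas sketch (the column values are made up):

import pandas as pd

df = pd.DataFrame({"sample_id": ["a", "b", "c"], "extract": [0, 1, 2]})
extract_value = 1

old_selection = df[df["extract"] == extract_value]  # keeps only "b"
new_selection = df[df["extract"] >= extract_value]  # keeps "b" and "c"
print(old_selection.sample_id.tolist(), new_selection.sample_id.tolist())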
@@ -144,9 +144,9 @@ def prepare_job_dataframe(
def setup_extraction_functions(
collection: ExtractionCollection,
extract_value: int,
memory: str,
python_memory: str,
max_executors: int,
memory: typing.Union[str, None],
python_memory: typing.Union[str, None],
max_executors: typing.Union[int, None],
) -> tuple[typing.Callable, typing.Callable, typing.Callable]:
"""Setup the datacube creation, path generation and post-job action
functions for the given collection. Returns a tuple of three functions:
@@ -158,33 +158,33 @@ def setup_extraction_functions(
datacube_creation = {
ExtractionCollection.PATCH_SENTINEL1: partial(
create_job_patch_s1,
executor_memory=memory,
python_memory=python_memory,
max_executors=max_executors,
executor_memory=memory if memory is not None else "1800m",
python_memory=python_memory if python_memory is not None else "1900m",
max_executors=max_executors if max_executors is not None else 22,
),
ExtractionCollection.PATCH_SENTINEL2: partial(
create_job_patch_s2,
executor_memory=memory,
python_memory=python_memory,
max_executors=max_executors,
executor_memory=memory if memory is not None else "1800m",
python_memory=python_memory if python_memory is not None else "1900m",
max_executors=max_executors if max_executors is not None else 22,
),
ExtractionCollection.PATCH_METEO: partial(
create_job_patch_meteo,
executor_memory=memory,
python_memory=python_memory,
max_executors=max_executors,
executor_memory=memory if memory is not None else "1800m",
python_memory=python_memory if python_memory is not None else "1000m",
max_executors=max_executors if max_executors is not None else 22,
),
ExtractionCollection.PATCH_WORLDCEREAL: partial(
create_job_patch_worldcereal,
executor_memory=memory,
python_memory=python_memory,
max_executors=max_executors,
executor_memory=memory if memory is not None else "1800m",
python_memory=python_memory if python_memory is not None else "3000m",
max_executors=max_executors if max_executors is not None else 22,
),
ExtractionCollection.POINT_WORLDCEREAL: partial(
create_job_point_worldcereal,
executor_memory=memory,
python_memory=python_memory,
max_executors=max_executors,
executor_memory=memory if memory is not None else "1800m",
python_memory=python_memory if python_memory is not None else "3000m",
max_executors=max_executors if max_executors is not None else 22,
),
}
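Resource settings are now optional and fall back to per-collection defaults inside setup_extraction_functions instead of being fixed in each job-creation function. A self-contained sketch of the pattern (the function names and default values below are illustrative, not the repository's actual code):

from functools import partial
from typing import Optional

def create_job(executor_memory: str, python_memory: str, max_executors: int) -> dict:
    # Stand-in for a create_job_* function that needs explicit resource settings.
    return {
        "executor-memory": executor_memory,
        "python-memory": python_memory,
        "max-executors": max_executors,
    }

def build_factory(memory: Optional[str], python_memory: Optional[str], max_executors: Optional[int]):
    # None means "use the collection-specific default", mirroring the pattern above.
    return partial(
        create_job,
        executor_memory=memory if memory is not None else "1800m",
        python_memory=python_memory if python_memory is not None else "1900m",
        max_executors=max_executors if max_executors is not None else 22,
    )

factory = build_factory(None, "3000m", None)
print(factory())  # {'executor-memory': '1800m', 'python-memory': '3000m', 'max-executors': 22}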

@@ -334,6 +334,102 @@ def manager_main_loop(
raise e


def run_extractions(
collection: ExtractionCollection,
output_folder: Path,
input_df: Path,
max_locations_per_job: int = 500,
memory: str = "1800m",
python_memory: str = "1900m",
max_executors: int = 22,
parallel_jobs: int = 2,
restart_failed: bool = False,
extract_value: int = 1,
backend=Backend.CDSE,
) -> None:
"""Main function responsible for launching point and patch extractions.
Parameters
----------
collection : ExtractionCollection
The collection to extract. Most commonly used: PATCH_WORLDCEREAL, POINT_WORLDCEREAL
output_folder : Path
The folder where to store the extracted data
input_df : Path
Path to the input dataframe containing the geometries
for which extractions need to be done
max_locations_per_job : int, optional
The maximum number of locations to extract per job, by default 500
memory : str, optional
Memory to allocate for the executor, by default "1800m"
python_memory : str, optional
Memory to allocate for the python processes as well as OrfeoToolbox in the executors,
by default "1900m"
max_executors : int, optional
Number of executors to run, by default 22
parallel_jobs : int, optional
The maximum number of parallel jobs to run at the same time, by default 2
restart_failed : bool, optional
Restart the jobs that previously failed, by default False
extract_value : int, optional
All samples with an "extract" value equal to or larger than this one will be extracted, by default 1
backend : openeo_gfmap.Backend, optional
Cloud backend on which to run the extractions, by default Backend.CDSE
Raises
------
ValueError
_description_
"""

if not output_folder.is_dir():
output_folder.mkdir(parents=True)

tracking_df_path = output_folder / "job_tracking.csv"

# Load the input dataframe and build the job dataframe
input_df = load_dataframe(input_df)

job_df = None
if not tracking_df_path.exists():
job_df = prepare_job_dataframe(
input_df, collection, max_locations_per_job, extract_value, backend
)

# Setup the extraction functions
pipeline_log.info("Setting up the extraction functions.")
datacube_fn, path_fn, post_job_fn = setup_extraction_functions(
collection, extract_value, memory, python_memory, max_executors
)

# Initialize and set up the job manager
pipeline_log.info("Initializing the job manager.")

job_manager = GFMAPJobManager(
output_dir=output_folder,
output_path_generator=path_fn,
post_job_action=post_job_fn,
poll_sleep=60,
n_threads=4,
restart_failed=restart_failed,
stac_enabled=False,
)

job_manager.add_backend(
backend.value,
cdse_connection,
parallel_jobs=parallel_jobs,
)

manager_main_loop(job_manager, collection, job_df, datacube_fn, tracking_df_path)

pipeline_log.info("Extraction completed successfully.")
send_notification(
title=f"WorldCereal Extraction {collection.value} - Completed",
message="Extractions have been completed successfully.",
)
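Since the refactor exposes run_extractions as a callable entry point, extractions can now be launched from Python as well as through the CLI below. A hypothetical usage sketch; the import path and file locations are assumptions, not part of this change:

from pathlib import Path

# Assumes the scripts/extractions directory is importable; adjust to your setup.
from extract import ExtractionCollection, run_extractions

run_extractions(
    collection=ExtractionCollection.POINT_WORLDCEREAL,
    output_folder=Path("./extractions_output"),
    input_df=Path("./samples.geoparquet"),   # placeholder input file
    max_locations_per_job=500,
    parallel_jobs=2,
    restart_failed=False,
    extract_value=1,
)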


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Extract data from a collection")
parser.add_argument(
Expand Down Expand Up @@ -372,8 +468,8 @@ def manager_main_loop(
parser.add_argument(
"--parallel_jobs",
type=int,
default=10,
help="The maximum number of parrallel jobs to run at the same time.",
default=2,
help="The maximum number of parallel jobs to run at the same time.",
)
parser.add_argument(
"--restart_failed",
@@ -389,55 +485,16 @@ def manager_main_loop(

args = parser.parse_args()

# Fetches values and setups hardocded values
collection = args.collection
extract_value = args.extract_value
max_locations_per_job = args.max_locations
backend = Backend.CDSE

if not args.output_folder.is_dir():
raise ValueError(f"Output folder {args.output_folder} does not exist.")

tracking_df_path = Path(args.output_folder) / "job_tracking.csv"

# Load the input dataframe and build the job dataframe
input_df = load_dataframe(args.input_df)

job_df = None
if not tracking_df_path.exists():
job_df = prepare_job_dataframe(
input_df, collection, max_locations_per_job, extract_value, backend
)

# Setup the extraction functions
pipeline_log.info("Setting up the extraction functions.")
datacube_fn, path_fn, post_job_fn = setup_extraction_functions(
collection, extract_value, args.memory, args.python_memory, args.max_executors
)

# Initialize and setups the job manager
pipeline_log.info("Initializing the job manager.")

job_manager = GFMAPJobManager(
output_dir=args.output_folder,
output_path_generator=path_fn,
post_job_action=post_job_fn,
poll_sleep=60,
n_threads=4,
restart_failed=args.restart_failed,
stac_enabled=False,
)

job_manager.add_backend(
Backend.CDSE.value,
cdse_connection,
run_extractions(
collection=args.collection,
output_folder=args.output_folder,
input_df=args.input_df,
max_locations_per_job=args.max_locations,
memory=args.memory,
python_memory=args.python_memory,
max_executors=args.max_executors,
parallel_jobs=args.parallel_jobs,
)

manager_main_loop(job_manager, collection, job_df, datacube_fn, tracking_df_path)

pipeline_log.info("Extraction completed successfully.")
send_notification(
title=f"WorldCereal Extraction {collection.value} - Completed",
message="Extractions have been completed successfully.",
restart_failed=args.restart_failed,
extract_value=args.extract_value,
backend=Backend.CDSE,
)
10 changes: 5 additions & 5 deletions scripts/extractions/patch_extractions/extract_patch_meteo.py
@@ -24,11 +24,11 @@ def create_job_dataframe_patch_meteo(
def create_job_patch_meteo(
row: pd.Series,
connection: openeo.DataCube,
provider=None,
connection_provider=None,
executor_memory: str = "2G",
python_memory: str = "1G",
max_executors: int = 22,
provider,
connection_provider,
executor_memory: str,
python_memory: str,
max_executors: int,
) -> gpd.GeoDataFrame:
start_date = row.start_date
end_date = row.end_date
6 changes: 3 additions & 3 deletions scripts/extractions/patch_extractions/extract_patch_s1.py
@@ -122,9 +122,9 @@ def create_job_patch_s1(
connection: openeo.DataCube,
provider,
connection_provider,
executor_memory: str = "5G",
python_memory: str = "2G",
max_executors: int = 22,
executor_memory: str,
python_memory: str,
max_executors: int,
) -> openeo.BatchJob:
"""Creates an OpenEO BatchJob from the given row information. This job is a
S1 patch of 32x32 pixels at 20m spatial resolution."""
10 changes: 5 additions & 5 deletions scripts/extractions/patch_extractions/extract_patch_s2.py
@@ -70,11 +70,11 @@ def create_job_dataframe_patch_s2(
def create_job_patch_s2(
row: pd.Series,
connection: openeo.DataCube,
provider=None,
connection_provider=None,
executor_memory: str = "5G",
python_memory: str = "2G",
max_executors: int = 22,
provider,
connection_provider,
executor_memory: str,
python_memory: str,
max_executors: int,
) -> gpd.GeoDataFrame:
start_date = row.start_date
end_date = row.end_date
@@ -124,9 +124,9 @@ def create_job_patch_worldcereal(
connection: openeo.DataCube,
provider,
connection_provider,
executor_memory: str = "5G",
python_memory: str = "2G",
max_executors: int = 22,
executor_memory: str,
python_memory: str,
max_executors: int,
) -> openeo.BatchJob:
"""Creates an OpenEO BatchJob from the given row information."""

@@ -398,13 +398,16 @@ def post_job_action_patch_worldcereal(


def generate_output_path_patch_worldcereal(
root_folder: Path, geometry_index: int, row: pd.Series, s2_grid: gpd.GeoDataFrame
root_folder: Path,
job_index: int,
row: pd.Series,
asset_id: str,
s2_grid: gpd.GeoDataFrame,
):
"""Generate the output path for the extracted data, from a base path and
the row information.
"""
features = geojson.loads(row.geometry)
sample_id = features[geometry_index].properties.get("sample_id", None)
sample_id = asset_id.replace(".nc", "").replace("openEO_", "")

s2_tile_id = row.s2_tile
epsg = s2_grid[s2_grid.tile == s2_tile_id].iloc[0].epsg
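generate_output_path_patch_worldcereal now derives the sample ID from the asset file name instead of looking it up by geometry index. A small sketch of the string handling (the asset name below is made up, following the "openEO_<...>.nc" convention):

asset_id = "openEO_2021-01-01Z_sample-0042.nc"
sample_id = asset_id.replace(".nc", "").replace("openEO_", "")
print(sample_id)  # 2021-01-01Z_sample-0042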