From 7bb7909ab88bad18844ad43b8c41043b27085f64 Mon Sep 17 00:00:00 2001 From: Lilly Thomas Date: Fri, 26 Apr 2024 09:36:23 -0700 Subject: [PATCH] restructuring --- scripts/worldcover/run.py | 42 +++++++++++++++++---------------------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/scripts/worldcover/run.py b/scripts/worldcover/run.py index 575ddfd5..95ad71b5 100755 --- a/scripts/worldcover/run.py +++ b/scripts/worldcover/run.py @@ -46,8 +46,7 @@ "https://huggingface.co/made-with-clay/Clay/resolve/main/" "Clay_v0.1_epoch-24_val-loss-0.46.ckpt" ) -# CKPT_PATH = "https://huggingface.co/made-with-clay/Clay/resolve/main/Clay_v0.1_epoch-24_val-loss-0.46.ckpt" -VERSION = "003" +VERSION = "005" BUCKET = "clay-worldcover-embeddings" URL = "https://esa-worldcover-s2.s3.amazonaws.com/rgbnir/{year}/N{yidx}/ESA_WorldCover_10m_{year}_v{version}_N{yidx}W{xidx}_S2RGBNIR.tif" WC_VERSION_LOOKUP = { @@ -142,9 +141,8 @@ def tiles_and_windows(input: Window): return result - def download_image(url): - # Download the image from the URL + # Download an image from a URL response = requests.get(url) # Check if the request was successful if response.status_code == 200: @@ -152,8 +150,8 @@ def download_image(url): else: raise Exception("Failed to download the image") -def patches_and_windows_from_url(url, chunk_size=(PATCH_SIZE, PATCH_SIZE)): - # Download the image from the URL +def patch_bounds_from_url(url, chunk_size=(PATCH_SIZE, PATCH_SIZE)): + # Download an image from a URL image_data = download_image(url) # Open the image using rasterio from memory @@ -198,7 +196,6 @@ def patches_and_windows_from_url(url, chunk_size=(PATCH_SIZE, PATCH_SIZE)): return chunk_bounds, img_crs - def make_batch(result): pixels = [] for url, win in result: @@ -282,7 +279,6 @@ def get_pixels(result): # Set the model to evaluation mode rgb_model.eval() - outdir_embeddings = Path("data/embeddings") outdir_embeddings.mkdir(exist_ok=True, parents=True) @@ -327,15 +323,16 @@ def get_pixels(result): print(len(embeddings), len(results)) - #embeddings = numpy.vstack(embeddings) - embeddings_ = embeddings[0] + embeddings_ = numpy.vstack(embeddings) + #embeddings_ = embeddings[0] print("Embeddings shape: ", embeddings_.shape) + + # remove date and lat/lon + embeddings_ = embeddings_[:, :-2, :].mean(axis=0) - embeddings_ = embeddings_[:, :-2, :] - - print(f"Embeddings have shape {embeddings_.shape}") #.mean(axis=1) + print(f"Embeddings have shape {embeddings_.shape}") - # remove date and lat/lon and reshape to disaggregated patches + # reshape to disaggregated patches embeddings_patch = embeddings_.reshape([2, 16, 16, 768]) # average over the band groups @@ -347,7 +344,7 @@ def get_pixels(result): if result is not None: print("result: ", result[0][0]) pix = get_pixels(result) - chunk_bounds, epsg = patches_and_windows_from_url(result[0][0]) + chunk_bounds, epsg = patch_bounds_from_url(result[0][0]) #print("chunk_bounds: ", chunk_bounds) print("chunk bounds length:", len(chunk_bounds)) @@ -365,16 +362,14 @@ def get_pixels(result): item_[0][1]["lon_end"], item_[0][1]["lat_end"], ] - #source_url = batch["source_url"] - date = batch["date"] - date_as_timestamp = pd.to_datetime(date, format="%Y-%m-%d") - - # Convert the Pandas Timestamp to the desired data type - #date_as_date32 = date_as_timestamp.astype('datetime64[D]') - #print(batch["date"]) data = { - "date": date_as_timestamp, + #"source_url": batch["source_url"][0], + #"date": pd.to_datetime(arg=date, format="%Y-%m-%d").astype( + # dtype="date32[day][pyarrow]" + #), + #"date": pd.to_datetime(date, format="%Y-%m-%d", dtype="date32[day][pyarrow]"), + "date": pd.to_datetime(batch["date"], format="%Y-%m-%d"), "embeddings": [numpy.ascontiguousarray(embeddings_output_patch)], } @@ -390,7 +385,6 @@ def get_pixels(result): # Reproject to WGS84 (lon/lat coordinates) gdf = gdf.to_crs(epsg=4326) - with tempfile.TemporaryDirectory() as tmp: # tmp = "/home/tam/Desktop/wcctmp"