NetCDF / xarray efficient processing #1

Open
gmillz012 opened this issue Nov 21, 2024 · 0 comments
Labels
help wanted (Extra attention is needed)

Comments

@gmillz012 (Owner) commented Nov 21, 2024

My original file (dataset_greece.nc) was ~21.9 GB. I have limited computing power, so I decided to split the file into geographical regions (climate classifications). In split_data.py, I read the data in with xarray, drop the static variables, mask out values over the ocean, add a region variable, and export subfiles by region using .to_netcdf. Writing takes a very long time and the resulting files are much bigger, up to 300 GB. I then process each subfile (process_class.py), creating two new data variables (spei, heatwave).

When I run either Python script, I run into out-of-memory problems. I'm submitting them via bash scripts to run on a compute node of a supercomputer with 48 CPUs and ~180 GB of memory. I've implemented chunking and a few other measures, but I'm still dealing with inflated file sizes and OOM errors. I've tried logging memory usage and deleting unneeded objects as I go, but I suspect the problem is inefficient chunking or the way I'm exporting the files.
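By "chunking" I mean the general pattern below (a minimal sketch only, not my actual script; the chunk sizes, variable filter, and output path here are placeholders for illustration, and the real values are in the full split_data.py further down):

import xarray as xr

# Open lazily with explicit dask chunks instead of chunks="auto"
ds = xr.open_dataset("dataset_greece.nc", chunks={"time": 100, "y": 128, "x": 128})

# Keep the on-disk netCDF chunking aligned with the dask chunks so each
# write touches a small contiguous block rather than a whole variable
encoding = {
    var: {"zlib": True, "complevel": 4, "chunksizes": (100, 128, 128)}
    for var in ds.data_vars
    if ds[var].dims == ("time", "y", "x")
}

ds.to_netcdf("subset.nc", engine="netcdf4", encoding=encoding)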

The code for split_data.py is below.

Description of one of my subfiles:

Dimensions:                     (time: ~4100, y: 385, x: 541)
Coordinates:
    time                        (time) datetime64[ns] 488B 2020-06-01T10:00:0...
    x                           (x) float64 4kB 21.34 21.35 ... 27.64 27.65
    y                           (y) float64 3kB 42.3 42.29 42.28 ... 39.18 39.17
    band                        int64 8B ...
    spatial_ref                 int64 8B ...
Data variables: (12/40)
    burned_areas                (time, y, x) float32 51MB ...
    ignition_points             (time, y, x) float32 51MB ...
    ndvi                        (time, y, x) float32 51MB ...
    number_of_fires             (time, y, x) float32 51MB ...
    evi                         (time, y, x) float32 51MB ...
    et                          (time, y, x) float32 51MB ...
    ...                          ...
    max_wind_direction          (time, y, x) float32 51MB ...
    max_rh                      (time, y, x) float32 51MB ...
    min_rh                      (time, y, x) float32 51MB ...
    avg_rh                      (time, y, x) float32 51MB ...
    classification_value        (y, x) float64 2MB ...
    classification_description  (y, x) <U23 19MB ...
Attributes:
    temporal_extent:  (2009-03-06, 2021-08-29)
    spatial_extent:   (18.7, 28.9, 34.3, 42.3)
    crs:              EPSG:4326
import xarray as xr
import numpy as np
import psutil
import rasterio
import os
import gc
import dask
from scipy.spatial import cKDTree

def log_memory(stage=""):
    process = psutil.Process()
    memory_used = process.memory_info().rss / 1024 ** 3  # Convert to GB
    print(f"[{stage}] Memory usage: {memory_used:.2f} GB", flush=True)

# Paths to files
legend_path = '/home/gridsan/gmiller/climate/legend.txt'
tif_path = '/home/gridsan/gmiller/climate/Beck_KG_V1_present_0p083.tif'
file_path = '/home/gridsan/gmiller/climate/dataset_greece.nc'

# Read legend
legend = {}
with open(legend_path, 'r') as file:
    for line in file:
        if ':' in line and line.strip()[0].isdigit():
            key, rest = line.strip().split(':', 1)
            key = int(key)
            classification = rest.split('[')[0].strip()
            legend[key] = classification

# Read raster data (Koppen-Geiger classifications)
log_memory("Before reading raster")
with rasterio.open(tif_path) as src:
    raster_data = src.read(1)  # Read classification band
    raster_transform = src.transform

    # Extract coordinates
    rows, cols = np.indices(raster_data.shape)
    lon, lat = rasterio.transform.xy(raster_transform, rows, cols, offset="center")
    lon = np.array(lon).flatten()
    lat = np.array(lat).flatten()
    values = raster_data.flatten()

    # Filter valid points
    lon_min, lon_max, lat_min, lat_max = 18, 32, 34, 43  # Bounding box (degrees) roughly covering Greece
    mask = (values != 0) & (lon_min <= lon) & (lon <= lon_max) & (lat_min <= lat) & (lat <= lat_max)
    lon, lat, values = lon[mask], lat[mask], values[mask]
    del raster_data, rows, cols, mask  # Free memory
    gc.collect()

    descriptions = [legend.get(value, "Unknown") for value in values]

log_memory("After reading raster")

# Create KDTree
coords_tree = cKDTree(np.column_stack((lon, lat)))
del lon, lat
log_memory("After creating KDTree")

# Load dataset with chunking to avoid OOM issues
log_memory("Before opening dataset")
ds = xr.open_dataset(file_path, chunks="auto")
ds = ds.unify_chunks()
print(ds.chunks, flush=True)
log_memory("After opening dataset")

# Filter variables with a time dimension
log_memory("Before filtering variables")
time_vars = [var for var in ds.data_vars if 'time' in ds[var].dims]
ds = ds[time_vars]
log_memory("After filtering variables")

# Create land mask using 'ndvi'
log_memory("Before creating land mask")
reference_var = "ndvi"
date_to_use = '2020-06-01T10:00:00.000000000'  # Specify the desired date explicitly

# Select the data for the specified date
land_mask = ds[reference_var].sel(time=date_to_use).notnull()
log_memory("After creating land mask")

# Apply land mask lazily
ds = ds.where(land_mask)
log_memory("After applying land mask")

# Generate valid coordinates
x_coords, y_coords = np.meshgrid(ds["x"].values, ds["y"].values)

# Flatten the grids and apply the land mask
land_mask_flat = land_mask.values.flatten()
valid_coords = np.column_stack((
    x_coords.flatten()[land_mask_flat],
    y_coords.flatten()[land_mask_flat]
))
del x_coords, y_coords
log_memory("After generating valid coordinates")

# Query KDTree
distances, indices = coords_tree.query(valid_coords)
del coords_tree, valid_coords
log_memory("After querying KDTree")

# Map each valid land cell to the Koppen-Geiger class of its nearest raster pixel
classification_values = values[indices]
del indices, values
classification_descriptions = [legend.get(int(val), "Unknown") for val in classification_values]
log_memory("After classification mapping")

# Assign classifications to dataset
classification_value_data = np.full(land_mask.shape, np.nan)
classification_description_data = np.full(land_mask.shape, np.nan, dtype=object)
classification_value_data[land_mask.values] = classification_values
classification_description_data[land_mask.values] = classification_descriptions

# Add to dataset
ds = ds.assign(
    classification_value=(("y", "x"), classification_value_data),
    classification_description=(("y", "x"), classification_description_data)
)
log_memory("After assigning classifications")

del classification_value_data, classification_description_data, classification_values, classification_descriptions
gc.collect()


output_dir = "classification_datasets"
os.makedirs(output_dir, exist_ok=True)

excluded_classifications = set()  # Classification values to skip (none for now)
unique_classifications = np.unique(ds["classification_value"].values[~np.isnan(ds["classification_value"].values)])
remaining_classifications = [c for c in unique_classifications if c not in excluded_classifications]

# Generate dynamic encoding for all variables
encoding = {}
for var in ds.data_vars:
    var_dims = ds[var].dims  # Get dimensions of the variable
    var_shape = ds[var].shape  # Get the shape of the variable
    var_chunks = tuple(min(size, 50) for size in var_shape)  # Cap on-disk chunk sizes at 50 per dimension (based on the full dataset's shape)
    encoding[var] = {
        "zlib": True,  # Enable compression
        "complevel": 4,  # Compression level (1-9, 4 is a good balance)
        "chunksizes": var_chunks  # Chunk sizes
    }

# Export each classification as a separate file without splitting by time
for classification in remaining_classifications:
    print(f"Processing classification {classification}...", flush=True)

    # Lazy mask application
    land_mask = ds["classification_value"] == classification
    classification_ds = ds.where(land_mask, drop=True)
    classification_ds = classification_ds.chunk({"time": 10})  # Ensure chunking

    # Output file path for this classification
    output_file = os.path.join(output_dir, f"classification_{int(classification)}.nc")

    # Save the dataset for the current classification
    with dask.config.set({"array.slicing.split_large_chunks": True}):
        classification_ds.to_netcdf(
            output_file,
            compute=True,
            engine="netcdf4",
            encoding=encoding
        )

    del classification_ds, land_mask
    gc.collect()
    log_memory(f"After saving classification {classification}")

print("Processing complete.", flush=True)
gmillz012 added the help wanted (Extra attention is needed) label on Nov 21, 2024