diff --git a/.gitignore b/.gitignore
index ebc2569..addfb30 100644
--- a/.gitignore
+++ b/.gitignore
@@ -129,4 +129,7 @@ dmypy.json
 .pyre/
 
 .idea
-localconfig.yaml
\ No newline at end of file
+localconfig.yaml
+tmp
+outputs
+multirun
diff --git a/docs/usage.rst b/docs/usage.rst
index 6338042..0102e24 100644
--- a/docs/usage.rst
+++ b/docs/usage.rst
@@ -8,39 +8,90 @@ To use sarwaveifrproc in a project::
 
 .. code-block::
 
+PYTHONPATH=. L2-wave-processor 'hydra/help=[default,doc]' --help
-    L2-wave-processor -h
-    usage: L2-wave-processor [-h] --input_path INPUT_PATH --save_directory SAVE_DIRECTORY --product_id PRODUCT_ID [--model_intraburst MODEL_INTRABURST]
-                             [--model_interburst MODEL_INTERBURST] [--scaler_intraburst SCALER_INTRABURST] [--scaler_interburst SCALER_INTERBURST]
-                             [--bins_intraburst BINS_INTRABURST] [--bins_interburst BINS_INTERBURST] [--predicted_variables PREDICTED_VARIABLES] [--overwrite] [--verbose]
-
-    Generate a L2 WAVE product from a L1B or L1C SAFE.
-
-    options:
-      -h, --help            show this help message and exit
-      --input_path INPUT_PATH
-                            l1b or l1c safe path or listing path (.txt file).
-      --save_directory SAVE_DIRECTORY
-                            where to save output data.
-      --product_id PRODUCT_ID
-                            3 digits ID representing the processing options. Ex: E00.
-      --overwrite           overwrite the existing outputs
-      --verbose
-
-    model:
-      Arguments related to the neural models
-
-      --model_intraburst MODEL_INTRABURST
-                            neural model path to predict sea states on intraburst data.
-      --model_interburst MODEL_INTERBURST
-                            neural model path to predict sea states on interburst data.
-      --scaler_intraburst SCALER_INTRABURST
-                            scaler path to standardize intraburst data before feeding it to the neural model.
-      --scaler_interburst SCALER_INTERBURST
-                            scaler path to standardize interburst data before feeding it to the neural model.
-      --bins_intraburst BINS_INTRABURST
-                            bins path that depicts the range of predictions on intraburst data.
-      --bins_interburst BINS_INTERBURST
-                            bins path that depicts the range of predictions on interburst data.
-      --predicted_variables PREDICTED_VARIABLES
-                            list of sea states variables to predict.
+    Generate a L2 WAVE product from a L1B or L1C SAFE.
+
+    input_path: l1b or l1c safe path or listing path (.txt file).
+    save_directory: where to save output data
+    product_id: 3 digits ID representing the processing options. Ex: E00.
+    models: onnx models and output
+    predicted_variables: model outputs and associated variables name to add to the L2 product
+    overwrite: overwrite the existing outputs
+    verbose: debug log level if True
+
+    == Configuration groups ==
+    Compose your configuration from those groups (group=option)
+
+    parallel: chunk
+
+
+    == Config ==
+    Override anything in the config (foo.bar=value)
+
+    _target_: sarwaveifrproc.main_new.main
+    input_path: ???
+    save_directory: ???
+    product_id: E09
+    models:
+      hs_mod:
+        path: models/hs.onnx
+        outputs:
+        - pred
+        - conf
+      t0m1_mod:
+        path: models/t0m1.onnx
+        outputs:
+        - pred
+        - conf
+      phs0_mod:
+        path: models/phs0.onnx
+        outputs:
+        - pred
+        - conf
+    predicted_variables:
+      intraburst:
+        hs_most_likely:
+          model: hs_mod
+          output: pred
+          attrs:
+            long_name: Most likely significant wave height
+            units: m
+        hs_conf:
+          model: hs_mod
+          output: conf
+          attrs:
+            long_name: Significant wave height confidence
+            units: ''
+        phs0_most_likely:
+          model: phs0_mod
+          output: pred
+          attrs:
+            long_name: Most likely wind sea significant wave height
+            units: m
+        phs0_conf:
+          model: phs0_mod
+          output: conf
+          attrs:
+            long_name: Wind sea significant wave height confidence
+            units: ''
+        t0m1_most_likely:
+          model: t0m1_mod
+          output: pred
+          attrs:
+            long_name: Most likely mean wave period
+            units: s
+        t0m1_conf:
+          model: t0m1_mod
+          output: conf
+          attrs:
+            long_name: Mean wave period confidence
+            units: ''
+      interburst: ${.intraburst}
+    overwrite: true
+    verbose: false
+
+
+    Powered by Hydra (https://hydra.cc)
+    Use --hydra-help to view Hydra specific help
+                            list of sea states variables to predict.
 
diff --git a/models/hs.onnx b/models/hs.onnx
new file mode 100644
index 0000000..8d0b4ae
Binary files /dev/null and b/models/hs.onnx differ
diff --git a/models/phs0.onnx b/models/phs0.onnx
new file mode 100644
index 0000000..5814b6d
Binary files /dev/null and b/models/phs0.onnx differ
diff --git a/models/t0m1.onnx b/models/t0m1.onnx
new file mode 100644
index 0000000..510b136
Binary files /dev/null and b/models/t0m1.onnx differ
diff --git a/pyproject.toml b/pyproject.toml
index fbc5e56..221fa6f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,28 +8,34 @@ license = {text = "MIT"}
 keywords = ["SAR", "wave", "reseach","sea-state"]
 authors = [
     {name = "Robin Marquart"},
-    {name = "Antoine Grouazel"}
+    {name = "Antoine Grouazel"},
+    {name = "Quentin Febvre"}
 ]
 classifiers = [
     "Programming Language :: Python :: 3",
 ]
 dependencies = [
     "xarray",
+    "hydra-zen",
     "netCDF4",
     "pyyaml",
     "scipy",
-    "tensorflow",
-    "xarray-datatree"
-    ]
+    "onnxruntime"
+]
 dynamic = ["version"]
 
+[project.optional-dependencies]
+gpu = [ "onnxruntime-gpu" ]
+
+joblib = [ "hydra-joblib-launcher" ]
+
 [build-system]
-requires = ["setuptools>=64.0", "setuptools-scm"]
+requires = ["setuptools>=64.0", "setuptools-scm>=8"]
 build-backend = "setuptools.build_meta"
 
 [tool.setuptools]
-packages = ["sarwaveifrproc"]
+packages = ["sarwaveifrproc", "sarwave_config"]
 
 [tool.setuptools_scm]
 fallback_version = "999"
 
@@ -42,4 +48,4 @@ default_section = "THIRDPARTY"
 known_first_party = "sarwaveifrproc"
 
 [project.scripts]
-L2-wave-processor = "sarwaveifrproc.main:main"
+L2-wave-processor = "sarwaveifrproc.main:hydra_main"
diff --git a/sarwave_config/__init__.py b/sarwave_config/__init__.py
new file mode 100644
index 0000000..63ae73d
--- /dev/null
+++ b/sarwave_config/__init__.py
@@ -0,0 +1,33 @@
+import sarwaveifrproc.main
+import hydra_zen
+from pathlib import Path
+
+def chunk_listing(path: str, i: int, n: int=10) -> str:
+    l = Path(path).read_text().split()
+    chunk_size = len(l)//(n - 1)
+    Path('tmp').mkdir(exist_ok=True)
+    Path(f'tmp/chunk_{i}_{Path(path).name}').write_text(
+        '\n'.join(l[i*chunk_size:(i+1)*chunk_size])
+    )
+    return f'tmp/chunk_{i}_{Path(path).name}'
+
+hydra_zen.store(sarwaveifrproc.main.main, name='base')
+hydra_zen.store(
+    dict(header=sarwaveifrproc.main.main.__doc__),
+    name='doc',
+    group='hydra/help',
+)
+
+hydra_zen.store(
+    dict(input_path=hydra_zen.builds(
chunk_listing, + path='???', + i='${hydra:job.num}', + n='${hydra:launcher.n_jobs}', + )), + name='chunk', + group='parallel', + package='_global_', +) + +hydra_zen.store.add_to_hydra_store(overwrite_ok=True) diff --git a/sarwave_config/e08.yaml b/sarwave_config/e08.yaml new file mode 100644 index 0000000..d9095d3 --- /dev/null +++ b/sarwave_config/e08.yaml @@ -0,0 +1,46 @@ + +models: + hs_mod: + path: models/hs.onnx + outputs: ['pred', 'conf'] + t0m1_mod: + path: models/t0m1.onnx + outputs: ['pred', 'conf'] + phs0_mod: + path: models/phs0.onnx + outputs: ['pred', 'conf'] + +predicted_variables: + intraburst: + hs_most_likely: + model: hs_mod + output: pred + attrs: {'long_name': 'Most likely significant wave height', 'units': 'm'} + hs_conf: + model: hs_mod + output: conf + attrs: {'long_name': 'Significant wave height confidence', 'units': ''} + phs0_most_likely: + model: phs0_mod + output: pred + attrs: {'long_name': 'Most likely wind sea significant wave height', 'units': 'm'} + phs0_conf: + model: phs0_mod + output: conf + attrs: {'long_name': 'Wind sea significant wave height confidence', 'units': ''} + t0m1_most_likely: + model: t0m1_mod + output: pred + attrs: {'long_name': 'Most likely mean wave period', 'units': 's'} + t0m1_conf: + model: t0m1_mod + output: conf + attrs: {'long_name': 'Mean wave period confidence', 'units': ''} + interburst: ${.intraburst} + +product_id: E08 +overwrite: True + +defaults: + - base + - _self_ diff --git a/sarwave_config/e09.yaml b/sarwave_config/e09.yaml new file mode 100644 index 0000000..7a57a38 --- /dev/null +++ b/sarwave_config/e09.yaml @@ -0,0 +1,46 @@ + +models: + hs_mod: + path: models/hs.onnx + outputs: ['pred', 'conf'] + t0m1_mod: + path: models/t0m1.onnx + outputs: ['pred', 'conf'] + phs0_mod: + path: models/phs0.onnx + outputs: ['pred', 'conf'] + +predicted_variables: + intraburst: + hs_most_likely: + model: hs_mod + output: pred + attrs: {'long_name': 'Most likely significant wave height', 'units': 'm'} + hs_conf: + model: hs_mod + output: conf + attrs: {'long_name': 'Significant wave height confidence', 'units': ''} + phs0_most_likely: + model: phs0_mod + output: pred + attrs: {'long_name': 'Most likely wind sea significant wave height', 'units': 'm'} + phs0_conf: + model: phs0_mod + output: conf + attrs: {'long_name': 'Wind sea significant wave height confidence', 'units': ''} + t0m1_most_likely: + model: t0m1_mod + output: pred + attrs: {'long_name': 'Most likely mean wave period', 'units': 's'} + t0m1_conf: + model: t0m1_mod + output: conf + attrs: {'long_name': 'Mean wave period confidence', 'units': ''} + interburst: ${.intraburst} + +product_id: E09 +overwrite: True + +defaults: + - base + - _self_ diff --git a/sarwave_config/e10.yaml b/sarwave_config/e10.yaml new file mode 100644 index 0000000..417a19e --- /dev/null +++ b/sarwave_config/e10.yaml @@ -0,0 +1,46 @@ + +models: + hs_mod: + path: models/hs.onnx + outputs: ['pred', 'conf'] + t0m1_mod: + path: models/t0m1.onnx + outputs: ['pred', 'conf'] + phs0_mod: + path: models/phs0.onnx + outputs: ['pred', 'conf'] + +predicted_variables: + intraburst: + hs_most_likely: + model: hs_mod + output: pred + attrs: {'long_name': 'Most likely significant wave height', 'units': 'm'} + hs_conf: + model: hs_mod + output: conf + attrs: {'long_name': 'Significant wave height confidence', 'units': ''} + phs0_most_likely: + model: phs0_mod + output: pred + attrs: {'long_name': 'Most likely wind sea significant wave height', 'units': 'm'} + phs0_conf: + model: phs0_mod + output: conf + attrs: 
{'long_name': 'Wind sea significant wave height confidence', 'units': ''} + t0m1_most_likely: + model: t0m1_mod + output: pred + attrs: {'long_name': 'Most likely mean wave period', 'units': 's'} + t0m1_conf: + model: t0m1_mod + output: conf + attrs: {'long_name': 'Mean wave period confidence', 'units': ''} + interburst: ${.intraburst} + +product_id: E10 +overwrite: True + +defaults: + - base + - _self_ diff --git a/sarwaveifrproc/cli.py b/sarwaveifrproc/cli.py deleted file mode 100644 index 1e93219..0000000 --- a/sarwaveifrproc/cli.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Console script for sarwaveifrproc.""" -import argparse -import sys - - -def main(): - """Console script for sarwaveifrproc.""" - parser = argparse.ArgumentParser() - parser.add_argument('_', nargs='*') - args = parser.parse_args() - - print("Arguments: " + str(args._)) - print("Replace this message by putting your code into " - "sarwaveifrproc.cli.main") - return 0 - - -if __name__ == "__main__": - sys.exit(main()) # pragma: no cover diff --git a/sarwaveifrproc/config.yaml b/sarwaveifrproc/config.yaml deleted file mode 100644 index d932ae5..0000000 --- a/sarwaveifrproc/config.yaml +++ /dev/null @@ -1,10 +0,0 @@ -model_intraburst: 'model_intraburst.keras' -model_interburst: 'model_interburst.keras' -scaler_intraburst: 'scaler_intraburst.npy' -scaler_interburst: 'scaler_interburst.npy' -bins_intraburst: 'bins_intraburst' -bins_interburst: 'bins_interburst' -predicted_variables: - - 'hs' - - 'phs0' - - 't0m1' \ No newline at end of file diff --git a/sarwaveifrproc/l2_wave.py b/sarwaveifrproc/l2_wave.py index e6bcf81..d000038 100644 --- a/sarwaveifrproc/l2_wave.py +++ b/sarwaveifrproc/l2_wave.py @@ -1,157 +1,217 @@ import numpy as np import os import xarray as xr -import datatree as dtt + +# import datatree as dtt from scipy import special import sarwaveifrproc -attributes_missing_variables = \ -{ - 'sigma0_filt': {'long_name': 'calibrated sigma0 with BT correction', 'units': 'linear'}, - 'normalized_variance_filt': {'long_name': 'normalized variance with BT correction', 'units': ''}, - 'azimuth_cutoff': {'long_name': 'Azimuthal cut-off (2tau)', 'units': 'm'}, - 'cwave_params': {'long_name': 'CWAVE parameters'}, - 'k_gp': {'long_name': 'Gegenbauer polynoms dimension'}, - 'phi_hf': {'long_name': 'Harmonic functions dimension (odd number)'} +attributes_missing_variables = { + "sigma0_filt": { + "long_name": "calibrated sigma0 with BT correction", + "units": "linear", + }, + "normalized_variance_filt": { + "long_name": "normalized variance with BT correction", + "units": "", + }, + "azimuth_cutoff": {"long_name": "Azimuthal cut-off (2tau)", "units": "m"}, + "cwave_params": {"long_name": "CWAVE parameters"}, + "k_gp": {"long_name": "Gegenbauer polynoms dimension"}, + "phi_hf": {"long_name": "Harmonic functions dimension (odd number)"}, } -def generate_l2_wave_product(xdt, intraburst_model, interburst_model, intraburst_scaler, interburst_scaler, intraburst_bins, interburst_bins, predicted_variables): + +def generate_l2_wave_product(xdt, models, models_outputs, predicted_variables): """ Generate a level-2 wave (L2 WAV) product. Parameters: - xdt (dict): DataTree containing intraburst and interburst datasets. - - intraburst_model (tf.keras.Model): Machine learning model for intraburst predictions. - - interburst_model (tf.keras.Model): Machine learning model for interburst predictions. - - intraburst_scaler (Union[StandardScaler, MinMaxScaler, RobustScaler]): Scaler used for intraburst normalization. 
- - interburst_scaler (Union[StandardScaler, MinMaxScaler, RobustScaler]): Scaler used for interburst normalization. - - intraburst_bins (dict): Dictionary containing bins for intraburst variables. - - interburst_bins (dict): Dictionary containing bins for interburst variables. - - predicted_variables (list): List of predicted variable names. + - models (dict[str, onnxruntime.InferenceSession): different ml models used + - models_outputs (dict[str, list]): list of variables predicted by each model + - predicted_variables (dict[dict]): variables to add to the product and corresponding model and output name Returns: - l2_wave_product (dtt.DataTree): Level-2 wave product. - + Notes: - The scaler objects should be one of StandardScaler, MinMaxScaler, or RobustScaler from sklearn.preprocessing. """ - kept_variables = ['corner_longitude', 'corner_latitude', 'land_flag', 'sigma0_filt', 'normalized_variance_filt', 'incidence', 'azimuth_cutoff', 'cwave_params'] - ds_intraburst = generate_intermediate_product(xdt['intraburst'].ds, intraburst_model, intraburst_scaler, intraburst_bins, predicted_variables, kept_variables) - ds_interburst = generate_intermediate_product(xdt['interburst'].ds, interburst_model, interburst_scaler, interburst_bins, predicted_variables, kept_variables) - - l2_wave_product = dtt.DataTree.from_dict({"intraburst": ds_intraburst, "interburst": ds_interburst}) - l2_wave_product.attrs['input_sar_product'] = os.path.basename(xdt.encoding['source']) + kept_variables = [ + "corner_longitude", + "corner_latitude", + "land_flag", + "sigma0_filt", + "normalized_variance_filt", + "incidence", + "azimuth_cutoff", + "cwave_params", + ] + intraburst_models = { + k: m + for k, m in models.items() + if k in [v.model for v in predicted_variables.intraburst.values()] + } + interburst_models = { + k: m + for k, m in models.items() + if k in [v.model for v in predicted_variables.interburst.values()] + } + ds_intraburst = generate_intermediate_product( + xdt["intraburst"].ds, + intraburst_models, + models_outputs, + predicted_variables.intraburst, + kept_variables, + ) + ds_interburst = generate_intermediate_product( + xdt["interburst"].ds, + interburst_models, + models_outputs, + predicted_variables.interburst, + kept_variables, + ) + + l2_wave_product = xr.DataTree.from_dict( + {"intraburst": ds_intraburst, "interburst": ds_interburst} + ) + l2_wave_product.attrs["input_sar_product"] = os.path.basename( + xdt.encoding["source"] + ) return l2_wave_product -def generate_intermediate_product(ds, model, scaler, bins, predicted_variables, kept_variables): +def generate_intermediate_product( + ds, models, models_outputs, predicted_variables, kept_variables, pol="VV" +): """ Generate an intermediate l2 product, depending of the input dataset (intraburst or interburst). Parameters: - ds (xarray.Dataset): Input dataset. - - model: Machine learning model for predicting sea parameters. - - scaler: Scaler used for normalization. - - bins (dict): Dictionary containing bins for each variable. - - predicted_variables (list): List of predicted variable names. + - models (dict[str, onnxruntime.InferenceSession): different ml models used + - models_outputs (dict[str, list]): list of variables predicted by each model + - predicted_variables (dict[dict]): variables to add to the product and corresponding model and output name - kept_variables (list): List of variables from the input dataset that are kept in the final product. 
+ - pol (str): polarisation to select Returns: - ds_pred (xarray.Dataset): Intermediate predictions dataset. """ - - if ds['land_flag'].all(): - return generate_product_on_land(ds, predicted_variables, kept_variables, bins) - - if '2tau' in ds.dims: - ds = ds.squeeze(dim='2tau') - ds.attrs['squeezed_dimensions']='2tau' - - tiles = ds[['sigma0_filt', 'normalized_variance_filt', 'incidence', 'azimuth_cutoff', 'cwave_params']] - if 'burst' in ds.coords: - tiles_stacked = tiles.stack(all_tiles = ['burst', 'tile_line','tile_sample'], k_phi = ['phi_hf', 'k_gp']) + + if ds["land_flag"].all(): + return generate_product_on_land(ds, predicted_variables, kept_variables) + + if "2tau" in ds.dims: + ds = ds.squeeze(dim="2tau") + ds.attrs["squeezed_dimensions"] = "2tau" + + tiles = ds[ + [ + "sigma0_filt", + "normalized_variance_filt", + "incidence", + "azimuth_cutoff", + "cwave_params", + ] + ].sel(pol=pol) + if "burst" in ds.coords: + tiles_stacked = tiles.stack( + all_tiles=["burst", "tile_line", "tile_sample"], k_phi=["k_gp", "phi_hf"] + ) else: - tiles_stacked = tiles.stack(all_tiles = ['tile_line','tile_sample'], k_phi = ['phi_hf', 'k_gp']) - - output_dims = [['all_tiles', f'{v}_mid'] for v in predicted_variables] - inds = np.cumsum([0] + [len(bins[v]) -1 for v in predicted_variables]) - - predictions = xr.apply_ufunc(predict_variables, - tiles_stacked['sigma0_filt'], tiles_stacked['normalized_variance_filt'], tiles_stacked['incidence'], - tiles_stacked['azimuth_cutoff'], tiles_stacked['cwave_params'], - model, scaler, inds, - input_core_dims=[['all_tiles'], ['all_tiles'], ['all_tiles'], ['all_tiles'], ['all_tiles','k_phi'], [], [], []], - output_core_dims=output_dims, - vectorize=False) - - ds_pred = format_dataset(ds, predictions, predicted_variables, kept_variables, bins) - + tiles_stacked = tiles.stack( + all_tiles=["tile_line", "tile_sample"], k_phi=["k_gp", "phi_hf"] + ) + + output_dims = [["preds", "all_tiles"]] + + predictions = xr.concat( + [ + xr.apply_ufunc( + predict_variables, + tiles_stacked["sigma0_filt"], + tiles_stacked["normalized_variance_filt"], + tiles_stacked["incidence"], + tiles_stacked["azimuth_cutoff"], + tiles_stacked["cwave_params"], + model, + input_core_dims=[ + ["all_tiles"], + ["all_tiles"], + ["all_tiles"], + ["all_tiles"], + ["all_tiles", "k_phi"], + [], + ], + output_core_dims=output_dims, + vectorize=False, + ).assign_coords(preds=[f"{k}_{o}" for o in models_outputs[k]]) + for (k, model) in models.items() + ], + dim="preds", + ) + + ds_pred = format_dataset(ds, predictions, predicted_variables, kept_variables) + return ds_pred -def generate_product_on_land(ds, predicted_variables, kept_variables, bins): - """ + +def generate_product_on_land(ds, predicted_variables, kept_variables): + """ Patch function when input dataset does not contain all necessary variables. - + Parameters: - ds (xarray.Dataset): Input dataset. - predicted_variables (list): List of predicted variable names. - kept_variables (list): List of variables from the input dataset that are kept in the final product. - bins (dict): Dictionary containing bins for each variable. 
""" - ds = ds[['corner_longitude', 'corner_latitude', 'incidence', 'land_flag']] - - shape = ds['land_flag'].shape - dims = ds['land_flag'].dims - + ds = ds[["corner_longitude", "corner_latitude", "incidence", "land_flag"]] + + shape = ds["land_flag"].shape + dims = ds["land_flag"].dims + data_to_merge = [] if set(predicted_variables).issubset(ds.keys()): kept_variables = kept_variables + predicted_variables - + for v in kept_variables: if v not in ds.variables: - v_array = xr.DataArray(data = np.full(shape, np.nan), dims = dims).rename(v) + v_array = xr.DataArray(data=np.full(shape, np.nan), dims=dims).rename(v) v_array.attrs = attributes_missing_variables[v] data_to_merge.append(v_array) - - # Add CWAVES independently because the required dimensions are not in the input dataset when there is only land. - k_gp = xr.DataArray(data=range(1, 5), dims='k_gp') - k_gp.attrs = attributes_missing_variables['k_gp'] - phi_hf = xr.DataArray(data=range(1, 6), dims='phi_hf') - phi_hf.attrs = attributes_missing_variables['phi_hf'] - - cwaves = xr.DataArray(data=np.full(shape + (4, 5), np.nan), dims=dims + ('k_gp', 'phi_hf')).rename('cwave_params') - cwaves.attrs = attributes_missing_variables['cwave_params'] - data_to_merge.append(cwaves.assign_coords({'k_gp': k_gp, 'phi_hf': phi_hf})) - - for v in predicted_variables: - attributes = get_attributes(v) - v_mid = xr.DataArray(data=(bins[v][:-1] + bins[v][1:])/2, dims=f'{v}_mid') - v_mid.attrs = attributes['mid'] - - v_pdf = xr.DataArray(data=np.full(v_mid.shape + shape, np.nan), dims=(f'{v}_mid', ) + dims).rename(f'{v}_pdf') - v_pdf.attrs = attributes['pdf'] - data_to_merge.append(v_pdf.assign_coords({f'{v}_mid': v_mid})) - - v_mean = xr.DataArray(data=np.full(shape, np.nan), dims=dims).rename(f'{v}_mean') - v_mean.attrs = attributes['mean'] - data_to_merge.append(v_mean) - - v_most_likely = xr.DataArray(data=np.full(shape, np.nan), dims=dims).rename(f'{v}_most_likely') - v_most_likely.attrs = attributes['most_likely'] - data_to_merge.append(v_most_likely) - - v_std = xr.DataArray(data=np.full(shape, np.nan), dims=dims).rename(f'{v}_std') - v_std.attrs = attributes['std'] - data_to_merge.append(v_std) - - ds = xr.merge([ds] + data_to_merge) - - ds.attrs.pop('name', None) - ds.attrs.pop('multidataset', None) + + # Add CWAVES independently because the required dimensions are not in the input dataset when there is only land. + k_gp = xr.DataArray(data=range(1, 5), dims="k_gp") + k_gp.attrs = attributes_missing_variables["k_gp"] + phi_hf = xr.DataArray(data=range(1, 6), dims="phi_hf") + phi_hf.attrs = attributes_missing_variables["phi_hf"] + + cwaves = xr.DataArray( + data=np.full(shape + (4, 5), np.nan), dims=dims + ("k_gp", "phi_hf") + ).rename("cwave_params") + cwaves.attrs = attributes_missing_variables["cwave_params"] + data_to_merge.append(cwaves.assign_coords({"k_gp": k_gp, "phi_hf": phi_hf})) + + for v, vd in predicted_variables.items(): + attributes = vd.attrs + preds = xr.DataArray(data=np.full(shape, np.nan), dims=dims).rename(v) + preds.attrs = attributes + data_to_merge.append(preds) + + ds = xr.merge([ds] + data_to_merge) + + ds.attrs.pop("name", None) + ds.attrs.pop("multidataset", None) return ds -def predict_variables(sigma0, normalized_variance, incidence, azimuth_cutoff, cwave_params, model, scaler, indices): + +def predict_variables( + sigma0, normalized_variance, incidence, azimuth_cutoff, cwave_params, model +): """ Launch predictions using a neural model. 
@@ -162,23 +222,23 @@ def predict_variables(sigma0, normalized_variance, incidence, azimuth_cutoff, cw - azimuth_cutoff (xarray.DataArray): Array containing azimuth cutoff values. - cwave_params (xarray.DataArray): Array containing the cwave parameters. - model (tf.keras.Model): Machine learning model for predicting sea parameters. - - scaler (Union[StandardScaler, MinMaxScaler, RobustScaler]): Scaler used for normalization. - - indices (list): List of indices to segment predictions. Returns: - res (tuple): Tuple containing predictions for each variable. - """ + """ X_stacked = np.vstack([sigma0, normalized_variance, incidence, azimuth_cutoff]).T - X_stacked = np.hstack([X_stacked, cwave_params]) - X_normalized = scaler.transform(X_stacked) - - predictions = model.predict(X_normalized) - - res = tuple(predictions[:, indices[i]:indices[i+1]] for i in range(len(indices)-1)) - + X_stacked = np.hstack([cwave_params, X_stacked]) + # X_normalized = scaler.transform(X_stacked) + + input_name = model.get_inputs()[0].name + inputs = {input_name: X_stacked.astype(np.float32)} + res = model.run(None, inputs) + res = [r[..., 0] for r in res] + return res -def format_dataset(ds, predictions, predicted_variables, kept_variables, bins): + +def format_dataset(ds, predictions, predicted_variables, kept_variables): """ Format a dataset based on predictions, variables, and bins. @@ -196,137 +256,26 @@ def format_dataset(ds, predictions, predicted_variables, kept_variables, bins): from the predictions and adds them to the new dataset. It also includes original data from the input dataset and some additional attributes. """ - + data_to_merge = [] - - n = len(predicted_variables) - for i in range(n): - v = predicted_variables[i] - attributes = get_attributes(v) - v_mid = xr.DataArray(data = (bins[v][:-1] + bins[v][1:])/2, dims = f'{v}_mid') - v_mid.attrs = attributes['mid'] - - v_pdf = xr.apply_ufunc(lambda x: special.softmax(x, axis=1), - predictions[i]).rename(f'{v}_pdf').assign_coords({f'{v}_mid': v_mid}).to_dataset() - v_pdf[f'{v}_pdf'].attrs = attributes['pdf'] - data_to_merge.append(v_pdf) - - v_mean = compute_values(v_pdf, v, compute_mean).rename(f'{v}_mean') - v_mean.attrs = attributes['mean'] - data_to_merge.append(v_mean) - - v_most_likely = compute_values(v_pdf, v, get_most_likely).rename(f'{v}_most_likely') - v_most_likely.attrs = attributes['most_likely'] - data_to_merge.append(v_most_likely) - - v_std = compute_values(v_pdf, v, compute_std, True).rename(f'{v}_std') - v_std.attrs = attributes['std'] - data_to_merge.append(v_std) + + for v, vd in predicted_variables.items(): + attributes = vd.attrs + preds = predictions.sel(preds=f"{vd.model}_{vd.output}", drop=True).rename(v) + preds.attrs = attributes + data_to_merge.append(preds) if set(predicted_variables).issubset(ds.keys()): kept_variables = kept_variables + predicted_variables - + data_to_merge = [ds.unstack() for ds in data_to_merge] ds = ds[kept_variables] - - ds = xr.merge([ds] + data_to_merge) - ds = ds.drop_vars(['tile_line', 'tile_sample']) - ds.attrs['l2_processor_name'] = 'sarwaveifrproc' - ds.attrs['l2_processor_version'] = sarwaveifrproc.__version__ - ds.attrs.pop('name', None) - ds.attrs.pop('multidataset', None) - - return ds -def compute_values(ds, var, function, vectorize=False): - """ - Compute values for the given variable using the given function. - - Args: - ds (xr.Dataset): Dataset containing the data. - var (str): Variable name. - function (callable): Function to compute values. 
- vectorize (bool): Whether to vectorize the computation. - - Returns: - xr.DataArray: Computed values. - """ - values = xr.apply_ufunc(function, - ds[f'{var}_mid'], ds[f'{var}_pdf'], - input_core_dims=[[f'{var}_mid'],[f'{var}_mid']], - vectorize=vectorize) - return values - -def get_attributes(var): - """ - Generate a dictionary of attributes for a given variable. - - Parameters: - - var (str): The variable name for which attributes are to be generated. This should correspond to keys in the `spec_dict`. - - Returns: - - attributes (dict): A dictionary containing various attributes for the given variable. - - The function uses a predefined `spec_dict` to look up information about the variable. - If the variable is not found in `spec_dict`, default attributes are generated with the variable name itself. - """ - spec_dict = { - 'hs': {'long_name': 'significant wave height', 'units': 'm'}, - 'phs0': {'long_name': 'wind sea significant wave height', 'units': 'm'}, - 't0m1': {'long_name': 'mean wave period', 'units': 's'}, - } - - var_spec = spec_dict.get(var, {"long_name": f'{var}', 'units': ''}) - - attributes = { - 'mid': {'long_name': f'central values of the bins used for discretizing the range of {var_spec["long_name"]} encountered during neural model training', 'units': var_spec['units']}, - 'pdf': {'long name': f'{var_spec["long_name"]} discrete probability density function', 'units': 'probability'}, - 'mean': {'long name': f'first-order moment of the {var_spec["long_name"]} discrete probability density function', 'units': var_spec['units']}, - 'most_likely': {'long name': f'most likely value of {var_spec["long_name"]} given its discrete probability density function', 'units': var_spec['units']}, - 'std': {'long name': f'square root of the second-order moment of the {var_spec["long_name"]} discrete probability density function', 'units': var_spec['units']} - } - - return attributes - -def get_most_likely(x, y): - """ - Get the maximum of probability for each prediction. - - Args: - x (np.ndarray): Input values. - y (np.ndarray): Probabilities. - - Returns: - np.ndarray: Most likely values. - """ - i_max = np.argmax(y, axis=1) - most_likely = x[i_max] - return most_likely - -def compute_mean(x, y): - """ - Compute the expected value. - - Args: - x (np.ndarray): Input values. - y (np.ndarray): Probabilities. - - Returns: - np.ndarray: Expected value. - """ - return np.sum(x * y, axis=1) - -def compute_std(x, y): - """ - Compute the standard deviation. - - Args: - x (np.ndarray): Input values. - y (np.ndarray): Probabilities. + ds = xr.merge([ds] + data_to_merge) + ds = ds.drop_vars(["tile_line", "tile_sample"]) + ds.attrs["l2_processor_name"] = "sarwaveifrproc" + ds.attrs["l2_processor_version"] = sarwaveifrproc.__version__ + ds.attrs.pop("name", None) + ds.attrs.pop("multidataset", None) - Returns: - np.ndarray: Standard deviation. 
- """ - mean = np.sum(x * y) - variance = np.sum(y * (x - mean) ** 2) - return np.sqrt(variance) + return ds diff --git a/sarwaveifrproc/main.py b/sarwaveifrproc/main.py index e40a129..9792603 100644 --- a/sarwaveifrproc/main.py +++ b/sarwaveifrproc/main.py @@ -1,124 +1,144 @@ -import argparse import logging +import hydra_zen +import hydra import os -import sys import glob import numpy as np -from sarwaveifrproc.utils import get_output_safe, load_config, load_models, process_files - -def parse_args(): - - parser = argparse.ArgumentParser(description="Generate a L2 WAVE product from a L1B or L1C SAFE.") - - # Define arguments - parser.add_argument("--input_path", required=True, help="l1b or l1c safe path or listing path (.txt file).") - parser.add_argument("--save_directory", required=True, help="where to save output data.") - parser.add_argument("--product_id", required=True, help="3 digits ID representing the processing options. Ex: E00.") - - # Group related arguments under 'model' and 'bins' - model_group = parser.add_argument_group('model', 'Arguments related to the neural models') - model_group.add_argument("--model_intraburst", required=False, help="neural model path to predict sea states on intraburst data.") - model_group.add_argument("--model_interburst", required=False, help="neural model path to predict sea states on interburst data.") - - model_group.add_argument("--scaler_intraburst", required=False, help="scaler path to standardize intraburst data before feeding it to the neural model.") - model_group.add_argument("--scaler_interburst", required=False, help="scaler path to standardize interburst data before feeding it to the neural model.") - - model_group.add_argument("--bins_intraburst", required=False, help="bins path that depicts the range of predictions on intraburst data.") - model_group.add_argument("--bins_interburst", required=False, help="bins path that depicts the range of predictions on interburst data.") - - model_group.add_argument("--predicted_variables", required=False, help="list of sea states variables to predict.") - - # Other arguments - parser.add_argument("--overwrite", action="store_true", default=False, help="overwrite the existing outputs") - parser.add_argument("--verbose", action="store_true", default=False) - - args = parser.parse_args() - return args - - -def setup_logging(verbose=False): - fmt = '%(asctime)s %(levelname)s %(filename)s(%(lineno)d) %(message)s' - level = logging.DEBUG if verbose else logging.INFO - logging.basicConfig(level=level, format=fmt, datefmt='%d/%m/%Y %H:%M:%S', force=True) - -def get_files(dir_path, listing): - - fn = [] - for s in listing: - search_path = os.path.join(dir_path, s.replace('WAVE', 'XSP_'), '*-?v-*.nc') - fn+=glob.glob(search_path) - - print('Number of files :', len(fn)) - return fn - - -def main(): - args = parse_args() - setup_logging(args.verbose) - - os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Suppress TensorFlow INFO and WARNING messages - os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # Hide CUDA devices - +import sarwaveifrproc.utils as utils +from dataclasses import dataclass +import onnxruntime + + +@dataclass +class Model: + """ + path: path to the onnx file + outputs: names of the model output variables + """ + + path: str + outputs: list[str] + + +@dataclass +class Prediction: + """ + model: name of the model that predict this variable + output_name: name of the model output + attrs: attribute dictionnary of the variable + """ + + model: str + output: str + attrs: dict[str, str] + + +@dataclass +class 
PredictedVariables: + """ + intraburst: variables to add in the intraburst + interburst: variables to add in the interburst + """ + + intraburst: dict[str, Prediction] + interburst: dict[str, Prediction] + + +def main( + input_path, + save_directory: str, + product_id: str, + models: dict[str, Model], + predicted_variables: PredictedVariables, + overwrite: bool = False, + verbose: bool = False, +): + """ + Generate a L2 WAVE product from a L1B or L1C SAFE. + + input_path: l1b or l1c safe path or listing path (.txt file). + save_directory: where to save output data + product_id: 3 digits ID representing the processing options. Ex: E00. + models: onnx models and output + predicted_variables: model outputs and associated variables name to add to the L2 product + overwrite: overwrite the existing outputs + verbose: debug log level if True + """ + + setup_logging(verbose) + os.environ["TF_CPP_MIN_LOG_LEVEL"] = ( + "3" # Suppress TensorFlow INFO and WARNING messages + ) + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Hide CUDA devices logging.info("Loading configuration file...") - conf = load_config() - input_path = args.input_path - save_directory = args.save_directory - product_id = args.product_id - predicted_variables = args.predicted_variables or conf['predicted_variables'] + logging.info("Loading models...") + ort_mods = {k: onnxruntime.InferenceSession(d.path) for k, d in models.items()} + mod_outs = {k: d.outputs for k, d in models.items()} + logging.info("Models loaded.") - if input_path.endswith('.txt'): + if input_path.endswith(".txt"): files = np.loadtxt(input_path, dtype=str) - output_safes = np.array([get_output_safe(f, save_directory, product_id) for f in files]) + output_safes = np.array( + [utils.get_output_safe(f, save_directory, product_id) for f in files] + ) - if not args.overwrite: + if not overwrite: mask = np.array([os.path.exists(f) for f in output_safes]) files, output_safes = files[~mask], output_safes[~mask] - logging.info(f"{np.sum(mask)} file(s) already exist(s) and overwriting is not allowed. Use --overwrite to overwrite existing files.") + logging.info( + f"{np.sum(mask)} file(s) already exist(s) and overwriting is not allowed. Use --overwrite to overwrite existing files." + ) if not files.size: return None - logging.info('Loading models...') - # Define the paths. When args paths are None, the conf paths are set by default (the "or" operator returns first element if true, second element if first is None).  
- paths = { - 'model_intraburst': args.model_intraburst or conf['model_intraburst'], - 'model_interburst': args.model_interburst or conf['model_interburst'], - 'scaler_intraburst': args.scaler_intraburst or conf['scaler_intraburst'], - 'scaler_interburst': args.scaler_interburst or conf['scaler_interburst'], - 'bins_intraburst': args.bins_intraburst or conf['bins_intraburst'], - 'bins_interburst': args.bins_interburst or conf['bins_interburst'], - } - model_intraburst, model_interburst, scaler_intraburst, scaler_interburst, bins_intraburst, bins_interburst = load_models(paths, predicted_variables) - logging.info('Models loaded.') - - logging.info('Processing files...') + logging.info("Processing files...") for f, output_safe in zip(files, output_safes): - process_files(f, output_safe, model_intraburst, model_interburst, scaler_intraburst, scaler_interburst, bins_intraburst, bins_interburst, predicted_variables, product_id) + utils.process_files( + f, output_safe, ort_mods, mod_outs, predicted_variables, product_id + ) - else: logging.info("Checking if output safe already exists...") - output_safe = get_output_safe(input_path, save_directory, product_id) + output_safe = utils.get_output_safe(input_path, save_directory, product_id) - if os.path.exists(output_safe) and not args.overwrite: - logging.info(f"{output_safe} already exists and overwriting is not allowed. Use --overwrite to overwrite existing files.") + if os.path.exists(output_safe) and not overwrite: + logging.info( + f"{output_safe} already exists and overwriting is not allowed. Use --overwrite to overwrite existing files." + ) return None - logging.info('Loading models...') - paths = { - 'model_intraburst': args.model_intraburst or conf['model_intraburst'], - 'model_interburst': args.model_interburst or conf['model_interburst'], - 'scaler_intraburst': args.scaler_intraburst or conf['scaler_intraburst'], - 'scaler_interburst': args.scaler_interburst or conf['scaler_interburst'], - 'bins_intraburst': args.bins_intraburst or conf['bins_intraburst'], - 'bins_interburst': args.bins_interburst or conf['bins_interburst'], - } - model_intraburst, model_interburst, scaler_intraburst, scaler_interburst, bins_intraburst, bins_interburst = load_models(paths, predicted_variables) - logging.info('Models loaded.') - - logging.info('Processing files...') - process_files(input_path, output_safe, model_intraburst, model_interburst, scaler_intraburst, scaler_interburst, bins_intraburst, bins_interburst, predicted_variables, product_id) - - logging.info(f'Processing terminated. Output directory: \n{save_directory}') + logging.info("Processing files...") + utils.process_files( + input_path, output_safe, ort_mods, mod_outs, predicted_variables, product_id + ) + + logging.info(f"Processing terminated. 
Output directory: \n{save_directory}") + + +def setup_logging(verbose=False): + fmt = "%(asctime)s %(levelname)s %(filename)s(%(lineno)d) %(message)s" + level = logging.DEBUG if verbose else logging.INFO + logging.basicConfig( + level=level, format=fmt, datefmt="%d/%m/%Y %H:%M:%S", force=True + ) + + +def get_files(dir_path, listing): + + fn = [] + for s in listing: + search_path = os.path.join(dir_path, s.replace("WAVE", "XSP_"), "*-?v-*.nc") + fn += glob.glob(search_path) + + print("Number of files :", len(fn)) + return fn + + +hydra_main = hydra.main( + config_name="e10", + config_path="pkg://sarwave_config", + version_base="1.3", +)(hydra_zen.zen(main)) diff --git a/sarwaveifrproc/sarwaveifrproc.py b/sarwaveifrproc/sarwaveifrproc.py deleted file mode 100644 index dd0b80e..0000000 --- a/sarwaveifrproc/sarwaveifrproc.py +++ /dev/null @@ -1 +0,0 @@ -"""Main module.""" diff --git a/sarwaveifrproc/utils.py b/sarwaveifrproc/utils.py index c5ff978..d12cd83 100644 --- a/sarwaveifrproc/utils.py +++ b/sarwaveifrproc/utils.py @@ -1,6 +1,7 @@ import sarwaveifrproc -import tensorflow as tf -import datatree as dtt +# import tensorflow as tf +# import datatree as dtt +import xarray as xr import numpy as np import glob import logging @@ -53,7 +54,6 @@ def get_output_safe(l1x_safe, root_savepath, tail='E00'): regexA = re.compile("A[0-9]{2}.SAFE") regexB = re.compile("B[0-9]{2}.SAFE") if re.search(regexA, final_safe.split('_')[-1]) or re.search(regexB, final_safe.split('_')[-1]): - print('match regexp') final_safe = final_safe.replace(final_safe.split('_')[-1], f'{tail.upper()}.SAFE') else: print('no slug existing-> just add the product ID') @@ -157,31 +157,27 @@ def load_models(paths, predicted_variables): return model_intraburst, model_interburst, scaler_intraburst, scaler_interburst, bins_intraburst, bins_interburst -def process_files(input_safe, output_safe, model_intraburst, model_interburst, scaler_intraburst, scaler_interburst, bins_intraburst, bins_interburst, predicted_variables, product_id): +def process_files(input_safe, output_safe, models, models_outputs, predicted_variables, product_id): """ Processes files in the input directory, generates predictions, and saves results in the output directory. Parameters: input_safe (str): Input safe path. output_safe (str): Path to the directory where output data will be saved. - model_intraburst (tf.keras.Model): Intraburst model for prediction. - model_interburst (tf.keras.Model): Interburst model for prediction. - scaler_intraburst (RobustScaler): Scaler for intraburst data. - scaler_interburst (RobustScaler): Scaler for interburst data. - bins_intraburst (dict): Dictionary containing intraburst bins for each predicted variable. - bins_interburst (dict): Dictionary containing interburst bins for each predicted variable. + models (dict): dict of onnx runtime inference sessions + models_outputs (dict): dict of List of model outputs names predicted_variables (list): List of variable names to be predicted. product_id (str): Identifier for the output product. 
-
+ort_mods, models, predicted_variables, product_id)
     Returns:
         None
     """
-    subswath_filenames = glob.glob(os.path.join(input_safe, '*-?v-*.nc'))
+    subswath_filenames = glob.glob(os.path.join(input_safe, '*?v*.nc'))
     logging.info(f'{len(subswath_filenames)} subswaths found in given safe.')
 
     for path in subswath_filenames:
-        xdt = dtt.open_datatree(path)
-        l2_product = generate_l2_wave_product(xdt, model_intraburst, model_interburst, scaler_intraburst, scaler_interburst, bins_intraburst, bins_interburst, predicted_variables)
+        xdt = xr.DataTree.from_dict(xr.open_groups(path))
+        l2_product = generate_l2_wave_product(xdt, models, models_outputs, predicted_variables)
 
         os.makedirs(output_safe, exist_ok=True)
         savepath = get_output_filename(path, output_safe, product_id)
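
With the console script now pointing at sarwaveifrproc.main:hydra_main, invocation goes through Hydra config names and overrides rather than argparse flags. A hedged sketch of typical calls, assuming the packaged e08/e09/e10 configs and the optional hydra-joblib-launcher extra; the exact override key for the listing path (input_path.path) and the need for the + prefix on the parallel group follow from how the chunk_listing node is composed and may differ:

    # single SAFE, default config (e10), paths are placeholders
    L2-wave-processor input_path=/path/to/L1C.SAFE save_directory=outputs

    # pick another packaged product configuration
    L2-wave-processor --config-name e08 input_path=/path/to/L1C.SAFE save_directory=outputs

    # split a .txt listing across joblib workers; each job processes one chunk written under tmp/
    L2-wave-processor --multirun hydra/launcher=joblib hydra.launcher.n_jobs=10 \
        +parallel=chunk input_path.path=listing.txt save_directory=outputs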