Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

param_edits w concurrent futures #235

Merged
merged 1 commit into from
Sep 6, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 182 additions & 5 deletions examples/param_edits.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
"As a consequence, one has to make the Parameter class editable. Below this is accomplished by doing\n",
"`the_parameters.to_dd()` which returns a DatasetDict which is editable. One has to know something about\n",
"how DatasetDicts are constructed to edit effectively, information can be found in the [documentation](https://pywatershed.readthedocs.io/en/main/api/generated/pywatershed.base.DatasetDict.html#pywatershed.base.DatasetDict). \n",
"The edited DatasetDict can be made a Parameters object again by `Parameters(**param_dict.data)`, as shown below. "
"The edited DatasetDict can be made a Parameters object again by `Parameters(**param_dict.data)`, as shown below. \n",
"\n",
"Note: this notebook requires notebooks 0-1 to have been run in advance."
]
},
{
Expand Down Expand Up @@ -46,6 +48,7 @@
"source": [
"import pathlib as pl\n",
"from pprint import pprint\n",
"import shutil\n",
"\n",
"import numpy as np\n",
"import pywatershed as pws\n",
Expand All @@ -61,7 +64,8 @@
"source": [
"domain_dir = pws.constants.__pywatershed_root__ / \"data/drb_2yr\"\n",
"nb_output_dir = pl.Path(\"./param_edits\")\n",
"nb_output_dir.mkdir(exist_ok=True)"
"nb_output_dir.mkdir(exist_ok=True)\n",
"(nb_output_dir / \"params\").mkdir(exist_ok=True)"
]
},
{
Expand Down Expand Up @@ -89,7 +93,9 @@
" multiplier = ii * 0.05 + 0.75\n",
" print(\"multiplier = \", multiplier)\n",
" param_dict.data_vars[\"K_coef\"] *= multiplier\n",
" param_file_name = nb_output_dir / f\"perturbed_params_{str(ii).zfill(3)}.nc\"\n",
" param_file_name = (\n",
" nb_output_dir / f\"params/perturbed_params_{str(ii).zfill(3)}.nc\"\n",
" )\n",
" param_files += [param_file_name]\n",
" # These could avoid export to netcdf4 if just using in memory\n",
" # could store in a list like: param_list.append(pws.Parameters(**param_dict.data))\n",
Expand All @@ -111,20 +117,191 @@
" # but we can just open the netcdf file as Parameters\n",
" # ds = xr.open_dataset(ff, decode_times=False, decode_timedelta=False)\n",
" # k_coef = ds[\"K_coef\"]\n",
" new_params = pws.Parameters.from_netcdf(ff)\n",
" new_params = pws.parameters.PrmsParameters.from_netcdf(ff)\n",
" k_coef = new_params.data_vars[\"K_coef\"]\n",
" multipliers = k_coef / params.data_vars[\"K_coef\"]\n",
" assert (multipliers - multipliers[0] < 1e-15).all()\n",
" print(multipliers[0])"
]
},
{
"cell_type": "markdown",
"id": "cad8c37c-ddbc-4c0f-868e-c0f9fdce6e78",
"metadata": {},
"source": [
"## A helper function for running the parameters through the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "97cc4ba6-9f81-4c6a-8a1f-d2e2a169c986",
"metadata": {},
"outputs": [],
"source": [
"def run_channel_model(output_dir_parent, param_file):\n",
" # for concurrent.futures we have to write this function to file/module\n",
"    # so we have to import things that won't be in scope in that case.\n",
" import numpy as np\n",
" import pywatershed as pws\n",
"\n",
" domain_dir = pws.constants.__pywatershed_root__ / \"data/drb_2yr\"\n",
"\n",
" params = pws.parameters.PrmsParameters.from_netcdf(param_file)\n",
"\n",
" param_id = param_file.with_suffix(\"\").name.split(\"_\")[-1]\n",
" nc_output_dir = output_dir_parent / f\"run_params_{param_id}\"\n",
" nc_output_dir.mkdir(parents=True, exist_ok=True)\n",
"\n",
" control = pws.Control.load(domain_dir / \"control.test\")\n",
" control.edit_end_time(np.datetime64(\"1979-07-01T00:00:00\"))\n",
" control.options = control.options | {\n",
" \"input_dir\": \"01_multi-process_models/nhm_memory\",\n",
" \"budget_type\": \"warn\",\n",
" \"calc_method\": \"numba\",\n",
" \"netcdf_output_dir\": nc_output_dir,\n",
" }\n",
"\n",
" model = pws.Model(\n",
" [pws.PRMSChannel],\n",
" control=control,\n",
" parameters=params,\n",
" )\n",
"\n",
" model.run(finalize=True)\n",
" return nc_output_dir"
]
},
{
"cell_type": "markdown",
"id": "bf2cb631-91da-4488-962c-ed5c80378049",
"metadata": {},
"source": [
"## Serial execution of the model over the parameter files"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8dc58917-2b29-474a-bb56-1d401c68b4ef",
"metadata": {},
"outputs": [],
"source": []
"source": [
"%%time\n",
"serial_output_dirs = []\n",
"serial_output_parent = nb_output_dir / \"serial\"\n",
"if serial_output_parent.exists():\n",
" shutil.rmtree(serial_output_parent)\n",
"serial_output_parent.mkdir()\n",
"for ff in param_files:\n",
" serial_output_dirs += [run_channel_model(serial_output_parent, ff)]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "81841620-fbc1-49cf-a699-ccad3b2db051",
"metadata": {},
"outputs": [],
"source": [
"serial_output_dirs"
]
},
{
"cell_type": "markdown",
"id": "e3ccefe5-2b6e-40e7-afee-fd6718a3af38",
"metadata": {},
"source": [
"## concurrent.futures approach\n",
"For concurrent.futures to work in an interactive setting, the iterated/mapped function must be imported from a module — it cannot be defined directly in the notebook/interactive setting. We can easily write the function out to a file (ensuring, as noted in the function itself, that everything it uses is in scope when imported)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f2188410-5d3e-40b0-a539-b26292ed0e17",
"metadata": {},
"outputs": [],
"source": [
"import inspect\n",
"\n",
"with open(\"param_edits/run_channel_model.py\", \"w\") as the_file:\n",
" the_file.write(inspect.getsource(run_channel_model))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7913ba4b-3a46-432d-abb5-a86a11f6e02c",
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"import time\n",
"from concurrent.futures import ProcessPoolExecutor as PoolExecutor\n",
"from concurrent.futures import as_completed\n",
"from functools import partial\n",
"from param_edits.run_channel_model import run_channel_model\n",
"\n",
"parallel_output_parent = nb_output_dir / \"parallel\"\n",
"if parallel_output_parent.exists():\n",
" shutil.rmtree(parallel_output_parent)\n",
"parallel_output_parent.mkdir()\n",
"\n",
"# you can set your choice of max_workers\n",
"with PoolExecutor(max_workers=11) as executor:\n",
" parallel_output_dirs = executor.map(\n",
" partial(run_channel_model, parallel_output_parent), param_files\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "34d68668-dac4-4f77-9917-ea72074f3dad",
"metadata": {},
"source": [
"# Checks"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dc4d7600-8488-44f4-81fb-893cb0876c29",
"metadata": {},
"outputs": [],
"source": [
"# check serial == parallel\n",
"serial_runs = sorted(serial_output_parent.glob(\"*\"))\n",
"parallel_runs = sorted(parallel_output_parent.glob(\"*\"))\n",
"\n",
"for ss, pp in zip(serial_runs, parallel_runs):\n",
" serial_files = sorted(ss.glob(\"*.nc\"))\n",
" parallel_files = sorted(pp.glob(\"*.nc\"))\n",
" for sf, pf in zip(serial_files, parallel_files):\n",
" s_ds = xr.open_dataset(sf)\n",
" p_ds = xr.open_dataset(pf)\n",
" xr.testing.assert_allclose(s_ds, p_ds)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cca4dc2f-c91b-475d-9021-c52261c0c31e",
"metadata": {},
"outputs": [],
"source": [
"# check serial 5 is the same as in notebook 02\n",
"run_005 = serial_output_parent / \"run_params_005\"\n",
"files_005 = sorted(run_005.glob(\"*.nc\"))\n",
"for ff in files_005:\n",
" if ff.name == \"PRMSChannel_budget.nc\":\n",
" continue\n",
" ds_005 = xr.open_dataset(ff)\n",
" ds_02 = xr.open_dataset(\n",
" pl.Path(\"01_multi-process_models/nhm_memory\") / ff.name\n",
" )\n",
" xr.testing.assert_allclose(ds_005, ds_02)"
]
}
],
"metadata": {
Expand Down
Loading