Skip to content

Commit

Permalink
Implement own correlation
Browse files Browse the repository at this point in the history
  • Loading branch information
blakeNaccarato committed Oct 1, 2024
1 parent 4456a5f commit b344904
Show file tree
Hide file tree
Showing 10 changed files with 342 additions and 280 deletions.
205 changes: 126 additions & 79 deletions docs/notebooks/e230920_fit_tracks.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,15 @@
"source": [
"from __future__ import annotations\n",
"\n",
"from typing import Any\n",
"from warnings import catch_warnings\n",
"\n",
"from boilercore.fits import Fit, fit_from_params, get_guesses\n",
"from boilercore.fits import Fit, fit_from_params\n",
"from boilercv_pipeline.models.path import get_datetime\n",
"from boilercv_pipeline.stages import get_thermal_data\n",
"from boilercv_pipeline.stages.find_tracks import FindTracks as Params\n",
"from dev.docs.nbs import init\n",
"from devtools import pprint\n",
"from more_itertools import one\n",
"from numpy import diagonal, full, inf, isinf, nan, sqrt, where\n",
"from numpy import inf\n",
"from pandas import concat, read_hdf\n",
"from scipy.optimize import OptimizeWarning, curve_fit\n",
"from scipy.stats import t\n",
"\n",
"from boilercv.dimensionless_params import jakob, prandtl"
]
Expand Down Expand Up @@ -84,6 +79,9 @@
"MAX_NUSSELT = 1000\n",
"\"\"\"Maximum Nusselt number to plot.\"\"\"\n",
"\n",
"C_2 = 0.61\n",
"C_3 = 0.33\n",
"\n",
"pprint(params)"
]
},
Expand All @@ -97,24 +95,11 @@
},
"outputs": [],
"source": [
"# C_1 = 1.46\n",
"C_2 = 0.61\n",
"C_3 = 0.33\n",
"\n",
"tracks = (\n",
" concat(\n",
" read_hdf(p, key=\"dst\").assign(**{TC.time(): time})\n",
" for p, time in zip(all_tracks, times, strict=True)\n",
" )\n",
" # .pipe(\n",
" # lambda df: df[\n",
" # (df[C.bub_beta()] > 0)\n",
" # & (df[C.bub_beta()] < MAX_BETA)\n",
" # & (df[C.bub_nusselt()] > 0)\n",
" # & (df[C.bub_nusselt()] < MAX_NUSSELT)\n",
" # & (df[C.bub_fourier()] < MAX_FOURIER)\n",
" # ]\n",
" # )\n",
" .set_index(TC.time())\n",
" .assign(**{\n",
" TC.subcool(): thermal.set_index(TC.time()).loc[times, TC.subcool()], # pyright: ignore[reportArgumentType, reportCallIssue]\n",
Expand All @@ -138,6 +123,11 @@
"\n",
"\n",
"def Nu(Re_b, C_1, C_4): # noqa: N803, D103, N802\n",
" return C_1 * Re_b**C_2 * Pr**C_3 * Ja**C_4\n",
"\n",
"\n",
"def Nu2(x, C_1, C_2, C_3, C_4): # noqa: N803, D103, N802\n",
" Re_b, Ja = x.T # noqa: N806\n",
" return C_1 * Re_b**C_2 * Pr**C_3 * Ja**C_4"
]
},
Expand All @@ -151,56 +141,30 @@
},
"outputs": [],
"source": [
"def fit(\n",
" model: Any,\n",
" free_params: list[str],\n",
" initial_values: dict[str, float],\n",
" x: Any,\n",
" y: Any,\n",
" n: int,\n",
") -> tuple[dict[str, float], dict[str, float]]:\n",
" \"\"\"Get fits and errors for project model.\"\"\"\n",
" with catch_warnings():\n",
" try:\n",
" fits, pcov = curve_fit(\n",
" f=model, p0=get_guesses(free_params, initial_values), xdata=x, ydata=y\n",
" )\n",
" except (RuntimeError, OptimizeWarning):\n",
" dim = len(free_params)\n",
" fits = full(dim, nan)\n",
" pcov = full((dim, dim), nan)\n",
" # Compute confidence interval\n",
" standard_errors = sqrt(diagonal(pcov))\n",
" errors = standard_errors * t.interval(0.95, n)[1]\n",
" # Catching `OptimizeWarning` should be enough, but let's explicitly check for inf\n",
" fits = where(isinf(errors), nan, fits)\n",
" errors = where(isinf(errors), nan, errors)\n",
" return (\n",
" dict(zip(free_params, fits, strict=True)),\n",
" dict(zip([f\"{p}_err\" for p in free_params], errors, strict=True)),\n",
" )\n",
"\n",
"\n",
"fits, errors = fit(\n",
"fits, errors = fit_from_params(\n",
" model=Nu,\n",
" free_params=([\"C_1\", \"C_4\"]),\n",
" initial_values={\"C_1\": 1, \"C_4\": 1.0},\n",
" params=Fit(\n",
" independent_params=([\"Re_b\"]),\n",
" free_params=([\"C_1\", \"C_4\"]),\n",
" values={\"C_1\": 1.0, \"C_4\": 1.0},\n",
" bounds={\"C_1\": (-inf, inf), \"C_4\": (-inf, inf)},\n",
" ),\n",
" x=tracks[C.bub_reynolds()].values,\n",
" y=tracks[C.bub_nusselt()].values,\n",
" n=len(tracks),\n",
")\n",
"display(\n",
" {\n",
" # \"C_1\": C_1,\n",
" \"C_2\": C_2,\n",
" \"C_3\": C_3,\n",
" **fits,\n",
" \"C_5\": 2 * fits[\"C_1\"],\n",
" \"C_6\": 1 + fits[\"C_4\"],\n",
" # \"C_5\": 2 * C_1 * (2 - C_2),\n",
" # \"C_6\": 1 + fits[\"C_4\"],\n",
" # \"C_7\": 1 / (2 - C_2),\n",
" },\n",
" dict(\n",
" sorted(\n",
" {\n",
" \"C_2\": C_2,\n",
" \"C_3\": C_3,\n",
" **fits,\n",
" \"C_5\": 2 * fits[\"C_1\"] * (2 - C_2),\n",
" \"C_6\": 1 + fits[\"C_4\"],\n",
" \"C_7\": 1 / (2 - C_2),\n",
" }.items()\n",
" )\n",
" ),\n",
" errors,\n",
")"
]
Expand All @@ -218,26 +182,109 @@
"fits, errors = fit_from_params(\n",
" model=Nu,\n",
" params=Fit(\n",
" independent_params=[\"Re_b\"],\n",
" free_params=(nusselt_params := [\"C_1\", \"C_4\"]),\n",
" fixed_params=[],\n",
" bounds={\"C_1\": (-inf, inf), \"C_4\": (-inf, inf)},\n",
" independent_params=([\"Re_b\"]),\n",
" free_params=([\"C_1\", \"C_4\"]),\n",
" values={\"C_1\": 1.0, \"C_4\": 1.0},\n",
" bounds={\"C_1\": (0, inf), \"C_4\": (0, inf)},\n",
" ),\n",
" x=tracks[C.bub_reynolds()].values,\n",
" y=tracks[C.bub_nusselt()].values,\n",
")\n",
"display(\n",
" {\n",
" \"C_2\": C_2,\n",
" \"C_3\": C_3,\n",
" **fits,\n",
" \"C_5\": 2 * fits[\"C_1\"],\n",
" \"C_6\": 1 + fits[\"C_4\"],\n",
" # \"C_5\": 2 * C_1 * (2 - C_2),\n",
" # \"C_6\": (1 + fits[\"C_4\"]),\n",
" # \"C_7\": (1 / (2 - C_2)),\n",
" },\n",
" dict(\n",
" sorted(\n",
" {\n",
" \"C_2\": C_2,\n",
" \"C_3\": C_3,\n",
" **fits,\n",
" \"C_5\": 2 * fits[\"C_1\"] * (2 - C_2),\n",
" \"C_6\": 1 + fits[\"C_4\"],\n",
" \"C_7\": 1 / (2 - C_2),\n",
" }.items()\n",
" )\n",
" ),\n",
" errors,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"fits, errors = fit_from_params(\n",
" model=Nu2,\n",
" params=Fit(\n",
" independent_params=([\"x\"]),\n",
" free_params=([\"C_1\", \"C_2\", \"C_3\", \"C_4\"]),\n",
" values={\"C_1\": 1.0, \"C_2\": 1.0, \"C_3\": 1.0, \"C_4\": 1.0},\n",
" bounds={\n",
" \"C_1\": (-inf, inf),\n",
" \"C_2\": (-inf, inf),\n",
" \"C_3\": (-inf, inf),\n",
" \"C_4\": (-inf, inf),\n",
" },\n",
" ),\n",
" x=tracks[[C.bub_reynolds(), \"jakob\"]].values,\n",
" y=tracks[C.bub_nusselt()].values,\n",
")\n",
"display(\n",
" dict(\n",
" sorted(\n",
" {\n",
" **fits,\n",
" \"C_5\": 2 * fits[\"C_1\"] * (2 - C_2),\n",
" \"C_6\": 1 + fits[\"C_4\"],\n",
" \"C_7\": 1 / (2 - C_2),\n",
" }.items()\n",
" )\n",
" ),\n",
" errors,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"fits, errors = fit_from_params(\n",
" model=Nu2,\n",
" params=Fit(\n",
" independent_params=([\"x\"]),\n",
" free_params=([\"C_1\", \"C_2\", \"C_3\", \"C_4\"]),\n",
" values={\"C_1\": 1.0, \"C_2\": 1.0, \"C_3\": 1.0, \"C_4\": 1.0},\n",
" bounds={\n",
" \"C_1\": (0.0, inf),\n",
" \"C_2\": (0.0, inf),\n",
" \"C_3\": (0.0, inf),\n",
" \"C_4\": (0.0, inf),\n",
" },\n",
" ),\n",
" x=tracks[[C.bub_reynolds(), \"jakob\"]].values,\n",
" y=tracks[C.bub_nusselt()].values,\n",
")\n",
"display(\n",
" dict(\n",
" sorted(\n",
" {\n",
" **fits,\n",
" \"C_5\": 2 * fits[\"C_1\"] * (2 - C_2),\n",
" \"C_6\": 1 + fits[\"C_4\"],\n",
" \"C_7\": 1 / (2 - C_2),\n",
" }.items()\n",
" )\n",
" ),\n",
" errors,\n",
")"
]
Expand Down
6 changes: 3 additions & 3 deletions docs/notebooks/find_tracks.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@
"# Plotting\n",
"groups = {C.corr[k](): v for k, v in GROUPS.items() if k in C.corr}\n",
"\"\"\"Groups for mapping to correlations in data.\"\"\"\n",
"GROUP_DRAW_ORDER = [\"Group 2\", \"Group 4\", \"Group 3\", \"Group 1\"] # , \"Group 5\"]\n",
"GROUP_DRAW_ORDER = [\"Group 2\", \"Group 4\", \"Group 3\", \"Group 1\", \"Ours\"] # , \"Group 5\"]\n",
"\"\"\"Order to draw groups.\"\"\"\n",
"GROUP_ORDER = sorted(GROUP_DRAW_ORDER)\n",
"\"\"\"Order to show groups in legend.\"\"\"\n",
Expand All @@ -187,7 +187,7 @@
"\"\"\"Maximum mean absolute error of beta to plot.\"\"\"\n",
"MAX_NUSSELT_ERR = 12000\n",
"\"\"\"Maximum mean absolute error of nusselt to plot.\"\"\"\n",
"WIDTH_SCALE = 1.215 # 1.48\n",
"WIDTH_SCALE = 1.48 # 1.215 # 1.48\n",
"\"\"\"Width to scale plots by.\"\"\"\n",
"HEIGHT_SCALE = 1.000\n",
"\"\"\"Width to scale plots by.\"\"\"\n",
Expand Down Expand Up @@ -302,7 +302,7 @@
" {\n",
" lab: h\n",
" for h, lab in zip(*ax.get_legend_handles_labels(), strict=False)\n",
" if \"Group\" in lab\n",
" if \"Group\" in lab or \"Ours\" in lab\n",
" }[lab]\n",
" for lab in GROUP_ORDER\n",
" ],\n",
Expand Down
Loading

0 comments on commit b344904

Please sign in to comment.