Skip to content

Commit

Permalink
Updated after lint check
Browse files Browse the repository at this point in the history
  • Loading branch information
halehawk committed Dec 7, 2023
1 parent e1ce138 commit 8d7cefa
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 56 deletions.
3 changes: 1 addition & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ repos:
rev: 5.12.0
hooks:
- id: isort
args: ['--profile=black', '--filter-files']
args: ["--profile=black", "--filter-files"]

- repo: https://github.com/pre-commit/mirrors-prettier
rev: v2.5.1
Expand All @@ -46,4 +46,3 @@ repos:
additional_dependencies: [pyupgrade==2.7.3]
- id: nbqa-isort
additional_dependencies: [isort==5.12.0]

113 changes: 59 additions & 54 deletions notebooks/teopb_MPAS_ECT.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
"outputs": [],
"source": [
"summ_path = \"/glade/work/abaker/mpas_data/100_ens_summary\"\n",
"summ_files = [summ_path+\"/mpas_sum_ts\"+str(i)+\".nc\" for i in [6, 9, 12, 15, 18]]\n",
"summ_files = [summ_path + \"/mpas_sum_ts\" + str(i) + \".nc\" for i in [6, 9, 12, 15, 18]]\n",
"test_summary = xr.open_dataset(summ_files[0])\n",
"var_strings = np.char.rstrip(np.char.decode(test_summary.vars))"
]
Expand Down Expand Up @@ -62,19 +62,19 @@
"# Correlation Plots\n",
"\n",
"summ_path = \"/glade/work/abaker/mpas_data/100_ens_summary\"\n",
"summ_files = [summ_path+\"/mpas_sum_ts\"+str(i)+\".nc\" for i in [6, 9, 12, 15, 18]]\n",
"summ_files = [summ_path + \"/mpas_sum_ts\" + str(i) + \".nc\" for i in [6, 9, 12, 15, 18]]\n",
"\n",
"cutoff = .85\n",
"cutoff = 0.85\n",
"\n",
"fig, axs = plt.subplots(len(summ_files), figsize=(10, 40))\n",
"\n",
"for i, f in enumerate(summ_files):\n",
" test_summary = xr.open_dataset(f)\n",
" corr = np.corrcoef(test_summary.global_mean)\n",
" sns.heatmap(corr, vmin=-1, vmax=1, ax=axs[i])\n",
" axs[i].set_title(\"Timeslice = \" + str(i*3 + 6))\n",
" \n",
" correlated_entries = np.argwhere((corr > cutoff)|(corr < -cutoff))\n",
" axs[i].set_title(\"Timeslice = \" + str(i * 3 + 6))\n",
"\n",
" correlated_entries = np.argwhere((corr > cutoff) | (corr < -cutoff))\n",
" filt_correlated_entries = []\n",
" for i in correlated_entries:\n",
" if i[0] != i[1]:\n",
Expand Down Expand Up @@ -180,15 +180,15 @@
"\n",
"test_summary = xr.open_dataset(summ_files[0])\n",
"\n",
"fig, axs = plt.subplots(11, 6, sharex=True, figsize=(16,16))\n",
"fig, axs = plt.subplots(11, 6, sharex=True, figsize=(16, 16))\n",
"for i in range(len(var_strings)):\n",
" row = i//6\n",
" row = i // 6\n",
" col = i % 6\n",
" sm.qqplot(test_summary.global_mean[i, :], ax=axs[row, col], fit=True, line=\"45\")\n",
" axs[row, col].set_title(var_strings[i])\n",
"\n",
"for i in range(len(var_strings), 11 * 6):\n",
" row = i//6\n",
" row = i // 6\n",
" col = i % 6\n",
" axs[row, col].axis('off')\n",
"\n",
Expand Down Expand Up @@ -760,7 +760,7 @@
],
"source": [
"test_summary = xr.open_dataset(summ_files[0])\n",
"plt.hist(test_summary.global_mean[0,:], bins=20)"
"plt.hist(test_summary.global_mean[0, :], bins=20)"
]
},
{
Expand Down Expand Up @@ -801,7 +801,7 @@
],
"source": [
"test_summary = xr.open_dataset(summ_files[1])\n",
"plt.hist(test_summary.global_mean[0,:], bins=20)"
"plt.hist(test_summary.global_mean[0, :], bins=20)"
]
},
{
Expand Down Expand Up @@ -842,7 +842,7 @@
],
"source": [
"test_summary = xr.open_dataset(summ_files[2])\n",
"plt.hist(test_summary.global_mean[0,:], bins=20)"
"plt.hist(test_summary.global_mean[0, :], bins=20)"
]
},
{
Expand Down Expand Up @@ -882,7 +882,7 @@
}
],
"source": [
"plt.hist(test_summary.global_mean[0,:], bins=20)"
"plt.hist(test_summary.global_mean[0, :], bins=20)"
]
},
{
Expand All @@ -893,7 +893,7 @@
"outputs": [],
"source": [
"hist_path = \"/glade/scratch/abaker/mpas_hist\"\n",
"hist_files = [hist_path+\"/history.\"+str(i).zfill(3)+\".nc\" for i in range(100)]"
"hist_files = [hist_path + \"/history.\" + str(i).zfill(3) + \".nc\" for i in range(100)]"
]
},
{
Expand Down Expand Up @@ -933,32 +933,31 @@
"source": [
"def plot_vars(data, lats, lons, title, unit, save=False, filename=None):\n",
" plt.clf()\n",
" \n",
"\n",
" cmap = sns.color_palette(\"flare\", as_cmap=True)\n",
" \n",
"\n",
" # f, ax = plt.subplots(figsize=(15, 10))\n",
" \n",
"\n",
" f, ax = plt.subplots(figsize=(15, 10), subplot_kw={'projection': ccrs.Robinson()})\n",
" \n",
"\n",
" ax.coastlines(alpha=0.3)\n",
" \n",
"\n",
" # points = ax.scatter(x=lons, y=lats, c=data, s=7, cmap=cmap)\n",
" \n",
"# convert from radians to degrees\n",
"\n",
" # convert from radians to degrees\n",
"\n",
" lats = lats * 180 / np.pi\n",
" lons = lons * 180 / np.pi\n",
" lons = ((lons - 180.0) % 360.0) - 180.0\n",
" \n",
"\n",
" points = ax.scatter(x=lons, y=lats, c=data, s=7, cmap=cmap, transform=ccrs.PlateCarree())\n",
" \n",
" f.colorbar(points, label = unit)\n",
" \n",
"\n",
" f.colorbar(points, label=unit)\n",
"\n",
" plt.title(title)\n",
" plt.xlabel(\"Longitude (rad)\")\n",
" plt.ylabel(\"Latitude (rad)\")\n",
" \n",
" \n",
"\n",
" if filename != None:\n",
" plt.savefig(filename)\n",
" plt.close('all')\n",
Expand Down Expand Up @@ -1003,7 +1002,7 @@
" all_hist_vars.append(get_var_timeslice(data, var_name, time_slice).data)\n",
"\n",
" stacked_vars = np.vstack(all_hist_vars)\n",
" \n",
"\n",
" if opp_name == \"mean\":\n",
" trans_var = np.mean(stacked_vars, axis=0)\n",
" title = \"Mean \" + title\n",
Expand All @@ -1017,8 +1016,8 @@
" trans_var = np.amax(stacked_vars, axis=0)\n",
" title = \"Max \" + title\n",
" if filepath != None:\n",
" filename = \"global_\"+var_name +\"_\" +opp_name+ \"_\" + \"t\"+ str(time_slice)+\".png\"\n",
" filename = filepath +\"/\" + filename\n",
" filename = \"global_\" + var_name + \"_\" + opp_name + \"_\" + \"t\" + str(time_slice) + \".png\"\n",
" filename = filepath + \"/\" + filename\n",
" plot_vars(trans_var, lats, lons, title, units, filename=filename)\n",
" else:\n",
" plot_vars(trans_var, lats, lons, title, units)"
Expand Down Expand Up @@ -1341,9 +1340,9 @@
" all_hist_vars.append(get_var_timeslice(data, var_name, time_slice).data)\n",
"\n",
" stacked_vars = np.vstack(all_hist_vars)\n",
" \n",
"\n",
" _, cell_count = stacked_vars.shape\n",
" \n",
"\n",
" if opp_name == \"mean\":\n",
" trans_var = np.mean(stacked_vars, axis=0)\n",
" title = \"Mean \" + title\n",
Expand All @@ -1356,38 +1355,40 @@
" elif opp_name == \"max\":\n",
" trans_var = np.amax(stacked_vars, axis=0)\n",
" title = \"Max \" + title\n",
" \n",
"# Create histogram of nonzero transformed variable cells (aka cells where peak to peak is greater than zero)\n",
"# Actually plot bins after first of 100 to enable viewing relevant values.\n",
"\n",
" # Create histogram of nonzero transformed variable cells (aka cells where peak to peak is greater than zero)\n",
" # Actually plot bins after first of 100 to enable viewing relevant values.\n",
" nonzero = np.count_nonzero(trans_var)\n",
" n, bins, patches= plt.hist(trans_var, bins=100)\n",
" n, bins, patches = plt.hist(trans_var, bins=100)\n",
" plt.clf()\n",
" plt.hist(trans_var[trans_var > bins[1]], bins=99)\n",
" title_str = f\"{title} \\n Cell {opp_name} Hist > {bins[1]:.4f}, Nonzero: {(nonzero/cell_count)*100:.1f}%\"\n",
" title_str = (\n",
" f\"{title} \\n Cell {opp_name} Hist > {bins[1]:.4f}, Nonzero: {(nonzero/cell_count)*100:.1f}%\"\n",
" )\n",
" x_label = f\"{opp_name} {units}\"\n",
" plt.xlabel(x_label)\n",
" plt.title(title_str)\n",
" \n",
"\n",
" if filepath != None:\n",
" filename = f\"{var_name}_{opp_name}_hist_t_{str(time_slice)}.png\"\n",
" filename = filepath +\"/\" + filename\n",
" filename = filepath + \"/\" + filename\n",
" plt.savefig(filename)\n",
" plt.clf()\n",
" else:\n",
" plt.show()\n",
" \n",
"# Get cell indices in last bin\n",
" n_last = n[-1]\n",
"\n",
" # Get cell indices in last bin\n",
" n_last = n[-1]\n",
" cell_idx = np.argwhere(trans_var > bins[-2])[0][0]\n",
" plt.hist(stacked_vars[:, cell_idx], bins=100)\n",
" title_str = f\"{title} \\n Hist of cell in top {opp_name} bin (1 of {int(n_last)})\"\n",
" plt.title(title_str)\n",
" x_label = f\"{units}\"\n",
" plt.xlabel(x_label)\n",
" \n",
"\n",
" if filepath != None:\n",
" filename = f\"{var_name}_{opp_name}_tophist_t_{str(time_slice)}.png\"\n",
" filename = filepath +\"/\" + filename\n",
" filename = filepath + \"/\" + filename\n",
" plt.savefig(filename)\n",
" plt.clf()\n",
" else:\n",
Expand Down Expand Up @@ -1451,7 +1452,9 @@
"for t in [6, 18]:\n",
" for name in var_strings:\n",
" try:\n",
" ensemble_variance_cause(hist_files, name, t, \"ptp\", filepath=\"/glade/u/home/teopb/figures\")\n",
" ensemble_variance_cause(\n",
" hist_files, name, t, \"ptp\", filepath=\"/glade/u/home/teopb/figures\"\n",
" )\n",
" except:\n",
" print(f\"issue with {name}\")"
]
Expand All @@ -1465,6 +1468,8 @@
"source": [
"# Spread of ensemble global mean over time\n",
"long_summ_filepath = \"/glade/work/abaker/mpas_data/100_ens_summary/\"\n",
"\n",
"\n",
"def gm_time_spread(long_summ_filepath, hist_file, var_name=None, filepath=None):\n",
" summ_files = [fn for fn in listdir(long_summ_filepath) if fn.endswith(\".nc\")]\n",
"\n",
Expand All @@ -1481,7 +1486,7 @@
" time_size = len(timesteps)\n",
"\n",
" temp_array = np.empty((time_size, var_size, ens_size))\n",
" \n",
"\n",
" hist_data = xr.open_dataset(hist_file)\n",
"\n",
" for t, step in enumerate(timesteps):\n",
Expand All @@ -1490,33 +1495,33 @@
" data = xr.open_dataset(file_name)\n",
" # print(data.global_mean.shape)\n",
" for v in range(len(var_names)):\n",
" temp_array[t, v, :] = data.global_mean[v,:]\n",
" temp_array[t, v, :] = data.global_mean[v, :]\n",
"\n",
" means = np.mean(temp_array, axis=2)\n",
" \n",
"# Plot 1 variable and display\n",
"\n",
" # Plot 1 variable and display\n",
" if var_name != None:\n",
" title, units = get_info(hist_data, var_name, timesteps[0])\n",
" var_idx = np.where(var_names == var_name)[0][0]\n",
" plt.clf()\n",
" for i in range(ens_size):\n",
" diff = temp_array[:, var_idx, i] - means[:, var_idx]\n",
" plt.plot(timesteps, diff, alpha=.5)\n",
" \n",
" plt.plot(timesteps, diff, alpha=0.5)\n",
"\n",
" title_str = f\"{title} Ensemble Spread From Mean\"\n",
" x_label = \"Timestep\"\n",
" plt.xlabel(x_label)\n",
" plt.title(title_str)\n",
" plt.show()\n",
" \n",
"# Plot all variables and save at filepath\n",
"\n",
" # Plot all variables and save at filepath\n",
" if filepath != None:\n",
" for var_idx, var_name in enumerate(var_names):\n",
" try:\n",
" title, units = get_info(hist_data, var_name, timesteps[0])\n",
" for i in range(ens_size):\n",
" diff = temp_array[:, var_idx, i] - means[:, var_idx]\n",
" plt.plot(timesteps, diff, alpha=.5)\n",
" plt.plot(timesteps, diff, alpha=0.5)\n",
"\n",
" title_str = f\"{title} Ensemble Spread From Mean\"\n",
" x_label = \"Timestep\"\n",
Expand Down

0 comments on commit 8d7cefa

Please sign in to comment.