Skip to content

Commit

Permalink
MAINT: Add mkdir for robustness, and utilize info from output.
Browse files Browse the repository at this point in the history
  • Loading branch information
jwboth committed Nov 5, 2023
1 parent 4809a80 commit 4417849
Showing 1 changed file with 19 additions and 20 deletions.
39 changes: 19 additions & 20 deletions src/darsia/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,28 @@


def plot_2d_wasserstein_distance(
grid: darsia.Grid,
mass_diff: np.ndarray,
flux: np.ndarray,
pressure: np.ndarray,
transport_density: np.ndarray,
info: dict,
**kwargs,
) -> None:
"""Plot the 2d Wasserstein distance between two mass distributions.
The inputs are assumed to satisfy the layout of the Beckman solution.
"""Post-processing utility to plot the 2d Wasserstein distance.
Args:
grid (darsia.Grid): grid
mass_diff (np.ndarray): difference of mass distributions
flux (np.ndarray): fluxes
pressure (np.ndarray): pressure
transport_density (np.ndarray): transport density
kwargs: additional keyword arguments
info (dict): information about the Beckman solution, output of
darsia.wasserstein_distance.
"""
# Fetch fields
grid = info["grid"]
mass_diff = info["mass_diff"]
flux = info["flux"]
pressure = info["pressure"]
transport_density = info["transport_density"]

# Fetch options
name = kwargs.get("name", None)
path = kwargs.get("path", None)
save_plot = kwargs.get("save", False)
if save_plot:
folder = kwargs.get("folder", ".")
Path(folder).mkdir(parents=True, exist_ok=True)
Path(path).mkdir(parents=True, exist_ok=True)
dpi = kwargs.get("dpi", 500)
show_plot = kwargs.get("show", True)

Expand All @@ -66,7 +62,7 @@ def plot_2d_wasserstein_distance(
# Save the plot
if save_plot:
plt.savefig(
folder + "/" + name + "_beckman_solution_pressure.png",
Path(str(path) + "_beckman_solution_pressure.png"),
dpi=dpi,
transparent=True,
)
Expand All @@ -90,7 +86,7 @@ def plot_2d_wasserstein_distance(
# Save the plot
if save_plot:
plt.savefig(
folder + "/" + name + "_beckman_solution_fluxes.png",
Path(str(path) + "_beckman_solution_fluxes.png"),
dpi=dpi,
transparent=True,
)
Expand All @@ -105,7 +101,7 @@ def plot_2d_wasserstein_distance(
# Save the plot
if save_plot:
plt.savefig(
folder + "/" + name + "_beckman_solution_transport_density.png",
Path(str(path) + "_beckman_solution_transport_density.png"),
dpi=dpi,
transparent=True,
)
Expand Down Expand Up @@ -218,6 +214,9 @@ def to_vtk(
img = (img[1], img[2], img[0])
cellData[name] = img

# Make directory if necessary
Path(path).mkdir(parents=True, exist_ok=True)

# Write to VTK
gridToVTK(str(Path(path)), x, y, z, cellData=cellData)

Expand Down

0 comments on commit 4417849

Please sign in to comment.