diff --git a/src/darsia/utils/plotting.py b/src/darsia/utils/plotting.py index 1ea79e6e..bad7037c 100644 --- a/src/darsia/utils/plotting.py +++ b/src/darsia/utils/plotting.py @@ -16,32 +16,28 @@ def plot_2d_wasserstein_distance( - grid: darsia.Grid, - mass_diff: np.ndarray, - flux: np.ndarray, - pressure: np.ndarray, - transport_density: np.ndarray, + info: dict, **kwargs, ) -> None: - """Plot the 2d Wasserstein distance between two mass distributions. - - The inputs are assumed to satisfy the layout of the Beckman solution. + """Post-processing utility to plot the 2d Wasserstein distance. Args: - grid (darsia.Grid): grid - mass_diff (np.ndarray): difference of mass distributions - flux (np.ndarray): fluxes - pressure (np.ndarray): pressure - transport_density (np.ndarray): transport density - kwargs: additional keyword arguments + info (dict): information about the Beckman solution, output of + darsia.wasserstein_distance. """ + # Fetch fields + grid = info["grid"] + mass_diff = info["mass_diff"] + flux = info["flux"] + pressure = info["pressure"] + transport_density = info["transport_density"] + # Fetch options - name = kwargs.get("name", None) + path = kwargs.get("path", None) save_plot = kwargs.get("save", False) if save_plot: - folder = kwargs.get("folder", ".") - Path(folder).mkdir(parents=True, exist_ok=True) + Path(path).mkdir(parents=True, exist_ok=True) dpi = kwargs.get("dpi", 500) show_plot = kwargs.get("show", True) @@ -66,7 +62,7 @@ def plot_2d_wasserstein_distance( # Save the plot if save_plot: plt.savefig( - folder + "/" + name + "_beckman_solution_pressure.png", + Path(str(path) + "_beckman_solution_pressure.png"), dpi=dpi, transparent=True, ) @@ -90,7 +86,7 @@ def plot_2d_wasserstein_distance( # Save the plot if save_plot: plt.savefig( - folder + "/" + name + "_beckman_solution_fluxes.png", + Path(str(path) + "_beckman_solution_fluxes.png"), dpi=dpi, transparent=True, ) @@ -105,7 +101,7 @@ def plot_2d_wasserstein_distance( # Save the plot if save_plot: plt.savefig( - folder + "/" + name + "_beckman_solution_transport_density.png", + Path(str(path) + "_beckman_solution_transport_density.png"), dpi=dpi, transparent=True, ) @@ -218,6 +214,9 @@ def to_vtk( img = (img[1], img[2], img[0]) cellData[name] = img + # Make directory if necessary + Path(path).mkdir(parents=True, exist_ok=True) + # Write to VTK gridToVTK(str(Path(path)), x, y, z, cellData=cellData)