Skip to content

Commit

Permalink
Add support for plotting 1D function (#2250)
Browse files Browse the repository at this point in the history
* add support for plotting 1D function

* allow more than one variable

* add a test

* format

* format

* simplify

* add comment

* function -> method

* add support for 1D StructuredMesh and 1D DGMulti

* specialize method for 1D meshes

* format

* Apply suggestions from code review

Co-authored-by: Hendrik Ranocha <[email protected]>

* always use solution_variables = cons2cons

* don't pass variable_names to recipe (title can be changed by `title`)

* fix plotting scalar function for DGMulti

* fix for StructuredMesh

* Update src/visualization/types.jl

Co-authored-by: Hendrik Ranocha <[email protected]>

* fix comment

* clarify output of function

---------

Co-authored-by: Hendrik Ranocha <[email protected]>
  • Loading branch information
JoshuaLampert and ranocha authored Feb 5, 2025
1 parent 06e8a20 commit c24ce91
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 15 deletions.
29 changes: 29 additions & 0 deletions src/visualization/recipes_plots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,19 @@ RecipesBase.@recipe function f(u, semi::AbstractSemidiscretization;
end
end

# Also allow plotting a function with signature `func(x, equations)`, e.g., for initial conditions.
# We need this recipe in addition to the one above to avoid method ambiguities.
RecipesBase.@recipe function f(func::Function, semi::AbstractSemidiscretization;
solution_variables = nothing)
n_variables = length(func(0.0, semi.equations))
variable_names = SVector(["func[$i]" for i in 1:n_variables]...)
if ndims(semi) == 1
return PlotData1D(func, semi; solution_variables = cons2cons, variable_names)
else
throw(ArgumentError("Plotting of functions is only supported in 1D."))
end
end

# Recipe specifically for TreeMesh-type solutions
# Note: If you change the defaults values here, you need to also change them in the PlotData1D or PlotData2D
# constructor.
Expand All @@ -189,6 +202,22 @@ RecipesBase.@recipe function f(u, semi::SemidiscretizationHyperbolic{<:TreeMesh}
end
end

# Also allow plotting a function with signature `func(x, equations)`, e.g., for initial conditions.
RecipesBase.@recipe function f(func::Function,
semi::SemidiscretizationHyperbolic{<:TreeMesh};
solution_variables = nothing,
nvisnodes = nothing, slice = :xy,
point = (0.0, 0.0, 0.0), curve = nothing)
n_variables = length(func(0.0, semi.equations))
variable_names = SVector(["func[$i]" for i in 1:n_variables]...)
if ndims(semi) == 1
return PlotData1D(func, semi; solution_variables = cons2cons, nvisnodes, slice,
point, curve, variable_names)
else
throw(ArgumentError("Plotting of functions is only supported in 1D."))
end
end

# Series recipe for PlotData2DTriangulated
RecipesBase.@recipe function f(pds::PlotDataSeries{<:PlotData2DTriangulated})
pd = pds.plot_data
Expand Down
60 changes: 45 additions & 15 deletions src/visualization/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -519,10 +519,15 @@ end
solution_variables=nothing, nvisnodes=nothing)
Create a new `PlotData1D` object that can be used for visualizing 1D DGSEM solution data array
`u` with `Plots.jl`. All relevant geometrical information is extracted from the semidiscretization
`semi`. By default, the primitive variables (if existent) or the conservative variables (otherwise)
from the solution are used for plotting. This can be changed by passing an appropriate conversion
function to `solution_variables`.
`u` with `Plots.jl`. All relevant geometrical information is extracted from the
semidiscretization `semi`. By default, the primitive variables (if existent)
or the conservative variables (otherwise) from the solution are used for
plotting. This can be changed by passing an appropriate conversion function to
`solution_variables`, e.g., [`cons2cons`](@ref) or [`cons2prim`](@ref).
Alternatively, you can also pass a function `u` with signature `u(x, equations)`
returning a vector. In this case, the `solution_variables` are ignored. This is useful,
e.g., to visualize an analytical solution.
`nvisnodes` specifies the number of visualization nodes to be used. If it is `nothing`,
twice the number of solution DG nodes are used for visualization, and if set to `0`,
Expand All @@ -547,11 +552,19 @@ function PlotData1D(u_ode, semi; kwargs...)
kwargs...)
end

function PlotData1D(func::Function, semi; kwargs...)
PlotData1D(func,
mesh_equations_solver_cache(semi)...;
kwargs...)
end

function PlotData1D(u, mesh::TreeMesh, equations, solver, cache;
solution_variables = nothing, nvisnodes = nothing,
slice = :x, point = (0.0, 0.0, 0.0), curve = nothing)
slice = :x, point = (0.0, 0.0, 0.0), curve = nothing,
variable_names = nothing)
solution_variables_ = digest_solution_variables(equations, solution_variables)
variable_names = SVector(varnames(solution_variables_, equations))
variable_names_ = digest_variable_names(solution_variables_, equations,
variable_names)

original_nodes = cache.elements.node_coordinates
unstructured_data = get_unstructured_data(u, solution_variables_, mesh, equations,
Expand Down Expand Up @@ -610,15 +623,17 @@ function PlotData1D(u, mesh::TreeMesh, equations, solver, cache;
end
end

return PlotData1D(x, data, variable_names, mesh_vertices_x,
return PlotData1D(x, data, variable_names_, mesh_vertices_x,
orientation_x)
end

function PlotData1D(u, mesh, equations, solver, cache;
solution_variables = nothing, nvisnodes = nothing,
slice = :x, point = (0.0, 0.0, 0.0), curve = nothing)
slice = :x, point = (0.0, 0.0, 0.0), curve = nothing,
variable_names = nothing)
solution_variables_ = digest_solution_variables(equations, solution_variables)
variable_names = SVector(varnames(solution_variables_, equations))
variable_names_ = digest_variable_names(solution_variables_, equations,
variable_names)

original_nodes = cache.elements.node_coordinates
unstructured_data = get_unstructured_data(u, solution_variables_, mesh, equations,
Expand All @@ -642,15 +657,25 @@ function PlotData1D(u, mesh, equations, solver, cache;
slice, point, nvisnodes)
end

return PlotData1D(x, data, variable_names, mesh_vertices_x,
return PlotData1D(x, data, variable_names_, mesh_vertices_x,
orientation_x)
end

function PlotData1D(func::Function, mesh, equations, dg::DGMulti{1}, cache;
solution_variables = nothing, variable_names = nothing)
x = mesh.md.x
u = func.(x, equations)

return PlotData1D(u, mesh, equations, dg, cache;
solution_variables, variable_names)
end

# Specializes the `PlotData1D` constructor for one-dimensional `DGMulti` solvers.
function PlotData1D(u, mesh, equations, dg::DGMulti{1}, cache;
solution_variables = nothing)
solution_variables = nothing, variable_names = nothing)
solution_variables_ = digest_solution_variables(equations, solution_variables)
variable_names = SVector(varnames(solution_variables_, equations))
variable_names_ = digest_variable_names(solution_variables_, equations,
variable_names)

orientation_x = 0 # Set 'orientation' to zero on default.

Expand Down Expand Up @@ -679,11 +704,16 @@ function PlotData1D(u, mesh, equations, dg::DGMulti{1}, cache;
# Same as above - we create `data_plot` as array of size `num_plotting_points`
# by "number of plotting variables".
x_plot = vec(x)
data_plot = permutedims(reinterpret(reshape, eltype(eltype(data)), vec(data)),
(2, 1))
data_ = reinterpret(reshape, eltype(eltype(data)), vec(data))
# If there is only one solution variable, we need to add a singleton dimension
if ndims(data_) == 1
data_plot = reshape(data_, :, 1)
else
data_plot = permutedims(data_, (2, 1))
end
end

return PlotData1D(x_plot, data_plot, variable_names, mesh.md.VX, orientation_x)
return PlotData1D(x_plot, data_plot, variable_names_, mesh.md.VX, orientation_x)
end

"""
Expand Down
16 changes: 16 additions & 0 deletions src/visualization/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,10 @@ function digest_solution_variables(equations, solution_variables::Nothing)
end
end

digest_variable_names(solution_variables_, equations, variable_names) = variable_names
digest_variable_names(solution_variables_, equations, ::Nothing) = SVector(varnames(solution_variables_,
equations))

"""
adapt_to_mesh_level!(u_ode, semi, level)
adapt_to_mesh_level!(sol::Trixi.TrixiODESolution, level)
Expand Down Expand Up @@ -481,6 +485,18 @@ function get_unstructured_data(u, solution_variables, mesh, equations, solver, c
return unstructured_data
end

# This method is only for plotting 1D functions
function get_unstructured_data(func::Function, solution_variables,
mesh::AbstractMesh{1}, equations, solver, cache)
original_nodes = cache.elements.node_coordinates
# original_nodes has size (1, nnodes, nelements)
# we want u to have size (nvars, nnodes, nelements)
# func.(original_nodes, equations) has size (1, nnodes, nelements), where each component has length n_vars
# Therefore, we drop the first (singleton) dimension and then stack the components
u = stack(func.(SVector.(dropdims(original_nodes; dims = 1)), equations))
return get_unstructured_data(u, solution_variables, mesh, equations, solver, cache)
end

# Convert cell-centered values to node-centered values by averaging over all
# four neighbors and making use of the periodicity of the solution
#
Expand Down
25 changes: 25 additions & 0 deletions test/test_visualization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ end
@test_nowarn_mod Plots.plot(pd)
@test_nowarn_mod Plots.plot(pd["p"])
@test_nowarn_mod Plots.plot(getmesh(pd))
initial_condition_t_end(x, equations) = initial_condition(x, last(tspan),
equations)
@test_nowarn_mod Plots.plot(initial_condition_t_end, semi)
@test_nowarn_mod Plots.plot((x, equations) -> x, semi)
end

# Fake a PlotDataXD objects to test code for plotting multiple variables on at least two rows
Expand Down Expand Up @@ -220,13 +224,34 @@ end
tspan = (0.0, 0.0),
approximation_type = Polynomial())
@test PlotData1D(sol) isa PlotData1D
initial_condition_t_end(x, equations) = initial_condition(x, last(tspan), equations)
@test_nowarn_mod Plots.plot(initial_condition_t_end, semi)
@test_nowarn_mod Plots.plot((x, equations) -> x, semi)

@test_nowarn_mod trixi_include(@__MODULE__,
joinpath(examples_dir(), "dgmulti_1d",
"elixir_euler_flux_diff.jl"),
tspan = (0.0, 0.0),
approximation_type = SBP())
@test PlotData1D(sol) isa PlotData1D
@test_nowarn_mod Plots.plot(initial_condition_t_end, semi)
@test_nowarn_mod Plots.plot((x, equations) -> x, semi)
end

@timed_testset "1D plot recipes (StructuredMesh)" begin
@test_nowarn_mod trixi_include(@__MODULE__,
joinpath(examples_dir(), "structured_1d_dgsem",
"elixir_euler_source_terms.jl"),
tspan = (0.0, 0.0))

pd = PlotData1D(sol)
initial_condition_t_end(x, equations) = initial_condition(x, last(tspan), equations)
@test_nowarn_mod Plots.plot(sol)
@test_nowarn_mod Plots.plot(pd)
@test_nowarn_mod Plots.plot(pd["p"])
@test_nowarn_mod Plots.plot(sol.u[end], semi)
@test_nowarn_mod Plots.plot(initial_condition_t_end, semi)
@test_nowarn_mod Plots.plot((x, equations) -> x, semi)
end

@timed_testset "plot time series" begin
Expand Down

0 comments on commit c24ce91

Please sign in to comment.