diff --git a/src/visualization/recipes_plots.jl b/src/visualization/recipes_plots.jl index 0e9b5a66a8d..17b2fe33b38 100644 --- a/src/visualization/recipes_plots.jl +++ b/src/visualization/recipes_plots.jl @@ -171,6 +171,19 @@ RecipesBase.@recipe function f(u, semi::AbstractSemidiscretization; end end +# Also allow plotting a function with signature `func(x, equations)`, e.g., for initial conditions. +# We need this recipe in addition to the one above to avoid method ambiguities. +RecipesBase.@recipe function f(func::Function, semi::AbstractSemidiscretization; + solution_variables = nothing) + n_variables = length(func(0.0, semi.equations)) + variable_names = SVector(["func[$i]" for i in 1:n_variables]...) + if ndims(semi) == 1 + return PlotData1D(func, semi; solution_variables = cons2cons, variable_names) + else + throw(ArgumentError("Plotting of functions is only supported in 1D.")) + end +end + # Recipe specifically for TreeMesh-type solutions # Note: If you change the defaults values here, you need to also change them in the PlotData1D or PlotData2D # constructor. @@ -189,6 +202,22 @@ RecipesBase.@recipe function f(u, semi::SemidiscretizationHyperbolic{<:TreeMesh} end end +# Also allow plotting a function with signature `func(x, equations)`, e.g., for initial conditions. +RecipesBase.@recipe function f(func::Function, + semi::SemidiscretizationHyperbolic{<:TreeMesh}; + solution_variables = nothing, + nvisnodes = nothing, slice = :xy, + point = (0.0, 0.0, 0.0), curve = nothing) + n_variables = length(func(0.0, semi.equations)) + variable_names = SVector(["func[$i]" for i in 1:n_variables]...) + if ndims(semi) == 1 + return PlotData1D(func, semi; solution_variables = cons2cons, nvisnodes, slice, + point, curve, variable_names) + else + throw(ArgumentError("Plotting of functions is only supported in 1D.")) + end +end + # Series recipe for PlotData2DTriangulated RecipesBase.@recipe function f(pds::PlotDataSeries{<:PlotData2DTriangulated}) pd = pds.plot_data diff --git a/src/visualization/types.jl b/src/visualization/types.jl index d3330f5536b..a233faeffbf 100644 --- a/src/visualization/types.jl +++ b/src/visualization/types.jl @@ -519,10 +519,15 @@ end solution_variables=nothing, nvisnodes=nothing) Create a new `PlotData1D` object that can be used for visualizing 1D DGSEM solution data array -`u` with `Plots.jl`. All relevant geometrical information is extracted from the semidiscretization -`semi`. By default, the primitive variables (if existent) or the conservative variables (otherwise) -from the solution are used for plotting. This can be changed by passing an appropriate conversion -function to `solution_variables`. +`u` with `Plots.jl`. All relevant geometrical information is extracted from the +semidiscretization `semi`. By default, the primitive variables (if existent) +or the conservative variables (otherwise) from the solution are used for +plotting. This can be changed by passing an appropriate conversion function to +`solution_variables`, e.g., [`cons2cons`](@ref) or [`cons2prim`](@ref). + +Alternatively, you can also pass a function `u` with signature `u(x, equations)` +returning a vector. In this case, the `solution_variables` are ignored. This is useful, +e.g., to visualize an analytical solution. `nvisnodes` specifies the number of visualization nodes to be used. If it is `nothing`, twice the number of solution DG nodes are used for visualization, and if set to `0`, @@ -547,11 +552,19 @@ function PlotData1D(u_ode, semi; kwargs...) kwargs...) end +function PlotData1D(func::Function, semi; kwargs...) + PlotData1D(func, + mesh_equations_solver_cache(semi)...; + kwargs...) +end + function PlotData1D(u, mesh::TreeMesh, equations, solver, cache; solution_variables = nothing, nvisnodes = nothing, - slice = :x, point = (0.0, 0.0, 0.0), curve = nothing) + slice = :x, point = (0.0, 0.0, 0.0), curve = nothing, + variable_names = nothing) solution_variables_ = digest_solution_variables(equations, solution_variables) - variable_names = SVector(varnames(solution_variables_, equations)) + variable_names_ = digest_variable_names(solution_variables_, equations, + variable_names) original_nodes = cache.elements.node_coordinates unstructured_data = get_unstructured_data(u, solution_variables_, mesh, equations, @@ -610,15 +623,17 @@ function PlotData1D(u, mesh::TreeMesh, equations, solver, cache; end end - return PlotData1D(x, data, variable_names, mesh_vertices_x, + return PlotData1D(x, data, variable_names_, mesh_vertices_x, orientation_x) end function PlotData1D(u, mesh, equations, solver, cache; solution_variables = nothing, nvisnodes = nothing, - slice = :x, point = (0.0, 0.0, 0.0), curve = nothing) + slice = :x, point = (0.0, 0.0, 0.0), curve = nothing, + variable_names = nothing) solution_variables_ = digest_solution_variables(equations, solution_variables) - variable_names = SVector(varnames(solution_variables_, equations)) + variable_names_ = digest_variable_names(solution_variables_, equations, + variable_names) original_nodes = cache.elements.node_coordinates unstructured_data = get_unstructured_data(u, solution_variables_, mesh, equations, @@ -642,15 +657,25 @@ function PlotData1D(u, mesh, equations, solver, cache; slice, point, nvisnodes) end - return PlotData1D(x, data, variable_names, mesh_vertices_x, + return PlotData1D(x, data, variable_names_, mesh_vertices_x, orientation_x) end +function PlotData1D(func::Function, mesh, equations, dg::DGMulti{1}, cache; + solution_variables = nothing, variable_names = nothing) + x = mesh.md.x + u = func.(x, equations) + + return PlotData1D(u, mesh, equations, dg, cache; + solution_variables, variable_names) +end + # Specializes the `PlotData1D` constructor for one-dimensional `DGMulti` solvers. function PlotData1D(u, mesh, equations, dg::DGMulti{1}, cache; - solution_variables = nothing) + solution_variables = nothing, variable_names = nothing) solution_variables_ = digest_solution_variables(equations, solution_variables) - variable_names = SVector(varnames(solution_variables_, equations)) + variable_names_ = digest_variable_names(solution_variables_, equations, + variable_names) orientation_x = 0 # Set 'orientation' to zero on default. @@ -679,11 +704,16 @@ function PlotData1D(u, mesh, equations, dg::DGMulti{1}, cache; # Same as above - we create `data_plot` as array of size `num_plotting_points` # by "number of plotting variables". x_plot = vec(x) - data_plot = permutedims(reinterpret(reshape, eltype(eltype(data)), vec(data)), - (2, 1)) + data_ = reinterpret(reshape, eltype(eltype(data)), vec(data)) + # If there is only one solution variable, we need to add a singleton dimension + if ndims(data_) == 1 + data_plot = reshape(data_, :, 1) + else + data_plot = permutedims(data_, (2, 1)) + end end - return PlotData1D(x_plot, data_plot, variable_names, mesh.md.VX, orientation_x) + return PlotData1D(x_plot, data_plot, variable_names_, mesh.md.VX, orientation_x) end """ diff --git a/src/visualization/utilities.jl b/src/visualization/utilities.jl index 1f843c6a9d2..db2e830bc0e 100644 --- a/src/visualization/utilities.jl +++ b/src/visualization/utilities.jl @@ -256,6 +256,10 @@ function digest_solution_variables(equations, solution_variables::Nothing) end end +digest_variable_names(solution_variables_, equations, variable_names) = variable_names +digest_variable_names(solution_variables_, equations, ::Nothing) = SVector(varnames(solution_variables_, + equations)) + """ adapt_to_mesh_level!(u_ode, semi, level) adapt_to_mesh_level!(sol::Trixi.TrixiODESolution, level) @@ -481,6 +485,18 @@ function get_unstructured_data(u, solution_variables, mesh, equations, solver, c return unstructured_data end +# This method is only for plotting 1D functions +function get_unstructured_data(func::Function, solution_variables, + mesh::AbstractMesh{1}, equations, solver, cache) + original_nodes = cache.elements.node_coordinates + # original_nodes has size (1, nnodes, nelements) + # we want u to have size (nvars, nnodes, nelements) + # func.(original_nodes, equations) has size (1, nnodes, nelements), where each component has length n_vars + # Therefore, we drop the first (singleton) dimension and then stack the components + u = stack(func.(SVector.(dropdims(original_nodes; dims = 1)), equations)) + return get_unstructured_data(u, solution_variables, mesh, equations, solver, cache) +end + # Convert cell-centered values to node-centered values by averaging over all # four neighbors and making use of the periodicity of the solution # diff --git a/test/test_visualization.jl b/test/test_visualization.jl index 5c7e5dbbd1f..0d3a081f0bc 100644 --- a/test/test_visualization.jl +++ b/test/test_visualization.jl @@ -188,6 +188,10 @@ end @test_nowarn_mod Plots.plot(pd) @test_nowarn_mod Plots.plot(pd["p"]) @test_nowarn_mod Plots.plot(getmesh(pd)) + initial_condition_t_end(x, equations) = initial_condition(x, last(tspan), + equations) + @test_nowarn_mod Plots.plot(initial_condition_t_end, semi) + @test_nowarn_mod Plots.plot((x, equations) -> x, semi) end # Fake a PlotDataXD objects to test code for plotting multiple variables on at least two rows @@ -220,6 +224,9 @@ end tspan = (0.0, 0.0), approximation_type = Polynomial()) @test PlotData1D(sol) isa PlotData1D + initial_condition_t_end(x, equations) = initial_condition(x, last(tspan), equations) + @test_nowarn_mod Plots.plot(initial_condition_t_end, semi) + @test_nowarn_mod Plots.plot((x, equations) -> x, semi) @test_nowarn_mod trixi_include(@__MODULE__, joinpath(examples_dir(), "dgmulti_1d", @@ -227,6 +234,24 @@ end tspan = (0.0, 0.0), approximation_type = SBP()) @test PlotData1D(sol) isa PlotData1D + @test_nowarn_mod Plots.plot(initial_condition_t_end, semi) + @test_nowarn_mod Plots.plot((x, equations) -> x, semi) +end + +@timed_testset "1D plot recipes (StructuredMesh)" begin + @test_nowarn_mod trixi_include(@__MODULE__, + joinpath(examples_dir(), "structured_1d_dgsem", + "elixir_euler_source_terms.jl"), + tspan = (0.0, 0.0)) + + pd = PlotData1D(sol) + initial_condition_t_end(x, equations) = initial_condition(x, last(tspan), equations) + @test_nowarn_mod Plots.plot(sol) + @test_nowarn_mod Plots.plot(pd) + @test_nowarn_mod Plots.plot(pd["p"]) + @test_nowarn_mod Plots.plot(sol.u[end], semi) + @test_nowarn_mod Plots.plot(initial_condition_t_end, semi) + @test_nowarn_mod Plots.plot((x, equations) -> x, semi) end @timed_testset "plot time series" begin