GradientScaleAdaptiveLoss has problems in the HMC/NUTS Bayesian posterior case
AstitvaAggarwal committed Jan 19, 2024
1 parent f2f91e7 commit 6111be7
Showing 4 changed files with 151 additions and 33 deletions.
9 changes: 8 additions & 1 deletion src/PDE_BPINN.jl
@@ -64,7 +64,7 @@ end

LogDensityProblems.dimension(Tar::PDELogTargetDensity) = Tar.dim

function LogDensityProblems.capabilities(::Type{<:PDELogTargetDensity})
function LogDensityProblems.capabilities(::PDELogTargetDensity)
LogDensityProblems.LogDensityOrder{1}()
end

@@ -351,6 +351,9 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
Integratorkwargs = (Integrator = Leapfrog,), saveats = [1 / 10.0],
numensemble = floor(Int, draw_samples / 3), progress = false, verbose = false)
pinnrep = symbolic_discretize(pde_system, discretization)

pinnrep.iteration = [0]

dataset_pde, dataset_bc = discretization.dataset

if ((dataset_bc isa Nothing) && (dataset_pde isa Nothing))
@@ -494,12 +497,16 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
end
return bpinnsols
else
println("now 1")

initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
integrator = integratorchoice(Integratorkwargs, initial_ϵ)
adaptor = adaptorchoice(Adaptor, MassMatrixAdaptor(metric),
StepSizeAdaptor(targetacceptancerate, integrator))

Kernel = AdvancedHMC.make_kernel(Kernel, integrator)
println("now 2")

samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples,
adaptor; progress = progress, verbose = verbose)

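For context, the capabilities hunk above swaps the conventional type-based method for an instance-based one. Below is a minimal sketch of the LogDensityProblems.jl interface that AdvancedHMC consumes; ToyTarget and its standard-normal log density are illustrative and not part of this commit.

using LogDensityProblems

struct ToyTarget
    dim::Int
end

# Unnormalized standard-normal log density
LogDensityProblems.logdensity(t::ToyTarget, θ) = -0.5 * sum(abs2, θ)
LogDensityProblems.dimension(t::ToyTarget) = t.dim
# The documented convention dispatches on the type. Declaring order 0 here means only
# `logdensity` is provided; the order-1 declaration in the hunk above additionally
# promises gradients (e.g. `logdensity_and_gradient` or an AD wrapper).
LogDensityProblems.capabilities(::Type{<:ToyTarget}) = LogDensityProblems.LogDensityOrder{0}()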
59 changes: 41 additions & 18 deletions src/adaptive_losses.jl
@@ -86,22 +86,24 @@ https://arxiv.org/abs/2001.04536v1
With code reference:
https://github.com/PredictiveIntelligenceLab/GradientPathologiesPINNs
"""
mutable struct GradientScaleAdaptiveLoss{T <: Real} <: AbstractAdaptiveLoss
mutable struct GradientScaleAdaptiveLoss{T <: Real, R <: Real} <: AbstractAdaptiveLoss
reweight_every::Int64
weight_change_inertia::T
pde_loss_weights::Vector{T}
bc_loss_weights::Vector{T}
bc_loss_weights::Vector{R}
additional_loss_weights::Vector{T}
SciMLBase.@add_kwonly function GradientScaleAdaptiveLoss{T}(reweight_every;
SciMLBase.@add_kwonly function GradientScaleAdaptiveLoss{T, R}(reweight_every;
weight_change_inertia = 0.9,
pde_loss_weights = 1,
bc_loss_weights = 1,
additional_loss_weights = 1) where {
T <:
Real,
R <:
Real
}
new(convert(Int64, reweight_every), convert(T, weight_change_inertia),
vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T),
vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, R),
vectorify(additional_loss_weights, T))
end
end
@@ -111,46 +113,67 @@ SciMLBase.@add_kwonly function GradientScaleAdaptiveLoss(reweight_every;
pde_loss_weights = 1,
bc_loss_weights = 1,
additional_loss_weights = 1)
GradientScaleAdaptiveLoss{Float64}(reweight_every;
GradientScaleAdaptiveLoss{Float64, Float64}(reweight_every;
weight_change_inertia = weight_change_inertia,
pde_loss_weights = pde_loss_weights,
bc_loss_weights = bc_loss_weights,
additional_loss_weights = additional_loss_weights)
end

# function GradientScaleAdaptiveLoss(reweight_every;
# weight_change_inertia = 0.9,
# pde_loss_weights = 1,
# bc_loss_weights = 1,
# additional_loss_weights = 1)
# GradientScaleAdaptiveLoss{Float64, ForwardDiff.Dual}(reweight_every;
# weight_change_inertia = weight_change_inertia,
# pde_loss_weights = pde_loss_weights,
# bc_loss_weights = bc_loss_weights,
# additional_loss_weights = additional_loss_weights)
# end

function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
adaloss::GradientScaleAdaptiveLoss,
pde_loss_functions, bc_loss_functions)
weight_change_inertia = adaloss.weight_change_inertia
iteration = pinnrep.iteration
adaloss_T = eltype(adaloss.pde_loss_weights)

function run_loss_gradients_adaptive_loss(θ, pde_losses, bc_losses)
if iteration[1] % adaloss.reweight_every == 0
# the paper assumes a single pde loss function, so here we grab the maximum of the maximums of each pde loss function
pde_grads_maxes = [maximum(abs.(Zygote.gradient(pde_loss_function, θ)[1]))
pde_grads_maxes = [maximum(abs.(ForwardDiff.gradient(pde_loss_function, θ)[1]))
for pde_loss_function in pde_loss_functions]
pde_grads_max = maximum(pde_grads_maxes)
bc_grads_mean = [mean(abs.(Zygote.gradient(bc_loss_function, θ)[1]))
bc_grads_mean = [mean(abs.(ForwardDiff.gradient(bc_loss_function, θ)[1]))
for bc_loss_function in bc_loss_functions]

nonzero_divisor_eps = adaloss_T isa Float64 ? Float64(1e-11) :
convert(adaloss_T, 1e-7)
bc_loss_weights_proposed = pde_grads_max ./
(bc_grads_mean .+ nonzero_divisor_eps)

# println("adaloss.bc_loss_weights :", adaloss.bc_loss_weights)
if bc_loss_weights_proposed[1] isa ForwardDiff.Dual
bc_loss_weights_proposed = [bc_loss_weights_propose.value
for bc_loss_weights_propose in bc_loss_weights_proposed]
end

adaloss.bc_loss_weights .= weight_change_inertia .*
adaloss.bc_loss_weights .+
(1 .- weight_change_inertia) .*
bc_loss_weights_proposed
logscalar(pinnrep.logger, pde_grads_max, "adaptive_loss/pde_grad_max",
iteration[1])
logvector(pinnrep.logger, pde_grads_maxes, "adaptive_loss/pde_grad_maxes",
iteration[1])
logvector(pinnrep.logger, bc_grads_mean, "adaptive_loss/bc_grad_mean",
iteration[1])
logvector(pinnrep.logger, adaloss.bc_loss_weights,
"adaptive_loss/bc_loss_weights",
iteration[1])

println("iteration : ", iteration[1])
println(" adaloss.bc_loss_weights : ", adaloss.bc_loss_weights)
# logscalar(pinnrep.logger, pde_grads_max, "adaptive_loss/pde_grad_max",
# iteration[1])
# logvector(pinnrep.logger, pde_grads_maxes, "adaptive_loss/pde_grad_maxes",
# iteration[1])
# logvector(pinnrep.logger, bc_grads_mean, "adaptive_loss/bc_grad_mean",
# iteration[1])
# logvector(pinnrep.logger, adaloss.bc_loss_weights,
# "adaptive_loss/bc_loss_weights",
# iteration[1])
end
nothing
end
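For reference, a standalone sketch of the reweighting rule this adaptive loss implements (Wang et al., https://arxiv.org/abs/2001.04536): boundary-condition weights are pulled, with inertia α, toward the ratio of the largest PDE-loss gradient magnitude to the mean BC-loss gradient magnitude. All numbers below are illustrative.

pde_grads_max = 3.0                        # max over PDE losses of max |∂L_pde/∂θ|
bc_grads_mean = [0.2, 0.05]                # mean |∂L_bc/∂θ| per boundary condition
α = 0.9                                    # weight_change_inertia
ϵ = 1e-11                                  # nonzero divisor guard
λ_bc = ones(length(bc_grads_mean))         # current bc_loss_weights
λ_bc_proposed = pde_grads_max ./ (bc_grads_mean .+ ϵ)
λ_bc .= α .* λ_bc .+ (1 - α) .* λ_bc_proposed   # exponential moving average update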
52 changes: 44 additions & 8 deletions src/discretize.jl
@@ -709,14 +709,13 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,

# this is kind of a hack, and means that whenever the outer function is evaluated the increment goes up, even if it's not being optimized
# that's why we prefer the user to maintain the increment in the outer loop callback during optimization
ChainRulesCore.@ignore_derivatives if self_increment
iteration[1] += 1
end

ChainRulesCore.@ignore_derivatives begin
reweight_losses_func(θ, pde_loglikelihoods,
bc_loglikelihoods)
end
# println("terminate increment")
# ChainRulesCore.@ignore_derivatives if iteration[1] > 100
# self_increment=false
# end
pinnrep.iteration[1] += 1
reweight_losses_func(θ, pde_loglikelihoods, bc_loglikelihoods)

weighted_pde_loglikelihood = adaloss.pde_loss_weights .* pde_loglikelihoods
weighted_bc_loglikelihood = adaloss.bc_loss_weights .* bc_loglikelihoods
@@ -751,9 +750,46 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
weighted_additional_loglikelihood = adaloss.additional_loss_weights[1] *
_additional_loglikelihood

weighted_loglikelihood_before_additional + weighted_additional_loglikelihood
weighted_loglikelihood_before_additional + weighted_additional_loglikelihood
end

# println("full_weighted_loglikelihood : ", full_weighted_loglikelihood)

# ChainRulesCore.@ignore_derivatives begin
# println(" inside lower chainrules logging log_frequency part ")
# if iteration[1] % log_frequency == 0
# logvector(pinnrep.logger, pde_loglikelihoods, "unweighted_likelihood/pde_loglikelihoods",
# iteration[1])
# logvector(pinnrep.logger,
# bc_loglikelihoods,
# "unweighted_likelihood/bc_loglikelihoods",
# iteration[1])
# logvector(pinnrep.logger, weighted_pde_loglikelihood,
# "weighted_likelihood/weighted_pde_loglikelihood",
# iteration[1])
# logvector(pinnrep.logger, weighted_bc_loglikelihood,
# "weighted_likelihood/weighted_bc_loglikelihood",
# iteration[1])
# if !(additional_loss isa Nothing)
# logscalar(pinnrep.logger, weighted_additional_loglikelihood,
# "weighted_likelihood/weighted_additional_loglikelihood", iteration[1])
# end
# logscalar(pinnrep.logger, sum_weighted_pde_loglikelihood,
# "weighted_likelihood/sum_weighted_pde_loglikelihood", iteration[1])
# logscalar(pinnrep.logger, sum_weighted_bc_loglikelihood,
# "weighted_likelihood/sum_weighted_bc_loglikelihood", iteration[1])
# logscalar(pinnrep.logger, full_weighted_loglikelihood,
# "weighted_likelihood/full_weighted_loglikelihood",
# iteration[1])
# logvector(pinnrep.logger, adaloss.pde_loss_weights,
# "adaptive_loss/pde_loss_weights",
# iteration[1])
# logvector(pinnrep.logger, adaloss.bc_loss_weights,
# "adaptive_loss/bc_loss_weights",
# iteration[1])
# end
# end

return full_weighted_loglikelihood
end

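The removed lines above relied on ChainRulesCore.@ignore_derivatives to keep the iteration counter and the reweighting call out of the AD pass; the replacement increments pinnrep.iteration unconditionally. A small standalone sketch of that pattern (toy loss and Zygote for AD, not this package's code):

using ChainRulesCore, Zygote

iteration = [0]
function loss(θ)
    # bookkeeping runs on the forward pass but is invisible to the gradient
    ChainRulesCore.@ignore_derivatives iteration[1] += 1
    sum(abs2, θ)
end

g = Zygote.gradient(loss, [1.0, 2.0])[1]   # returns [2.0, 4.0]; iteration[1] == 1 afterwards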
64 changes: 58 additions & 6 deletions test/BPINN_PDE_tests.jl
@@ -274,7 +274,7 @@ chain = Lux.Chain(Lux.Dense(2, 8, Lux.tanh),
Lux.Dense(8, 1))

discretization = NeuralPDE.BayesianPINN([chain],
adaptive_loss =GradientScaleAdaptiveLoss(5),
adaptive_loss = GradientScaleAdaptiveLoss(5),
# MiniMaxAdaptiveLoss(5),
GridTraining([dx, dt]), param_estim = true, dataset = [datasetpde, nothing])
@named pde_system = PDESystem(eq,
@@ -287,9 +287,10 @@ discretization = NeuralPDE.BayesianPINN([chain],

sol1 = ahmc_bayesian_pinn_pde(pde_system,
discretization;
draw_samples = 100, Kernel = AdvancedHMC.NUTS(0.80),
bcstd = [0.2, 0.2, 0.2, 0.2, 0.2],
phystd = [1.0], l2std = [0.05], param = [Distributions.LogNormal(0.5, 2)],
draw_samples = 1000,
Kernel = AdvancedHMC.NUTS(0.80),
bcstd = [1.0, 1.0, 1.0, 1.0, 1.0],
phystd = [0.1], l2std = [0.05], param = [Distributions.LogNormal(0.5, 2)],
priorsNNw = (0.0, 10.0),
saveats = [1 / 100.0, 1 / 100.0], progress = true)

@@ -373,7 +374,6 @@ datasetpde[1][:, 1] = datasetpde[1][:, 1] .+
datasetpde[1][:, 1]
plot!(datasetpde[1][:, 2], datasetpde[1][:, 1])


function CostFun(x::AbstractVector{T}) where {T}
function SpringEqu!(du, u, x, t)
du[1] = u[2]
@@ -396,4 +396,56 @@

totalCost = sum(Simpos)
return totalCost
end
end

using NeuralPDE, Lux, ModelingToolkit, Optimization, OptimizationOptimJL
import ModelingToolkit: Interval, infimum, supremum

@parameters x, t
@variables u(..)
Dt = Differential(t)
Dx = Differential(x)
Dx2 = Differential(x)^2
Dx3 = Differential(x)^3
Dx4 = Differential(x)^4

α = 1
β = 4
γ = 1
eq = Dt(u(x, t)) + u(x, t) * Dx(u(x, t)) + α * Dx2(u(x, t)) + β * Dx3(u(x, t)) + γ * Dx4(u(x, t)) ~ 0

u_analytic(x, t; z = -x / 2 + t) = 11 + 15 * tanh(z) - 15 * tanh(z)^2 - 15 * tanh(z)^3
du(x, t; z = -x / 2 + t) = 15 / 2 * (tanh(z) + 1) * (3 * tanh(z) - 1) * sech(z)^2

bcs = [u(x, 0) ~ u_analytic(x, 0),
u(-10, t) ~ u_analytic(-10, t),
u(10, t) ~ u_analytic(10, t),
Dx(u(-10, t)) ~ du(-10, t),
Dx(u(10, t)) ~ du(10, t)]

# Space and time domains
domains = [x ∈ Interval(-10.0, 10.0),
    t ∈ Interval(0.0, 1.0)]
# Discretization
dx = 0.4;
dt = 0.2;

# Neural network
chain = Lux.Chain(Lux.Dense(2, 8, Lux.tanh),
Lux.Dense(8, 8, Lux.tanh),
Lux.Dense(8, 1))

discretization = PhysicsInformedNN(chain,
adaptive_loss = GradientScaleAdaptiveLoss(1),
GridTraining([dx, dt]))
@named pde_system = PDESystem(eq, bcs, domains, [x, t], [u(x, t)])
prob = discretize(pde_system, discretization)

callback = function (p, l)
println("Current loss is: $l")
return false
end

opt = OptimizationOptimJL.BFGS()
res = Optimization.solve(prob, opt; callback = callback, maxiters = 100)
phi = discretization.phi
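
A hedged sketch (not part of the commit) of how this run is typically checked against the analytic solution, following the usual NeuralPDE pattern of evaluating the trained phi on the grid with the optimized parameters res.u:

xs = -10.0:dx:10.0
ts = 0.0:dt:1.0
u_pred  = [first(phi([x, t], res.u)) for x in xs, t in ts]   # network prediction on the grid
u_exact = [u_analytic(x, t) for x in xs, t in ts]             # analytic reference
max_abs_err = maximum(abs.(u_pred .- u_exact))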
