Skip to content

Commit

Permalink
fixup! test: update NNODE tests - forward pass in additional loss
Browse files Browse the repository at this point in the history
  • Loading branch information
sathvikbhagavan committed Jun 29, 2024
1 parent ba1ca23 commit dcf9007
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion test/NNODE_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ end
luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
(u_, t_) = (u_analytical(ts), ts)
function additional_loss(phi, θ)
return sum(sum(abs2, vec(phi(t_, θ)) .- u_)) / length(u_)
return sum(sum(abs2, vec(phi(t_, θ)) .- u_)) / length(u_)
end
alg1 = NNODE(luxchain, opt, strategy = StochasticTraining(1000),
additional_loss = additional_loss)
Expand Down

0 comments on commit dcf9007

Please sign in to comment.