From c747ac02f3ad015d588883afeac1733aac294e29 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Wed, 17 Jan 2024 06:27:34 +0000 Subject: [PATCH] refactor: Quadrature Training with Integrals.jl@v4 --- src/training_strategies.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/training_strategies.jl b/src/training_strategies.jl index ca66f6b203..df185e5a67 100644 --- a/src/training_strategies.jl +++ b/src/training_strategies.jl @@ -320,9 +320,10 @@ function get_loss_function(loss_function, lb, ub, eltypeθ, strategy::Quadrature # mean(abs2,loss_(x,θ), dims=2) # size_x = fill(size(x)[2],(1,1)) x = adapt(parameterless_type(ComponentArrays.getdata(θ)), x) - sum(abs2, loss_(x, θ), dims = 2) #./ size_x + sum(abs2, vec(loss_(x, θ)), dims = 2) #./ size_x end - prob = IntegralProblem(integrand, lb, ub, θ, batch = strategy.batch, nout = 1) + integral_function = BatchIntegralFunction(integrand, max_batch = strategy.batch) + prob = IntegralProblem(integral_function, lb, ub, θ) solve(prob, strategy.quadrature_alg, reltol = strategy.reltol,