diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index cb0d18e178..a938fccfb9 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -166,7 +166,7 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative end end vr = mapreduce(vcat, dvs, init = []) do w - @rule w => ufunc(w, phi, varmap)(reducevcat(arguments(w), eltypeθ), θ) + @rule w => ufunc(w, phi, varmap)(reducevcat(arguments(w), eltypeθ, switch), θ) end sr = @rule switch => 1