From 72ee135200757b1bc0ed211dce362b37d9dc90df Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Tue, 26 Mar 2024 08:10:56 +0330 Subject: [PATCH] bugfix: loss dispatch --- src/icnf.jl | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/src/icnf.jl b/src/icnf.jl index 6a217763..9eb8e07c 100644 --- a/src/icnf.jl +++ b/src/icnf.jl @@ -804,20 +804,46 @@ end @inline function loss( icnf::ICNF{<:AbstractFloat, <:VectorMode}, - mode::Mode, + mode::TrainMode, + xs::AbstractVector{<:Real}, + ps::Any, + st::Any, +) + logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ps, st) + -logp̂x + icnf.λ₁ * Ė + icnf.λ₂ * ṅ + icnf.λ₃ * Ȧ +end + +@inline function loss( + icnf::ICNF{<:AbstractFloat, <:VectorMode}, + mode::TrainMode, xs::AbstractVector{<:Real}, + ys::AbstractVector{<:Real}, ps::Any, + st::Any, ) - logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ps) + logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ys, ps, st) -logp̂x + icnf.λ₁ * Ė + icnf.λ₂ * ṅ + icnf.λ₃ * Ȧ end @inline function loss( icnf::ICNF{<:AbstractFloat, <:MatrixMode}, - mode::Mode, + mode::TrainMode, + xs::AbstractMatrix{<:Real}, + ps::Any, + st::Any, +) + logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ps, st) + mean(-logp̂x + icnf.λ₁ * Ė + icnf.λ₂ * ṅ + icnf.λ₃ * Ȧ) +end + +@inline function loss( + icnf::ICNF{<:AbstractFloat, <:MatrixMode}, + mode::TrainMode, xs::AbstractMatrix{<:Real}, + ys::AbstractMatrix{<:Real}, ps::Any, + st::Any, ) - logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ps) + logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ys, ps, st) mean(-logp̂x + icnf.λ₁ * Ė + icnf.λ₂ * ṅ + icnf.λ₃ * Ȧ) end