Skip to content

Commit

Permalink
bugfix: loss dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg committed Mar 26, 2024
1 parent 330d335 commit 72ee135
Showing 1 changed file with 30 additions and 4 deletions.
34 changes: 30 additions & 4 deletions src/icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -804,20 +804,46 @@ end

@inline function loss(
icnf::ICNF{<:AbstractFloat, <:VectorMode},
mode::Mode,
mode::TrainMode,
xs::AbstractVector{<:Real},
ps::Any,
st::Any,
)
logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ps, st)
-logp̂x + icnf.λ₁ *+ icnf.λ₂ *+ icnf.λ₃ *
end

@inline function loss(
icnf::ICNF{<:AbstractFloat, <:VectorMode},
mode::TrainMode,
xs::AbstractVector{<:Real},
ys::AbstractVector{<:Real},
ps::Any,
st::Any,
)
logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ps)
logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ys, ps, st)
-logp̂x + icnf.λ₁ *+ icnf.λ₂ *+ icnf.λ₃ *
end

@inline function loss(
icnf::ICNF{<:AbstractFloat, <:MatrixMode},
mode::Mode,
mode::TrainMode,
xs::AbstractMatrix{<:Real},
ps::Any,
st::Any,
)
logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ps, st)
mean(-logp̂x + icnf.λ₁ *+ icnf.λ₂ *+ icnf.λ₃ * Ȧ)
end

@inline function loss(
icnf::ICNF{<:AbstractFloat, <:MatrixMode},
mode::TrainMode,
xs::AbstractMatrix{<:Real},
ys::AbstractMatrix{<:Real},
ps::Any,
st::Any,
)
logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ps)
logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ys, ps, st)
mean(-logp̂x + icnf.λ₁ *+ icnf.λ₂ *+ icnf.λ₃ * Ȧ)
end

0 comments on commit 72ee135

Please sign in to comment.