Skip to content

Commit

Permalink
fix cuda error (#379)
Browse files Browse the repository at this point in the history
* use `logpdf!`

* use `similar`

* use `oftype`

* use `covert`

* try a fix

* fix zygote error

* fix

* fix
  • Loading branch information
prbzrg authored Feb 27, 2024
1 parent 15681e2 commit 90e0709
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@ using ContinuousNormalizingFlows.ComputationalResources
CURAND.default_rng()
end

@inline function ContinuousNormalizingFlows.base_AT(::CUDALibs)
CuArray
@inline function ContinuousNormalizingFlows.base_AT(
::CUDALibs,
::ContinuousNormalizingFlows.AbstractFlows{T},
dims...,
) where {T <: AbstractFloat}
CuArray{T}(undef, dims...)
end

end
16 changes: 12 additions & 4 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,16 @@ end
Random.default_rng()
end

@inline function base_AT(::AbstractResource)
Array
@inline function base_AT(
::AbstractResource,
::AbstractFlows{T},
dims...,
) where {T <: AbstractFloat}
Array{T}(undef, dims...)
end

@non_differentiable base_AT(::Any...)

function inference_sol(
icnf::AbstractFlows{T, <:VectorMode, INPLACE},
mode::Mode,
Expand All @@ -163,8 +169,9 @@ function inference_sol(
fsol = get_fsol(sol)
z = fsol[begin:(end - n_aug - 1)]
Δlogp = fsol[(end - n_aug)]
logp̂x = logpdf(icnf.basedist, z) - Δlogp
augs = fsol[(end - n_aug + 1):end]
logpz = oftype(Δlogp, logpdf(icnf.basedist, z))
logp̂x = logpz - Δlogp
(logp̂x, augs)
end

Expand All @@ -178,8 +185,9 @@ function inference_sol(
fsol = get_fsol(sol)
z = fsol[begin:(end - n_aug - 1), :]
Δlogp = fsol[(end - n_aug), :]
logp̂x = logpdf(icnf.basedist, z) - Δlogp
augs = fsol[(end - n_aug + 1):end, :]
logpz = oftype(Δlogp, logpdf(icnf.basedist, z))
logp̂x = logpz - Δlogp
(logp̂x, eachrow(augs))
end

Expand Down
12 changes: 6 additions & 6 deletions src/base_cond_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function inference_prob(
n_aug_input = n_augment_input(icnf)
zrs = similar(xs, n_aug_input + n_aug + 1)
@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input)
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
rand!(icnf.rng, icnf.epsdist, ϵ)
ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
ifelse(
Expand Down Expand Up @@ -42,7 +42,7 @@ function inference_prob(
n_aug_input = n_augment_input(icnf)
zrs = similar(xs, n_aug_input + n_aug + 1, size(xs, 2))
@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input, size(xs, 2))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, size(xs, 2))
rand!(icnf.rng, icnf.epsdist, ϵ)
ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
ifelse(
Expand All @@ -69,11 +69,11 @@ function generate_prob(
) where {T <: AbstractFloat, INPLACE}
n_aug = n_augment(icnf, mode)
n_aug_input = n_augment_input(icnf)
new_xs = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input)
new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
rand!(icnf.rng, icnf.basedist, new_xs)
zrs = similar(new_xs, n_aug + 1)
@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input)
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
rand!(icnf.rng, icnf.epsdist, ϵ)
ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
ifelse(
Expand Down Expand Up @@ -101,11 +101,11 @@ function generate_prob(
) where {T <: AbstractFloat, INPLACE}
n_aug = n_augment(icnf, mode)
n_aug_input = n_augment_input(icnf)
new_xs = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input, n)
new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n)
rand!(icnf.rng, icnf.basedist, new_xs)
zrs = similar(new_xs, n_aug + 1, n)
@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input, n)
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n)
rand!(icnf.rng, icnf.epsdist, ϵ)
ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
ifelse(
Expand Down
12 changes: 6 additions & 6 deletions src/base_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ function inference_prob(
n_aug_input = n_augment_input(icnf)
zrs = similar(xs, n_aug_input + n_aug + 1)
@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input)
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
rand!(icnf.rng, icnf.epsdist, ϵ)
ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
ifelse(
Expand Down Expand Up @@ -40,7 +40,7 @@ function inference_prob(
n_aug_input = n_augment_input(icnf)
zrs = similar(xs, n_aug_input + n_aug + 1, size(xs, 2))
@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input, size(xs, 2))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, size(xs, 2))
rand!(icnf.rng, icnf.epsdist, ϵ)
ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
ifelse(
Expand All @@ -66,11 +66,11 @@ function generate_prob(
) where {T <: AbstractFloat, INPLACE}
n_aug = n_augment(icnf, mode)
n_aug_input = n_augment_input(icnf)
new_xs = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input)
new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
rand!(icnf.rng, icnf.basedist, new_xs)
zrs = similar(new_xs, n_aug + 1)
@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input)
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
rand!(icnf.rng, icnf.epsdist, ϵ)
ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
ifelse(
Expand All @@ -97,11 +97,11 @@ function generate_prob(
) where {T <: AbstractFloat, INPLACE}
n_aug = n_augment(icnf, mode)
n_aug_input = n_augment_input(icnf)
new_xs = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input, n)
new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n)
rand!(icnf.rng, icnf.basedist, new_xs)
zrs = similar(new_xs, n_aug + 1, n)
@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input, n)
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n)
rand!(icnf.rng, icnf.epsdist, ϵ)
ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
ifelse(
Expand Down

0 comments on commit 90e0709

Please sign in to comment.