From 90e07092cef88f37cddb2ebbf55d888a4befded1 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Wed, 28 Feb 2024 00:10:22 +0330 Subject: [PATCH] fix cuda error (#379) * use `logpdf!` * use `similar` * use `oftype` * use `covert` * try a fix * fix zygote error * fix * fix --- .../ContinuousNormalizingFlowsCUDAExt.jl | 8 ++++++-- src/base.jl | 16 ++++++++++++---- src/base_cond_icnf.jl | 12 ++++++------ src/base_icnf.jl | 12 ++++++------ 4 files changed, 30 insertions(+), 18 deletions(-) diff --git a/ext/ContinuousNormalizingFlowsCUDAExt/ContinuousNormalizingFlowsCUDAExt.jl b/ext/ContinuousNormalizingFlowsCUDAExt/ContinuousNormalizingFlowsCUDAExt.jl index 4e1f8742..d4b5060f 100644 --- a/ext/ContinuousNormalizingFlowsCUDAExt/ContinuousNormalizingFlowsCUDAExt.jl +++ b/ext/ContinuousNormalizingFlowsCUDAExt/ContinuousNormalizingFlowsCUDAExt.jl @@ -7,8 +7,12 @@ using ContinuousNormalizingFlows.ComputationalResources CURAND.default_rng() end -@inline function ContinuousNormalizingFlows.base_AT(::CUDALibs) - CuArray +@inline function ContinuousNormalizingFlows.base_AT( + ::CUDALibs, + ::ContinuousNormalizingFlows.AbstractFlows{T}, + dims..., +) where {T <: AbstractFloat} + CuArray{T}(undef, dims...) end end diff --git a/src/base.jl b/src/base.jl index c4a1bff6..2634e098 100644 --- a/src/base.jl +++ b/src/base.jl @@ -149,10 +149,16 @@ end Random.default_rng() end -@inline function base_AT(::AbstractResource) - Array +@inline function base_AT( + ::AbstractResource, + ::AbstractFlows{T}, + dims..., +) where {T <: AbstractFloat} + Array{T}(undef, dims...) end +@non_differentiable base_AT(::Any...) + function inference_sol( icnf::AbstractFlows{T, <:VectorMode, INPLACE}, mode::Mode, @@ -163,8 +169,9 @@ function inference_sol( fsol = get_fsol(sol) z = fsol[begin:(end - n_aug - 1)] Δlogp = fsol[(end - n_aug)] - logp̂x = logpdf(icnf.basedist, z) - Δlogp augs = fsol[(end - n_aug + 1):end] + logpz = oftype(Δlogp, logpdf(icnf.basedist, z)) + logp̂x = logpz - Δlogp (logp̂x, augs) end @@ -178,8 +185,9 @@ function inference_sol( fsol = get_fsol(sol) z = fsol[begin:(end - n_aug - 1), :] Δlogp = fsol[(end - n_aug), :] - logp̂x = logpdf(icnf.basedist, z) - Δlogp augs = fsol[(end - n_aug + 1):end, :] + logpz = oftype(Δlogp, logpdf(icnf.basedist, z)) + logp̂x = logpz - Δlogp (logp̂x, eachrow(augs)) end diff --git a/src/base_cond_icnf.jl b/src/base_cond_icnf.jl index 03fb5e1e..407b5f5a 100644 --- a/src/base_cond_icnf.jl +++ b/src/base_cond_icnf.jl @@ -12,7 +12,7 @@ function inference_prob( n_aug_input = n_augment_input(icnf) zrs = similar(xs, n_aug_input + n_aug + 1) @ignore_derivatives fill!(zrs, zero(T)) - ϵ = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input) + ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input) rand!(icnf.rng, icnf.epsdist, ϵ) ODEProblem{INPLACE, SciMLBase.FullSpecialize}( ifelse( @@ -42,7 +42,7 @@ function inference_prob( n_aug_input = n_augment_input(icnf) zrs = similar(xs, n_aug_input + n_aug + 1, size(xs, 2)) @ignore_derivatives fill!(zrs, zero(T)) - ϵ = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input, size(xs, 2)) + ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, size(xs, 2)) rand!(icnf.rng, icnf.epsdist, ϵ) ODEProblem{INPLACE, SciMLBase.FullSpecialize}( ifelse( @@ -69,11 +69,11 @@ function generate_prob( ) where {T <: AbstractFloat, INPLACE} n_aug = n_augment(icnf, mode) n_aug_input = n_augment_input(icnf) - new_xs = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input) + new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input) rand!(icnf.rng, icnf.basedist, new_xs) zrs = similar(new_xs, n_aug + 1) @ignore_derivatives fill!(zrs, zero(T)) - ϵ = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input) + ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input) rand!(icnf.rng, icnf.epsdist, ϵ) ODEProblem{INPLACE, SciMLBase.FullSpecialize}( ifelse( @@ -101,11 +101,11 @@ function generate_prob( ) where {T <: AbstractFloat, INPLACE} n_aug = n_augment(icnf, mode) n_aug_input = n_augment_input(icnf) - new_xs = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input, n) + new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n) rand!(icnf.rng, icnf.basedist, new_xs) zrs = similar(new_xs, n_aug + 1, n) @ignore_derivatives fill!(zrs, zero(T)) - ϵ = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input, n) + ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n) rand!(icnf.rng, icnf.epsdist, ϵ) ODEProblem{INPLACE, SciMLBase.FullSpecialize}( ifelse( diff --git a/src/base_icnf.jl b/src/base_icnf.jl index 2b3e536b..8c0b1635 100644 --- a/src/base_icnf.jl +++ b/src/base_icnf.jl @@ -11,7 +11,7 @@ function inference_prob( n_aug_input = n_augment_input(icnf) zrs = similar(xs, n_aug_input + n_aug + 1) @ignore_derivatives fill!(zrs, zero(T)) - ϵ = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input) + ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input) rand!(icnf.rng, icnf.epsdist, ϵ) ODEProblem{INPLACE, SciMLBase.FullSpecialize}( ifelse( @@ -40,7 +40,7 @@ function inference_prob( n_aug_input = n_augment_input(icnf) zrs = similar(xs, n_aug_input + n_aug + 1, size(xs, 2)) @ignore_derivatives fill!(zrs, zero(T)) - ϵ = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input, size(xs, 2)) + ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, size(xs, 2)) rand!(icnf.rng, icnf.epsdist, ϵ) ODEProblem{INPLACE, SciMLBase.FullSpecialize}( ifelse( @@ -66,11 +66,11 @@ function generate_prob( ) where {T <: AbstractFloat, INPLACE} n_aug = n_augment(icnf, mode) n_aug_input = n_augment_input(icnf) - new_xs = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input) + new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input) rand!(icnf.rng, icnf.basedist, new_xs) zrs = similar(new_xs, n_aug + 1) @ignore_derivatives fill!(zrs, zero(T)) - ϵ = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input) + ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input) rand!(icnf.rng, icnf.epsdist, ϵ) ODEProblem{INPLACE, SciMLBase.FullSpecialize}( ifelse( @@ -97,11 +97,11 @@ function generate_prob( ) where {T <: AbstractFloat, INPLACE} n_aug = n_augment(icnf, mode) n_aug_input = n_augment_input(icnf) - new_xs = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input, n) + new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n) rand!(icnf.rng, icnf.basedist, new_xs) zrs = similar(new_xs, n_aug + 1, n) @ignore_derivatives fill!(zrs, zero(T)) - ϵ = base_AT(icnf.resource){T}(undef, icnf.nvars + n_aug_input, n) + ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n) rand!(icnf.rng, icnf.epsdist, ϵ) ODEProblem{INPLACE, SciMLBase.FullSpecialize}( ifelse(