Skip to content

Commit

Permalink
Merge branch 'main' into polyalg-23
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg authored May 8, 2024
2 parents 741beb1 + 355fa2c commit 0a972e6
Show file tree
Hide file tree
Showing 13 changed files with 110 additions and 133 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/Benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
version: '1'
arch: x64
show-versioninfo: true
- uses: julia-actions/cache@v1
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- name: Install dependencies
shell: julia --color=yes {0}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/CI-ENH.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ jobs:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
show-versioninfo: true
- uses: julia-actions/cache@v1
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
with:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ jobs:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
show-versioninfo: true
- uses: julia-actions/cache@v1
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
with:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/CompatHelper.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
version: '1'
arch: x64
show-versioninfo: true
- uses: julia-actions/cache@v1
- uses: julia-actions/cache@v2
- name: Pkg.add
shell: julia --color=yes {0}
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/Documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
version: '1'
arch: x64
show-versioninfo: true
- uses: julia-actions/cache@v1
- uses: julia-actions/cache@v2
- name: Configure doc environment
shell: julia --project=docs --color=yes {0}
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/Formatter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
version: '1'
arch: x64
show-versioninfo: true
- uses: julia-actions/cache@v1
- uses: julia-actions/cache@v2
- name: Install JuliaFormatter and format
shell: julia --color=yes {0}
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/Invalidations.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
version: '1'
arch: x64
show-versioninfo: true
- uses: julia-actions/cache@v1
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-invalidations@v1
id: invs_pr
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ContinuousNormalizingFlows"
uuid = "00b1973d-5b2e-40bf-8604-5c9c1d8f50ac"
authors = ["Hossein Pourbozorg <[email protected]> and contributors"]
version = "0.22.1"
version = "0.22.3"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -50,7 +50,7 @@ ComputationalResources = "0.3"
DataFrames = "1"
Dates = "1"
DifferentialEquations = "7"
DifferentiationInterface = "0.1, 0.2"
DifferentiationInterface = "0.1, 0.2, 0.3"
Distributions = "0.25"
DistributionsAD = "0.6"
FillArrays = "1"
Expand Down
2 changes: 1 addition & 1 deletion benchmark/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ADTypes = "0.2, 1"
BenchmarkTools = "1"
ComponentArrays = "0.15"
DifferentiationInterface = "0.1, 0.2"
DifferentiationInterface = "0.1, 0.2, 0.3"
Lux = "0.5"
PkgBenchmark = "0.2"
StableRNGs = "1"
Expand Down
47 changes: 20 additions & 27 deletions src/base_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,9 @@ function inference_prob(
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = Lux.StatefulLuxLayer(icnf.nn, ps, st)
nn = icnf.nn
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
make_ode_func(icnf, mode, nn, ϵ),
make_ode_func(icnf, mode, nn, st, ϵ),
vcat(xs, zrs),
steer_tspan(icnf, mode),
ps,
Expand All @@ -256,9 +256,9 @@ function inference_prob(
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = Lux.StatefulLuxLayer(CondLayer(icnf.nn, ys), ps, st)
nn = CondLayer(icnf.nn, ys)
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
make_ode_func(icnf, mode, nn, ϵ),
make_ode_func(icnf, mode, nn, st, ϵ),
vcat(xs, zrs),
steer_tspan(icnf, mode),
ps,
Expand All @@ -278,9 +278,9 @@ function inference_prob(
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, size(xs, 2))
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = Lux.StatefulLuxLayer(icnf.nn, ps, st)
nn = icnf.nn
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
make_ode_func(icnf, mode, nn, ϵ),
make_ode_func(icnf, mode, nn, st, ϵ),
vcat(xs, zrs),
steer_tspan(icnf, mode),
ps,
Expand All @@ -301,9 +301,9 @@ function inference_prob(
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, size(xs, 2))
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = Lux.StatefulLuxLayer(CondLayer(icnf.nn, ys), ps, st)
nn = CondLayer(icnf.nn, ys)
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
make_ode_func(icnf, mode, nn, ϵ),
make_ode_func(icnf, mode, nn, st, ϵ),
vcat(xs, zrs),
steer_tspan(icnf, mode),
ps,
Expand All @@ -324,9 +324,9 @@ function generate_prob(
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = Lux.StatefulLuxLayer(icnf.nn, ps, st)
nn = icnf.nn
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
make_ode_func(icnf, mode, nn, ϵ),
make_ode_func(icnf, mode, nn, st, ϵ),
vcat(new_xs, zrs),
reverse(steer_tspan(icnf, mode)),
ps,
Expand All @@ -348,9 +348,9 @@ function generate_prob(
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = Lux.StatefulLuxLayer(CondLayer(icnf.nn, ys), ps, st)
nn = CondLayer(icnf.nn, ys)
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
make_ode_func(icnf, mode, nn, ϵ),
make_ode_func(icnf, mode, nn, st, ϵ),
vcat(new_xs, zrs),
reverse(steer_tspan(icnf, mode)),
ps,
Expand All @@ -372,9 +372,9 @@ function generate_prob(
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n)
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = Lux.StatefulLuxLayer(icnf.nn, ps, st)
nn = icnf.nn
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
make_ode_func(icnf, mode, nn, ϵ),
make_ode_func(icnf, mode, nn, st, ϵ),
vcat(new_xs, zrs),
reverse(steer_tspan(icnf, mode)),
ps,
Expand All @@ -397,9 +397,9 @@ function generate_prob(
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n)
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = Lux.StatefulLuxLayer(CondLayer(icnf.nn, ys), ps, st)
nn = CondLayer(icnf.nn, ys)
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
make_ode_func(icnf, mode, nn, ϵ),
make_ode_func(icnf, mode, nn, st, ϵ),
vcat(new_xs, zrs),
reverse(steer_tspan(icnf, mode)),
ps,
Expand Down Expand Up @@ -512,28 +512,21 @@ end
@inline function make_ode_func(
icnf::AbstractICNF{T, CM, INPLACE},
mode::Mode,
nn::Lux.StatefulLuxLayer,
nn::LuxCore.AbstractExplicitLayer,
st::NamedTuple,
ϵ::AbstractVecOrMat{T},
) where {T <: AbstractFloat, CM, INPLACE}
function ode_func_op(u, p, t)
augmented_f(u, p, t, icnf, mode, nn, ϵ)
augmented_f(u, p, t, icnf, mode, nn, st, ϵ)
end

function ode_func_ip(du, u, p, t)
augmented_f(du, u, p, t, icnf, mode, nn, ϵ)
augmented_f(du, u, p, t, icnf, mode, nn, st, ϵ)
end

ifelse(INPLACE, ode_func_ip, ode_func_op)
end

@inline function make_dyn_func(nn::Lux.StatefulLuxLayer, ps::Any)
function dyn_func(x)
LuxCore.apply(nn, x, ps)
end

dyn_func
end

@inline function (icnf::AbstractICNF{T, CM, INPLACE, false})(
xs::AbstractVecOrMat,
ps::Any,
Expand Down
Loading

0 comments on commit 0a972e6

Please sign in to comment.