Skip to content

Commit

Permalink
add mode to AD
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg committed Oct 3, 2024
1 parent d7eee90 commit d7ede3b
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ icnf = construct(
nn,
nvars, # number of variables
naugs; # number of augmented dimensions
# compute_mode = DIJacVecMatrixMode(AutoEnzyme(; function_annotation = Enzyme.Const)), # process data in batches and use Enzyme
# compute_mode = DIJacVecMatrixMode(AutoEnzyme(; mode = Enzyme.Forward, function_annotation = Enzyme.Const)), # process data in batches and use Enzyme
# inplace = true, # use the inplace version of functions
# resource = CUDALibs(), # process data by GPU
tspan = (0.0f0, 13.0f0), # have bigger time span
Expand Down
4 changes: 2 additions & 2 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ icnf = ContinuousNormalizingFlows.construct(
nvars,
naugs;
compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(; mode = Enzyme.Forward, function_annotation = Enzyme.Const),
),
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
Expand Down Expand Up @@ -83,7 +83,7 @@ icnf2 = ContinuousNormalizingFlows.construct(
naugs;
inplace = true,
compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(; mode = Enzyme.Forward, function_annotation = Enzyme.Const),
),
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
Expand Down
2 changes: 1 addition & 1 deletion src/base_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ function construct(
naugmented::Int = 0;
data_type::Type{<:AbstractFloat} = Float32,
compute_mode::ComputeMode = DIJacVecMatrixMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(; mode = Enzyme.Forward, function_annotation = Enzyme.Const),
),
inplace::Bool = false,
cond::Bool = aicnf <: Union{CondRNODE, CondFFJORD, CondPlanar},
Expand Down
8 changes: 4 additions & 4 deletions test/call_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,16 @@ Test.@testset "Call Tests" begin
ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIVecJacVectorMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(; mode = Enzyme.Reverse, function_annotation = Enzyme.Const),
),
ContinuousNormalizingFlows.DIJacVecVectorMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(; mode = Enzyme.Forward, function_annotation = Enzyme.Const),
),
ContinuousNormalizingFlows.DIVecJacMatrixMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(; mode = Enzyme.Reverse, function_annotation = Enzyme.Const),
),
ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(; mode = Enzyme.Forward, function_annotation = Enzyme.Const),
),
]
data_types = Type{<:AbstractFloat}[Float32]
Expand Down
8 changes: 4 additions & 4 deletions test/fit_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ Test.@testset "Fit Tests" begin
ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIVecJacVectorMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(; mode = Enzyme.Reverse, function_annotation = Enzyme.Const),
),
ContinuousNormalizingFlows.DIJacVecVectorMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(; mode = Enzyme.Forward, function_annotation = Enzyme.Const),
),
ContinuousNormalizingFlows.DIVecJacMatrixMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(; mode = Enzyme.Reverse, function_annotation = Enzyme.Const),
),
ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(; mode = Enzyme.Forward, function_annotation = Enzyme.Const),
),
]
data_types = Type{<:AbstractFloat}[Float32]
Expand Down
2 changes: 1 addition & 1 deletion test/instability_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Test.@testset "Instability" begin
nvars,
naugs;
compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(; mode = Enzyme.Forward, function_annotation = Enzyme.Const),
),
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
Expand Down
2 changes: 1 addition & 1 deletion test/regression_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Test.@testset "Regression Tests" begin
nvars,
naugs;
compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(; mode = Enzyme.Forward, function_annotation = Enzyme.Const),
),
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
Expand Down

0 comments on commit d7ede3b

Please sign in to comment.