How to add checkpointing scheme for reactant #777

Open
swilliamson7 opened this issue Feb 18, 2025 · 8 comments

Comments

@swilliamson7
Collaborator

Sorry if this isn't the right place to post, but in the process of adding Reactant to my code I'm seeing two main issues and wanted to document them here:

  1. If I don't use Checkpointing.jl, memory usage keeps growing until my computer kills the process. This happens for a one-day integration, which shouldn't be long enough to cause memory issues.

  2. To work around (1), I tried switching to the integration that uses checkpointing by adding a few lines:

    S = Reactant.to_rarray(S)
    dS = Reactant.to_rarray(dS)
    revolve = Reactant.to_rarray(revolve) # first added line
    compiled_outer = @compile outer(S, dS, revolve) # adjusted to take the checkpointing scheme as an argument

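    # NOTE: this rebinding discards the compiled function, so the call below runs the uncompiled `outer`
    compiled_outer = outer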
    compiled_outer(S, dS, revolve)
    return S, dS

inside the relevant functions run_adjoint and outer,

function outer(S, dS, revolve)
    autodiff(Enzyme.ReverseWithPrimal, checkpointed_integration, Duplicated(S, dS), Const(revolve))
    nothing
end

function run_adjoint(::Type{T}=Float32;
    kwargs...
    ) where {T<:AbstractFloat}

    P = ShallowWaters.Parameter(T=T;kwargs...)
    @show typeof(P)
    S = ShallowWaters.model_setup(P)

    dS = Enzyme.Compiler.make_zero(S)
    snaps = Int(floor(sqrt(S.grid.nt)))
    revolve = Revolve{ShallowWaters.ModelSetup}(S.grid.nt,
        snaps;
        verbose=1,
        gc=true,
        write_checkpoints=false,
        write_checkpoints_filename = "technicalpaper_checkingderivatives_30dayrun_withcheckpointing_120924",
        write_checkpoints_period = 224
    )

    # autodiff(Enzyme.ReverseWithPrimal, checkpointed_integration, Duplicated(S, dS), Const(revolve))
    # autodiff(Enzyme.ReverseWithPrimal, integration, Duplicated(S, dS))
    # outer()

    S = Reactant.to_rarray(S)
    dS = Reactant.to_rarray(dS)
    revolve = Reactant.to_rarray(revolve)
    compiled_outer = @compile outer(S, dS, revolve)

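    # NOTE: this rebinding discards the compiled function, so the call below runs the uncompiled `outer`
    compiled_outer = outer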
    compiled_outer(S, dS, revolve)
    return S, dS

end

but this just resulted in the error message:

swilliamson@CRIOS-A66253 ~/D/G/S/eddy-stresses> julia --project=. eddy_paper.jl
┌ Warning: It's not recommended to use allowscalar([true]) to allow scalar indexing.
│ Instead, use `allowscalar() do end` or `@allowscalar` to denote exactly which operations can use scalar operations.
└ @ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:184
typeof(P) = Parameter{Array{Float64, 3}, Vector{Float64}, Vector{Float64}}
[ Info: Revolve: Number of checkpoints: 15
[ Info: Revolve: Number of steps: 225
[ Info: Prediction:
[ Info: Forward steps   : 522
[ Info: Overhead factor : 2.32
ERROR: LoadError: Abstract type Function does not have a definite size.
Stacktrace:
  [1] sizeof
    @ ./essentials.jl:631 [inlined]
  [2] traced_type_inner(T::Type{<:Function}, seen::Dict{Type, Type}, mode::Reactant.TraceMode, track_numbers::Type)
    @ Reactant ~/.julia/dev/Reactant/src/Tracing.jl:76
  [3] traced_type_inner(T::Type, seen::Dict{Type, Type}, mode::Reactant.TraceMode, track_numbers::Type)
    @ Reactant ~/.julia/dev/Reactant/src/Tracing.jl:415
  [4] traced_type_inner(T::Type, seen::Dict{Type, Type}, mode::Reactant.TraceMode, track_numbers::Type)
    @ Reactant ~/.julia/dev/Reactant/src/Tracing.jl:444
  [5] traced_type(T::Type, ::Val{Reactant.ArrayToConcrete}, track_numbers::Type)
    @ Reactant ~/.julia/dev/Reactant/src/Tracing.jl:611
  [6] make_tracer(seen::Reactant.OrderedIdDict{Any, Any}, prev::Any, path::Any, mode::Reactant.TraceMode; toscalar::Bool, tobatch::Nothing, track_numbers::Type, kwargs::@Kwargs{})
    @ Reactant ~/.julia/dev/Reactant/src/Tracing.jl:668
  [7] make_tracer
    @ ~/.julia/dev/Reactant/src/Tracing.jl:654 [inlined]
  [8] to_rarray_internal
    @ ~/.julia/dev/Reactant/src/Tracing.jl:1067 [inlined]
  [9] #to_rarray#79
    @ ~/.julia/dev/Reactant/src/Tracing.jl:1063 [inlined]
 [10] to_rarray
    @ ~/.julia/dev/Reactant/src/Tracing.jl:1061 [inlined]
 [11] run_adjoint(::Type{Float32}; kwargs::@Kwargs{output::Bool, L_ratio::Int64, g::Float64, H::Int64, wind_forcing_x::String, Lx::Float64, seasonal_wind_x::Bool, topography::String, bc::String, bottom_drag::String, nn_forcing_dissipation::Bool, handwritten::Bool, α::Int64, nx::Int64, Ndays::Int64})
    @ Main ~/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper_experiment_functions.jl:39
 [12] run_adjoint
    @ ~/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper_experiment_functions.jl:14 [inlined]
 [13] top-level scope
    @ ~/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper_run_experiments.jl:3
 [14] include(fname::String)
    @ Base.MainInclude ./client.jl:494
 [15] top-level scope
    @ ~/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper.jl:28
in expression starting at /Users/swilliamson/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper_run_experiments.jl:3
in expression starting at /Users/swilliamson/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper.jl:28

I tried to add the scheme in the same way that S and dS are treated. Is this not the right idea? I mainly want to move to checkpointing because, without it, the process is killed before I can see any error messages.

@swilliamson7
Collaborator Author

Adding that line 39 in eddy_paper_experiment_functions.jl points to

revolve = Reactant.to_rarray(revolve)

so I know it's some problem related to how I've included revolve.
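
For context, the message in the trace is the one Base raises when `sizeof` is called on an abstract type, and frame [2] shows Reactant's `traced_type_inner` reaching a type that matches `Type{<:Function}`. Below is a minimal sketch in plain Julia (no Reactant involved) that reproduces only the message. `ToyScheme`, `abstract_fields`, and `sizeof_error` are hypothetical names for illustration, and whether `Revolve` really carries a `Function`-typed field is an assumption inferred from the trace, not something confirmed here.

```julia
# Hypothetical stand-in for a scheme object; NOT the real Revolve definition.
struct ToyScheme
    steps::Int
    callback::Function   # abstract field type: no fixed memory layout
end

# List the fields whose declared type is not concrete.
abstract_fields(T::Type) =
    [name for (name, ftype) in zip(fieldnames(T), fieldtypes(T)) if !isconcretetype(ftype)]

# Capture the message `sizeof` raises for a type, or `nothing` if it succeeds.
function sizeof_error(T::Type)
    try
        sizeof(T)
        return nothing
    catch err
        return sprint(showerror, err)
    end
end

println(abstract_fields(ToyScheme))   # fields that a size-based conversion could trip over
println(sizeof_error(Function))       # the message seen in the stack trace
println(sizeof_error(Int))            # concrete types have a definite size
```

Running `abstract_fields` on the actual scheme type would be one way to check which fields, if any, are abstractly typed before passing the object to `Reactant.to_rarray`.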

wsmoses transferred this issue from EnzymeAD/Enzyme.jl on Feb 20, 2025
@wsmoses
Member

wsmoses commented Feb 20, 2025

Where you have the previously checkpointed loop, you should definitely use @trace, like

@trace for i in 1:n

@swilliamson7
Collaborator Author

Not sure I follow. If I just use @trace for S.parameters.i = 1:S.grid.nt I see

swilliamson@CRIOS-A66253 ~/D/G/S/eddy-stresses> julia --project=. eddy_paper.jl
┌ Warning: It's not recommended to use allowscalar([true]) to allow scalar indexing.
│ Instead, use `allowscalar() do end` or `@allowscalar` to denote exactly which operations can use scalar operations.
└ @ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:184
ERROR: LoadError: malformed for loop assignment
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] trace_for(mod::Module, expr::Expr)
   @ ReactantCore ~/.julia/packages/ReactantCore/t40zc/src/ReactantCore.jl:148
 [3] var"@trace"(__source__::LineNumberNode, __module__::Module, expr::Any)
   @ ReactantCore ~/.julia/packages/ReactantCore/t40zc/src/ReactantCore.jl:135
 [4] include(fname::String)
   @ Base.MainInclude ./client.jl:494
 [5] top-level scope
   @ ~/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper.jl:24
in expression starting at /Users/swilliamson/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper_integration.jl:71
in expression starting at /Users/swilliamson/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper_integration.jl:67
in expression starting at /Users/swilliamson/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper.jl:24

and if I try to keep the pieces needed for checkpointing, as in @checkpoint_struct scheme S @trace for S.parameters.i = 1:S.grid.nt, I see

swilliamson@CRIOS-A66253 ~/D/G/S/eddy-stresses> julia --project=. eddy_paper.jl
┌ Warning: It's not recommended to use allowscalar([true]) to allow scalar indexing.
│ Instead, use `allowscalar() do end` or `@allowscalar` to denote exactly which operations can use scalar operations.
└ @ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:184
ERROR: LoadError: Checkpointing.jl: Unknown loop construct.
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] var"@checkpoint_struct"(__source__::LineNumberNode, __module__::Module, alg::Any, model::Any, loop::Any)
   @ Checkpointing ~/.julia/packages/Checkpointing/uBrnJ/src/Checkpointing.jl:200
 [3] include(fname::String)
   @ Base.MainInclude ./client.jl:494
 [4] top-level scope
   @ ~/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper.jl:24
in expression starting at /Users/swilliamson/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper_integration.jl:71
in expression starting at /Users/swilliamson/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper_integration.jl:67
in expression starting at /Users/swilliamson/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper.jl:24

@swilliamson7
Collaborator Author

swilliamson7 commented Feb 20, 2025

Mainly, I don't think I can just get rid of the @checkpoint_struct scheme S before the for loop.

@wsmoses
Member

wsmoses commented Feb 22, 2025

Can you instead do

@trace for i = 1:S.grid.nt

@swilliamson7
Collaborator Author

This was the first thing I tried; it led to:

swilliamson@CRIOS-A66253 ~/D/G/S/eddy-stresses> julia --project=. eddy_paper.jl
┌ Warning: It's not recommended to use allowscalar([true]) to allow scalar indexing.
│ Instead, use `allowscalar() do end` or `@allowscalar` to denote exactly which operations can use scalar operations.
└ @ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:184
ERROR: LoadError: malformed for loop assignment
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] trace_for(mod::Module, expr::Expr)
   @ ReactantCore ~/.julia/packages/ReactantCore/t40zc/src/ReactantCore.jl:148
 [3] var"@trace"(__source__::LineNumberNode, __module__::Module, expr::Any)
   @ ReactantCore ~/.julia/packages/ReactantCore/t40zc/src/ReactantCore.jl:135
 [4] include(fname::String)
   @ Base.MainInclude ./client.jl:494
 [5] top-level scope
   @ ~/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper.jl:24
in expression starting at /Users/swilliamson/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper_integration.jl:71
in expression starting at /Users/swilliamson/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper_integration.jl:67
in expression starting at /Users/swilliamson/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper.jl:24

@swilliamson7
Collaborator Author

Oh wait, my mistake: I can adjust how i is defined.
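
The `malformed for loop assignment` error is raised while the macro expands, so it depends only on the shape of the loop header. Below is a minimal sketch in plain Julia that inspects the parsed loop. It assumes (not confirmed in this thread) that `@trace` wants the loop variable to be a bare symbol, and `loop_variable` is a hypothetical helper, not part of ReactantCore.

```julia
# Hypothetical helper: return the left-hand side of a parsed `for` header.
function loop_variable(ex::Expr)
    ex.head === :for || error("expected a for loop")
    header = ex.args[1]        # the `var = range` part of the loop
    return header.args[1]      # the iteration variable expression
end

# Parse both loop headers from the thread without running them.
field_loop = Meta.parse("for S.parameters.i = 1:S.grid.nt\nend")
plain_loop = Meta.parse("for i = 1:S.grid.nt\nend")

println(loop_variable(field_loop) isa Symbol)   # field access is an Expr, not a Symbol
println(loop_variable(plain_loop) isa Symbol)   # a plain `i` is a Symbol
```

If the step counter still needs to live in `S.parameters.i`, one option might be to assign it from `i` inside the loop body. Whether that works under tracing is not something this thread settles.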

@swilliamson7
Collaborator Author

Okay, using @trace for i = 1:S.grid.nt gives

swilliamson@CRIOS-A66253 ~/D/G/S/eddy-stresses> julia --project=. eddy_paper.jl
┌ Warning: It's not recommended to use allowscalar([true]) to allow scalar indexing.
│ Instead, use `allowscalar() do end` or `@allowscalar` to denote exactly which operations can use scalar operations.
└ @ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:184
typeof(P) = Parameter{Array{Float64, 3}, Vector{Float64}, Vector{Float64}}
[ Info: Revolve: Number of checkpoints: 15
[ Info: Revolve: Number of steps: 225
[ Info: Prediction:
[ Info: Forward steps   : 522
[ Info: Overhead factor : 2.32
ERROR: LoadError: Abstract type Function does not have a definite size.
Stacktrace:
  [1] sizeof
    @ ./essentials.jl:631 [inlined]
  [2] traced_type_inner(T::Type{<:Function}, seen::Dict{Type, Type}, mode::Reactant.TraceMode, track_numbers::Type)
    @ Reactant ~/.julia/dev/Reactant/src/Tracing.jl:76
  [3] traced_type_inner(T::Type, seen::Dict{Type, Type}, mode::Reactant.TraceMode, track_numbers::Type)
    @ Reactant ~/.julia/dev/Reactant/src/Tracing.jl:415
  [4] traced_type_inner(T::Type, seen::Dict{Type, Type}, mode::Reactant.TraceMode, track_numbers::Type)
    @ Reactant ~/.julia/dev/Reactant/src/Tracing.jl:444
  [5] traced_type(T::Type, ::Val{Reactant.ArrayToConcrete}, track_numbers::Type)
    @ Reactant ~/.julia/dev/Reactant/src/Tracing.jl:611
  [6] make_tracer(seen::Reactant.OrderedIdDict{Any, Any}, prev::Any, path::Any, mode::Reactant.TraceMode; toscalar::Bool, tobatch::Nothing, track_numbers::Type, kwargs::@Kwargs{})
    @ Reactant ~/.julia/dev/Reactant/src/Tracing.jl:668
  [7] make_tracer
    @ ~/.julia/dev/Reactant/src/Tracing.jl:654 [inlined]
  [8] to_rarray_internal
    @ ~/.julia/dev/Reactant/src/Tracing.jl:1067 [inlined]
  [9] #to_rarray#79
    @ ~/.julia/dev/Reactant/src/Tracing.jl:1063 [inlined]
 [10] to_rarray
    @ ~/.julia/dev/Reactant/src/Tracing.jl:1061 [inlined]
 [11] run_adjoint(::Type{Float32}; kwargs::@Kwargs{output::Bool, L_ratio::Int64, g::Float64, H::Int64, wind_forcing_x::String, Lx::Float64, seasonal_wind_x::Bool, topography::String, bc::String, bottom_drag::String, nn_forcing_dissipation::Bool, handwritten::Bool, α::Int64, nx::Int64, Ndays::Int64})
    @ Main ~/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper_experiment_functions.jl:39
 [12] run_adjoint
    @ ~/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper_experiment_functions.jl:14 [inlined]
 [13] top-level scope
    @ ~/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper_run_experiments.jl:3
 [14] include(fname::String)
    @ Base.MainInclude ./client.jl:494
 [15] top-level scope
    @ ~/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper.jl:28
in expression starting at /Users/swilliamson/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper_run_experiments.jl:3
in expression starting at /Users/swilliamson/Documents/GitHub/ShallowWaters_work/eddy-stresses/eddy_paper.jl:28
