Skip to content

Commit

Permalink
Add restarts with JLD2
Browse files Browse the repository at this point in the history
[skip ci]
  • Loading branch information
Sbozzolo committed Feb 6, 2025
1 parent 5b2b44d commit bc99b9e
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 73 deletions.
4 changes: 2 additions & 2 deletions experiments/ClimaEarth/cli_options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ function argparse_settings()
default = nothing
"--restart_t"
help = "Time in seconds rounded to the nearest index to use at `t_start` for restarted simulation [0 (default)]"
arg_type = Int
default = 0
arg_type = String
default = "0secs"
# Diagnostics information
"--use_coupler_diagnostics"
help = "Boolean flag indicating whether to compute and output coupler diagnostics [`true` (default), `false`]"
Expand Down
7 changes: 5 additions & 2 deletions experiments/ClimaEarth/components/atmosphere/climaatmos.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,11 @@ function Checkpointer.get_model_cache(sim::ClimaAtmosSimulation)
end

function Checkpointer.restore_cache!(sim::ClimaAtmosSimulation, new_cache)
restore!(Checkpointer.get_model_cache(sim), new_cache;
ignore = Set([:rc, :params, :ghost_buffer, :hyperdiffusion_ghost_buffer, :data_handler]))
restore!(
Checkpointer.get_model_cache(sim),
new_cache;
ignore = Set([:rc, :params, :ghost_buffer, :hyperdiffusion_ghost_buffer, :data_handler]),
)
return nothing
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,7 @@ import ClimaCore.Geometry: AxisTensor
import ClimaCore.Spaces: AbstractSpace
import NCDatasets

function restore!(
v1::T,
v2::T;
name = "",
ignore = Set([:rc]),
) where {T <: Union{NamedTuple, CA.AtmosCache}}
function restore!(v1::T, v2::T; name = "", ignore = Set([:rc])) where {T <: Union{NamedTuple, CA.AtmosCache}}
_restore!(v1, v2; name, ignore)
return nothing
end
Expand Down
5 changes: 4 additions & 1 deletion experiments/ClimaEarth/components/land/climaland_bucket.jl
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,10 @@ function recursively_reset!(v1::T, v2::T; ignore = Set([:rc])) where {T}
end
end

function recursively_reset_base!(v1::T, v2::T) where {T <: Union{CC.Fields.Field, CC.Fields.FieldVector, CC.DataLayouts.AbstractData, AbstractArray}}
function recursively_reset_base!(
v1::T,
v2::T,
) where {T <: Union{CC.Fields.Field, CC.Fields.FieldVector, CC.DataLayouts.AbstractData, AbstractArray}}
parent(v1) .= parent(v2)
end

Expand Down
121 changes: 65 additions & 56 deletions experiments/ClimaEarth/run_amip.jl
Original file line number Diff line number Diff line change
Expand Up @@ -529,74 +529,83 @@ Utilities.show_memory_usage()
If a restart directory is specified and contains output files from the `checkpoint_cb` callback, the component model states are restarted from those files. The restart directory
is specified in the `config_dict` dictionary. The `restart_t` field specifies the time step at which the restart is performed.
=#
restart_dir = dir_paths.checkpoints
restart_t = 180
if !isnothing(restart_dir)
for sim in cs.model_sims
if Checkpointer.get_model_prog_state(sim) !== nothing
Checkpointer.restart_model_state!(sim, comms_ctx, restart_t; input_dir = restart_dir)
Checkpointer.restart_model_cache!(sim, comms_ctx, restart_t; input_dir = restart_dir)
end
end
end

#=
## Initialize Component Model Exchange

We need to ensure all models' initial conditions are shared to enable the coupler to calculate the first instance of surface fluxes. Some auxiliary variables (namely surface humidity and radiation fluxes)
depend on initial conditions of other component models than those in which the variables are calculated, which is why we need to step these models in time and/or reinitialize them.
The concrete steps for proper initialization are:
=#
# Reset coupler fields
output_dir = cs.dirs.checkpoints
pid = ClimaComms.mypid(comms_ctx)
input_file = joinpath(output_dir, "checkpoint", "checkpoint_coupler_fields_$(pid)_$(restart_t).jld2")
pid = ClimaComms.mypid(comms_ctx)
@info "Restoring coupler fields from $(input_file)"
fields_read = Checkpointer.JLD2.jldopen(input_file)["coupler_fields"]
for name in coupler_field_names
parent(getproperty(cs.fields, name)) .= parent(getproperty(fields_read, name))
end
else
#=
## Initialize Component Model Exchange
We need to ensure all models' initial conditions are shared to enable the coupler to calculate the first instance of surface fluxes. Some auxiliary variables (namely surface humidity and radiation fluxes)
depend on initial conditions of other component models than those in which the variables are calculated, which is why we need to step these models in time and/or reinitialize them.
The concrete steps for proper initialization are:
=#

# 1.coupler updates surface model area fractions
FieldExchanger.update_surface_fractions!(cs)

# 2.surface density (`ρ_sfc`): calculated by the coupler by adiabatically extrapolating atmospheric thermal state to the surface.
# For this, we need to import surface and atmospheric fields. The model sims are then updated with the new surface density.
FieldExchanger.import_combined_surface_fields!(cs.fields, cs.model_sims, cs.turbulent_fluxes)
FieldExchanger.import_atmos_fields!(cs.fields, cs.model_sims, cs.boundary_space, cs.turbulent_fluxes)
FieldExchanger.update_model_sims!(cs.model_sims, cs.fields, cs.turbulent_fluxes)

# 3.surface vapor specific humidity (`q_sfc`): step surface models with the new surface density to calculate their respective `q_sfc` internally
## TODO: the q_sfc calculation follows the design of the bucket q_sfc, but it would be neater to abstract this from step! (#331)
Interfacer.step!(land_sim, tspan[1] + Δt_cpl)
Interfacer.step!(ocean_sim, tspan[1] + Δt_cpl)
Interfacer.step!(ice_sim, tspan[1] + Δt_cpl)

# 4.turbulent fluxes: now we have all information needed for calculating the initial turbulent
# surface fluxes using either the combined state or the partitioned state method
if cs.turbulent_fluxes isa FluxCalculator.CombinedStateFluxesMOST
## import the new surface properties into the coupler (note the atmos state was also imported in step 3.)
FieldExchanger.import_combined_surface_fields!(cs.fields, cs.model_sims, cs.turbulent_fluxes) # i.e. T_sfc, albedo, z0, beta, q_sfc
## calculate turbulent fluxes inside the atmos cache based on the combined surface state in each grid box
FluxCalculator.combined_turbulent_fluxes!(cs.model_sims, cs.fields, cs.turbulent_fluxes) # this updates the atmos thermo state, sfc_ts
elseif cs.turbulent_fluxes isa FluxCalculator.PartitionedStateFluxes
## calculate turbulent fluxes in surface models and save the weighted average in coupler fields
FluxCalculator.partitioned_turbulent_fluxes!(
cs.model_sims,
cs.fields,
cs.boundary_space,
FluxCalculator.MoninObukhovScheme(),
cs.thermo_params,
)

## update atmos sfc_conditions for surface temperature
## TODO: this is hard coded and needs to be simplified (req. CA modification) (#479)
new_p = get_new_cache(atmos_sim, cs.fields)
CA.SurfaceConditions.update_surface_conditions!(atmos_sim.integrator.u, new_p, atmos_sim.integrator.t) ## sets T_sfc (but SF calculation not necessary - requires split functionality in CA)
atmos_sim.integrator.p.precomputed.sfc_conditions .= new_p.precomputed.sfc_conditions
end

# 1.coupler updates surface model area fractions
FieldExchanger.update_surface_fractions!(cs)

# 2.surface density (`ρ_sfc`): calculated by the coupler by adiabatically extrapolating atmospheric thermal state to the surface.
# For this, we need to import surface and atmospheric fields. The model sims are then updated with the new surface density.
FieldExchanger.import_combined_surface_fields!(cs.fields, cs.model_sims, cs.turbulent_fluxes)
FieldExchanger.import_atmos_fields!(cs.fields, cs.model_sims, cs.boundary_space, cs.turbulent_fluxes)
FieldExchanger.update_model_sims!(cs.model_sims, cs.fields, cs.turbulent_fluxes)

# 3.surface vapor specific humidity (`q_sfc`): step surface models with the new surface density to calculate their respective `q_sfc` internally
## TODO: the q_sfc calculation follows the design of the bucket q_sfc, but it would be neater to abstract this from step! (#331)
Interfacer.step!(land_sim, tspan[1] + Δt_cpl)
Interfacer.step!(ocean_sim, tspan[1] + Δt_cpl)
Interfacer.step!(ice_sim, tspan[1] + Δt_cpl)

# 4.turbulent fluxes: now we have all information needed for calculating the initial turbulent
# surface fluxes using either the combined state or the partitioned state method
if cs.turbulent_fluxes isa FluxCalculator.CombinedStateFluxesMOST
## import the new surface properties into the coupler (note the atmos state was also imported in step 3.)
FieldExchanger.import_combined_surface_fields!(cs.fields, cs.model_sims, cs.turbulent_fluxes) # i.e. T_sfc, albedo, z0, beta, q_sfc
## calculate turbulent fluxes inside the atmos cache based on the combined surface state in each grid box
FluxCalculator.combined_turbulent_fluxes!(cs.model_sims, cs.fields, cs.turbulent_fluxes) # this updates the atmos thermo state, sfc_ts
elseif cs.turbulent_fluxes isa FluxCalculator.PartitionedStateFluxes
## calculate turbulent fluxes in surface models and save the weighted average in coupler fields
FluxCalculator.partitioned_turbulent_fluxes!(
cs.model_sims,
cs.fields,
cs.boundary_space,
FluxCalculator.MoninObukhovScheme(),
cs.thermo_params,
)
# 5.reinitialize models + radiative flux: prognostic states and time are set to their initial conditions. For atmos, this also triggers the callbacks and sets a nonzero radiation flux (given the new sfc_conditions)
FieldExchanger.reinit_model_sims!(cs.model_sims)

## update atmos sfc_conditions for surface temperature
## TODO: this is hard coded and needs to be simplified (req. CA modification) (#479)
new_p = get_new_cache(atmos_sim, cs.fields)
CA.SurfaceConditions.update_surface_conditions!(atmos_sim.integrator.u, new_p, atmos_sim.integrator.t) ## sets T_sfc (but SF calculation not necessary - requires split functionality in CA)
atmos_sim.integrator.p.precomputed.sfc_conditions .= new_p.precomputed.sfc_conditions
# 6.update all fluxes: coupler re-imports updated atmos fluxes (radiative fluxes for both `turbulent_fluxes` types
# and also turbulent fluxes if `turbulent_fluxes isa CombinedStateFluxesMOST`,
# and sends them to the surface component models. If `turbulent_fluxes isa PartitionedStateFluxes`
# atmos receives the turbulent fluxes from the coupler.
FieldExchanger.import_atmos_fields!(cs.fields, cs.model_sims, cs.boundary_space, cs.turbulent_fluxes)
FieldExchanger.update_model_sims!(cs.model_sims, cs.fields, cs.turbulent_fluxes)
end

# 5.reinitialize models + radiative flux: prognostic states and time are set to their initial conditions. For atmos, this also triggers the callbacks and sets a nonzero radiation flux (given the new sfc_conditions)
FieldExchanger.reinit_model_sims!(cs.model_sims)

# 6.update all fluxes: coupler re-imports updated atmos fluxes (radiative fluxes for both `turbulent_fluxes` types
# and also turbulent fluxes if `turbulent_fluxes isa CombinedStateFluxesMOST`,
# and sends them to the surface component models. If `turbulent_fluxes isa PartitionedStateFluxes`
# atmos receives the turbulent fluxes from the coupler.
FieldExchanger.import_atmos_fields!(cs.fields, cs.model_sims, cs.boundary_space, cs.turbulent_fluxes)
FieldExchanger.update_model_sims!(cs.model_sims, cs.fields, cs.turbulent_fluxes)

#=
## Coupling Loop
Expand Down
4 changes: 2 additions & 2 deletions experiments/ClimaEarth/test/compare.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ Keyword arguments
`:rc` is some CUDA/CuArray internal object that we don't care about
"""
function compare(
v1::T,
v2::T;
v1::T1,
v2::T2;
name = "",
ignore = Set([:rc]),
) where {T <: Union{FieldVector, AbstractSpace, NamedTuple, CA.AtmosCache}}
Expand Down
14 changes: 13 additions & 1 deletion experiments/ClimaEarth/test/restart.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ println("Reading and simulating last step")

# Restart (just re-run from the same folder)
restarted = YAML.load_file(joinpath(tmpdir, "two_steps.yml"))
restarted["restart_dir"] = TestTwo1.dir_paths.checkpoints
restarted["restart_t"] = "360secs"
restarted["t_end"] = "540secs"
YAML.write_file(joinpath(tmpdir, "restarted.yml"), restarted)
push!(ARGS, "--config_file", joinpath(tmpdir, "restarted.yml"))
Expand Down Expand Up @@ -93,11 +95,21 @@ end
@test compare(TestThree.cs.model_sims.ice_sim.integrator.u, TestTwo2.cs.model_sims.ice_sim.integrator.u)

@test compare(TestThree.cs.model_sims.land_sim.integrator.u, TestTwo2.cs.model_sims.land_sim.integrator.u)
@test compare(TestThree.cs.model_sims.land_sim.integrator.p, TestTwo2.cs.model_sims.land_sim.integrator.p)

@test compare(
TestThree.cs.model_sims.land_sim.integrator.p,
TestTwo2.cs.model_sims.land_sim.integrator.p;
ignore = [:dss_buffer_3d, :dss_buffer_2d],
)

@test compare(TestThree.cs.model_sims.ocean_sim.cache, TestTwo2.cs.model_sims.ocean_sim.cache)
# Ignoring SST_timevaryinginput because it contains closures (which should be reinitialized correctly)
# We have to remove it from the type, otherwise comapre will not work
function delete(nt::NamedTuple, fieldnames...)
return (; filter(p -> !(first(p) in fieldnames), collect(pairs(nt)))...)
end

ocean_cache_three = delete(TestThree.cs.model_sims.ocean_sim.cache, :SST_timevaryinginput)
ocean_cache_two2 = delete(TestTwo2.cs.model_sims.ocean_sim.cache, :SST_timevaryinginput)

@test compare(ocean_cache_three, ocean_cache_two2)
2 changes: 1 addition & 1 deletion experiments/ClimaEarth/user_io/arg_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ function get_coupler_args(config_dict::Dict)

# Restart information
restart_dir = config_dict["restart_dir"]
restart_t = Int(config_dict["restart_t"])
restart_t = Int64(Utilities.time_to_seconds(config_dict["restart_t"]))

# Diagnostics information
use_coupler_diagnostics = config_dict["use_coupler_diagnostics"]
Expand Down
16 changes: 14 additions & 2 deletions src/Checkpointer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,10 @@ end
This is a callback function that checkpoints all simulations defined in the current coupled simulation.
"""
function checkpoint_sims(cs::Interfacer.CoupledSimulation)
t = Dates.datetime2epochms(cs.dates.date[1])
t0 = Dates.datetime2epochms(cs.dates.date0[1])
for sim in cs.model_sims
if Checkpointer.get_model_prog_state(sim) !== nothing
t = Dates.datetime2epochms(cs.dates.date[1])
t0 = Dates.datetime2epochms(cs.dates.date0[1])
Checkpointer.checkpoint_model_state(
sim,
cs.comms_ctx,
Expand All @@ -164,6 +164,18 @@ function checkpoint_sims(cs::Interfacer.CoupledSimulation)
)
end
end

# Checkpoint the Coupler fields
output_dir = cs.dirs.checkpoints
comms_ctx = cs.comms_ctx
time = Int((t - t0) / 1e3)
day = floor(Int, time / (60 * 60 * 24))
sec = floor(Int, time % (60 * 60 * 24))
pid = ClimaComms.mypid(comms_ctx)
@info "Saving coupler fields to JLD2 on day $day second $sec"
output_file = joinpath(output_dir, "checkpoint", "checkpoint_coupler_fields_$(pid)_$time.jld2")
mkpath(joinpath(output_dir, "checkpoint"))
JLD2.jldsave(output_file, coupler_fields = cs.fields)
end

end # module

0 comments on commit bc99b9e

Please sign in to comment.