Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support identical restarts with JLD2 files #1179

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
ClimaCore = "d414da3d-4745-48bb-8d80-42e94e092884"
ClimaUtilities = "b3f4f4ca-9299-4f7f-bd9b-81e1242a7513"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand All @@ -19,6 +20,7 @@ ClimaComms = "0.6.2"
ClimaCore = "0.14.17"
ClimaUtilities = "0.1.9"
Dates = "1"
JLD2 = "0.5.11"
Logging = "1"
SciMLBase = "2.11"
StaticArrays = "1.6"
Expand Down
28 changes: 17 additions & 11 deletions experiments/ClimaEarth/Manifest-v1.11.toml
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ uuid = "908f55d8-4145-4867-9c14-5dad1a479e4d"
version = "0.4.6"

[[deps.ClimaCoupler]]
deps = ["ClimaComms", "ClimaCore", "ClimaUtilities", "Dates", "Logging", "SciMLBase", "StaticArrays", "SurfaceFluxes", "Thermodynamics"]
deps = ["ClimaComms", "ClimaCore", "ClimaUtilities", "Dates", "JLD2", "Logging", "SciMLBase", "StaticArrays", "SurfaceFluxes", "Thermodynamics"]
path = "../.."
uuid = "4ade58fe-a8da-486c-bd89-46df092ec0c7"
version = "0.1.2"
Expand Down Expand Up @@ -376,9 +376,9 @@ version = "0.10.18"

[[deps.ClimaTimeSteppers]]
deps = ["ClimaComms", "Colors", "DataStructures", "DiffEqBase", "KernelAbstractions", "Krylov", "LinearAlgebra", "LinearOperators", "NVTX", "SciMLBase", "StaticArrays"]
git-tree-sha1 = "f03e9f4316d380cdf851ec2c4c55efbfdb064439"
git-tree-sha1 = "b452132022416ad3511143230f51660a62d583b2"
uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
version = "0.8.1"
version = "0.8.2"

[deps.ClimaTimeSteppers.extensions]
ClimaTimeSteppersBenchmarkToolsExt = ["CUDA", "BenchmarkTools", "OrderedCollections", "StatsBase", "PrettyTables"]
Expand Down Expand Up @@ -538,9 +538,9 @@ version = "0.6.3"

[[deps.CoordinateTransformations]]
deps = ["LinearAlgebra", "StaticArrays"]
git-tree-sha1 = "f9d7112bfff8a19a3a4ea4e03a8e6a91fe8456bf"
git-tree-sha1 = "a692f5e257d332de1e554e4566a4e5a8a72de2b2"
uuid = "150eb455-5306-5404-9cee-2592286d6298"
version = "0.6.3"
version = "0.6.4"

[[deps.CpuId]]
deps = ["Markdown"]
Expand Down Expand Up @@ -1196,6 +1196,12 @@ git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
uuid = "82899510-4779-5014-852e-03e436cf321d"
version = "1.0.0"

[[deps.JLD2]]
deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"]
git-tree-sha1 = "91d501cb908df6f134352ad73cde5efc50138279"
uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
version = "0.5.11"

[[deps.JLLWrappers]]
deps = ["Artifacts", "Preferences"]
git-tree-sha1 = "a007feb38b422fbdab534406aeca1b86823cb4d6"
Expand Down Expand Up @@ -1880,9 +1886,9 @@ version = "1.0.1"

[[deps.QuadGK]]
deps = ["DataStructures", "LinearAlgebra"]
git-tree-sha1 = "cda3b045cf9ef07a08ad46731f5a3165e56cf3da"
git-tree-sha1 = "9da16da70037ba9d701192e27befedefb91ec284"
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
version = "2.11.1"
version = "2.11.2"

[deps.QuadGK.extensions]
QuadGKEnzymeExt = "Enzyme"
Expand Down Expand Up @@ -1937,9 +1943,9 @@ version = "1.3.4"

[[deps.RecursiveArrayTools]]
deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"]
git-tree-sha1 = "ea6ad53c168c7c1c2e8f870aefda269692a8a91f"
git-tree-sha1 = "fe9d37a17ab4d41a98951332ee8067f8dca8c4c2"
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
version = "3.28.0"
version = "3.29.0"

[deps.RecursiveArrayTools.extensions]
RecursiveArrayToolsFastBroadcastExt = "FastBroadcast"
Expand Down Expand Up @@ -2606,9 +2612,9 @@ version = "1.5.0+0"

[[deps.libzip_jll]]
deps = ["Artifacts", "Bzip2_jll", "JLLWrappers", "Libdl", "OpenSSL_jll", "XZ_jll", "Zlib_jll", "Zstd_jll"]
git-tree-sha1 = "e797fa066eba69f4c0585ffbd81bc780b5118ce2"
git-tree-sha1 = "86addc139bca85fdf9e7741e10977c45785727b7"
uuid = "337d8026-41b4-5cde-a456-74a10e5b31d1"
version = "1.11.2+2"
version = "1.11.3+0"

[[deps.nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
Expand Down
4 changes: 2 additions & 2 deletions experiments/ClimaEarth/cli_options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ function argparse_settings()
default = nothing
"--restart_t"
help = "Time in seconds rounded to the nearest index to use at `t_start` for restarted simulation [0 (default)]"
arg_type = Int
default = 0
arg_type = String
default = "0secs"
# Diagnostics information
"--use_coupler_diagnostics"
help = "Boolean flag indicating whether to compute and output coupler diagnostics [`true` (default), `false`]"
Expand Down
15 changes: 15 additions & 0 deletions experiments/ClimaEarth/components/atmosphere/climaatmos.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import Thermodynamics as TD
import ClimaCoupler: Checkpointer, FieldExchanger, FluxCalculator, Interfacer, Utilities

include("climaatmos_extra_diags.jl")
include("climaatmos_recursive.jl")

###
### Functions required by ClimaCoupler.jl for an AtmosModelSimulation
Expand Down Expand Up @@ -102,6 +103,20 @@ function Checkpointer.get_model_prog_state(sim::ClimaAtmosSimulation)
return sim.integrator.u
end

function Checkpointer.get_model_cache(sim::ClimaAtmosSimulation)
return sim.integrator.p
end

function Checkpointer.restore_cache!(sim::ClimaAtmosSimulation, new_cache)
restore!(
Checkpointer.get_model_cache(sim),
new_cache;
ignore = Set([:rc, :params, :ghost_buffer, :hyperdiffusion_ghost_buffer, :data_handler]),
)
return nothing
end


"""
Interfacer.get_field(atmos_sim::ClimaAtmosSimulation, ::Val{:radiative_energy_flux_toa})

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import ClimaComms
import ClimaAtmos as CA
import ClimaCore
import ClimaCore: DataLayouts, Fields, Geometry
import ClimaCore.Fields: Field, FieldVector, field_values
import ClimaCore.DataLayouts: AbstractData
import ClimaCore.Geometry: AxisTensor
import ClimaCore.Spaces: AbstractSpace
import NCDatasets

function restore!(v1::T, v2::T; name = "", ignore = Set([:rc])) where {T <: Union{NamedTuple, CA.AtmosCache}}
_restore!(v1, v2; name, ignore)
return nothing
end

function _restore!(v1::T, v2::T; name, ignore) where {T}
properties = filter(x -> !(x in ignore), propertynames(v1))
if isempty(properties)
if !Base.issingletontype(typeof(v1))
_restore_base!(v1, v2; name, ignore)
else
v1 == v2 || error("$v1 != $v2")
end
else
# Recursive case
for p in properties
_restore!(getproperty(v1, p), getproperty(v2, p); name = "$(name).$(p)", ignore)
end
end
return nothing
end

# function _restore!(v1::T, v2::T; name, ignore) where {T}
# v1 .= v2
# return nothing
# end

function _restore_base!(v1::T, v2::T; name, ignore) where {T <: Union{AbstractString, Symbol, CA.AtmosModel, Nothing}}
v1 == v2 || error("$v1 != $v2")
return nothing
end

function _restore_base!(v1::T, v2::T; name, ignore) where {T <: Number}
# To account for NaN
v1 === v2 || error("$v1 != $v2")
return nothing
end

# # Ignore NCDatasets
# function _restore_base!(v1::T, v2::T; name, ignore) where {T <: NCDatasets.NCDataset}
# return nothing
# end


function _restore_base!(v1::T, v2::T; name, ignore) where {T <: Union{Field, FieldVector, AbstractData, AbstractArray}}
parent(v1) .= parent(v2)
return nothing
end


# # We ignore NCDatasets. They contain a lot of state-ful information
# function _restore!(pass, v1::T, v2::T; name, ignore) where {T <: NCDatasets.NCDataset}
# return nothing
# end

function _restore!(v1::T1, v2::T2; name, ignore) where {T1, T2}
error("v1 and v2 have different types")
end
34 changes: 33 additions & 1 deletion experiments/ClimaEarth/components/land/climaland_bucket.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ function BucketSimulation(
output_dir::String;
space,
dt::Float64,
saveat::Float64,
saveat::Vector{Float64},
area_fraction,
stepper = CTS.RK4(),
date_ref::Dates.DateTime,
Expand Down Expand Up @@ -396,6 +396,38 @@ function make_land_domain(
return CL.Domains.SphericalShell{FT}(radius, depth, nothing, nelements, npolynomial, space, fields)
end

function Checkpointer.get_model_cache(sim::BucketSimulation)
return sim.integrator.p
end

function Checkpointer.restore_cache!(sim::Interfacer.ComponentModelSimulation, new_cache)
old_cache = Checkpointer.get_model_cache(sim)
recursively_reset!(old_cache, new_cache)
end

function recursively_reset!(v1::T, v2::T; ignore = Set([:rc])) where {T}
properties = filter(x -> !(x in ignore), propertynames(v1))
if isempty(properties)
if !Base.issingletontype(typeof(v1))
recursively_reset_base!(v1, v2)
else
v1 == v2 || error("v1 != v2")
end
else
# Recursive case
for p in properties
recursively_reset!(getproperty(v1, p), getproperty(v2, p); ignore)
end
end
end

function recursively_reset_base!(
v1::T,
v2::T,
) where {T <: Union{CC.Fields.Field, CC.Fields.FieldVector, CC.DataLayouts.AbstractData, AbstractArray}}
parent(v1) .= parent(v2)
end

"""
dss_state!(sim::BucketSimulation)

Expand Down
2 changes: 1 addition & 1 deletion experiments/ClimaEarth/components/ocean/eisenman_seaice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ function EisenmanIceSimulation(
dss_buffer = CC.Spaces.create_dss_buffer(Y),
)
problem = SciMLBase.ODEProblem(ode_function, Y, Float64.(tspan), cache)
integrator = SciMLBase.init(problem, ode_algo, dt = Float64(dt), saveat = Float64(saveat), adaptive = false)
integrator = SciMLBase.init(problem, ode_algo, dt = Float64(dt), saveat = Float64.(saveat), adaptive = false)

sim = EisenmanIceSimulation(params, Y, space, integrator)
return sim
Expand Down
2 changes: 1 addition & 1 deletion experiments/ClimaEarth/components/ocean/prescr_seaice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ function PrescribedIceSimulation(
ode_function = CTS.ClimaODEFunction(T_exp! = ice_rhs!, dss! = (Y, p, t) -> CC.Spaces.weighted_dss!(Y, p.dss_buffer))

problem = SciMLBase.ODEProblem(ode_function, Y, Float64.(tspan), (; cache..., params = params))
integrator = SciMLBase.init(problem, ode_algo, dt = Float64(dt), saveat = Float64(saveat), adaptive = false)
integrator = SciMLBase.init(problem, ode_algo, dt = Float64(dt), saveat = Float64.(saveat), adaptive = false)

sim = PrescribedIceSimulation(params, Y, space, integrator)

Expand Down
2 changes: 1 addition & 1 deletion experiments/ClimaEarth/components/ocean/slab_ocean.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function SlabOceanSimulation(
CTS.ClimaODEFunction(; T_exp! = slab_ocean_rhs!, dss! = (Y, p, t) -> CC.Spaces.weighted_dss!(Y, p.dss_buffer))

problem = SciMLBase.ODEProblem(ode_function, Y, Float64.(tspan), cache)
integrator = SciMLBase.init(problem, ode_algo, dt = Float64(dt), saveat = Float64(saveat), adaptive = false)
integrator = SciMLBase.init(problem, ode_algo, dt = Float64(dt), saveat = Float64.(saveat), adaptive = false)

sim = SlabOceanSimulation(params, Y, space, integrator)

Expand Down
Loading