From 465e7dbadc4e0624d04264f837d3179f3d76387b Mon Sep 17 00:00:00 2001 From: Gabriele Bozzola Date: Mon, 3 Mar 2025 07:31:43 -0800 Subject: [PATCH] Add restart test and support for restarts This commit implements a way to restart simulations by saving both state and caches of component models, as well as the coupler fields. Given that caches are complex object, I implemented this using JLD2 files. The challenges with JLD2 files are that: - they are not MPI compatible, - they are not GPU compatible. For this reason, I have to move everything to the CPU, and have each process write to its own output. This adds a restriction: only the same number of MPI process (and the same machine) can be used for restarts. In addition to this, this approach requires component models to implement their functions to restore their caches. Something that can be improved in the future is that, ClimaAtmos is currently producing two checkpoints, one independently, and one from ClimaCoupler. This should not be needed, but it is currently needed because there's no other way to start ClimaAtmos at a different time. The other problem here is that the MPI test occasionally hangs (as it does in ClimaAtmos). --- .buildkite/pipeline.yml | 34 ++-- NEWS.md | 14 ++ Project.toml | 4 +- docs/src/checkpointer.md | 131 ++++++++++++- experiments/ClimaEarth/Manifest-v1.11.toml | 78 +++++--- experiments/ClimaEarth/cli_options.jl | 6 +- .../components/atmosphere/climaatmos.jl | 52 +++++ .../components/land/climaland_bucket.jl | 18 ++ .../components/ocean/prescr_ocean.jl | 4 + .../components/ocean/prescr_seaice.jl | 4 + .../ClimaEarth/components/shared/restore.jl | 86 +++++++++ experiments/ClimaEarth/run_amip.jl | 2 +- .../ClimaEarth/run_cloudless_aquaplanet.jl | 2 +- .../ClimaEarth/run_cloudy_aquaplanet.jl | 2 +- .../ClimaEarth/run_cloudy_slabplanet.jl | 2 +- experiments/ClimaEarth/run_dry_held_suarez.jl | 2 +- .../ClimaEarth/run_moist_held_suarez.jl | 2 +- experiments/ClimaEarth/setup_run.jl | 150 ++++++++------- experiments/ClimaEarth/test/compare.jl | 122 ++++++++++++ experiments/ClimaEarth/test/restart.jl | 117 +++++++++++ experiments/ClimaEarth/test/runtests.jl | 4 + experiments/ClimaEarth/user_io/arg_parsing.jl | 8 +- src/Checkpointer.jl | 181 +++++++++++++++--- src/Utilities.jl | 2 +- test/mpi_tests/checkpointer_mpi_tests.jl | 47 ----- test/runtests.jl | 3 - 26 files changed, 878 insertions(+), 199 deletions(-) create mode 100644 experiments/ClimaEarth/components/shared/restore.jl create mode 100644 experiments/ClimaEarth/test/compare.jl create mode 100644 experiments/ClimaEarth/test/restart.jl delete mode 100644 test/mpi_tests/checkpointer_mpi_tests.jl diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 92030ebfa9..e3230333c3 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -67,16 +67,6 @@ steps: - group: "Unit Tests" steps: - - label: "MPI Checkpointer unit tests" - key: "checkpointer_mpi_tests" - command: "srun julia --color=yes --project=test/ test/mpi_tests/checkpointer_mpi_tests.jl" - timeout_in_minutes: 20 - env: - CLIMACOMMS_CONTEXT: "MPI" - agents: - slurm_ntasks: 2 - slurm_mem: 16GB - - label: "MPI Utilities unit tests" key: "utilities_mpi_tests" command: "srun julia --color=yes --project=test/ test/utilities_tests.jl" @@ -97,6 +87,7 @@ steps: agents: slurm_ntasks: 1 slurm_gres: "gpu:1" + slurm_mem: 24GB - group: "GPU: experiments/ClimaEarth/ unit tests and global bucket" steps: @@ -109,6 +100,27 @@ steps: slurm_gres: "gpu:1" slurm_mem: 20GB + - group: "ClimaEarth test" + steps: + - label: "ClimaEarth test" + key: "restarts" + command: "julia --color=yes --project=experiments/ClimaEarth/ experiments/ClimaEarth/test/runtests.jl" + agents: + slurm_mem: 16GB + + - label: "MPI restarts" + key: "mpi_restarts" + command: "srun julia --color=yes --project=experiments/ClimaEarth/ experiments/ClimaEarth/test/restart.jl" + env: + CLIMACOMMS_CONTEXT: "MPI" + timeout_in_minutes: 40 + soft_fail: + - exit_status: -1 + - exit_status: 255 + agents: + slurm_ntasks: 2 + slurm_mem: 32G + - group: "Integration Tests" steps: # SLABPLANET EXPERIMENTS @@ -218,7 +230,7 @@ steps: CLIMACOMMS_CONTEXT: "MPI" agents: slurm_ntasks: 4 - slurm_mem_per_cpu: 8GB + slurm_mem_per_cpu: 12GB # short high-res performance test - label: "Unthreaded AMIP FINE" # also reported by longruns with a flame graph diff --git a/NEWS.md b/NEWS.md index 5f9e17bffa..ea5a664e44 100644 --- a/NEWS.md +++ b/NEWS.md @@ -21,6 +21,20 @@ TOA radiation and net precipitation are added only if conservation is enabled. The coupler fields are also now stored as a ClimaCore Field of NamedTuples, rather than as a NamedTuple of ClimaCore Fields. +#### Restart simulations with JLD2 files PR[#1179](https://github.com/CliMA/ClimaCoupler.jl/pull/1179) + +`ClimaCoupler` can now use `JLD2` files to save state and cache for its model +component, allowing it to restart from saved checkpoints. Some restrictions +apply: + +- The number of MPI processes has to remain the same across checkpoints +- Restart files are generally not portable across machines, julia versions, and package versions +- Adding/changing new component models will probably require adding/changing code + +Please, refer to the +[documentation](https://clima.github.io/ClimaCoupler.jl/dev/checkpointer/) for +more information. + #### Remove extra `get_field` functions PR[#1203](https://github.com/CliMA/ClimaCoupler.jl/pull/1203) Removes the `get_field` functions for `air_density` for all models, which were unused except for the `BucketSimulation` method, which is replaced by a diff --git a/Project.toml b/Project.toml index 80ab0ed366..928df8677e 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" ClimaCore = "d414da3d-4745-48bb-8d80-42e94e092884" ClimaUtilities = "b3f4f4ca-9299-4f7f-bd9b-81e1242a7513" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -16,9 +17,10 @@ Thermodynamics = "b60c26fb-14c3-4610-9d3e-2d17fe7ff00c" [compat] ClimaComms = "0.6.2" -ClimaCore = "0.14.23" +ClimaCore = "0.14.25" ClimaUtilities = "0.1.22" Dates = "1" +JLD2 = "0.5.11" Logging = "1" SciMLBase = "2.11" StaticArrays = "1.6" diff --git a/docs/src/checkpointer.md b/docs/src/checkpointer.md index 9b84246e3d..5b3aae7020 100644 --- a/docs/src/checkpointer.md +++ b/docs/src/checkpointer.md @@ -1,12 +1,137 @@ # Checkpointer -This module contains general functions for logging the model states and restarting simulations. The `Checkpointer` uses `ClimaCore.InputOutput` infrastructure, which allows it to handle arbitrarily distributed logging and restart setups. +## How to save and restart from checkpoints + +`ClimaCoupler` supports saving and reading simulation checkpoints. This is +useful to split a long simulation into smaller, more manageable chunks. + +Checkpoints are a mix of HDF5 and JLD2 files and are typically saved in a +`checkpoints` folder in the simulation output. See +[`Utilities.setup_output_dirs`](@ref) for more information. + +!!! known limitations + + - The number of MPI processes has to remain the same across checkpoints + - Restart files are generally not portable across machines, julia versions, and package versions + - Adding/changing new component models will probably require adding/changing code + +### Saving checkpoints + +If you are running a model (such as AMIP), chances are that you can enable +checkpointing just by setting a command-line argument; The `checkpoint_dt` +option controls how frequently a checkpoint should be produced. + +If your model does not come with this option already, you can checkpoint the +simulation by adding a callback that calls the +[`Checkpointer.checkpoint_sims`](@ref) function. + +For example, to add a callback to checkpoint every hour of simulated time, +assuming you have a `start_date` +```julia +import Dates + +import ClimaCoupler: Checkpointer, TimeManager +import ClimaDiagnostics.Schedules: EveryCalendarDtSchedule + +schedule = EveryCalendarDtSchedule(Dates.Hour(1); start_date) +checkpoint_callback = TimeManager.Callback(schedule_checkpoint, Checkpointer.checkpoint_sims) + +# In the coupling loop: +TimeManager.maybe_trigger_callback(checkpoint_callback, coupled_simulation, time) +``` + +### Reading checkpoints + +There are two ways to restart a simulation from checkpoints. By default, +`ClimaCoupler` tries finding suitable checkpoints and automatically use them. +Alternatively, you can specify a directory `restart_dir` and a simulation time +`restart_t` and restart from files saved in the given directory at the given +time. If the model you are running supports writing checkpoints via command-line +argument, it will probably also support reading them. In this case, the +arguments `restart_dir` and `restart_t` identify the path of the top level +directory containing all the checkpoint files and the simulated times in second. + +If the model does not support directly reading a checkpoint, the `Checkpointer` +module provides a straightforward way to add this feature. +[`Checkpointer.restart!`](@ref) takes a coupled simulation, a `restart_dir`, and +a `restart_t` and overwrites the content of the coupled simulation with what is +in the checkpoint. + +## Developer notes + +In theory, the state of the component models should fully determine the state of +the coupled simulation and one should be able to restart a coupled simulation +just by using the states of the component models. Unfortunately, this is +currently not the case in `ClimaCoupler`. The main reason for this is the +complex interdependencies between component models and within `ClimaAtmos` which +make the initialization step inconsistent. For example, in a coupled simulation, +the surface albedo should be determined by the surface models and used by the +atmospheric model for radiation transfer, but `ClimaAtmos` also tries to set the +surface albedo (since it has to do so when run in standalone mode). In addition +to this, `ClimaAtmos` has a large cache that has internal interdependencies that +are hard to disentangle, and changing a field might require changing some other +field in a different part of the cache. As a result, it is not easy for +`ClimaCoupler` to consistently do initialization from a cold state. To conclude, +restarting a simulation exclusively using the states of the component models is +currently impossible. + +Given that restarting a simulation from the state is impossible, `ClimaCoupler` +needs to save the states and the caches. Let us review how we use +`ClimaCore.InputOutput` and `JLD2` package to accomplish this. + +`ClimaCore.InputOutput` provides a loss-less way to save the content of certain +`ClimaCore` objects to HDF5 files. Objects saved in this way are not tied to a +particular computing device or configuration. When running with MPI, +`ClimaCore.InputOutput` are also efficiently written in parallel. + +Unfortunately, `ClimaCore.InputOutput` only supports certain objects, such as +`Field`s and `Space`s, but the cache in component models is more complex than +this and contains complex objects with highly stateful quantities (e.g., C +pointers). Because of this, model states are saved to HDF5 but caches must be +saved to JLD2 files. + +`JLD2` allows us to save more complex objects without writing specific +serialization methods for every struct. `JLD2` allows us to take a big step +forward, but there are still several challenges that need to be solved: +1. `JLD2` does not support CUDA natively. To go around this, we have to move + everything onto the CPU first. Then, when the data is read back, we have to + move it back to the GPU. +2. `JLD2` does not support MPI natively. To go around this, each process writes + its `jld2` checkpoint and reads it back. This introduces the constraint that + the number of MPI processes cannot change across restarts. +3. Some quantities are best not saved and read (for example, anything with + pointers). For this, we write a recursive function that traverses the cache + and only restores quantities of a certain type (typically, `ClimaCore` + objects) + +Point 3. adds significant amount of code and requires component models to +specify how their cache has to be restored. + +If you are adding a component model, you have to extend the +``` +Checkpointer.get_model_prog_state +Checkpointer.get_model_cache +Checkpointer.restore_cache! +``` +methods. + +`ClimaCoupler` moves objects to the CPU with `Adapt(Array, x)`. `Adapt` +traverses the object recursively, and proper `Adapt` methods have to be defined +for every object involved in the chain. The easiest way to do this is using the +`Adapt.@adapt_structure` macro, which defines a recursive Adapt for the given +object. + +Types to watch for: +- `MPI` related objects (e.g., `MPICommsContext`) +- `TimeVaryingInputs` (because they contain `NCDatasets`, which contain pointers + to files) ## Checkpointer API ```@docs ClimaCoupler.Checkpointer.get_model_prog_state - ClimaCoupler.Checkpointer.restart_model_state! - ClimaCoupler.Checkpointer.checkpoint_model_state + ClimaCoupler.Checkpointer.get_model_cache + ClimaCoupler.Checkpointer.restart! ClimaCoupler.Checkpointer.checkpoint_sims + ClimaCoupler.Checkpointer.t_start_from_checkpoint ``` diff --git a/experiments/ClimaEarth/Manifest-v1.11.toml b/experiments/ClimaEarth/Manifest-v1.11.toml index e620101a7d..58da936549 100644 --- a/experiments/ClimaEarth/Manifest-v1.11.toml +++ b/experiments/ClimaEarth/Manifest-v1.11.toml @@ -337,7 +337,7 @@ uuid = "908f55d8-4145-4867-9c14-5dad1a479e4d" version = "0.4.6" [[deps.ClimaCoupler]] -deps = ["ClimaComms", "ClimaCore", "ClimaUtilities", "Dates", "Logging", "SciMLBase", "StaticArrays", "SurfaceFluxes", "Thermodynamics"] +deps = ["ClimaComms", "ClimaCore", "ClimaUtilities", "Dates", "JLD2", "Logging", "SciMLBase", "StaticArrays", "SurfaceFluxes", "Thermodynamics"] path = "../.." uuid = "4ade58fe-a8da-486c-bd89-46df092ec0c7" version = "0.1.2" @@ -718,11 +718,10 @@ git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" version = "0.1.10" -[[deps.Expronicon]] -deps = ["MLStyle", "Pkg", "TOML"] -git-tree-sha1 = "fc3951d4d398b5515f91d7fe5d45fc31dccb3c9b" -uuid = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636" -version = "0.8.5" +[[deps.ExproniconLite]] +git-tree-sha1 = "c13f0b150373771b0fdc1713c97860f8df12e6c2" +uuid = "55351af7-c7e9-48d6-89ff-24e801d99491" +version = "0.10.14" [[deps.Extents]] git-tree-sha1 = "063512a13dbe9c40d999c439268539aa552d1ae6" @@ -961,10 +960,10 @@ uuid = "5c1252a2-5f33-56bf-86c9-59e7332b4326" version = "0.4.11" [[deps.GeometryOps]] -deps = ["CoordinateTransformations", "DataAPI", "DelaunayTriangulation", "ExactPredicates", "GeoInterface", "GeometryBasics", "LinearAlgebra", "SortTileRecursiveTree", "Statistics", "Tables"] -git-tree-sha1 = "7eaffabf21dcdc7a5e543c309b903371af5c9b07" +deps = ["CoordinateTransformations", "DataAPI", "DelaunayTriangulation", "ExactPredicates", "GeoInterface", "GeometryBasics", "GeometryOpsCore", "LinearAlgebra", "SortTileRecursiveTree", "Statistics", "Tables"] +git-tree-sha1 = "226ebac075e4a477bbaeacb4f7e720f9dce019a9" uuid = "3251bfac-6a57-4b6d-aa61-ac1fef2975ab" -version = "0.1.14" +version = "0.1.15" [deps.GeometryOps.extensions] GeometryOpsFlexiJoinsExt = "FlexiJoins" @@ -976,6 +975,12 @@ version = "0.1.14" LibGEOS = "a90b1aa1-3769-5649-ba7e-abc5a9d163eb" Proj = "c94c279d-25a6-4763-9509-64d165bea63e" +[[deps.GeometryOpsCore]] +deps = ["DataAPI", "GeoInterface", "Tables"] +git-tree-sha1 = "390a7ff6a89a997d6a1fa76c857490e0bd34565f" +uuid = "05efe853-fabf-41c8-927e-7063c8b9f013" +version = "0.1.2" + [[deps.Gettext_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "XML2_jll"] git-tree-sha1 = "9b02998aba7bf074d14de89f9d37ca24a1a0b046" @@ -1201,6 +1206,12 @@ git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" uuid = "82899510-4779-5014-852e-03e436cf321d" version = "1.0.0" +[[deps.JLD2]] +deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"] +git-tree-sha1 = "91d501cb908df6f134352ad73cde5efc50138279" +uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +version = "0.5.11" + [[deps.JLLWrappers]] deps = ["Artifacts", "Preferences"] git-tree-sha1 = "a007feb38b422fbdab534406aeca1b86823cb4d6" @@ -1225,6 +1236,12 @@ version = "1.14.1" [deps.JSON3.weakdeps] ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" +[[deps.Jieko]] +deps = ["ExproniconLite"] +git-tree-sha1 = "2f05ed29618da60c06a87e9c033982d4f71d0b6c" +uuid = "ae98c720-c025-4a4a-838c-29b094483192" +version = "0.2.1" + [[deps.JpegTurbo]] deps = ["CEnum", "FileIO", "ImageCore", "JpegTurbo_jll", "TOML"] git-tree-sha1 = "fa6d0bcff8583bac20f1ffa708c3913ca605c611" @@ -1478,11 +1495,6 @@ git-tree-sha1 = "5de60bc6cb3899cd318d80d627560fae2e2d99ae" uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7" version = "2025.0.1+1" -[[deps.MLStyle]] -git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" -uuid = "d8e11817-5142-5d16-987a-aa16d5891078" -version = "0.4.17" - [[deps.MPICH_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] git-tree-sha1 = "e7159031670cee777cc2840aef7a521c3603e36c" @@ -1566,6 +1578,12 @@ git-tree-sha1 = "7b86a5d4d70a9f5cdf2dacb3cbe6d251d1a61dbe" uuid = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389" version = "0.3.4" +[[deps.Moshi]] +deps = ["ExproniconLite", "Jieko"] +git-tree-sha1 = "453de0fc2be3d11b9b93ca4d0fddd91196dcf1ed" +uuid = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d" +version = "0.3.5" + [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" version = "2023.12.12" @@ -1613,9 +1631,9 @@ version = "1.1.2" [[deps.NaNStatistics]] deps = ["PrecompileTools", "Static", "StaticArrayInterface"] -git-tree-sha1 = "1529c48f8c63a815c985890e0192b006b06de528" +git-tree-sha1 = "c1a9def67b8a871b51ccf84512f173e7979c186b" uuid = "b946abbf-3ea7-4610-9019-9858bfdeaf2d" -version = "0.6.45" +version = "0.6.47" [[deps.NaturalEarth]] deps = ["Downloads", "GeoJSON", "Pkg", "Scratch"] @@ -1734,9 +1752,9 @@ version = "0.11.32" [[deps.PNGFiles]] deps = ["Base64", "CEnum", "ImageCore", "IndirectArrays", "OffsetArrays", "libpng_jll"] -git-tree-sha1 = "67186a2bc9a90f9f85ff3cc8277868961fb57cbd" +git-tree-sha1 = "cf181f0b1e6a18dfeb0ee8acc4a9d1672499626c" uuid = "f57f5aa1-a3ce-4bc8-8ab9-96f992907883" -version = "0.4.3" +version = "0.4.4" [[deps.PROJ_jll]] deps = ["Artifacts", "JLLWrappers", "LibCURL_jll", "Libdl", "Libtiff_jll", "OpenSSL_jll", "SQLite_jll"] @@ -1935,9 +1953,9 @@ version = "1.3.4" [[deps.RecursiveArrayTools]] deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] -git-tree-sha1 = "fe9d37a17ab4d41a98951332ee8067f8dca8c4c2" +git-tree-sha1 = "a967273b0c96f9e55ccb93322ada38f43e685c49" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" -version = "3.29.0" +version = "3.30.0" [deps.RecursiveArrayTools.extensions] RecursiveArrayToolsFastBroadcastExt = "FastBroadcast" @@ -2034,13 +2052,14 @@ uuid = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" version = "0.1.0" [[deps.SciMLBase]] -deps = ["ADTypes", "Accessors", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "Expronicon", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "SciMLStructures", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface"] -git-tree-sha1 = "ffed2507209da5b42c6881944ef41a340ab5449b" +deps = ["ADTypes", "Accessors", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "Moshi", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "SciMLStructures", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface"] +git-tree-sha1 = "2242fd564bb0202a22a91f575dc58b8820612b6b" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -version = "2.74.1" +version = "2.75.0" [deps.SciMLBase.extensions] SciMLBaseChainRulesCoreExt = "ChainRulesCore" + SciMLBaseMLStyleExt = "MLStyle" SciMLBaseMakieExt = "Makie" SciMLBasePartialFunctionsExt = "PartialFunctions" SciMLBasePyCallExt = "PyCall" @@ -2051,6 +2070,7 @@ version = "2.74.1" [deps.SciMLBase.weakdeps] ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" @@ -2298,9 +2318,9 @@ weakdeps = ["ClimaParams"] [[deps.SymbolicIndexingInterface]] deps = ["Accessors", "ArrayInterface", "RuntimeGeneratedFunctions", "StaticArraysCore"] -git-tree-sha1 = "fd2d4f0499f6bb4a0d9f5030f5c7d61eed385e03" +git-tree-sha1 = "d6c04e26aa1c8f7d144e1a8c47f1c73d3013e289" uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5" -version = "0.3.37" +version = "0.3.38" [[deps.TOML]] deps = ["Dates"] @@ -2473,9 +2493,9 @@ version = "1.0.0" [[deps.XML2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"] -git-tree-sha1 = "ee6f41aac16f6c9a8cab34e2f7a200418b1cc1e3" +git-tree-sha1 = "b8b243e47228b4a3877f1dd6aee0c5d56db7fcf4" uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" -version = "2.13.6+0" +version = "2.13.6+1" [[deps.XSLT_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "XML2_jll", "Zlib_jll"] @@ -2550,9 +2570,9 @@ version = "1.2.13+1" [[deps.Zstd_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "622cf78670d067c738667aaa96c553430b65e269" +git-tree-sha1 = "446b23e73536f84e8037f5dce465e92275f6a308" uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" -version = "1.5.7+0" +version = "1.5.7+1" [[deps.isoband_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] diff --git a/experiments/ClimaEarth/cli_options.jl b/experiments/ClimaEarth/cli_options.jl index 6c994653ed..4aebdad2e6 100644 --- a/experiments/ClimaEarth/cli_options.jl +++ b/experiments/ClimaEarth/cli_options.jl @@ -75,7 +75,7 @@ function argparse_settings() arg_type = String default = "10days" "--checkpoint_dt" - help = "Time interval for hourly checkpointing [\"20days\" (default)]" + help = "Time interval for checkpointing [\"20days\" (default)]" arg_type = String default = "20days" # Restart information @@ -84,9 +84,9 @@ function argparse_settings() arg_type = String default = nothing "--restart_t" - help = "Time in seconds rounded to the nearest index to use at `t_start` for restarted simulation [0 (default)]" + help = "Time in seconds rounded to the nearest index to use at `t_start` for restarted simulation [nothing (default)]" arg_type = Int - default = 0 + default = nothing # Diagnostics information "--use_coupler_diagnostics" help = "Boolean flag indicating whether to compute and output coupler diagnostics [`true` (default), `false`]" diff --git a/experiments/ClimaEarth/components/atmosphere/climaatmos.jl b/experiments/ClimaEarth/components/atmosphere/climaatmos.jl index 5d81414c93..507dcc5bf6 100644 --- a/experiments/ClimaEarth/components/atmosphere/climaatmos.jl +++ b/experiments/ClimaEarth/components/atmosphere/climaatmos.jl @@ -14,6 +14,15 @@ import ClimaUtilities.TimeManager: ITime include("climaatmos_extra_diags.jl") +if pkgversion(CA) < v"0.28.6" + # Allow cache to be moved to CPU (this is a little bit of type piracy, but we + # allow it in this particular file) + CC.Adapt.@adapt_structure CA.AtmosCache + CC.Adapt.@adapt_structure CA.RRTMGPInterface.RRTMGPModel +end + +include("../shared/restore.jl") + ### ### Functions required by ClimaCoupler.jl for an AtmosModelSimulation ### @@ -103,6 +112,22 @@ function Checkpointer.get_model_prog_state(sim::ClimaAtmosSimulation) return sim.integrator.u end +function Checkpointer.get_model_cache(sim::ClimaAtmosSimulation) + return sim.integrator.p +end + +function Checkpointer.restore_cache!(sim::ClimaAtmosSimulation, new_cache) + comms_ctx = ClimaComms.context(sim.integrator.u.c) + restore!( + Checkpointer.get_model_cache(sim), + new_cache, + comms_ctx; + ignore = Set([:rc, :params, :ghost_buffer, :hyperdiffusion_ghost_buffer, :data_handler, :graph_context]), + ) + return nothing +end + + """ Interfacer.get_field(atmos_sim::ClimaAtmosSimulation, ::Val{:radiative_energy_flux_toa}) @@ -375,6 +400,11 @@ function get_atmos_config_dict(coupler_dict::Dict, job_id::String, atmos_output_ atmos_config["output_dir_style"] = "RemovePreexisting" atmos_config["output_dir"] = atmos_output_dir + # Ensure Atmos's own checkpoints are synced up with ClimaCoupler, so that we + # can pick up from where we have left. NOTE: This should not be needed, but + # there is no easy way to initialize ClimaAtmos with a different t_start + atmos_config["dt_save_state_to_disk"] = coupler_dict["checkpoint_dt"] + # Add all extra atmos diagnostic entries into the vector of atmos diagnostics # If atmos doesn't have any diagnostics, use the extra_atmos_diagnostics from the coupler atmos_config["diagnostics"] = @@ -529,3 +559,25 @@ function FluxCalculator.water_albedo_from_atmosphere!( diffuse_albedo .= CA.surface_albedo_diffuse(α_model).(λ, μ, LinearAlgebra.norm.(CC.Fields.level(Y.c.uₕ, 1))) end + +""" + climaatmos_restart_path(output_dir_root, t) + +Look at the most recent output in `output_dir_root` and find a checkpoint for time `t`. +""" +function climaatmos_restart_path(output_dir_root, t) + isdir(output_dir_root) || error("$(output_dir_root) does not exist") + name_rx = r"output_(\d\d\d\d)" + existing_outputs = filter(x -> !isnothing(match(name_rx, x)), readdir(output_dir_root)) + + day = floor(Int, t / (60 * 60 * 24)) + sec = floor(Int, t % (60 * 60 * 24)) + + # Walk back the folders and tyr to find a checkpoint + for output in sort(existing_outputs, rev = true) + previous_folder = joinpath(output_dir_root, output) + restart_file = joinpath(previous_folder, "clima_atmos", "day$day.$sec.hdf5") + ispath(restart_file) && return restart_file + end + error("Restart file for time $t not found") +end diff --git a/experiments/ClimaEarth/components/land/climaland_bucket.jl b/experiments/ClimaEarth/components/land/climaland_bucket.jl index da97ea3eb0..a626bb2eee 100644 --- a/experiments/ClimaEarth/components/land/climaland_bucket.jl +++ b/experiments/ClimaEarth/components/land/climaland_bucket.jl @@ -1,6 +1,7 @@ import Dates import SciMLBase import Statistics +import ClimaComms import ClimaCore as CC import ClimaTimeSteppers as CTS import Thermodynamics as TD @@ -10,6 +11,8 @@ import ClimaDiagnostics as CD import ClimaCoupler: Checkpointer, FluxCalculator, Interfacer using NCDatasets +include("../shared/restore.jl") + ### ### Functions required by ClimaCoupler.jl for a SurfaceModelSimulation ### @@ -407,6 +410,21 @@ function make_land_domain( return CL.Domains.SphericalShell{FT}(radius, depth, nothing, nelements, npolynomial, space, fields) end +function Checkpointer.get_model_cache(sim::BucketSimulation) + return sim.integrator.p +end + +function Checkpointer.restore_cache!(sim::BucketSimulation, new_cache) + old_cache = Checkpointer.get_model_cache(sim) + comms_ctx = ClimaComms.context(sim.model) + restore!( + old_cache, + new_cache, + comms_ctx, + ignore = Set([:rc, :params, :dss_buffer_2d, :dss_buffer_3d, :graph_context]), + ) +end + """ dss_state!(sim::BucketSimulation) diff --git a/experiments/ClimaEarth/components/ocean/prescr_ocean.jl b/experiments/ClimaEarth/components/ocean/prescr_ocean.jl index 75e3b6644b..ea9f9b5fae 100644 --- a/experiments/ClimaEarth/components/ocean/prescr_ocean.jl +++ b/experiments/ClimaEarth/components/ocean/prescr_ocean.jl @@ -112,3 +112,7 @@ at each timestep. function Interfacer.step!(sim::PrescribedOceanSimulation, t) evaluate!(sim.cache.T_sfc, sim.cache.SST_timevaryinginput, t) end + +function Checkpointer.get_model_cache(sim::PrescribedOceanSimulation) + return sim.cache +end diff --git a/experiments/ClimaEarth/components/ocean/prescr_seaice.jl b/experiments/ClimaEarth/components/ocean/prescr_seaice.jl index ac2a7ef7e9..a08f0e0319 100644 --- a/experiments/ClimaEarth/components/ocean/prescr_seaice.jl +++ b/experiments/ClimaEarth/components/ocean/prescr_seaice.jl @@ -262,3 +262,7 @@ Perform DSS on the state of a component simulation, intended to be used before the initial step of a run. This method acts on prescribed ice simulations. """ dss_state!(sim::PrescribedIceSimulation) = CC.Spaces.weighted_dss!(sim.integrator.u, sim.integrator.p.dss_buffer) + +function Checkpointer.get_model_cache(sim::PrescribedIceSimulation) + return sim.integrator.p +end diff --git a/experiments/ClimaEarth/components/shared/restore.jl b/experiments/ClimaEarth/components/shared/restore.jl new file mode 100644 index 0000000000..fc56e5c21a --- /dev/null +++ b/experiments/ClimaEarth/components/shared/restore.jl @@ -0,0 +1,86 @@ +# Define shared methods to allow reading back a saved cache + +import ClimaComms +import ClimaCore +import ClimaCore: DataLayouts, Fields, Geometry +import ClimaCore.Fields: Field, FieldVector, field_values +import ClimaCore.DataLayouts: AbstractData +import ClimaCore.Geometry: AxisTensor +import ClimaCore.Spaces: AbstractSpace +import ClimaUtilities.TimeVaryingInputs: AbstractTimeVaryingInput +import StaticArrays +import NCDatasets + +""" + restore!(v1, v2, comms_ctx; ignore) + +Recursively traverse `v1` and `v2`, setting each field of `v1` with the +corresponding field in `v2`. In this, ignore all the properties that have name +within the `ignore` iterable. + +`ignore` is useful when there are stateful properties, such as live pointers. +""" +function restore!(v1::T1, v2::T2, comms_ctx; name = "", ignore) where {T1, T2} + # We pick fieldnames(T2) because v2 tend to be simpler (Array as opposed + # to CuArray) + fields = filter(x -> !(x in ignore), fieldnames(T2)) + if isempty(fields) + if !Base.issingletontype(typeof(v1)) + restore!(v1, v2, comms_ctx; name, ignore) + else + v1 == v2 || error("$v1 != $v2") + end + else + # Recursive case + for p in fields + restore!(getfield(v1, p), getfield(v2, p), comms_ctx; name = "$(name).$(p)", ignore) + end + end + return nothing +end + +# Ignoring certain types that don't need to be restored +# UnionAll and DataType are infinitely recursive, so we also ignore those +function restore!( + v1::Union{AbstractTimeVaryingInput, ClimaComms.AbstractCommsContext, ClimaComms.AbstractDevice, UnionAll, DataType}, + v2::Union{AbstractTimeVaryingInput, ClimaComms.AbstractCommsContext, ClimaComms.AbstractDevice, UnionAll, DataType}, + _comms_ctx; + name, + ignore, +) + return nothing +end + +function restore!( + v1::T1, + v2::T2, + comms_ctx; + name, + ignore, +) where {T1 <: Union{AbstractData, AbstractArray}, T2 <: Union{AbstractData, AbstractArray}} + ArrayType = parent(v1) isa Array ? Array : ClimaComms.array_type(ClimaComms.device(comms_ctx)) + moved_to_device = ArrayType(parent(v2)) + + parent(v1) .= moved_to_device + return nothing +end + +function restore!( + v1::T1, + v2::T2, + comms_ctx; + name, + ignore, +) where { + T1 <: Union{StaticArrays.StaticArray, Number, UnitRange, Symbol}, + T2 <: Union{StaticArrays.StaticArray, Number, UnitRange, Symbol}, +} + v1 == v2 || error("$name is a immutable but it inconsistent") + return nothing +end + +function restore!(v1::T1, v2::T2, comms_ctx; name, ignore) where {T1 <: Dict, T2 <: Dict} + # RRTGMP has some internal dictionaries + v1 == v2 || error("$name is inconsistent") + return nothing +end diff --git a/experiments/ClimaEarth/run_amip.jl b/experiments/ClimaEarth/run_amip.jl index c4b697e923..31a754986f 100644 --- a/experiments/ClimaEarth/run_amip.jl +++ b/experiments/ClimaEarth/run_amip.jl @@ -22,4 +22,4 @@ include("setup_run.jl") config_file = parse_commandline(argparse_settings())["config_file"] # Set up and run the coupled simulation -setup_and_run(config_file) +cs = setup_and_run(config_file) diff --git a/experiments/ClimaEarth/run_cloudless_aquaplanet.jl b/experiments/ClimaEarth/run_cloudless_aquaplanet.jl index 6db69b2d45..2c4945b94a 100644 --- a/experiments/ClimaEarth/run_cloudless_aquaplanet.jl +++ b/experiments/ClimaEarth/run_cloudless_aquaplanet.jl @@ -77,7 +77,6 @@ checkpoint_dt = "480hours" =# dir_paths = Utilities.setup_output_dirs(output_dir = coupler_output_dir, comms_ctx = ClimaComms.context()) -@info(config_dict) ## namelist config_dict = Dict( @@ -103,6 +102,7 @@ config_dict = Dict( "nh_poly" => 4, # output "dt_save_to_sol" => "1days", + "checkpoint_dt" => "1days", # numerics "apply_limiter" => false, "viscous_sponge" => false, diff --git a/experiments/ClimaEarth/run_cloudy_aquaplanet.jl b/experiments/ClimaEarth/run_cloudy_aquaplanet.jl index a5739f20c2..8e27e23373 100644 --- a/experiments/ClimaEarth/run_cloudy_aquaplanet.jl +++ b/experiments/ClimaEarth/run_cloudy_aquaplanet.jl @@ -74,7 +74,6 @@ checkpoint_dt = "480hours" =# dir_paths = Utilities.setup_output_dirs(output_dir = coupler_output_dir, comms_ctx = ClimaComms.context()) -@info(config_dict) ## namelist config_dict = Dict( @@ -100,6 +99,7 @@ config_dict = Dict( "nh_poly" => 4, # output "dt_save_to_sol" => "1days", + "checkpoint_dt" => "1days", # numerics "apply_limiter" => false, "viscous_sponge" => false, diff --git a/experiments/ClimaEarth/run_cloudy_slabplanet.jl b/experiments/ClimaEarth/run_cloudy_slabplanet.jl index ba3a061d94..e91753700e 100644 --- a/experiments/ClimaEarth/run_cloudy_slabplanet.jl +++ b/experiments/ClimaEarth/run_cloudy_slabplanet.jl @@ -82,7 +82,6 @@ dt_rad = "6hours" =# dir_paths = Utilities.setup_output_dirs(output_dir = coupler_output_dir, comms_ctx = ClimaComms.context()) -@info(config_dict) ## namelist @@ -109,6 +108,7 @@ config_dict = Dict( "nh_poly" => 4, # output "dt_save_to_sol" => "1days", + "checkpoint_dt" => "1days", # numerics "apply_limiter" => false, "viscous_sponge" => false, diff --git a/experiments/ClimaEarth/run_dry_held_suarez.jl b/experiments/ClimaEarth/run_dry_held_suarez.jl index b8ef9a54dd..17e282e8d9 100644 --- a/experiments/ClimaEarth/run_dry_held_suarez.jl +++ b/experiments/ClimaEarth/run_dry_held_suarez.jl @@ -77,7 +77,6 @@ checkpoint_dt = "480hours" =# dir_paths = Utilities.setup_output_dirs(output_dir = coupler_output_dir, comms_ctx = ClimaComms.context()) -@info(config_dict) ## namelist config_dict = Dict( @@ -102,6 +101,7 @@ config_dict = Dict( "nh_poly" => 4, # output "dt_save_to_sol" => "1days", + "checkpoint_dt" => "1days", # numerics "apply_limiter" => false, "viscous_sponge" => false, diff --git a/experiments/ClimaEarth/run_moist_held_suarez.jl b/experiments/ClimaEarth/run_moist_held_suarez.jl index 7b7b72d25b..8c13e01a27 100644 --- a/experiments/ClimaEarth/run_moist_held_suarez.jl +++ b/experiments/ClimaEarth/run_moist_held_suarez.jl @@ -80,7 +80,6 @@ checkpoint_dt = "480hours" dir_paths = Utilities.setup_output_dirs(output_dir = coupler_output_dir, comms_ctx = ClimaComms.context()) -@info(config_dict) ## namelist config_dict = Dict( @@ -106,6 +105,7 @@ config_dict = Dict( "nh_poly" => 4, # output "dt_save_to_sol" => "1days", + "checkpoint_dt" => "1days", # numerics "apply_limiter" => false, "viscous_sponge" => false, diff --git a/experiments/ClimaEarth/setup_run.jl b/experiments/ClimaEarth/setup_run.jl index 8e6ccce707..84a56475af 100644 --- a/experiments/ClimaEarth/setup_run.jl +++ b/experiments/ClimaEarth/setup_run.jl @@ -90,7 +90,6 @@ and exchanges combined fields and calculates fluxes using the selected turbulent Note that we want to implement this in a dispatchable function to allow for other forms of timestepping (e.g. leapfrog). """ - function solve_coupler!(cs) (; model_sims, Δt_cpl, tspan, comms_ctx) = cs (; atmos_sim, land_sim, ocean_sim, ice_sim) = model_sims @@ -164,6 +163,9 @@ input config file. It initializes the component models, all coupler objects, diagnostics, and conservation checks, and then runs the simulation. """ function setup_and_run(config_dict::AbstractDict) + # Make a copy so that we don't modify the original input + config_dict = copy(config_dict) + # Initialize communication context (do this first so all printing is only on root) comms_ctx = Utilities.get_comms_context(config_dict) # Select the correct timestep for each component model based on which are available @@ -209,8 +211,7 @@ function setup_and_run(config_dict::AbstractDict) #and `dir_paths.checkpoints`, where restart files are saved. =# - COUPLER_OUTPUT_DIR = joinpath(output_dir_root, job_id) - dir_paths = Utilities.setup_output_dirs(output_dir = COUPLER_OUTPUT_DIR, comms_ctx = comms_ctx) + dir_paths = Utilities.setup_output_dirs(output_dir = output_dir_root, comms_ctx = comms_ctx) @info "Coupler output directory $(dir_paths.output)" @info "Coupler artifacts directory $(dir_paths.artifacts)" @info "Coupler checkpoint directory $(dir_paths.checkpoints)" @@ -231,6 +232,23 @@ function setup_and_run(config_dict::AbstractDict) Random.seed!(random_seed) @info "Random seed set to $(random_seed)" + isnothing(restart_t) && (restart_t = Checkpointer.t_start_from_checkpoint(dir_paths.checkpoints)) + isnothing(restart_dir) && (restart_dir = dir_paths.checkpoints) + should_restart = !isnothing(restart_t) && !isnothing(restart_dir) + if should_restart + if t_start isa ITime + t_start, _ = promote(ITime(restart_t), t_start) + else + t_start = restart_t + end + + # TODO: Find a cleaner way to do this instead of having a second restart + # just for atmos + atmos_config_dict["restart_file"] = climaatmos_restart_path(output_dir_root, restart_t) + + @info "Starting from t_start $(t_start)" + end + tspan = (t_start, t_end) #= @@ -579,71 +597,66 @@ function setup_and_run(config_dict::AbstractDict) If a restart directory is specified and contains output files from the `checkpoint_cb` callback, the component model states are restarted from those files. The restart directory is specified in the `config_dict` dictionary. The `restart_t` field specifies the time step at which the restart is performed. =# + should_restart && Checkpointer.restart!(cs, restart_dir, restart_t) - if !isnothing(restart_dir) - for sim in cs.model_sims - if Checkpointer.get_model_prog_state(sim) !== nothing - Checkpointer.restart_model_state!(sim, comms_ctx, restart_t; input_dir = restart_dir) - end - end - end + if !should_restart + #= + ## Initialize Component Model Exchange - #= - ## Initialize Component Model Exchange + We need to ensure all models' initial conditions are shared to enable the coupler to calculate the first instance of surface fluxes. Some auxiliary variables (namely surface humidity and radiation fluxes) + depend on initial conditions of other component models than those in which the variables are calculated, which is why we need to step these models in time and/or reinitialize them. + The concrete steps for proper initialization are: + =# - We need to ensure all models' initial conditions are shared to enable the coupler to calculate the first instance of surface fluxes. Some auxiliary variables (namely surface humidity and radiation fluxes) - depend on initial conditions of other component models than those in which the variables are calculated, which is why we need to step these models in time and/or reinitialize them. - The concrete steps for proper initialization are: - =# + # 1.coupler updates surface model area fractions + FieldExchanger.update_surface_fractions!(cs) - # 1.coupler updates surface model area fractions - FieldExchanger.update_surface_fractions!(cs) - - # 2.surface density (`ρ_sfc`): calculated by the coupler by adiabatically extrapolating atmospheric thermal state to the surface. - # For this, we need to import surface and atmospheric fields. The model sims are then updated with the new surface density. - FieldExchanger.import_combined_surface_fields!(cs.fields, cs.model_sims, cs.turbulent_fluxes) - FieldExchanger.import_atmos_fields!(cs.fields, cs.model_sims, cs.boundary_space, cs.turbulent_fluxes) - FieldExchanger.update_model_sims!(cs.model_sims, cs.fields, cs.turbulent_fluxes) - - # 3.surface vapor specific humidity (`q_sfc`): step surface models with the new surface density to calculate their respective `q_sfc` internally - ## TODO: the q_sfc calculation follows the design of the bucket q_sfc, but it would be neater to abstract this from step! (#331) - Interfacer.step!(land_sim, tspan[1] + Δt_cpl) - Interfacer.step!(ocean_sim, tspan[1] + Δt_cpl) - Interfacer.step!(ice_sim, tspan[1] + Δt_cpl) - - # 4.turbulent fluxes: now we have all information needed for calculating the initial turbulent - # surface fluxes using either the combined state or the partitioned state method - if cs.turbulent_fluxes isa FluxCalculator.CombinedStateFluxesMOST - ## import the new surface properties into the coupler (note the atmos state was also imported in step 3.) - FieldExchanger.import_combined_surface_fields!(cs.fields, cs.model_sims, cs.turbulent_fluxes) # i.e. T_sfc, albedo, z0, beta, q_sfc - ## calculate turbulent fluxes inside the atmos cache based on the combined surface state in each grid box - FluxCalculator.combined_turbulent_fluxes!(cs.model_sims, cs.fields, cs.turbulent_fluxes) # this updates the atmos thermo state, sfc_ts - elseif cs.turbulent_fluxes isa FluxCalculator.PartitionedStateFluxes - ## calculate turbulent fluxes in surface models and save the weighted average in coupler fields - FluxCalculator.partitioned_turbulent_fluxes!( - cs.model_sims, - cs.fields, - cs.boundary_space, - FluxCalculator.MoninObukhovScheme(), - cs.thermo_params, - ) + # 2.surface density (`ρ_sfc`): calculated by the coupler by adiabatically extrapolating atmospheric thermal state to the surface. + # For this, we need to import surface and atmospheric fields. The model sims are then updated with the new surface density. + FieldExchanger.import_combined_surface_fields!(cs.fields, cs.model_sims, cs.turbulent_fluxes) + FieldExchanger.import_atmos_fields!(cs.fields, cs.model_sims, cs.boundary_space, cs.turbulent_fluxes) + FieldExchanger.update_model_sims!(cs.model_sims, cs.fields, cs.turbulent_fluxes) - ## update atmos sfc_conditions for surface temperature - ## TODO: this is hard coded and needs to be simplified (req. CA modification) (#479) - new_p = get_new_cache(atmos_sim, cs.fields) - CA.SurfaceConditions.update_surface_conditions!(atmos_sim.integrator.u, new_p, atmos_sim.integrator.t) ## sets T_sfc (but SF calculation not necessary - requires split functionality in CA) - atmos_sim.integrator.p.precomputed.sfc_conditions .= new_p.precomputed.sfc_conditions - end + # 3.surface vapor specific humidity (`q_sfc`): step surface models with the new surface density to calculate their respective `q_sfc` internally + ## TODO: the q_sfc calculation follows the design of the bucket q_sfc, but it would be neater to abstract this from step! (#331) + Interfacer.step!(land_sim, tspan[1] + Δt_cpl) + Interfacer.step!(ocean_sim, tspan[1] + Δt_cpl) + Interfacer.step!(ice_sim, tspan[1] + Δt_cpl) - # 5.reinitialize models + radiative flux: prognostic states and time are set to their initial conditions. For atmos, this also triggers the callbacks and sets a nonzero radiation flux (given the new sfc_conditions) - FieldExchanger.reinit_model_sims!(cs.model_sims) + # 4.turbulent fluxes: now we have all information needed for calculating the initial turbulent + # surface fluxes using either the combined state or the partitioned state method + if cs.turbulent_fluxes isa FluxCalculator.CombinedStateFluxesMOST + ## import the new surface properties into the coupler (note the atmos state was also imported in step 3.) + FieldExchanger.import_combined_surface_fields!(cs.fields, cs.model_sims, cs.turbulent_fluxes) # i.e. T_sfc, albedo, z0, beta, q_sfc + ## calculate turbulent fluxes inside the atmos cache based on the combined surface state in each grid box + FluxCalculator.combined_turbulent_fluxes!(cs.model_sims, cs.fields, cs.turbulent_fluxes) # this updates the atmos thermo state, sfc_ts + elseif cs.turbulent_fluxes isa FluxCalculator.PartitionedStateFluxes + ## calculate turbulent fluxes in surface models and save the weighted average in coupler fields + FluxCalculator.partitioned_turbulent_fluxes!( + cs.model_sims, + cs.fields, + cs.boundary_space, + FluxCalculator.MoninObukhovScheme(), + cs.thermo_params, + ) + + ## update atmos sfc_conditions for surface temperature + ## TODO: this is hard coded and needs to be simplified (req. CA modification) (#479) + new_p = get_new_cache(atmos_sim, cs.fields) + CA.SurfaceConditions.update_surface_conditions!(atmos_sim.integrator.u, new_p, atmos_sim.integrator.t) ## sets T_sfc (but SF calculation not necessary - requires split functionality in CA) + atmos_sim.integrator.p.precomputed.sfc_conditions .= new_p.precomputed.sfc_conditions + end + + # 5.reinitialize models + radiative flux: prognostic states and time are set to their initial conditions. For atmos, this also triggers the callbacks and sets a nonzero radiation flux (given the new sfc_conditions) + FieldExchanger.reinit_model_sims!(cs.model_sims) - # 6.update all fluxes: coupler re-imports updated atmos fluxes (radiative fluxes for both `turbulent_fluxes` types - # and also turbulent fluxes if `turbulent_fluxes isa CombinedStateFluxesMOST`, - # and sends them to the surface component models. If `turbulent_fluxes isa PartitionedStateFluxes` - # atmos receives the turbulent fluxes from the coupler. - FieldExchanger.import_atmos_fields!(cs.fields, cs.model_sims, cs.boundary_space, cs.turbulent_fluxes) - FieldExchanger.update_model_sims!(cs.model_sims, cs.fields, cs.turbulent_fluxes) + # 6.update all fluxes: coupler re-imports updated atmos fluxes (radiative fluxes for both `turbulent_fluxes` types + # and also turbulent fluxes if `turbulent_fluxes isa CombinedStateFluxesMOST`, + # and sends them to the surface component models. If `turbulent_fluxes isa PartitionedStateFluxes` + # atmos receives the turbulent fluxes from the coupler. + FieldExchanger.import_atmos_fields!(cs.fields, cs.model_sims, cs.boundary_space, cs.turbulent_fluxes) + FieldExchanger.update_model_sims!(cs.model_sims, cs.fields, cs.turbulent_fluxes) + end #= ## Precompilation of Coupling Loop @@ -653,13 +666,15 @@ function setup_and_run(config_dict::AbstractDict) beginning and end of the simulation timespan to the correct values. =# - ## run the coupled simulation for two timesteps to precompile - cs.tspan[2] = tspan[1] + Δt_cpl * 2 - solve_coupler!(cs) + if tspan[2] > 2Δt_cpl + tspan[1] + ## run the coupled simulation for two timesteps to precompile + cs.tspan[2] = tspan[1] + Δt_cpl * 2 + solve_coupler!(cs) - ## update the timespan to the correct values - cs.tspan[1] = tspan[1] + Δt_cpl * 2 - cs.tspan[2] = tspan[2] + ## update the timespan to the correct values + cs.tspan[1] = tspan[1] + Δt_cpl * 2 + cs.tspan[2] = tspan[2] + end ## Run garbage collection before solving for more accurate memory comparison to ClimaAtmos GC.gc() @@ -724,4 +739,5 @@ function setup_and_run(config_dict::AbstractDict) # Close all diagnostics file writers isnothing(cs.diags_handler) || foreach(diag -> close(diag.output_writer), cs.diags_handler.scheduled_diagnostics) isnothing(atmos_sim.output_writers) || foreach(close, atmos_sim.output_writers) + return cs end diff --git a/experiments/ClimaEarth/test/compare.jl b/experiments/ClimaEarth/test/compare.jl new file mode 100644 index 0000000000..3ec12b8674 --- /dev/null +++ b/experiments/ClimaEarth/test/compare.jl @@ -0,0 +1,122 @@ +# compare.jl provides function to recursively compare complex objects while also +# allowing for some numerical tolerance. + +import ClimaComms +import ClimaAtmos as CA +import ClimaCore +import ClimaCore: DataLayouts, Fields, Geometry +import ClimaCore.Fields: Field, FieldVector, field_values +import ClimaCore.DataLayouts: AbstractData +import ClimaCore.Geometry: AxisTensor +import ClimaCore.Spaces: AbstractSpace +import NCDatasets + +""" + _error(arr1::AbstractArray, arr2::AbstractArray; ABS_TOL = 100eps(eltype(arr1))) + +We compute the error in this way: +- when the absolute value is larger than ABS_TOL, we use the absolute error +- in the other cases, we compare the relative errors +""" +function _error(arr1::AbstractArray, arr2::AbstractArray; ABS_TOL = 100eps(eltype(arr1))) + # There are some parameters, e.g. Obukhov length, for which Inf + # is a reasonable value (implying a stability parameter in the neutral boundary layer + # regime, for instance). We account for such instances with the `isfinite` function. + arr1 = Array(arr1) .* isfinite.(Array(arr1)) + arr2 = Array(arr2) .* isfinite.(Array(arr2)) + diff = abs.(arr1 .- arr2) + denominator = abs.(arr1) + error = ifelse.(denominator .> ABS_TOL, diff ./ denominator, diff) + return error +end + +""" + compare(v1, v2; name = "", ignore = Set([:rc])) + +Return whether `v1` and `v2` are the same (up to floating point errors). +`compare` walks through all the properties in `v1` and `v2` until it finds +that there are no more properties. At that point, `compare` tries to match the +resulting objects. When such objects are arrays with floating point, `compare` +defines a notion of `error` that is the following: when the absolute value is +less than `100eps(eltype)`, `error = absolute_error`, otherwise it is relative +error. The `error` is then compared against a tolerance. +Keyword arguments +================= +- `name` is used to collect the name of the property while we go recursively + over all the properties. You can pass a base name. +- `ignore` is a collection of `Symbol`s that identify properties that are + ignored when walking through the tree. This is useful for properties that + are known to be different (e.g., `output_dir`). +`:rc` is some CUDA/CuArray internal object that we don't care about +""" +function compare( + v1::T1, + v2::T2; + name = "", + ignore = Set([:rc]), +) where { + T1 <: Union{FieldVector, AbstractSpace, NamedTuple, CA.AtmosCache}, + T2 <: Union{FieldVector, AbstractSpace, NamedTuple, CA.AtmosCache}, +} + pass = true + return _compare(pass, v1, v2; name, ignore) +end + +function _compare(pass, v1::T, v2::T; name, ignore) where {T} + properties = filter(x -> !(x in ignore), propertynames(v1)) + if isempty(properties) + pass &= _compare(v1, v2; name, ignore) + else + # Recursive case + for p in properties + pass &= _compare(pass, getproperty(v1, p), getproperty(v2, p); name = "$(name).$(p)", ignore) + end + end + return pass +end + +function _compare(v1::T, v2::T; name, ignore) where {T} + return print_maybe(v1 == v2, "$name differs") +end + +function _compare(v1::T, v2::T; name, ignore) where {T <: Union{AbstractString, Symbol}} + # What we can safely print without filling STDOUT + return print_maybe(v1 == v2, "$name differs: $v1 vs $v2") +end + +function _compare(v1::T, v2::T; name, ignore) where {T <: Number} + # We check with triple equal so that we also catch NaNs being equal + return print_maybe(v1 === v2, "$name differs: $v1 vs $v2") +end + +# We ignore NCDatasets. They contain a lot of state-ful information +function _compare(pass, v1::T, v2::T; name, ignore) where {T <: NCDatasets.NCDataset} + return pass +end + +function _compare(v1::T, v2::T; name, ignore) where {T <: Field{<:AbstractData{<:Real}}} + return _compare(parent(v1), parent(v2); name, ignore) +end + +function _compare(pass, v1::T, v2::T; name, ignore) where {T <: AbstractData} + return pass && _compare(parent(v1), parent(v2); name, ignore) +end + +# Handle views +function _compare(pass, v1::SubArray{FT}, v2::SubArray{FT}; name, ignore) where {FT <: AbstractFloat} + return pass && _compare(collect(v1), collect(v2); name, ignore) +end + +function _compare(v1::AbstractArray{FT}, v2::AbstractArray{FT}; name, ignore) where {FT <: AbstractFloat} + error = maximum(_error(v1, v2); init = zero(eltype(v1))) + return print_maybe(error <= 100eps(eltype(v1)), "$name error: $error") +end + +function _compare(pass, v1::T1, v2::T2; name, ignore) where {T1, T2} + error("v1 and v2 have different types") +end + +function print_maybe(exp, what) + exp || println(what) + return exp +end diff --git a/experiments/ClimaEarth/test/restart.jl b/experiments/ClimaEarth/test/restart.jl new file mode 100644 index 0000000000..9779391348 --- /dev/null +++ b/experiments/ClimaEarth/test/restart.jl @@ -0,0 +1,117 @@ +# This test runs a small AMIP simulation four times. +# +# - The first time the simulation is run for four steps +# - The second time the simulation is run for two steps +# - The third time the simulation is run for two steps, but restarting from the +# second simulation +# +# After all these simulations are run, we compare the first and last runs. They +# should be bit-wise identical. +# +# The content of the simulation is not the most important, but it helps if it +# has all of the complexity possible. + +import ClimaComms +ClimaComms.@import_required_backends +import ClimaUtilities.OutputPathGenerator: maybe_wait_filesystem +import YAML +import Logging +using Test + +# Uncomment the following for cleaner output (but more difficult debugging) +# Logging.disable_logging(Logging.Warn) + +include("compare.jl") +include("../setup_run.jl") + +comms_ctx = ClimaComms.context() +@info "Context: $(comms_ctx)" +ClimaComms.init(comms_ctx) + +# Make sure that all MPI processes agree on the output_loc +tmpdir = ClimaComms.iamroot(comms_ctx) ? mktempdir(pwd()) : "" +tmpdir = ClimaComms.bcast(comms_ctx, tmpdir) +# Sometimes the shared filesystem doesn't work properly and the folder is not +# synced across MPI processes. Let's add an additional check here. +maybe_wait_filesystem(ClimaComms.context(), tmpdir) + +default_config = parse_commandline(argparse_settings()) +base_config_file = joinpath(@__DIR__, "amip_test.yml") +base_config = YAML.load_file(base_config_file) +merge!(default_config, base_config) + +# Four steps +four_steps = deepcopy(default_config) + +four_steps["dt"] = "180secs" +four_steps["dt_cpl"] = "180secs" +four_steps["t_end"] = "720secs" +four_steps["dt_rad"] = "180secs" +four_steps["job_id"] = "four_steps" + +cs_four_steps = setup_and_run(four_steps) + +println("Simulating two steps") + +# Now, two steps plus one +two_steps = deepcopy(default_config) + +two_steps["dt"] = "180secs" +two_steps["dt_cpl"] = "180secs" +two_steps["t_end"] = "360secs" +two_steps["dt_rad"] = "180secs" +two_steps["coupler_output_dir"] = tmpdir +two_steps["checkpoint_dt"] = "360secs" +two_steps["job_id"] = "two_steps" + +# Copying since setup_and_run changes its content +cs_two_steps1 = setup_and_run(two_steps) + +println("Reading and simulating last step") +# Two additional steps +two_steps["t_end"] = "720secs" +cs_two_steps2 = setup_and_run(two_steps) + +@testset "Restarts" begin + # We put cs_four_steps.fields in a NamedTuple so that we can start the recursion in compare + @test compare((; coupler_fields = cs_four_steps.fields), (; coupler_fields = cs_two_steps2.fields)) + + @test compare(cs_four_steps.model_sims.atmos_sim.integrator.u, cs_two_steps2.model_sims.atmos_sim.integrator.u) + + @test compare( + cs_four_steps.model_sims.atmos_sim.integrator.p, + cs_two_steps2.model_sims.atmos_sim.integrator.p, + ignore = [ + :walltime_estimate, # Stateful + :output_dir, # Changes across runs + :scratch, # Irrelevant + :ghost_buffer, # Irrelevant + :hyperdiffusion_ghost_buffer, # Irrelevant + :data_handler, # Stateful + :face_clear_sw_direct_flux_dn, # Not filled by RRTGMP + :face_sw_direct_flux_dn, # Not filled by RRTGMP + :rc, # CUDA internal object + ], + ) + + @test compare(cs_four_steps.model_sims.ice_sim.integrator.u, cs_two_steps2.model_sims.ice_sim.integrator.u) + + @test compare(cs_four_steps.model_sims.land_sim.integrator.u, cs_two_steps2.model_sims.land_sim.integrator.u) + @test compare( + cs_four_steps.model_sims.land_sim.integrator.p, + cs_two_steps2.model_sims.land_sim.integrator.p, + ignore = [:dss_buffer_3d, :dss_buffer_2d, :rc], + ) + + # Ignoring SST_timevaryinginput because it contains closures (which should be + # reinitialized correctly). We have to remove it from the type, otherwise + # compare will not work + function delete(nt::NamedTuple, fieldnames...) + return (; filter(p -> !(first(p) in fieldnames), collect(pairs(nt)))...) + end + + ocean_cache_four = delete(cs_four_steps.model_sims.ocean_sim.cache, :SST_timevaryinginput) + ocean_cache_two2 = delete(cs_two_steps2.model_sims.ocean_sim.cache, :SST_timevaryinginput) + + @test compare(ocean_cache_four, ocean_cache_two2) +end diff --git a/experiments/ClimaEarth/test/runtests.jl b/experiments/ClimaEarth/test/runtests.jl index 92e9f69ef0..ccd36cb1bc 100644 --- a/experiments/ClimaEarth/test/runtests.jl +++ b/experiments/ClimaEarth/test/runtests.jl @@ -25,3 +25,7 @@ end @safetestset "AMIP test" begin include("amip_test.jl") end + +@safetestset "Restart test" begin + include("restart.jl") +end diff --git a/experiments/ClimaEarth/user_io/arg_parsing.jl b/experiments/ClimaEarth/user_io/arg_parsing.jl index 23032ddb28..918b81626a 100644 --- a/experiments/ClimaEarth/user_io/arg_parsing.jl +++ b/experiments/ClimaEarth/user_io/arg_parsing.jl @@ -44,6 +44,9 @@ This function may modify the input dictionary to remove unnecessary keys. - All arguments needed for the coupled simulation """ function get_coupler_args(config_dict::Dict) + # Make a copy so that we don't modify the original input + config_dict = copy(config_dict) + # Simulation-identifying information; Print `config_dict` if requested config_dict["print_config_dict"] && @info(config_dict) job_id = config_dict["job_id"] @@ -84,7 +87,8 @@ function get_coupler_args(config_dict::Dict) # Restart information restart_dir = config_dict["restart_dir"] - restart_t = Int(config_dict["restart_t"]) + restart_t = + isnothing(config_dict["restart_t"]) ? nothing : Int64(Utilities.time_to_seconds(config_dict["restart_t"])) # Diagnostics information use_coupler_diagnostics = config_dict["use_coupler_diagnostics"] @@ -101,7 +105,7 @@ function get_coupler_args(config_dict::Dict) conservation_softfail = config_dict["conservation_softfail"] # Output information - output_dir_root = config_dict["coupler_output_dir"] + output_dir_root = joinpath(config_dict["coupler_output_dir"], job_id) # ClimaLand-specific information land_domain_type = config_dict["land_domain_type"] diff --git a/src/Checkpointer.jl b/src/Checkpointer.jl index eec2463198..51899be24e 100644 --- a/src/Checkpointer.jl +++ b/src/Checkpointer.jl @@ -7,10 +7,13 @@ module Checkpointer import ClimaComms import ClimaCore as CC +import ClimaUtilities.Utils: sort_by_creation_time import ..Interfacer import Dates -export get_model_prog_state, checkpoint_model_state, restart_model_state!, checkpoint_sims +import JLD2 + +export get_model_prog_state, checkpoint_model_state, checkpoint_sims """ get_model_prog_state(sim::Interfacer.ComponentModelSimulation) @@ -20,10 +23,18 @@ This is a template function that should be implemented for each component model. """ get_model_prog_state(sim::Interfacer.ComponentModelSimulation) = nothing +""" + get_model_cache(sim::Interfacer.ComponentModelSimulation) + +Returns the model cache of a simulation. +This is a template function that should be implemented for each component model. +""" +get_model_cache(sim::Interfacer.ComponentModelSimulation) = nothing + """ checkpoint_model_state(sim::Interfacer.ComponentModelSimulation, comms_ctx::ClimaComms.AbstractCommsContext, t::Int; output_dir = "output") -Checkpoints the model state of a simulation to a HDF5 file at a given time, t (in seconds). +Checkpoint the model state of a simulation to a HDF5 file at a given time, t (in seconds). """ function checkpoint_model_state( sim::Interfacer.ComponentModelSimulation, @@ -35,8 +46,7 @@ function checkpoint_model_state( day = floor(Int, t / (60 * 60 * 24)) sec = floor(Int, t % (60 * 60 * 24)) @info "Saving checkpoint " * Interfacer.name(sim) * " model state to HDF5 on day $day second $sec" - mkpath(joinpath(output_dir, "checkpoint")) - output_file = joinpath(output_dir, "checkpoint", "checkpoint_" * Interfacer.name(sim) * "_$t.hdf5") + output_file = joinpath(output_dir, "checkpoint_" * Interfacer.name(sim) * "_$t.hdf5") checkpoint_writer = CC.InputOutput.HDF5Writer(output_file, comms_ctx) CC.InputOutput.HDF5.write_attribute(checkpoint_writer.file, "time", t) CC.InputOutput.write!(checkpoint_writer, Y, "model_state") @@ -46,30 +56,41 @@ function checkpoint_model_state( end """ - restart_model_state!(sim::Interfacer.ComponentModelSimulation, comms_ctx::ClimaComms.AbstractCommsContext, t::Int; input_dir = "input") + checkpoint_model_cache(sim::Interfacer.ComponentModelSimulation, comms_ctx::ClimaComms.AbstractCommsContext, t::Int; output_dir = "output") -Sets the model state of a simulation from a HDF5 file from a given time, t (in seconds). +Checkpoint the model cache to N JLD2 files at a given time, t (in seconds), +where N is the number of MPI ranks. + +Objects are saved to JLD2 files because caches are generally not ClimaCore +objects (and ClimaCore.InputOutput can only save `Field`s or `FieldVector`s). """ -function restart_model_state!( +function checkpoint_model_cache( sim::Interfacer.ComponentModelSimulation, comms_ctx::ClimaComms.AbstractCommsContext, t::Int; - input_dir = "input", + output_dir = "output", ) - Y = get_model_prog_state(sim) + # Move p to CPU (because we cannot save CUArrays) + p = CC.Adapt.adapt(Array, get_model_cache(sim)) day = floor(Int, t / (60 * 60 * 24)) sec = floor(Int, t % (60 * 60 * 24)) - input_file = joinpath(input_dir, "checkpoint", "checkpoint_" * Interfacer.name(sim) * "_$t.hdf5") + @info "Saving checkpoint " * Interfacer.name(sim) * " model cache to JLD2 on day $day second $sec" + pid = ClimaComms.mypid(comms_ctx) + output_file = joinpath(output_dir, "checkpoint_cache_$(pid)_" * Interfacer.name(sim) * "_$t.jld2") + JLD2.jldsave(output_file, cache = p) + return nothing +end - @info "Setting " Interfacer.name(sim) " state to checkpoint: $input_file, corresponding to day $day second $sec" - # open file and read - restart_reader = CC.InputOutput.HDF5Reader(input_file, comms_ctx) - Y_new = CC.InputOutput.read_field(restart_reader, "model_state") - Base.close(restart_reader) +""" + restore_cache!(sim::Interfacer.ComponentModelSimulation, new_cache) + +Replace the cache in `sim` with `new_cache`. - # set new state - Y .= Y_new +Component models can define new methods for this to change how cache is restored. +""" +function restore_cache!(sim::Interfacer.ComponentModelSimulation, new_cache) + return nothing end """ @@ -78,18 +99,126 @@ end This is a callback function that checkpoints all simulations defined in the current coupled simulation. """ function checkpoint_sims(cs::Interfacer.CoupledSimulation) + t = Dates.datetime2epochms(cs.dates.date[1]) + t0 = Dates.datetime2epochms(cs.dates.date0[1]) + time = Int((t - t0) / 1e3) + day = floor(Int, time / (60 * 60 * 24)) + sec = floor(Int, time % (60 * 60 * 24)) + output_dir = cs.dirs.checkpoints + comms_ctx = cs.comms_ctx + for sim in cs.model_sims + if !isnothing(Checkpointer.get_model_prog_state(sim)) + Checkpointer.checkpoint_model_state(sim, comms_ctx, time; output_dir) + end + if !isnothing(Checkpointer.get_model_cache(sim)) + Checkpointer.checkpoint_model_cache(sim, comms_ctx, time; output_dir) + end + end + # Checkpoint the Coupler fields + pid = ClimaComms.mypid(comms_ctx) + @info "Saving coupler fields to JLD2 on day $day second $sec" + output_file = joinpath(output_dir, "checkpoint_coupler_fields_$(pid)_$time.jld2") + # Adapt to Array move fields to the CPU + JLD2.jldsave(output_file, coupler_fields = CC.Adapt.adapt(Array, cs.fields)) +end + +""" + restart!(cs::CoupledSimulation, checkpoint_dir, checkpoint_t) + +Overwrite the content of `cs` with checkpoints in `checkpoint_dir` at time `checkpoint_t`. + +Return a true if the simulation was restarted. +""" +function restart!(cs, checkpoint_dir, checkpoint_t) + pid = ClimaComms.mypid(cs.comms_ctx) for sim in cs.model_sims - if Checkpointer.get_model_prog_state(sim) !== nothing - t = Dates.datetime2epochms(cs.dates.date[1]) - t0 = Dates.datetime2epochms(cs.dates.date0[1]) - Checkpointer.checkpoint_model_state( - sim, - cs.comms_ctx, - Int((t - t0) / 1e3), - output_dir = cs.dirs.checkpoints, - ) + if !isnothing(Checkpointer.get_model_prog_state(sim)) + input_file_state = + output_file = joinpath(checkpoint_dir, "checkpoint_$(Interfacer.name(sim))_$(checkpoint_t).hdf5") + restart_model_state!(sim, input_file_state, cs.comms_ctx) + end + if !isnothing(Checkpointer.get_model_cache(sim)) + input_file_cache = + joinpath(checkpoint_dir, "checkpoint_cache_$(pid)_$(Interfacer.name(sim))_$(checkpoint_t).jld2") + restart_model_cache!(sim, input_file_cache) + end + end + input_file_coupler_fields = joinpath(checkpoint_dir, "checkpoint_coupler_fields_$(pid)_$(checkpoint_t).jld2") + restart_coupler_fields!(cs, input_file_coupler_fields) + return true +end + +""" + restart_model_cache!(sim, input_file) + +Overwrite the content of `sim` with the cache from the `input_file`. + +It relies on `restore_cache!(sim, old_cache)`, which has to be implemented by +the component models that have a cache. +""" +function restart_model_cache!(sim, input_file) + ispath(input_file) || error("File $(input_file) not found") + # Component models are responsible for defining a method for this + restore_cache!(sim, JLD2.jldopen(input_file)["cache"]) +end + +""" + restart_model_state!(sim, input_file, comms_ctx) + +Overwrite the content of `sim` with the state from the `input_file`. +""" +function restart_model_state!(sim, input_file, comms_ctx) + ispath(input_file) || error("File $(input_file) not found") + Y = get_model_prog_state(sim) + # open file and read + CC.InputOutput.HDF5Reader(input_file, comms_ctx) do restart_reader + Y_new = CC.InputOutput.read_field(restart_reader, "model_state") + # set new state + Y .= Y_new + end + return nothing +end + +""" + restart_coupler_fields!(cs, input_file) + +Overwrite the content of the coupled simulation `cs` with the coupler fields +read from `input_file`. +""" +function restart_coupler_fields!(cs, input_file) + ispath(input_file) || error("File $(input_file) not found") + fields_read = JLD2.jldopen(input_file)["coupler_fields"] + for name in propertynames(cs.fields) + ArrayType = ClimaComms.array_type(ClimaComms.device(cs.comms_ctx)) + parent(getproperty(cs.fields, name)) .= ArrayType(parent(getproperty(fields_read, name))) + end +end + +""" + t_start_from_checkpoint(checkpoint_dir) + +Look for restart files in `checkpoint_dir`, if found, return the time of the latest. +If not found, return `nothing`. +""" +function t_start_from_checkpoint(checkpoint_dir) + isdir(checkpoint_dir) || return nothing + restart_file_rx = r"checkpoint_(\w+)_(\d+).hdf5" + restarts = filter(f -> !isnothing(match(restart_file_rx, f)), readdir(checkpoint_dir)) + isempty(restarts) && return nothing + latest_restart = last(sort_by_creation_time(restarts)) + return parse(Int, match(restart_file_rx, latest_restart)[2]) +end + +# AbstractSurfaceStubs are typically simple enough that we can provide a generic method for them +function restore_cache!(sim::Interfacer.AbstractSurfaceStub, new_cache) + old_cache = get_model_cache(sim) + for p in propertynames(old_cache) + if getproperty(old_cache, p) isa CC.Fields.Field + ArrayType = ClimaComms.array_type(getproperty(old_cache, p)) + parent(getproperty(old_cache, p)) .= ArrayType(parent(getproperty(new_cache, p))) end end end + end # module diff --git a/src/Utilities.jl b/src/Utilities.jl index 75b359a729..dbfabb27e9 100644 --- a/src/Utilities.jl +++ b/src/Utilities.jl @@ -115,7 +115,7 @@ CPU of this process since it began. function show_memory_usage() cpu_max_rss_GB = "" cpu_max_rss_GB = string(round(Sys.maxrss() / 1e9, digits = 3)) * " GiB" - @info cpu_max_rss_GB + @info "Memory in use: $(cpu_max_rss_GB)" return cpu_max_rss_GB end diff --git a/test/mpi_tests/checkpointer_mpi_tests.jl b/test/mpi_tests/checkpointer_mpi_tests.jl deleted file mode 100644 index be57f5d472..0000000000 --- a/test/mpi_tests/checkpointer_mpi_tests.jl +++ /dev/null @@ -1,47 +0,0 @@ -#= - Unit tests for ClimaCoupler Checkpointer module functions to exercise MPI - -These are in a separate testing file from the other Checkpointer unit tests so -that MPI can be enabled for testing of these functions. -=# -import Test: @test, @testset -import ClimaComms -@static pkgversion(ClimaComms) >= v"0.6" && ClimaComms.@import_required_backends -import ClimaCore as CC -import ClimaCoupler -import ClimaCoupler: Checkpointer, Interfacer - -include(joinpath("..", "..", "experiments", "ClimaEarth", "test", "TestHelper.jl")) -import .TestHelper - -# set up MPI communications context -const comms_ctx = ClimaComms.context(ClimaComms.CPUSingleThreaded()) -const pid, nprocs = ClimaComms.init(comms_ctx) -@info pid -ClimaComms.barrier(comms_ctx) - -FT = Float64 -struct DummySimulation{S} <: Interfacer.AtmosModelSimulation - state::S -end -Checkpointer.get_model_prog_state(sim::DummySimulation) = sim.state -@testset "checkpoint_model_state, restart_model_state!" begin - boundary_space = TestHelper.create_space(FT, comms_ctx = comms_ctx) - t = 1 - - # old sim run - sim = DummySimulation(CC.Fields.FieldVector(T = ones(boundary_space))) - Checkpointer.checkpoint_model_state(sim, comms_ctx, t, output_dir = "test_checkpoint") - ClimaComms.barrier(comms_ctx) - - # new sim run - sim_new = DummySimulation(CC.Fields.FieldVector(T = zeros(boundary_space))) - Checkpointer.restart_model_state!(sim_new, comms_ctx, t, input_dir = "test_checkpoint") - @test sim_new.state.T == sim.state.T - - # remove checkpoint directory - ClimaComms.barrier(comms_ctx) - if ClimaComms.iamroot(comms_ctx) - rm("./test_checkpoint/", force = true, recursive = true) - end -end diff --git a/test/runtests.jl b/test/runtests.jl index beff063424..a8bb21c21c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -36,6 +36,3 @@ end @safetestset "FluxCalculator tests" begin include("flux_calculator_tests.jl") end -@safetestset "Checkpointer tests" begin - include("checkpointer_tests.jl") -end