diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 92030ebfa9..e3230333c3 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -67,16 +67,6 @@ steps: - group: "Unit Tests" steps: - - label: "MPI Checkpointer unit tests" - key: "checkpointer_mpi_tests" - command: "srun julia --color=yes --project=test/ test/mpi_tests/checkpointer_mpi_tests.jl" - timeout_in_minutes: 20 - env: - CLIMACOMMS_CONTEXT: "MPI" - agents: - slurm_ntasks: 2 - slurm_mem: 16GB - - label: "MPI Utilities unit tests" key: "utilities_mpi_tests" command: "srun julia --color=yes --project=test/ test/utilities_tests.jl" @@ -97,6 +87,7 @@ steps: agents: slurm_ntasks: 1 slurm_gres: "gpu:1" + slurm_mem: 24GB - group: "GPU: experiments/ClimaEarth/ unit tests and global bucket" steps: @@ -109,6 +100,27 @@ steps: slurm_gres: "gpu:1" slurm_mem: 20GB + - group: "ClimaEarth test" + steps: + - label: "ClimaEarth test" + key: "restarts" + command: "julia --color=yes --project=experiments/ClimaEarth/ experiments/ClimaEarth/test/runtests.jl" + agents: + slurm_mem: 16GB + + - label: "MPI restarts" + key: "mpi_restarts" + command: "srun julia --color=yes --project=experiments/ClimaEarth/ experiments/ClimaEarth/test/restart.jl" + env: + CLIMACOMMS_CONTEXT: "MPI" + timeout_in_minutes: 40 + soft_fail: + - exit_status: -1 + - exit_status: 255 + agents: + slurm_ntasks: 2 + slurm_mem: 32G + - group: "Integration Tests" steps: # SLABPLANET EXPERIMENTS @@ -218,7 +230,7 @@ steps: CLIMACOMMS_CONTEXT: "MPI" agents: slurm_ntasks: 4 - slurm_mem_per_cpu: 8GB + slurm_mem_per_cpu: 12GB # short high-res performance test - label: "Unthreaded AMIP FINE" # also reported by longruns with a flame graph diff --git a/NEWS.md b/NEWS.md index 5f9e17bffa..a1d6dbe41e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -6,6 +6,23 @@ ClimaCoupler.jl Release Notes ### ClimaCoupler features +#### Shared component `dt` can be overwritten for individual components +Previously, we required that the user either specify a shared `dt` to be +used by all component models, or specify values for all component models +(`dt_atmos`, `dt_ocean`, `dt_seaice`, `dt_land`). If fewer than 4 +model-specific timesteps were provided, they would be discarded and +`dt` would be used uniformly instead. After this PR, if a user provides +fewer than 4 model-specific timesteps, they will be used for those models, +and the generic `dt` will be used for any models that don't have a more +specific timestep. +This makes choosing the timesteps simpler and allows us to easily set +specific `dt`s only for the models we're interested in. + +This PR also changes the prescribed ocean and sea ice simulations +to update the stored SST/SIC based on a daily schedule. Now, the +input data will be interpolated from monthly to daily instead of +to every timestep. + #### Add default `get_field` methods for surface models PR[#1210](https://github.com/CliMA/ClimaCoupler.jl/pull/1210) Add default methods for `get_field` methods that are commonly not extended for surface models. These return reasonable default @@ -21,6 +38,20 @@ TOA radiation and net precipitation are added only if conservation is enabled. The coupler fields are also now stored as a ClimaCore Field of NamedTuples, rather than as a NamedTuple of ClimaCore Fields. +#### Restart simulations with JLD2 files PR[#1179](https://github.com/CliMA/ClimaCoupler.jl/pull/1179) + +`ClimaCoupler` can now use `JLD2` files to save state and cache for its model +component, allowing it to restart from saved checkpoints. 
Some restrictions +apply: + +- The number of MPI processes has to remain the same across checkpoints +- Restart files are generally not portable across machines, julia versions, and package versions +- Adding/changing new component models will probably require adding/changing code + +Please, refer to the +[documentation](https://clima.github.io/ClimaCoupler.jl/dev/checkpointer/) for +more information. + #### Remove extra `get_field` functions PR[#1203](https://github.com/CliMA/ClimaCoupler.jl/pull/1203) Removes the `get_field` functions for `air_density` for all models, which were unused except for the `BucketSimulation` method, which is replaced by a diff --git a/Project.toml b/Project.toml index 80ab0ed366..928df8677e 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" ClimaCore = "d414da3d-4745-48bb-8d80-42e94e092884" ClimaUtilities = "b3f4f4ca-9299-4f7f-bd9b-81e1242a7513" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -16,9 +17,10 @@ Thermodynamics = "b60c26fb-14c3-4610-9d3e-2d17fe7ff00c" [compat] ClimaComms = "0.6.2" -ClimaCore = "0.14.23" +ClimaCore = "0.14.25" ClimaUtilities = "0.1.22" Dates = "1" +JLD2 = "0.5.11" Logging = "1" SciMLBase = "2.11" StaticArrays = "1.6" diff --git a/config/ci_configs/amip_component_dts.yml b/config/ci_configs/amip_component_dts.yml index d8770b9812..67c5168319 100644 --- a/config/ci_configs/amip_component_dts.yml +++ b/config/ci_configs/amip_component_dts.yml @@ -3,7 +3,7 @@ co2: "maunaloa" dt_atmos: "150secs" dt_cpl: "150secs" dt_land: "50secs" -dt_ocean: "30secs" +dt_ocean: "300secs" dt_rad: "1hours" dt_save_to_sol: "1days" dt_seaice: "37.5secs" diff --git a/docs/src/checkpointer.md b/docs/src/checkpointer.md index 9b84246e3d..5b3aae7020 100644 --- a/docs/src/checkpointer.md +++ b/docs/src/checkpointer.md @@ -1,12 +1,137 @@ # Checkpointer -This module contains general functions for logging the model states and restarting simulations. The `Checkpointer` uses `ClimaCore.InputOutput` infrastructure, which allows it to handle arbitrarily distributed logging and restart setups. +## How to save and restart from checkpoints + +`ClimaCoupler` supports saving and reading simulation checkpoints. This is +useful to split a long simulation into smaller, more manageable chunks. + +Checkpoints are a mix of HDF5 and JLD2 files and are typically saved in a +`checkpoints` folder in the simulation output. See +[`Utilities.setup_output_dirs`](@ref) for more information. + +!!! known limitations + + - The number of MPI processes has to remain the same across checkpoints + - Restart files are generally not portable across machines, julia versions, and package versions + - Adding/changing new component models will probably require adding/changing code + +### Saving checkpoints + +If you are running a model (such as AMIP), chances are that you can enable +checkpointing just by setting a command-line argument; The `checkpoint_dt` +option controls how frequently a checkpoint should be produced. + +If your model does not come with this option already, you can checkpoint the +simulation by adding a callback that calls the +[`Checkpointer.checkpoint_sims`](@ref) function. 
+ +For example, to add a callback to checkpoint every hour of simulated time, +assuming you have a `start_date`: +```julia +import Dates + +import ClimaCoupler: Checkpointer, TimeManager +import ClimaDiagnostics.Schedules: EveryCalendarDtSchedule + +schedule = EveryCalendarDtSchedule(Dates.Hour(1); start_date) +checkpoint_callback = TimeManager.Callback(schedule, Checkpointer.checkpoint_sims) + +# In the coupling loop: +TimeManager.maybe_trigger_callback(checkpoint_callback, coupled_simulation, time) +``` + +### Reading checkpoints + +There are two ways to restart a simulation from checkpoints. By default, +`ClimaCoupler` tries to find suitable checkpoints and automatically uses them. +Alternatively, you can specify a directory `restart_dir` and a simulation time +`restart_t` and restart from files saved in the given directory at the given +time. If the model you are running supports writing checkpoints via a command-line +argument, it will probably also support reading them. In this case, the +arguments `restart_dir` and `restart_t` identify the path of the top-level +directory containing all the checkpoint files and the simulated time in seconds. + +If the model does not support directly reading a checkpoint, the `Checkpointer` +module provides a straightforward way to add this feature. +[`Checkpointer.restart!`](@ref) takes a coupled simulation, a `restart_dir`, and +a `restart_t` and overwrites the content of the coupled simulation with what is +in the checkpoint. + +## Developer notes + +In theory, the state of the component models should fully determine the state of +the coupled simulation and one should be able to restart a coupled simulation +just by using the states of the component models. Unfortunately, this is +currently not the case in `ClimaCoupler`. The main reason for this is the +complex interdependencies between component models and within `ClimaAtmos`, which +make the initialization step inconsistent. For example, in a coupled simulation, +the surface albedo should be determined by the surface models and used by the +atmospheric model for radiation transfer, but `ClimaAtmos` also tries to set the +surface albedo (since it has to do so when run in standalone mode). In addition +to this, `ClimaAtmos` has a large cache that has internal interdependencies that +are hard to disentangle, and changing a field might require changing some other +field in a different part of the cache. As a result, it is not easy for +`ClimaCoupler` to consistently do initialization from a cold state. To conclude, +restarting a simulation exclusively using the states of the component models is +currently impossible. + +Given that restarting a simulation from the state alone is impossible, `ClimaCoupler` +needs to save the states and the caches. Let us review how we use the +`ClimaCore.InputOutput` and `JLD2` packages to accomplish this. + +`ClimaCore.InputOutput` provides a lossless way to save the content of certain +`ClimaCore` objects to HDF5 files. Objects saved in this way are not tied to a +particular computing device or configuration. When running with MPI, +`ClimaCore.InputOutput` files are also written efficiently in parallel. + +Unfortunately, `ClimaCore.InputOutput` only supports certain objects, such as +`Field`s and `Space`s, but the cache in component models is more complex than +this and contains objects with highly stateful quantities (e.g., C +pointers). Because of this, model states are saved to HDF5 but caches must be +saved to JLD2 files.
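As a rough illustration of this HDF5/JLD2 split, the JLD2 leg of a checkpoint could be sketched as follows. This is a hypothetical, minimal example: the function names (`save_cache_sketch`, `load_cache_sketch`) and the `cache`/`checkpoint_path` arguments are illustrative and are not the actual `Checkpointer` API.

```julia
import Adapt
import JLD2

# Hypothetical sketch, not ClimaCoupler's actual implementation.
function save_cache_sketch(cache, checkpoint_path)
    # Move any GPU-backed data to the CPU; this relies on suitable Adapt
    # methods being defined for every object reachable from the cache.
    cpu_cache = Adapt.adapt(Array, cache)
    JLD2.jldsave(checkpoint_path; cache = cpu_cache)
    return nothing
end

function load_cache_sketch(checkpoint_path)
    # Read the CPU copy back; moving data back to the GPU (and skipping
    # stateful objects) is handled separately when the cache is restored.
    return JLD2.jldopen(checkpoint_path, "r") do file
        file["cache"]
    end
end
```

The real implementation additionally writes one file per MPI process and walks the cache recursively to restore only supported objects, as discussed below.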
+ +`JLD2` allows us to save more complex objects without writing specific +serialization methods for every struct. This takes us a big step +forward, but there are still several challenges that need to be solved: +1. `JLD2` does not support CUDA natively. To work around this, we have to move + everything onto the CPU first. Then, when the data is read back, we have to + move it back to the GPU. +2. `JLD2` does not support MPI natively. To work around this, each process writes + its own `jld2` checkpoint and reads it back. This introduces the constraint that + the number of MPI processes cannot change across restarts. +3. Some quantities are best not saved and read (for example, anything with + pointers). For this, we write a recursive function that traverses the cache + and only restores quantities of a certain type (typically, `ClimaCore` + objects). + +Point 3 adds a significant amount of code and requires component models to +specify how their cache has to be restored. + +If you are adding a component model, you have to extend the +``` +Checkpointer.get_model_prog_state +Checkpointer.get_model_cache +Checkpointer.restore_cache! +``` +methods. + +`ClimaCoupler` moves objects to the CPU with `Adapt.adapt(Array, x)`. `Adapt` +traverses the object recursively, and proper `Adapt` methods have to be defined +for every object involved in the chain. The easiest way to do this is using the +`Adapt.@adapt_structure` macro, which defines a recursive `adapt_structure` method for the given +type. + +Types to watch for: +- `MPI`-related objects (e.g., `MPICommsContext`) +- `TimeVaryingInputs` (because they contain `NCDatasets`, which contain pointers + to files) ## Checkpointer API ```@docs ClimaCoupler.Checkpointer.get_model_prog_state - ClimaCoupler.Checkpointer.restart_model_state! - ClimaCoupler.Checkpointer.checkpoint_model_state + ClimaCoupler.Checkpointer.get_model_cache + ClimaCoupler.Checkpointer.restart! ClimaCoupler.Checkpointer.checkpoint_sims + ClimaCoupler.Checkpointer.t_start_from_checkpoint ``` diff --git a/experiments/ClimaEarth/Manifest-v1.11.toml b/experiments/ClimaEarth/Manifest-v1.11.toml index e620101a7d..58da936549 100644 --- a/experiments/ClimaEarth/Manifest-v1.11.toml +++ b/experiments/ClimaEarth/Manifest-v1.11.toml @@ -337,7 +337,7 @@ uuid = "908f55d8-4145-4867-9c14-5dad1a479e4d" version = "0.4.6" [[deps.ClimaCoupler]] -deps = ["ClimaComms", "ClimaCore", "ClimaUtilities", "Dates", "Logging", "SciMLBase", "StaticArrays", "SurfaceFluxes", "Thermodynamics"] +deps = ["ClimaComms", "ClimaCore", "ClimaUtilities", "Dates", "JLD2", "Logging", "SciMLBase", "StaticArrays", "SurfaceFluxes", "Thermodynamics"] path = "../.."
uuid = "4ade58fe-a8da-486c-bd89-46df092ec0c7" version = "0.1.2" @@ -718,11 +718,10 @@ git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" version = "0.1.10" -[[deps.Expronicon]] -deps = ["MLStyle", "Pkg", "TOML"] -git-tree-sha1 = "fc3951d4d398b5515f91d7fe5d45fc31dccb3c9b" -uuid = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636" -version = "0.8.5" +[[deps.ExproniconLite]] +git-tree-sha1 = "c13f0b150373771b0fdc1713c97860f8df12e6c2" +uuid = "55351af7-c7e9-48d6-89ff-24e801d99491" +version = "0.10.14" [[deps.Extents]] git-tree-sha1 = "063512a13dbe9c40d999c439268539aa552d1ae6" @@ -961,10 +960,10 @@ uuid = "5c1252a2-5f33-56bf-86c9-59e7332b4326" version = "0.4.11" [[deps.GeometryOps]] -deps = ["CoordinateTransformations", "DataAPI", "DelaunayTriangulation", "ExactPredicates", "GeoInterface", "GeometryBasics", "LinearAlgebra", "SortTileRecursiveTree", "Statistics", "Tables"] -git-tree-sha1 = "7eaffabf21dcdc7a5e543c309b903371af5c9b07" +deps = ["CoordinateTransformations", "DataAPI", "DelaunayTriangulation", "ExactPredicates", "GeoInterface", "GeometryBasics", "GeometryOpsCore", "LinearAlgebra", "SortTileRecursiveTree", "Statistics", "Tables"] +git-tree-sha1 = "226ebac075e4a477bbaeacb4f7e720f9dce019a9" uuid = "3251bfac-6a57-4b6d-aa61-ac1fef2975ab" -version = "0.1.14" +version = "0.1.15" [deps.GeometryOps.extensions] GeometryOpsFlexiJoinsExt = "FlexiJoins" @@ -976,6 +975,12 @@ version = "0.1.14" LibGEOS = "a90b1aa1-3769-5649-ba7e-abc5a9d163eb" Proj = "c94c279d-25a6-4763-9509-64d165bea63e" +[[deps.GeometryOpsCore]] +deps = ["DataAPI", "GeoInterface", "Tables"] +git-tree-sha1 = "390a7ff6a89a997d6a1fa76c857490e0bd34565f" +uuid = "05efe853-fabf-41c8-927e-7063c8b9f013" +version = "0.1.2" + [[deps.Gettext_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "XML2_jll"] git-tree-sha1 = "9b02998aba7bf074d14de89f9d37ca24a1a0b046" @@ -1201,6 +1206,12 @@ git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" uuid = "82899510-4779-5014-852e-03e436cf321d" version = "1.0.0" +[[deps.JLD2]] +deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"] +git-tree-sha1 = "91d501cb908df6f134352ad73cde5efc50138279" +uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +version = "0.5.11" + [[deps.JLLWrappers]] deps = ["Artifacts", "Preferences"] git-tree-sha1 = "a007feb38b422fbdab534406aeca1b86823cb4d6" @@ -1225,6 +1236,12 @@ version = "1.14.1" [deps.JSON3.weakdeps] ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" +[[deps.Jieko]] +deps = ["ExproniconLite"] +git-tree-sha1 = "2f05ed29618da60c06a87e9c033982d4f71d0b6c" +uuid = "ae98c720-c025-4a4a-838c-29b094483192" +version = "0.2.1" + [[deps.JpegTurbo]] deps = ["CEnum", "FileIO", "ImageCore", "JpegTurbo_jll", "TOML"] git-tree-sha1 = "fa6d0bcff8583bac20f1ffa708c3913ca605c611" @@ -1478,11 +1495,6 @@ git-tree-sha1 = "5de60bc6cb3899cd318d80d627560fae2e2d99ae" uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7" version = "2025.0.1+1" -[[deps.MLStyle]] -git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" -uuid = "d8e11817-5142-5d16-987a-aa16d5891078" -version = "0.4.17" - [[deps.MPICH_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] git-tree-sha1 = "e7159031670cee777cc2840aef7a521c3603e36c" @@ -1566,6 +1578,12 @@ git-tree-sha1 = "7b86a5d4d70a9f5cdf2dacb3cbe6d251d1a61dbe" uuid = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389" version = "0.3.4" 
+[[deps.Moshi]] +deps = ["ExproniconLite", "Jieko"] +git-tree-sha1 = "453de0fc2be3d11b9b93ca4d0fddd91196dcf1ed" +uuid = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d" +version = "0.3.5" + [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" version = "2023.12.12" @@ -1613,9 +1631,9 @@ version = "1.1.2" [[deps.NaNStatistics]] deps = ["PrecompileTools", "Static", "StaticArrayInterface"] -git-tree-sha1 = "1529c48f8c63a815c985890e0192b006b06de528" +git-tree-sha1 = "c1a9def67b8a871b51ccf84512f173e7979c186b" uuid = "b946abbf-3ea7-4610-9019-9858bfdeaf2d" -version = "0.6.45" +version = "0.6.47" [[deps.NaturalEarth]] deps = ["Downloads", "GeoJSON", "Pkg", "Scratch"] @@ -1734,9 +1752,9 @@ version = "0.11.32" [[deps.PNGFiles]] deps = ["Base64", "CEnum", "ImageCore", "IndirectArrays", "OffsetArrays", "libpng_jll"] -git-tree-sha1 = "67186a2bc9a90f9f85ff3cc8277868961fb57cbd" +git-tree-sha1 = "cf181f0b1e6a18dfeb0ee8acc4a9d1672499626c" uuid = "f57f5aa1-a3ce-4bc8-8ab9-96f992907883" -version = "0.4.3" +version = "0.4.4" [[deps.PROJ_jll]] deps = ["Artifacts", "JLLWrappers", "LibCURL_jll", "Libdl", "Libtiff_jll", "OpenSSL_jll", "SQLite_jll"] @@ -1935,9 +1953,9 @@ version = "1.3.4" [[deps.RecursiveArrayTools]] deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] -git-tree-sha1 = "fe9d37a17ab4d41a98951332ee8067f8dca8c4c2" +git-tree-sha1 = "a967273b0c96f9e55ccb93322ada38f43e685c49" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" -version = "3.29.0" +version = "3.30.0" [deps.RecursiveArrayTools.extensions] RecursiveArrayToolsFastBroadcastExt = "FastBroadcast" @@ -2034,13 +2052,14 @@ uuid = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" version = "0.1.0" [[deps.SciMLBase]] -deps = ["ADTypes", "Accessors", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "Expronicon", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "SciMLStructures", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface"] -git-tree-sha1 = "ffed2507209da5b42c6881944ef41a340ab5449b" +deps = ["ADTypes", "Accessors", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "Moshi", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "SciMLStructures", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface"] +git-tree-sha1 = "2242fd564bb0202a22a91f575dc58b8820612b6b" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -version = "2.74.1" +version = "2.75.0" [deps.SciMLBase.extensions] SciMLBaseChainRulesCoreExt = "ChainRulesCore" + SciMLBaseMLStyleExt = "MLStyle" SciMLBaseMakieExt = "Makie" SciMLBasePartialFunctionsExt = "PartialFunctions" SciMLBasePyCallExt = "PyCall" @@ -2051,6 +2070,7 @@ version = "2.74.1" [deps.SciMLBase.weakdeps] ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" PyCall = 
"438e738f-606a-5dbb-bf0a-cddfbfd45ab0" @@ -2298,9 +2318,9 @@ weakdeps = ["ClimaParams"] [[deps.SymbolicIndexingInterface]] deps = ["Accessors", "ArrayInterface", "RuntimeGeneratedFunctions", "StaticArraysCore"] -git-tree-sha1 = "fd2d4f0499f6bb4a0d9f5030f5c7d61eed385e03" +git-tree-sha1 = "d6c04e26aa1c8f7d144e1a8c47f1c73d3013e289" uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5" -version = "0.3.37" +version = "0.3.38" [[deps.TOML]] deps = ["Dates"] @@ -2473,9 +2493,9 @@ version = "1.0.0" [[deps.XML2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"] -git-tree-sha1 = "ee6f41aac16f6c9a8cab34e2f7a200418b1cc1e3" +git-tree-sha1 = "b8b243e47228b4a3877f1dd6aee0c5d56db7fcf4" uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" -version = "2.13.6+0" +version = "2.13.6+1" [[deps.XSLT_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "XML2_jll", "Zlib_jll"] @@ -2550,9 +2570,9 @@ version = "1.2.13+1" [[deps.Zstd_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "622cf78670d067c738667aaa96c553430b65e269" +git-tree-sha1 = "446b23e73536f84e8037f5dce465e92275f6a308" uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" -version = "1.5.7+0" +version = "1.5.7+1" [[deps.isoband_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] diff --git a/experiments/ClimaEarth/cli_options.jl b/experiments/ClimaEarth/cli_options.jl index 6c994653ed..7cb987ebef 100644 --- a/experiments/ClimaEarth/cli_options.jl +++ b/experiments/ClimaEarth/cli_options.jl @@ -51,11 +51,11 @@ function argparse_settings() arg_type = String default = "20000101" "--dt_cpl" - help = "Coupling time step in seconds [400 (default); allowed formats: \"Nsecs\", \"Nmins\", \"Nhours\", \"Ndays\", \"Inf\"]" + help = "Coupling time step in seconds [400secs (default); allowed formats: \"Nsecs\", \"Nmins\", \"Nhours\", \"Ndays\", \"Inf\"]" arg_type = String default = "400secs" "--dt" - help = "Component model time step [allowed formats: \"Nsecs\", \"Nmins\", \"Nhours\", \"Ndays\", \"Inf\"]" + help = "Component model time step [400secs (default); allowed formats: \"Nsecs\", \"Nmins\", \"Nhours\", \"Ndays\", \"Inf\"]" arg_type = String default = "400secs" "--dt_atmos" @@ -75,7 +75,7 @@ function argparse_settings() arg_type = String default = "10days" "--checkpoint_dt" - help = "Time interval for hourly checkpointing [\"20days\" (default)]" + help = "Time interval for checkpointing [\"20days\" (default)]" arg_type = String default = "20days" # Restart information @@ -84,9 +84,9 @@ function argparse_settings() arg_type = String default = nothing "--restart_t" - help = "Time in seconds rounded to the nearest index to use at `t_start` for restarted simulation [0 (default)]" + help = "Time in seconds rounded to the nearest index to use at `t_start` for restarted simulation [nothing (default)]" arg_type = Int - default = 0 + default = nothing # Diagnostics information "--use_coupler_diagnostics" help = "Boolean flag indicating whether to compute and output coupler diagnostics [`true` (default), `false`]" diff --git a/experiments/ClimaEarth/components/atmosphere/climaatmos.jl b/experiments/ClimaEarth/components/atmosphere/climaatmos.jl index 2d5017db71..9dee582a19 100644 --- a/experiments/ClimaEarth/components/atmosphere/climaatmos.jl +++ b/experiments/ClimaEarth/components/atmosphere/climaatmos.jl @@ -14,6 +14,15 @@ import ClimaUtilities.TimeManager: ITime include("climaatmos_extra_diags.jl") +if pkgversion(CA) < v"0.28.6" + # Allow cache to be moved to CPU (this is a little bit 
of type piracy, but we + # allow it in this particular file) + CC.Adapt.@adapt_structure CA.AtmosCache + CC.Adapt.@adapt_structure CA.RRTMGPInterface.RRTMGPModel +end + +include("../shared/restore.jl") + ### ### Functions required by ClimaCoupler.jl for an AtmosModelSimulation ### @@ -103,6 +112,22 @@ function Checkpointer.get_model_prog_state(sim::ClimaAtmosSimulation) return sim.integrator.u end +function Checkpointer.get_model_cache(sim::ClimaAtmosSimulation) + return sim.integrator.p +end + +function Checkpointer.restore_cache!(sim::ClimaAtmosSimulation, new_cache) + comms_ctx = ClimaComms.context(sim.integrator.u.c) + restore!( + Checkpointer.get_model_cache(sim), + new_cache, + comms_ctx; + ignore = Set([:rc, :params, :ghost_buffer, :hyperdiffusion_ghost_buffer, :data_handler, :graph_context]), + ) + return nothing +end + + """ Interfacer.get_field(atmos_sim::ClimaAtmosSimulation, ::Val{:radiative_energy_flux_toa}) @@ -414,6 +439,11 @@ function get_atmos_config_dict(coupler_dict::Dict, job_id::String, atmos_output_ atmos_config["output_dir_style"] = "RemovePreexisting" atmos_config["output_dir"] = atmos_output_dir + # Ensure Atmos's own checkpoints are synced up with ClimaCoupler, so that we + # can pick up from where we have left. NOTE: This should not be needed, but + # there is no easy way to initialize ClimaAtmos with a different t_start + atmos_config["dt_save_state_to_disk"] = coupler_dict["checkpoint_dt"] + # Add all extra atmos diagnostic entries into the vector of atmos diagnostics # If atmos doesn't have any diagnostics, use the extra_atmos_diagnostics from the coupler atmos_config["diagnostics"] = @@ -568,3 +598,25 @@ function FluxCalculator.water_albedo_from_atmosphere!( diffuse_albedo .= CA.surface_albedo_diffuse(α_model).(λ, μ, LinearAlgebra.norm.(CC.Fields.level(Y.c.uₕ, 1))) end + +""" + climaatmos_restart_path(output_dir_root, t) + +Look at the most recent output in `output_dir_root` and find a checkpoint for time `t`. 
+""" +function climaatmos_restart_path(output_dir_root, t) + isdir(output_dir_root) || error("$(output_dir_root) does not exist") + name_rx = r"output_(\d\d\d\d)" + existing_outputs = filter(x -> !isnothing(match(name_rx, x)), readdir(output_dir_root)) + + day = floor(Int, t / (60 * 60 * 24)) + sec = floor(Int, t % (60 * 60 * 24)) + + # Walk back the folders and try to find a checkpoint + for output in sort(existing_outputs, rev = true) + previous_folder = joinpath(output_dir_root, output) + restart_file = joinpath(previous_folder, "clima_atmos", "day$day.$sec.hdf5") + ispath(restart_file) && return restart_file + end + error("Restart file for time $t not found") +end diff --git a/experiments/ClimaEarth/components/land/climaland_bucket.jl b/experiments/ClimaEarth/components/land/climaland_bucket.jl index da97ea3eb0..a626bb2eee 100644 --- a/experiments/ClimaEarth/components/land/climaland_bucket.jl +++ b/experiments/ClimaEarth/components/land/climaland_bucket.jl @@ -1,6 +1,7 @@ import Dates import SciMLBase import Statistics +import ClimaComms import ClimaCore as CC import ClimaTimeSteppers as CTS import Thermodynamics as TD @@ -10,6 +11,8 @@ import ClimaDiagnostics as CD import ClimaCoupler: Checkpointer, FluxCalculator, Interfacer using NCDatasets +include("../shared/restore.jl") + ### ### Functions required by ClimaCoupler.jl for a SurfaceModelSimulation ### @@ -407,6 +410,21 @@ function make_land_domain( return CL.Domains.SphericalShell{FT}(radius, depth, nothing, nelements, npolynomial, space, fields) end +function Checkpointer.get_model_cache(sim::BucketSimulation) + return sim.integrator.p +end + +function Checkpointer.restore_cache!(sim::BucketSimulation, new_cache) + old_cache = Checkpointer.get_model_cache(sim) + comms_ctx = ClimaComms.context(sim.model) + restore!( + old_cache, + new_cache, + comms_ctx, + ignore = Set([:rc, :params, :dss_buffer_2d, :dss_buffer_3d, :graph_context]), + ) +end + """ dss_state!(sim::BucketSimulation) diff --git a/experiments/ClimaEarth/components/ocean/prescr_ocean.jl b/experiments/ClimaEarth/components/ocean/prescr_ocean.jl index 75e3b6644b..742a052dce 100644 --- a/experiments/ClimaEarth/components/ocean/prescr_ocean.jl +++ b/experiments/ClimaEarth/components/ocean/prescr_ocean.jl @@ -3,6 +3,7 @@ import ClimaUtilities.ClimaArtifacts: @clima_artifact import Interpolations # triggers InterpolationsExt in ClimaUtilities import Thermodynamics as TD import ClimaCoupler: Checkpointer, FieldExchanger, Interfacer +import ClimaDiagnostics.Schedules: EveryCalendarDtSchedule """ PrescribedOceanSimulation{C} @@ -84,6 +85,8 @@ function PrescribedOceanSimulation( SST_init = zeros(space) evaluate!(SST_init, SST_timevaryinginput, t_start) + SST_schedule = CD.Schedules.EveryCalendarDtSchedule(TimeManager.time_to_period("1days"); start_date = date0) + # Create the cache cache = (; T_sfc = SST_init, @@ -97,6 +100,7 @@ function PrescribedOceanSimulation( phase = TD.Liquid(), thermo_params = thermo_params, SST_timevaryinginput = SST_timevaryinginput, + SST_schedule = SST_schedule, ) return PrescribedOceanSimulation(cache) end @@ -107,8 +111,23 @@ end Interfacer.step!(sim::PrescribedOceanSimulation, t) Update the cached surface temperature field using the prescribed data -at each timestep. +only when the SST data is scheduled to be updated, +rather than at every timestep.
""" function Interfacer.step!(sim::PrescribedOceanSimulation, t) - evaluate!(sim.cache.T_sfc, sim.cache.SST_timevaryinginput, t) + sim.cache.SST_schedule(t) && evaluate!(sim.cache.T_sfc, sim.cache.SST_timevaryinginput, t) +end + +function Checkpointer.get_model_cache(sim::PrescribedOceanSimulation) + return sim.cache +end + +function Checkpointer.restore_cache!(sim::PrescribedOceanSimulation, new_cache) + old_cache = Checkpointer.get_model_cache(sim) + for p in propertynames(old_cache) + if getproperty(old_cache, p) isa Field + ArrayType = ClimaComms.array_type(getproperty(old_cache, p)) + parent(getproperty(old_cache, p)) .= ArrayType(parent(getproperty(new_cache, p))) + end + end end diff --git a/experiments/ClimaEarth/components/ocean/prescr_seaice.jl b/experiments/ClimaEarth/components/ocean/prescr_seaice.jl index ac2a7ef7e9..e90c9b4849 100644 --- a/experiments/ClimaEarth/components/ocean/prescr_seaice.jl +++ b/experiments/ClimaEarth/components/ocean/prescr_seaice.jl @@ -1,5 +1,6 @@ import SciMLBase import ClimaCore as CC +import ClimaDiagnostics as CD import ClimaTimeSteppers as CTS import ClimaUtilities.TimeVaryingInputs: TimeVaryingInput, evaluate! import ClimaUtilities.ClimaArtifacts: @clima_artifact @@ -143,8 +144,20 @@ function PrescribedIceSimulation( tspan = Float64.(tspan) saveat = Float64.(saveat) end + + # Set up a callback to read in SIC data daily + SIC_schedule = CD.Schedules.EveryCalendarDtSchedule(TimeManager.time_to_period("1days"); start_date = date0) + SIC_update_cb = TimeManager.Callback(SIC_schedule, read_sic_data!) + problem = SciMLBase.ODEProblem(ode_function, Y, tspan, (; cache..., params = params)) - integrator = SciMLBase.init(problem, ode_algo, dt = dt, saveat = saveat, adaptive = false) + integrator = SciMLBase.init( + problem, + ode_algo, + dt = dt, + saveat = saveat, + adaptive = false, + callback = SciMLBase.CallbackSet(SIC_update_cb), + ) sim = PrescribedIceSimulation(params, space, integrator) @@ -153,6 +166,16 @@ function PrescribedIceSimulation( return sim end +""" + read_sic_data!(integrator) + +Read in the sea ice concentration data at the current time step. +This function is intended to be used within a callback. +""" +function read_sic_data!(integrator) + evaluate!(integrator.p.area_fraction, integrator.p.SIC_timevaryinginput, integrator.t) +end + # extensions required by Interfacer Interfacer.get_field(sim::PrescribedIceSimulation, ::Val{:area_fraction}) = sim.integrator.p.area_fraction Interfacer.get_field(sim::PrescribedIceSimulation, ::Val{:roughness_buoyancy}) = sim.integrator.p.params.z0b @@ -237,9 +260,6 @@ function ice_rhs!(dY, Y, p, t) FT = eltype(Y) params = p.params - # Update the cached area fraction with the current SIC - evaluate!(p.area_fraction, p.SIC_timevaryinginput, t) - # Overwrite ice fraction with the static land area fraction anywhere we have nonzero land area # max needed to avoid Float32 errors (see issue #271; Heisenbug on HPC) @. p.area_fraction = max(min(p.area_fraction, FT(1) - p.land_fraction), FT(0)) @@ -262,3 +282,17 @@ Perform DSS on the state of a component simulation, intended to be used before the initial step of a run. This method acts on prescribed ice simulations.
""" dss_state!(sim::PrescribedIceSimulation) = CC.Spaces.weighted_dss!(sim.integrator.u, sim.integrator.p.dss_buffer) + +function Checkpointer.get_model_cache(sim::PrescribedIceSimulation) + return sim.integrator.p +end + +function Checkpointer.restore_cache!(sim::PrescribedIceSimulation, new_cache) + old_cache = Checkpointer.get_model_cache(sim) + for p in propertynames(old_cache) + if getproperty(old_cache, p) isa Field + ArrayType = ClimaComms.array_type(getproperty(old_cache, p)) + parent(getproperty(old_cache, p)) .= ArrayType(parent(getproperty(new_cache, p))) + end + end +end diff --git a/experiments/ClimaEarth/components/shared/restore.jl b/experiments/ClimaEarth/components/shared/restore.jl new file mode 100644 index 0000000000..fc56e5c21a --- /dev/null +++ b/experiments/ClimaEarth/components/shared/restore.jl @@ -0,0 +1,86 @@ +# Define shared methods to allow reading back a saved cache + +import ClimaComms +import ClimaCore +import ClimaCore: DataLayouts, Fields, Geometry +import ClimaCore.Fields: Field, FieldVector, field_values +import ClimaCore.DataLayouts: AbstractData +import ClimaCore.Geometry: AxisTensor +import ClimaCore.Spaces: AbstractSpace +import ClimaUtilities.TimeVaryingInputs: AbstractTimeVaryingInput +import StaticArrays +import NCDatasets + +""" + restore!(v1, v2, comms_ctx; ignore) + +Recursively traverse `v1` and `v2`, setting each field of `v1` to the +corresponding field in `v2`. While doing so, ignore all the properties whose +name is in the `ignore` iterable. + +`ignore` is useful when there are stateful properties, such as live pointers. +""" +function restore!(v1::T1, v2::T2, comms_ctx; name = "", ignore) where {T1, T2} + # We pick fieldnames(T2) because v2 tends to be simpler (Array as opposed + # to CuArray) + fields = filter(x -> !(x in ignore), fieldnames(T2)) + if isempty(fields) + if !Base.issingletontype(typeof(v1)) + restore!(v1, v2, comms_ctx; name, ignore) + else + v1 == v2 || error("$v1 != $v2") + end + else + # Recursive case + for p in fields + restore!(getfield(v1, p), getfield(v2, p), comms_ctx; name = "$(name).$(p)", ignore) + end + end + return nothing +end + +# Ignoring certain types that don't need to be restored +# UnionAll and DataType are infinitely recursive, so we also ignore those +function restore!( + v1::Union{AbstractTimeVaryingInput, ClimaComms.AbstractCommsContext, ClimaComms.AbstractDevice, UnionAll, DataType}, + v2::Union{AbstractTimeVaryingInput, ClimaComms.AbstractCommsContext, ClimaComms.AbstractDevice, UnionAll, DataType}, + _comms_ctx; + name, + ignore, +) + return nothing +end + +function restore!( + v1::T1, + v2::T2, + comms_ctx; + name, + ignore, +) where {T1 <: Union{AbstractData, AbstractArray}, T2 <: Union{AbstractData, AbstractArray}} + ArrayType = parent(v1) isa Array ?
Array : ClimaComms.array_type(ClimaComms.device(comms_ctx)) + moved_to_device = ArrayType(parent(v2)) + + parent(v1) .= moved_to_device + return nothing +end + +function restore!( + v1::T1, + v2::T2, + comms_ctx; + name, + ignore, +) where { + T1 <: Union{StaticArrays.StaticArray, Number, UnitRange, Symbol}, + T2 <: Union{StaticArrays.StaticArray, Number, UnitRange, Symbol}, +} + v1 == v2 || error("$name is immutable but inconsistent") + return nothing +end + +function restore!(v1::T1, v2::T2, comms_ctx; name, ignore) where {T1 <: Dict, T2 <: Dict} + # RRTMGP has some internal dictionaries + v1 == v2 || error("$name is inconsistent") + return nothing +end diff --git a/experiments/ClimaEarth/run_amip.jl b/experiments/ClimaEarth/run_amip.jl index c4b697e923..31a754986f 100644 --- a/experiments/ClimaEarth/run_amip.jl +++ b/experiments/ClimaEarth/run_amip.jl @@ -22,4 +22,4 @@ include("setup_run.jl") config_file = parse_commandline(argparse_settings())["config_file"] # Set up and run the coupled simulation -setup_and_run(config_file) +cs = setup_and_run(config_file) diff --git a/experiments/ClimaEarth/run_cloudless_aquaplanet.jl b/experiments/ClimaEarth/run_cloudless_aquaplanet.jl index 6db69b2d45..2c4945b94a 100644 --- a/experiments/ClimaEarth/run_cloudless_aquaplanet.jl +++ b/experiments/ClimaEarth/run_cloudless_aquaplanet.jl @@ -77,7 +77,6 @@ checkpoint_dt = "480hours" =# dir_paths = Utilities.setup_output_dirs(output_dir = coupler_output_dir, comms_ctx = ClimaComms.context()) -@info(config_dict) ## namelist config_dict = Dict( @@ -103,6 +102,7 @@ config_dict = Dict( "nh_poly" => 4, # output "dt_save_to_sol" => "1days", + "checkpoint_dt" => "1days", # numerics "apply_limiter" => false, "viscous_sponge" => false, diff --git a/experiments/ClimaEarth/run_cloudy_aquaplanet.jl b/experiments/ClimaEarth/run_cloudy_aquaplanet.jl index a5739f20c2..8e27e23373 100644 --- a/experiments/ClimaEarth/run_cloudy_aquaplanet.jl +++ b/experiments/ClimaEarth/run_cloudy_aquaplanet.jl @@ -74,7 +74,6 @@ checkpoint_dt = "480hours" =# dir_paths = Utilities.setup_output_dirs(output_dir = coupler_output_dir, comms_ctx = ClimaComms.context()) -@info(config_dict) ## namelist config_dict = Dict( @@ -100,6 +99,7 @@ config_dict = Dict( "nh_poly" => 4, # output "dt_save_to_sol" => "1days", + "checkpoint_dt" => "1days", # numerics "apply_limiter" => false, "viscous_sponge" => false, diff --git a/experiments/ClimaEarth/run_cloudy_slabplanet.jl b/experiments/ClimaEarth/run_cloudy_slabplanet.jl index ba3a061d94..e91753700e 100644 --- a/experiments/ClimaEarth/run_cloudy_slabplanet.jl +++ b/experiments/ClimaEarth/run_cloudy_slabplanet.jl @@ -82,7 +82,6 @@ dt_rad = "6hours" =# dir_paths = Utilities.setup_output_dirs(output_dir = coupler_output_dir, comms_ctx = ClimaComms.context()) -@info(config_dict) ## namelist @@ -109,6 +108,7 @@ config_dict = Dict( "nh_poly" => 4, # output "dt_save_to_sol" => "1days", + "checkpoint_dt" => "1days", # numerics "apply_limiter" => false, "viscous_sponge" => false, diff --git a/experiments/ClimaEarth/run_dry_held_suarez.jl b/experiments/ClimaEarth/run_dry_held_suarez.jl index b8ef9a54dd..17e282e8d9 100644 --- a/experiments/ClimaEarth/run_dry_held_suarez.jl +++ b/experiments/ClimaEarth/run_dry_held_suarez.jl @@ -77,7 +77,6 @@ checkpoint_dt = "480hours" =# dir_paths = Utilities.setup_output_dirs(output_dir = coupler_output_dir, comms_ctx = ClimaComms.context()) -@info(config_dict) ## namelist config_dict = Dict( @@ -102,6 +101,7 @@ config_dict = Dict( "nh_poly" => 4, # output "dt_save_to_sol" =>
"1days", + "checkpoint_dt" => "1days", # numerics "apply_limiter" => false, "viscous_sponge" => false, diff --git a/experiments/ClimaEarth/run_moist_held_suarez.jl b/experiments/ClimaEarth/run_moist_held_suarez.jl index 7b7b72d25b..8c13e01a27 100644 --- a/experiments/ClimaEarth/run_moist_held_suarez.jl +++ b/experiments/ClimaEarth/run_moist_held_suarez.jl @@ -80,7 +80,6 @@ checkpoint_dt = "480hours" dir_paths = Utilities.setup_output_dirs(output_dir = coupler_output_dir, comms_ctx = ClimaComms.context()) -@info(config_dict) ## namelist config_dict = Dict( @@ -106,6 +105,7 @@ config_dict = Dict( "nh_poly" => 4, # output "dt_save_to_sol" => "1days", + "checkpoint_dt" => "1days", # numerics "apply_limiter" => false, "viscous_sponge" => false, diff --git a/experiments/ClimaEarth/setup_run.jl b/experiments/ClimaEarth/setup_run.jl index 8e6ccce707..4ff8ccd5c3 100644 --- a/experiments/ClimaEarth/setup_run.jl +++ b/experiments/ClimaEarth/setup_run.jl @@ -90,7 +90,6 @@ and exchanges combined fields and calculates fluxes using the selected turbulent Note that we want to implement this in a dispatchable function to allow for other forms of timestepping (e.g. leapfrog). """ - function solve_coupler!(cs) (; model_sims, Δt_cpl, tspan, comms_ctx) = cs (; atmos_sim, land_sim, ocean_sim, ice_sim) = model_sims @@ -164,6 +163,9 @@ input config file. It initializes the component models, all coupler objects, diagnostics, and conservation checks, and then runs the simulation. """ function setup_and_run(config_dict::AbstractDict) + # Make a copy so that we don't modify the original input + config_dict = copy(config_dict) + # Initialize communication context (do this first so all printing is only on root) comms_ctx = Utilities.get_comms_context(config_dict) # Select the correct timestep for each component model based on which are available @@ -209,8 +211,7 @@ function setup_and_run(config_dict::AbstractDict) #and `dir_paths.checkpoints`, where restart files are saved. =# - COUPLER_OUTPUT_DIR = joinpath(output_dir_root, job_id) - dir_paths = Utilities.setup_output_dirs(output_dir = COUPLER_OUTPUT_DIR, comms_ctx = comms_ctx) + dir_paths = Utilities.setup_output_dirs(output_dir = output_dir_root, comms_ctx = comms_ctx) @info "Coupler output directory $(dir_paths.output)" @info "Coupler artifacts directory $(dir_paths.artifacts)" @info "Coupler checkpoint directory $(dir_paths.checkpoints)" @@ -231,6 +232,23 @@ function setup_and_run(config_dict::AbstractDict) Random.seed!(random_seed) @info "Random seed set to $(random_seed)" + isnothing(restart_t) && (restart_t = Checkpointer.t_start_from_checkpoint(dir_paths.checkpoints)) + isnothing(restart_dir) && (restart_dir = dir_paths.checkpoints) + should_restart = !isnothing(restart_t) && !isnothing(restart_dir) + if should_restart + if t_start isa ITime + t_start, _ = promote(ITime(restart_t), t_start) + else + t_start = restart_t + end + + # TODO: Find a cleaner way to do this instead of having a second restart + # just for atmos + atmos_config_dict["restart_file"] = climaatmos_restart_path(output_dir_root, restart_t) + + @info "Starting from t_start $(t_start)" + end + tspan = (t_start, t_end) #= @@ -509,7 +527,7 @@ function setup_and_run(config_dict::AbstractDict) NB: Eventually, we will call all of radiation from the coupler, in addition to the albedo calculation. 
=# schedule_checkpoint = EveryCalendarDtSchedule(TimeManager.time_to_period(checkpoint_dt); start_date = date0) - checkpoint_cb = TimeManager.TimeManager.Callback(schedule_checkpoint, Checkpointer.checkpoint_sims) + checkpoint_cb = TimeManager.Callback(schedule_checkpoint, Checkpointer.checkpoint_sims) if sim_mode <: AMIPMode schedule_albedo = EveryCalendarDtSchedule(TimeManager.time_to_period(dt_rad); start_date = date0) @@ -579,71 +597,66 @@ function setup_and_run(config_dict::AbstractDict) If a restart directory is specified and contains output files from the `checkpoint_cb` callback, the component model states are restarted from those files. The restart directory is specified in the `config_dict` dictionary. The `restart_t` field specifies the time step at which the restart is performed. =# + should_restart && Checkpointer.restart!(cs, restart_dir, restart_t) - if !isnothing(restart_dir) - for sim in cs.model_sims - if Checkpointer.get_model_prog_state(sim) !== nothing - Checkpointer.restart_model_state!(sim, comms_ctx, restart_t; input_dir = restart_dir) - end - end - end + if !should_restart + #= + ## Initialize Component Model Exchange - #= - ## Initialize Component Model Exchange + We need to ensure all models' initial conditions are shared to enable the coupler to calculate the first instance of surface fluxes. Some auxiliary variables (namely surface humidity and radiation fluxes) + depend on initial conditions of other component models than those in which the variables are calculated, which is why we need to step these models in time and/or reinitialize them. + The concrete steps for proper initialization are: + =# - We need to ensure all models' initial conditions are shared to enable the coupler to calculate the first instance of surface fluxes. Some auxiliary variables (namely surface humidity and radiation fluxes) - depend on initial conditions of other component models than those in which the variables are calculated, which is why we need to step these models in time and/or reinitialize them. - The concrete steps for proper initialization are: - =# + # 1.coupler updates surface model area fractions + FieldExchanger.update_surface_fractions!(cs) - # 1.coupler updates surface model area fractions - FieldExchanger.update_surface_fractions!(cs) - - # 2.surface density (`ρ_sfc`): calculated by the coupler by adiabatically extrapolating atmospheric thermal state to the surface. - # For this, we need to import surface and atmospheric fields. The model sims are then updated with the new surface density. - FieldExchanger.import_combined_surface_fields!(cs.fields, cs.model_sims, cs.turbulent_fluxes) - FieldExchanger.import_atmos_fields!(cs.fields, cs.model_sims, cs.boundary_space, cs.turbulent_fluxes) - FieldExchanger.update_model_sims!(cs.model_sims, cs.fields, cs.turbulent_fluxes) - - # 3.surface vapor specific humidity (`q_sfc`): step surface models with the new surface density to calculate their respective `q_sfc` internally - ## TODO: the q_sfc calculation follows the design of the bucket q_sfc, but it would be neater to abstract this from step! 
(#331) - Interfacer.step!(land_sim, tspan[1] + Δt_cpl) - Interfacer.step!(ocean_sim, tspan[1] + Δt_cpl) - Interfacer.step!(ice_sim, tspan[1] + Δt_cpl) - - # 4.turbulent fluxes: now we have all information needed for calculating the initial turbulent - # surface fluxes using either the combined state or the partitioned state method - if cs.turbulent_fluxes isa FluxCalculator.CombinedStateFluxesMOST - ## import the new surface properties into the coupler (note the atmos state was also imported in step 3.) - FieldExchanger.import_combined_surface_fields!(cs.fields, cs.model_sims, cs.turbulent_fluxes) # i.e. T_sfc, albedo, z0, beta, q_sfc - ## calculate turbulent fluxes inside the atmos cache based on the combined surface state in each grid box - FluxCalculator.combined_turbulent_fluxes!(cs.model_sims, cs.fields, cs.turbulent_fluxes) # this updates the atmos thermo state, sfc_ts - elseif cs.turbulent_fluxes isa FluxCalculator.PartitionedStateFluxes - ## calculate turbulent fluxes in surface models and save the weighted average in coupler fields - FluxCalculator.partitioned_turbulent_fluxes!( - cs.model_sims, - cs.fields, - cs.boundary_space, - FluxCalculator.MoninObukhovScheme(), - cs.thermo_params, - ) + # 2.surface density (`ρ_sfc`): calculated by the coupler by adiabatically extrapolating atmospheric thermal state to the surface. + # For this, we need to import surface and atmospheric fields. The model sims are then updated with the new surface density. + FieldExchanger.import_combined_surface_fields!(cs.fields, cs.model_sims, cs.turbulent_fluxes) + FieldExchanger.import_atmos_fields!(cs.fields, cs.model_sims, cs.boundary_space, cs.turbulent_fluxes) + FieldExchanger.update_model_sims!(cs.model_sims, cs.fields, cs.turbulent_fluxes) - ## update atmos sfc_conditions for surface temperature - ## TODO: this is hard coded and needs to be simplified (req. CA modification) (#479) - new_p = get_new_cache(atmos_sim, cs.fields) - CA.SurfaceConditions.update_surface_conditions!(atmos_sim.integrator.u, new_p, atmos_sim.integrator.t) ## sets T_sfc (but SF calculation not necessary - requires split functionality in CA) - atmos_sim.integrator.p.precomputed.sfc_conditions .= new_p.precomputed.sfc_conditions - end + # 3.surface vapor specific humidity (`q_sfc`): step surface models with the new surface density to calculate their respective `q_sfc` internally + ## TODO: the q_sfc calculation follows the design of the bucket q_sfc, but it would be neater to abstract this from step! (#331) + Interfacer.step!(land_sim, tspan[1] + Δt_cpl) + Interfacer.step!(ocean_sim, tspan[1] + Δt_cpl) + Interfacer.step!(ice_sim, tspan[1] + Δt_cpl) - # 5.reinitialize models + radiative flux: prognostic states and time are set to their initial conditions. For atmos, this also triggers the callbacks and sets a nonzero radiation flux (given the new sfc_conditions) - FieldExchanger.reinit_model_sims!(cs.model_sims) + # 4.turbulent fluxes: now we have all information needed for calculating the initial turbulent + # surface fluxes using either the combined state or the partitioned state method + if cs.turbulent_fluxes isa FluxCalculator.CombinedStateFluxesMOST + ## import the new surface properties into the coupler (note the atmos state was also imported in step 3.) + FieldExchanger.import_combined_surface_fields!(cs.fields, cs.model_sims, cs.turbulent_fluxes) # i.e. 
T_sfc, albedo, z0, beta, q_sfc + ## calculate turbulent fluxes inside the atmos cache based on the combined surface state in each grid box + FluxCalculator.combined_turbulent_fluxes!(cs.model_sims, cs.fields, cs.turbulent_fluxes) # this updates the atmos thermo state, sfc_ts + elseif cs.turbulent_fluxes isa FluxCalculator.PartitionedStateFluxes + ## calculate turbulent fluxes in surface models and save the weighted average in coupler fields + FluxCalculator.partitioned_turbulent_fluxes!( + cs.model_sims, + cs.fields, + cs.boundary_space, + FluxCalculator.MoninObukhovScheme(), + cs.thermo_params, + ) + + ## update atmos sfc_conditions for surface temperature + ## TODO: this is hard coded and needs to be simplified (req. CA modification) (#479) + new_p = get_new_cache(atmos_sim, cs.fields) + CA.SurfaceConditions.update_surface_conditions!(atmos_sim.integrator.u, new_p, atmos_sim.integrator.t) ## sets T_sfc (but SF calculation not necessary - requires split functionality in CA) + atmos_sim.integrator.p.precomputed.sfc_conditions .= new_p.precomputed.sfc_conditions + end + + # 5.reinitialize models + radiative flux: prognostic states and time are set to their initial conditions. For atmos, this also triggers the callbacks and sets a nonzero radiation flux (given the new sfc_conditions) + FieldExchanger.reinit_model_sims!(cs.model_sims) - # 6.update all fluxes: coupler re-imports updated atmos fluxes (radiative fluxes for both `turbulent_fluxes` types - # and also turbulent fluxes if `turbulent_fluxes isa CombinedStateFluxesMOST`, - # and sends them to the surface component models. If `turbulent_fluxes isa PartitionedStateFluxes` - # atmos receives the turbulent fluxes from the coupler. - FieldExchanger.import_atmos_fields!(cs.fields, cs.model_sims, cs.boundary_space, cs.turbulent_fluxes) - FieldExchanger.update_model_sims!(cs.model_sims, cs.fields, cs.turbulent_fluxes) + # 6.update all fluxes: coupler re-imports updated atmos fluxes (radiative fluxes for both `turbulent_fluxes` types + # and also turbulent fluxes if `turbulent_fluxes isa CombinedStateFluxesMOST`, + # and sends them to the surface component models. If `turbulent_fluxes isa PartitionedStateFluxes` + # atmos receives the turbulent fluxes from the coupler. + FieldExchanger.import_atmos_fields!(cs.fields, cs.model_sims, cs.boundary_space, cs.turbulent_fluxes) + FieldExchanger.update_model_sims!(cs.model_sims, cs.fields, cs.turbulent_fluxes) + end #= ## Precompilation of Coupling Loop @@ -653,13 +666,15 @@ function setup_and_run(config_dict::AbstractDict) beginning and end of the simulation timespan to the correct values. 
=# - ## run the coupled simulation for two timesteps to precompile - cs.tspan[2] = tspan[1] + Δt_cpl * 2 - solve_coupler!(cs) + if tspan[2] > 2Δt_cpl + tspan[1] + ## run the coupled simulation for two timesteps to precompile + cs.tspan[2] = tspan[1] + Δt_cpl * 2 + solve_coupler!(cs) - ## update the timespan to the correct values - cs.tspan[1] = tspan[1] + Δt_cpl * 2 - cs.tspan[2] = tspan[2] + ## update the timespan to the correct values + cs.tspan[1] = tspan[1] + Δt_cpl * 2 + cs.tspan[2] = tspan[2] + end ## Run garbage collection before solving for more accurate memory comparison to ClimaAtmos GC.gc() @@ -724,4 +739,5 @@ function setup_and_run(config_dict::AbstractDict) # Close all diagnostics file writers isnothing(cs.diags_handler) || foreach(diag -> close(diag.output_writer), cs.diags_handler.scheduled_diagnostics) isnothing(atmos_sim.output_writers) || foreach(close, atmos_sim.output_writers) + return cs end diff --git a/experiments/ClimaEarth/test/compare.jl b/experiments/ClimaEarth/test/compare.jl new file mode 100644 index 0000000000..3ec12b8674 --- /dev/null +++ b/experiments/ClimaEarth/test/compare.jl @@ -0,0 +1,122 @@ +# compare.jl provides function to recursively compare complex objects while also +# allowing for some numerical tolerance. + +import ClimaComms +import ClimaAtmos as CA +import ClimaCore +import ClimaCore: DataLayouts, Fields, Geometry +import ClimaCore.Fields: Field, FieldVector, field_values +import ClimaCore.DataLayouts: AbstractData +import ClimaCore.Geometry: AxisTensor +import ClimaCore.Spaces: AbstractSpace +import NCDatasets + +""" + _error(arr1::AbstractArray, arr2::AbstractArray; ABS_TOL = 100eps(eltype(arr1))) + +We compute the error in this way: +- when the absolute value is larger than ABS_TOL, we use the absolute error +- in the other cases, we compare the relative errors +""" +function _error(arr1::AbstractArray, arr2::AbstractArray; ABS_TOL = 100eps(eltype(arr1))) + # There are some parameters, e.g. Obukhov length, for which Inf + # is a reasonable value (implying a stability parameter in the neutral boundary layer + # regime, for instance). We account for such instances with the `isfinite` function. + arr1 = Array(arr1) .* isfinite.(Array(arr1)) + arr2 = Array(arr2) .* isfinite.(Array(arr2)) + diff = abs.(arr1 .- arr2) + denominator = abs.(arr1) + error = ifelse.(denominator .> ABS_TOL, diff ./ denominator, diff) + return error +end + +""" + compare(v1, v2; name = "", ignore = Set([:rc])) + +Return whether `v1` and `v2` are the same (up to floating point errors). +`compare` walks through all the properties in `v1` and `v2` until it finds +that there are no more properties. At that point, `compare` tries to match the +resulting objects. When such objects are arrays with floating point, `compare` +defines a notion of `error` that is the following: when the absolute value is +less than `100eps(eltype)`, `error = absolute_error`, otherwise it is relative +error. The `error` is then compared against a tolerance. +Keyword arguments +================= +- `name` is used to collect the name of the property while we go recursively + over all the properties. You can pass a base name. +- `ignore` is a collection of `Symbol`s that identify properties that are + ignored when walking through the tree. This is useful for properties that + are known to be different (e.g., `output_dir`). 
+`:rc` is some CUDA/CuArray internal object that we don't care about +""" +function compare( + v1::T1, + v2::T2; + name = "", + ignore = Set([:rc]), +) where { + T1 <: Union{FieldVector, AbstractSpace, NamedTuple, CA.AtmosCache}, + T2 <: Union{FieldVector, AbstractSpace, NamedTuple, CA.AtmosCache}, +} + pass = true + return _compare(pass, v1, v2; name, ignore) +end + +function _compare(pass, v1::T, v2::T; name, ignore) where {T} + properties = filter(x -> !(x in ignore), propertynames(v1)) + if isempty(properties) + pass &= _compare(v1, v2; name, ignore) + else + # Recursive case + for p in properties + pass &= _compare(pass, getproperty(v1, p), getproperty(v2, p); name = "$(name).$(p)", ignore) + end + end + return pass +end + +function _compare(v1::T, v2::T; name, ignore) where {T} + return print_maybe(v1 == v2, "$name differs") +end + +function _compare(v1::T, v2::T; name, ignore) where {T <: Union{AbstractString, Symbol}} + # What we can safely print without filling STDOUT + return print_maybe(v1 == v2, "$name differs: $v1 vs $v2") +end + +function _compare(v1::T, v2::T; name, ignore) where {T <: Number} + # We check with triple equal so that we also catch NaNs being equal + return print_maybe(v1 === v2, "$name differs: $v1 vs $v2") +end + +# We ignore NCDatasets. They contain a lot of state-ful information +function _compare(pass, v1::T, v2::T; name, ignore) where {T <: NCDatasets.NCDataset} + return pass +end + +function _compare(v1::T, v2::T; name, ignore) where {T <: Field{<:AbstractData{<:Real}}} + return _compare(parent(v1), parent(v2); name, ignore) +end + +function _compare(pass, v1::T, v2::T; name, ignore) where {T <: AbstractData} + return pass && _compare(parent(v1), parent(v2); name, ignore) +end + +# Handle views +function _compare(pass, v1::SubArray{FT}, v2::SubArray{FT}; name, ignore) where {FT <: AbstractFloat} + return pass && _compare(collect(v1), collect(v2); name, ignore) +end + +function _compare(v1::AbstractArray{FT}, v2::AbstractArray{FT}; name, ignore) where {FT <: AbstractFloat} + error = maximum(_error(v1, v2); init = zero(eltype(v1))) + return print_maybe(error <= 100eps(eltype(v1)), "$name error: $error") +end + +function _compare(pass, v1::T1, v2::T2; name, ignore) where {T1, T2} + error("v1 and v2 have different types") +end + +function print_maybe(exp, what) + exp || println(what) + return exp +end diff --git a/experiments/ClimaEarth/test/restart.jl b/experiments/ClimaEarth/test/restart.jl new file mode 100644 index 0000000000..9779391348 --- /dev/null +++ b/experiments/ClimaEarth/test/restart.jl @@ -0,0 +1,117 @@ +# This test runs a small AMIP simulation four times. +# +# - The first time the simulation is run for four steps +# - The second time the simulation is run for two steps +# - The third time the simulation is run for two steps, but restarting from the +# second simulation +# +# After all these simulations are run, we compare the first and last runs. They +# should be bit-wise identical. +# +# The content of the simulation is not the most important, but it helps if it +# has all of the complexity possible. 
+
+import ClimaComms
+ClimaComms.@import_required_backends
+import ClimaUtilities.OutputPathGenerator: maybe_wait_filesystem
+import YAML
+import Logging
+using Test
+
+# Uncomment the following for cleaner output (but more difficult debugging)
+# Logging.disable_logging(Logging.Warn)
+
+include("compare.jl")
+include("../setup_run.jl")
+
+comms_ctx = ClimaComms.context()
+@info "Context: $(comms_ctx)"
+ClimaComms.init(comms_ctx)
+
+# Make sure that all MPI processes agree on the output_loc
+tmpdir = ClimaComms.iamroot(comms_ctx) ? mktempdir(pwd()) : ""
+tmpdir = ClimaComms.bcast(comms_ctx, tmpdir)
+# Sometimes the shared filesystem doesn't work properly and the folder is not
+# synced across MPI processes. Let's add an additional check here.
+maybe_wait_filesystem(ClimaComms.context(), tmpdir)
+
+default_config = parse_commandline(argparse_settings())
+base_config_file = joinpath(@__DIR__, "amip_test.yml")
+base_config = YAML.load_file(base_config_file)
+merge!(default_config, base_config)
+
+# Four steps
+four_steps = deepcopy(default_config)
+
+four_steps["dt"] = "180secs"
+four_steps["dt_cpl"] = "180secs"
+four_steps["t_end"] = "720secs"
+four_steps["dt_rad"] = "180secs"
+four_steps["job_id"] = "four_steps"
+
+cs_four_steps = setup_and_run(four_steps)
+
+println("Simulating two steps")
+
+# Now, two steps, then a restart for two more steps
+two_steps = deepcopy(default_config)
+
+two_steps["dt"] = "180secs"
+two_steps["dt_cpl"] = "180secs"
+two_steps["t_end"] = "360secs"
+two_steps["dt_rad"] = "180secs"
+two_steps["coupler_output_dir"] = tmpdir
+two_steps["checkpoint_dt"] = "360secs"
+two_steps["job_id"] = "two_steps"
+
+# (two_steps was deepcopied above because setup_and_run may modify its content)
+cs_two_steps1 = setup_and_run(two_steps)
+
+println("Reading checkpoint and simulating the last two steps")
+# Two additional steps
+two_steps["t_end"] = "720secs"
+cs_two_steps2 = setup_and_run(two_steps)
+
+@testset "Restarts" begin
+    # We put cs_four_steps.fields in a NamedTuple so that we can start the recursion in compare
+    @test compare((; coupler_fields = cs_four_steps.fields), (; coupler_fields = cs_two_steps2.fields))
+
+    @test compare(cs_four_steps.model_sims.atmos_sim.integrator.u, cs_two_steps2.model_sims.atmos_sim.integrator.u)
+
+    @test compare(
+        cs_four_steps.model_sims.atmos_sim.integrator.p,
+        cs_two_steps2.model_sims.atmos_sim.integrator.p,
+        ignore = [
+            :walltime_estimate, # Stateful
+            :output_dir, # Changes across runs
+            :scratch, # Irrelevant
+            :ghost_buffer, # Irrelevant
+            :hyperdiffusion_ghost_buffer, # Irrelevant
+            :data_handler, # Stateful
+            :face_clear_sw_direct_flux_dn, # Not filled by RRTMGP
+            :face_sw_direct_flux_dn, # Not filled by RRTMGP
+            :rc, # CUDA internal object
+        ],
+    )
+
+    @test compare(cs_four_steps.model_sims.ice_sim.integrator.u, cs_two_steps2.model_sims.ice_sim.integrator.u)
+
+    @test compare(cs_four_steps.model_sims.land_sim.integrator.u, cs_two_steps2.model_sims.land_sim.integrator.u)
+    @test compare(
+        cs_four_steps.model_sims.land_sim.integrator.p,
+        cs_two_steps2.model_sims.land_sim.integrator.p,
+        ignore = [:dss_buffer_3d, :dss_buffer_2d, :rc],
+    )
+
+    # Ignoring SST_timevaryinginput because it contains closures (which should be
+    # reinitialized correctly). We have to remove it from the NamedTuple,
+    # otherwise `compare` will not work
+    function delete(nt::NamedTuple, fieldnames...)
+        return (; filter(p -> !(first(p) in fieldnames), collect(pairs(nt)))...)
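+        # For example, delete((; a = 1, b = 2), :b) returns (a = 1,):
+        # collect(pairs(nt)) materializes the NamedTuple as key-value pairs,
+        # filter drops the requested fieldnames, and splatting the result back
+        # into (; ...) rebuilds a NamedTuple without them.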
+ end + + ocean_cache_four = delete(cs_four_steps.model_sims.ocean_sim.cache, :SST_timevaryinginput) + ocean_cache_two2 = delete(cs_two_steps2.model_sims.ocean_sim.cache, :SST_timevaryinginput) + + @test compare(ocean_cache_four, ocean_cache_two2) +end diff --git a/experiments/ClimaEarth/test/runtests.jl b/experiments/ClimaEarth/test/runtests.jl index 92e9f69ef0..ccd36cb1bc 100644 --- a/experiments/ClimaEarth/test/runtests.jl +++ b/experiments/ClimaEarth/test/runtests.jl @@ -25,3 +25,7 @@ end @safetestset "AMIP test" begin include("amip_test.jl") end + +@safetestset "Restart test" begin + include("restart.jl") +end diff --git a/experiments/ClimaEarth/user_io/arg_parsing.jl b/experiments/ClimaEarth/user_io/arg_parsing.jl index 23032ddb28..95c63c0d7f 100644 --- a/experiments/ClimaEarth/user_io/arg_parsing.jl +++ b/experiments/ClimaEarth/user_io/arg_parsing.jl @@ -44,6 +44,9 @@ This function may modify the input dictionary to remove unnecessary keys. - All arguments needed for the coupled simulation """ function get_coupler_args(config_dict::Dict) + # Make a copy so that we don't modify the original input + config_dict = copy(config_dict) + # Simulation-identifying information; Print `config_dict` if requested config_dict["print_config_dict"] && @info(config_dict) job_id = config_dict["job_id"] @@ -84,7 +87,8 @@ function get_coupler_args(config_dict::Dict) # Restart information restart_dir = config_dict["restart_dir"] - restart_t = Int(config_dict["restart_t"]) + restart_t = + isnothing(config_dict["restart_t"]) ? nothing : Int64(Utilities.time_to_seconds(config_dict["restart_t"])) # Diagnostics information use_coupler_diagnostics = config_dict["use_coupler_diagnostics"] @@ -101,7 +105,7 @@ function get_coupler_args(config_dict::Dict) conservation_softfail = config_dict["conservation_softfail"] # Output information - output_dir_root = config_dict["coupler_output_dir"] + output_dir_root = joinpath(config_dict["coupler_output_dir"], job_id) # ClimaLand-specific information land_domain_type = config_dict["land_domain_type"] @@ -228,27 +232,20 @@ function parse_component_dts!(config_dict) # Specify component model names component_dt_names = ["dt_atmos", "dt_land", "dt_ocean", "dt_seaice"] component_dt_dict = Dict{String, typeof(Δt_cpl)}() - # check if all component dt's are specified - if all(key -> !isnothing(config_dict[key]), component_dt_names) - # when all component dt's are specified, ignore the dt field - if haskey(config_dict, "dt") - @warn "Removing dt in favor of individual component dt's" - delete!(config_dict, "dt") - end - for key in component_dt_names + + @assert all(key -> !isnothing(config_dict[key]), component_dt_names) || haskey(config_dict, "dt") "all model-specific timesteps (dt_atmos, dt_land, dt_ocean, and dt_seaice) or a generic timestep (dt) must be specified" + + for key in component_dt_names + if !isnothing(config_dict[key]) + # Check if the component timestep is specified component_dt = Float64(Utilities.time_to_seconds(config_dict[key])) @assert isapprox(Δt_cpl % component_dt, 0.0) "Coupler dt must be divisible by all component dt's\n dt_cpl = $Δt_cpl\n $key = $component_dt" component_dt_dict[key] = component_dt - end - else - # when not all component dt's are specified, use the dt field - @assert haskey(config_dict, "dt") "dt or (dt_atmos, dt_land, dt_ocean, and dt_seaice) must be specified" - for key in component_dt_names - if !isnothing(config_dict[key]) - @warn "Removing $key from config in favor of dt because not all component dt's are specified" - end - 
delete!(config_dict, key) - component_dt_dict[key] = Float64(Utilities.time_to_seconds(config_dict["dt"])) + else + # If the component timestep is not specified, use the generic timestep + dt = Float64(Utilities.time_to_seconds(config_dict["dt"])) + @assert isapprox(Δt_cpl % dt, 0.0) "Coupler dt must be divisible by all component dt's\n dt_cpl = $Δt_cpl\n dt = $dt" + component_dt_dict[key] = dt end end config_dict["component_dt_dict"] = component_dt_dict diff --git a/src/Checkpointer.jl b/src/Checkpointer.jl index eec2463198..c018598405 100644 --- a/src/Checkpointer.jl +++ b/src/Checkpointer.jl @@ -7,10 +7,13 @@ module Checkpointer import ClimaComms import ClimaCore as CC +import ClimaUtilities.Utils: sort_by_creation_time import ..Interfacer import Dates -export get_model_prog_state, checkpoint_model_state, restart_model_state!, checkpoint_sims +import JLD2 + +export get_model_prog_state, checkpoint_model_state, checkpoint_sims """ get_model_prog_state(sim::Interfacer.ComponentModelSimulation) @@ -20,10 +23,18 @@ This is a template function that should be implemented for each component model. """ get_model_prog_state(sim::Interfacer.ComponentModelSimulation) = nothing +""" + get_model_cache(sim::Interfacer.ComponentModelSimulation) + +Returns the model cache of a simulation. +This is a template function that should be implemented for each component model. +""" +get_model_cache(sim::Interfacer.ComponentModelSimulation) = nothing + """ checkpoint_model_state(sim::Interfacer.ComponentModelSimulation, comms_ctx::ClimaComms.AbstractCommsContext, t::Int; output_dir = "output") -Checkpoints the model state of a simulation to a HDF5 file at a given time, t (in seconds). +Checkpoint the model state of a simulation to a HDF5 file at a given time, t (in seconds). """ function checkpoint_model_state( sim::Interfacer.ComponentModelSimulation, @@ -35,8 +46,7 @@ function checkpoint_model_state( day = floor(Int, t / (60 * 60 * 24)) sec = floor(Int, t % (60 * 60 * 24)) @info "Saving checkpoint " * Interfacer.name(sim) * " model state to HDF5 on day $day second $sec" - mkpath(joinpath(output_dir, "checkpoint")) - output_file = joinpath(output_dir, "checkpoint", "checkpoint_" * Interfacer.name(sim) * "_$t.hdf5") + output_file = joinpath(output_dir, "checkpoint_" * Interfacer.name(sim) * "_$t.hdf5") checkpoint_writer = CC.InputOutput.HDF5Writer(output_file, comms_ctx) CC.InputOutput.HDF5.write_attribute(checkpoint_writer.file, "time", t) CC.InputOutput.write!(checkpoint_writer, Y, "model_state") @@ -46,30 +56,41 @@ function checkpoint_model_state( end """ - restart_model_state!(sim::Interfacer.ComponentModelSimulation, comms_ctx::ClimaComms.AbstractCommsContext, t::Int; input_dir = "input") + checkpoint_model_cache(sim::Interfacer.ComponentModelSimulation, comms_ctx::ClimaComms.AbstractCommsContext, t::Int; output_dir = "output") + +Checkpoint the model cache to N JLD2 files at a given time, t (in seconds), +where N is the number of MPI ranks. -Sets the model state of a simulation from a HDF5 file from a given time, t (in seconds). +Objects are saved to JLD2 files because caches are generally not ClimaCore +objects (and ClimaCore.InputOutput can only save `Field`s or `FieldVector`s). 
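+
+Each MPI rank writes its own file, named
+`checkpoint_cache_<pid>_<model name>_<t>.jld2`, inside `output_dir`.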
""" -function restart_model_state!( +function checkpoint_model_cache( sim::Interfacer.ComponentModelSimulation, comms_ctx::ClimaComms.AbstractCommsContext, t::Int; - input_dir = "input", + output_dir = "output", ) - Y = get_model_prog_state(sim) + # Move p to CPU (because we cannot save CUArrays) + p = CC.Adapt.adapt(Array, get_model_cache(sim)) day = floor(Int, t / (60 * 60 * 24)) sec = floor(Int, t % (60 * 60 * 24)) - input_file = joinpath(input_dir, "checkpoint", "checkpoint_" * Interfacer.name(sim) * "_$t.hdf5") + @info "Saving checkpoint " * Interfacer.name(sim) * " model cache to JLD2 on day $day second $sec" + pid = ClimaComms.mypid(comms_ctx) + output_file = joinpath(output_dir, "checkpoint_cache_$(pid)_" * Interfacer.name(sim) * "_$t.jld2") + JLD2.jldsave(output_file, cache = p) + return nothing +end - @info "Setting " Interfacer.name(sim) " state to checkpoint: $input_file, corresponding to day $day second $sec" - # open file and read - restart_reader = CC.InputOutput.HDF5Reader(input_file, comms_ctx) - Y_new = CC.InputOutput.read_field(restart_reader, "model_state") - Base.close(restart_reader) +""" + restore_cache!(sim::Interfacer.ComponentModelSimulation, new_cache) + +Replace the cache in `sim` with `new_cache`. - # set new state - Y .= Y_new +Component models can define new methods for this to change how cache is restored. +""" +function restore_cache!(sim::Interfacer.ComponentModelSimulation, new_cache) + return nothing end """ @@ -78,18 +99,115 @@ end This is a callback function that checkpoints all simulations defined in the current coupled simulation. """ function checkpoint_sims(cs::Interfacer.CoupledSimulation) + t = Dates.datetime2epochms(cs.dates.date[1]) + t0 = Dates.datetime2epochms(cs.dates.date0[1]) + time = Int((t - t0) / 1e3) + day = floor(Int, time / (60 * 60 * 24)) + sec = floor(Int, time % (60 * 60 * 24)) + output_dir = cs.dirs.checkpoints + comms_ctx = cs.comms_ctx for sim in cs.model_sims - if Checkpointer.get_model_prog_state(sim) !== nothing - t = Dates.datetime2epochms(cs.dates.date[1]) - t0 = Dates.datetime2epochms(cs.dates.date0[1]) - Checkpointer.checkpoint_model_state( - sim, - cs.comms_ctx, - Int((t - t0) / 1e3), - output_dir = cs.dirs.checkpoints, - ) + if !isnothing(Checkpointer.get_model_prog_state(sim)) + Checkpointer.checkpoint_model_state(sim, comms_ctx, time; output_dir) + end + if !isnothing(Checkpointer.get_model_cache(sim)) + Checkpointer.checkpoint_model_cache(sim, comms_ctx, time; output_dir) + end + end + # Checkpoint the Coupler fields + pid = ClimaComms.mypid(comms_ctx) + @info "Saving coupler fields to JLD2 on day $day second $sec" + output_file = joinpath(output_dir, "checkpoint_coupler_fields_$(pid)_$time.jld2") + # Adapt to Array move fields to the CPU + JLD2.jldsave(output_file, coupler_fields = CC.Adapt.adapt(Array, cs.fields)) +end + +""" + restart!(cs::CoupledSimulation, checkpoint_dir, checkpoint_t) + +Overwrite the content of `cs` with checkpoints in `checkpoint_dir` at time `checkpoint_t`. + +Return a true if the simulation was restarted. 
+""" +function restart!(cs, checkpoint_dir, checkpoint_t) + @info "Restarting from time $(checkpoint_t) and directory $(checkpoint_dir)" + pid = ClimaComms.mypid(cs.comms_ctx) + for sim in cs.model_sims + if !isnothing(Checkpointer.get_model_prog_state(sim)) + input_file_state = + output_file = joinpath(checkpoint_dir, "checkpoint_$(Interfacer.name(sim))_$(checkpoint_t).hdf5") + restart_model_state!(sim, input_file_state, cs.comms_ctx) + end + if !isnothing(Checkpointer.get_model_cache(sim)) + input_file_cache = + joinpath(checkpoint_dir, "checkpoint_cache_$(pid)_$(Interfacer.name(sim))_$(checkpoint_t).jld2") + restart_model_cache!(sim, input_file_cache) end end + input_file_coupler_fields = joinpath(checkpoint_dir, "checkpoint_coupler_fields_$(pid)_$(checkpoint_t).jld2") + restart_coupler_fields!(cs, input_file_coupler_fields) + return true +end + +""" + restart_model_cache!(sim, input_file) + +Overwrite the content of `sim` with the cache from the `input_file`. + +It relies on `restore_cache!(sim, old_cache)`, which has to be implemented by +the component models that have a cache. +""" +function restart_model_cache!(sim, input_file) + ispath(input_file) || error("File $(input_file) not found") + # Component models are responsible for defining a method for this + restore_cache!(sim, JLD2.jldopen(input_file)["cache"]) +end + +""" + restart_model_state!(sim, input_file, comms_ctx) + +Overwrite the content of `sim` with the state from the `input_file`. +""" +function restart_model_state!(sim, input_file, comms_ctx) + ispath(input_file) || error("File $(input_file) not found") + Y = get_model_prog_state(sim) + # open file and read + CC.InputOutput.HDF5Reader(input_file, comms_ctx) do restart_reader + Y_new = CC.InputOutput.read_field(restart_reader, "model_state") + # set new state + Y .= Y_new + end + return nothing +end + +""" + restart_coupler_fields!(cs, input_file) + +Overwrite the content of the coupled simulation `cs` with the coupler fields +read from `input_file`. +""" +function restart_coupler_fields!(cs, input_file) + ispath(input_file) || error("File $(input_file) not found") + fields_read = JLD2.jldopen(input_file)["coupler_fields"] + for name in propertynames(cs.fields) + ArrayType = ClimaComms.array_type(ClimaComms.device(cs.comms_ctx)) + parent(getproperty(cs.fields, name)) .= ArrayType(parent(getproperty(fields_read, name))) + end +end + +""" + t_start_from_checkpoint(checkpoint_dir) + +Look for restart files in `checkpoint_dir`, if found, return the time of the latest. +If not found, return `nothing`. +""" +function t_start_from_checkpoint(checkpoint_dir) + isdir(checkpoint_dir) || return nothing + restart_file_rx = r"checkpoint_(\w+)_(\d+).hdf5" + restarts = filter(f -> !isnothing(match(restart_file_rx, f)), readdir(checkpoint_dir)) + isempty(restarts) && return nothing + latest_restart = last(sort_by_creation_time(restarts)) + return parse(Int, match(restart_file_rx, latest_restart)[2]) end end # module diff --git a/src/Interfacer.jl b/src/Interfacer.jl index d63f42c39e..a5d6c5bd47 100644 --- a/src/Interfacer.jl +++ b/src/Interfacer.jl @@ -6,6 +6,7 @@ This modules contains abstract types, interface templates and model stubs for co module Interfacer import SciMLBase +import ClimaComms import ClimaCore as CC import Thermodynamics as TD import SciMLBase: step!, reinit! # explicitly import to extend these functions @@ -77,6 +78,17 @@ end CoupledSimulation{FT}(args...) where {FT} = CoupledSimulation{FT, typeof.(args[1:end])...}(args...) 
+function Base.show(io::IO, sim::CoupledSimulation) + device_type = nameof(typeof(ClimaComms.device(sim.comms_ctx))) + return print( + io, + "Coupled Simulation\n", + "├── Running on: $(device_type)\n", + "├── Output folder: $(sim.dirs.output)\n", + "└── Current date: $(sim.dates.date[])", + ) +end + """ float_type(::CoupledSimulation) diff --git a/src/Utilities.jl b/src/Utilities.jl index 75b359a729..dbfabb27e9 100644 --- a/src/Utilities.jl +++ b/src/Utilities.jl @@ -115,7 +115,7 @@ CPU of this process since it began. function show_memory_usage() cpu_max_rss_GB = "" cpu_max_rss_GB = string(round(Sys.maxrss() / 1e9, digits = 3)) * " GiB" - @info cpu_max_rss_GB + @info "Memory in use: $(cpu_max_rss_GB)" return cpu_max_rss_GB end diff --git a/test/mpi_tests/checkpointer_mpi_tests.jl b/test/mpi_tests/checkpointer_mpi_tests.jl deleted file mode 100644 index be57f5d472..0000000000 --- a/test/mpi_tests/checkpointer_mpi_tests.jl +++ /dev/null @@ -1,47 +0,0 @@ -#= - Unit tests for ClimaCoupler Checkpointer module functions to exercise MPI - -These are in a separate testing file from the other Checkpointer unit tests so -that MPI can be enabled for testing of these functions. -=# -import Test: @test, @testset -import ClimaComms -@static pkgversion(ClimaComms) >= v"0.6" && ClimaComms.@import_required_backends -import ClimaCore as CC -import ClimaCoupler -import ClimaCoupler: Checkpointer, Interfacer - -include(joinpath("..", "..", "experiments", "ClimaEarth", "test", "TestHelper.jl")) -import .TestHelper - -# set up MPI communications context -const comms_ctx = ClimaComms.context(ClimaComms.CPUSingleThreaded()) -const pid, nprocs = ClimaComms.init(comms_ctx) -@info pid -ClimaComms.barrier(comms_ctx) - -FT = Float64 -struct DummySimulation{S} <: Interfacer.AtmosModelSimulation - state::S -end -Checkpointer.get_model_prog_state(sim::DummySimulation) = sim.state -@testset "checkpoint_model_state, restart_model_state!" begin - boundary_space = TestHelper.create_space(FT, comms_ctx = comms_ctx) - t = 1 - - # old sim run - sim = DummySimulation(CC.Fields.FieldVector(T = ones(boundary_space))) - Checkpointer.checkpoint_model_state(sim, comms_ctx, t, output_dir = "test_checkpoint") - ClimaComms.barrier(comms_ctx) - - # new sim run - sim_new = DummySimulation(CC.Fields.FieldVector(T = zeros(boundary_space))) - Checkpointer.restart_model_state!(sim_new, comms_ctx, t, input_dir = "test_checkpoint") - @test sim_new.state.T == sim.state.T - - # remove checkpoint directory - ClimaComms.barrier(comms_ctx) - if ClimaComms.iamroot(comms_ctx) - rm("./test_checkpoint/", force = true, recursive = true) - end -end diff --git a/test/runtests.jl b/test/runtests.jl index beff063424..a8bb21c21c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -36,6 +36,3 @@ end @safetestset "FluxCalculator tests" begin include("flux_calculator_tests.jl") end -@safetestset "Checkpointer tests" begin - include("checkpointer_tests.jl") -end