From b236e551168db14b23c1384ea58333824a4e1f82 Mon Sep 17 00:00:00 2001 From: Gabriele Bozzola Date: Sat, 1 Mar 2025 15:01:52 -0800 Subject: [PATCH 1/4] Add show method to CoupledSimulation --- src/Interfacer.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/Interfacer.jl b/src/Interfacer.jl index d63f42c39e..a5d6c5bd47 100644 --- a/src/Interfacer.jl +++ b/src/Interfacer.jl @@ -6,6 +6,7 @@ This modules contains abstract types, interface templates and model stubs for co module Interfacer import SciMLBase +import ClimaComms import ClimaCore as CC import Thermodynamics as TD import SciMLBase: step!, reinit! # explicitly import to extend these functions @@ -77,6 +78,17 @@ end CoupledSimulation{FT}(args...) where {FT} = CoupledSimulation{FT, typeof.(args[1:end])...}(args...) +function Base.show(io::IO, sim::CoupledSimulation) + device_type = nameof(typeof(ClimaComms.device(sim.comms_ctx))) + return print( + io, + "Coupled Simulation\n", + "├── Running on: $(device_type)\n", + "├── Output folder: $(sim.dirs.output)\n", + "└── Current date: $(sim.dates.date[])", + ) +end + """ float_type(::CoupledSimulation) From ddb109d9ecf9b7c3f55e341a6f549f3390673c9c Mon Sep 17 00:00:00 2001 From: Gabriele Bozzola Date: Mon, 3 Mar 2025 07:31:43 -0800 Subject: [PATCH 2/4] Add restart test and support for restarts This commit implements a way to restart simulations by saving both state and caches of component models, as well as the coupler fields. Given that caches are complex object, I implemented this using JLD2 files. The challenges with JLD2 files are that: - they are not MPI compatible, - they are not GPU compatible. For this reason, I have to move everything to the CPU, and have each process write to its own output. This adds a restriction: only the same number of MPI process (and the same machine) can be used for restarts. In addition to this, this approach requires component models to implement their functions to restore their caches. Something that can be improved in the future is that, ClimaAtmos is currently producing two checkpoints, one independently, and one from ClimaCoupler. This should not be needed, but it is currently needed because there's no other way to start ClimaAtmos at a different time. The other problem here is that the MPI test occasionally hangs (as it does in ClimaAtmos). --- .buildkite/pipeline.yml | 34 ++-- NEWS.md | 14 ++ Project.toml | 4 +- docs/src/checkpointer.md | 131 +++++++++++++- experiments/ClimaEarth/Manifest-v1.11.toml | 78 +++++--- experiments/ClimaEarth/cli_options.jl | 6 +- .../components/atmosphere/climaatmos.jl | 52 ++++++ .../components/land/climaland_bucket.jl | 18 ++ .../components/ocean/prescr_ocean.jl | 14 ++ .../components/ocean/prescr_seaice.jl | 14 ++ .../ClimaEarth/components/shared/restore.jl | 86 +++++++++ experiments/ClimaEarth/run_amip.jl | 2 +- .../ClimaEarth/run_cloudless_aquaplanet.jl | 2 +- .../ClimaEarth/run_cloudy_aquaplanet.jl | 2 +- .../ClimaEarth/run_cloudy_slabplanet.jl | 2 +- experiments/ClimaEarth/run_dry_held_suarez.jl | 2 +- .../ClimaEarth/run_moist_held_suarez.jl | 2 +- experiments/ClimaEarth/setup_run.jl | 150 +++++++++------- experiments/ClimaEarth/test/compare.jl | 122 +++++++++++++ experiments/ClimaEarth/test/restart.jl | 117 ++++++++++++ experiments/ClimaEarth/test/runtests.jl | 4 + experiments/ClimaEarth/user_io/arg_parsing.jl | 8 +- src/Checkpointer.jl | 170 +++++++++++++++--- src/Utilities.jl | 2 +- test/mpi_tests/checkpointer_mpi_tests.jl | 47 ----- test/runtests.jl | 3 - 26 files changed, 887 insertions(+), 199 deletions(-) create mode 100644 experiments/ClimaEarth/components/shared/restore.jl create mode 100644 experiments/ClimaEarth/test/compare.jl create mode 100644 experiments/ClimaEarth/test/restart.jl delete mode 100644 test/mpi_tests/checkpointer_mpi_tests.jl diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 92030ebfa9..e3230333c3 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -67,16 +67,6 @@ steps: - group: "Unit Tests" steps: - - label: "MPI Checkpointer unit tests" - key: "checkpointer_mpi_tests" - command: "srun julia --color=yes --project=test/ test/mpi_tests/checkpointer_mpi_tests.jl" - timeout_in_minutes: 20 - env: - CLIMACOMMS_CONTEXT: "MPI" - agents: - slurm_ntasks: 2 - slurm_mem: 16GB - - label: "MPI Utilities unit tests" key: "utilities_mpi_tests" command: "srun julia --color=yes --project=test/ test/utilities_tests.jl" @@ -97,6 +87,7 @@ steps: agents: slurm_ntasks: 1 slurm_gres: "gpu:1" + slurm_mem: 24GB - group: "GPU: experiments/ClimaEarth/ unit tests and global bucket" steps: @@ -109,6 +100,27 @@ steps: slurm_gres: "gpu:1" slurm_mem: 20GB + - group: "ClimaEarth test" + steps: + - label: "ClimaEarth test" + key: "restarts" + command: "julia --color=yes --project=experiments/ClimaEarth/ experiments/ClimaEarth/test/runtests.jl" + agents: + slurm_mem: 16GB + + - label: "MPI restarts" + key: "mpi_restarts" + command: "srun julia --color=yes --project=experiments/ClimaEarth/ experiments/ClimaEarth/test/restart.jl" + env: + CLIMACOMMS_CONTEXT: "MPI" + timeout_in_minutes: 40 + soft_fail: + - exit_status: -1 + - exit_status: 255 + agents: + slurm_ntasks: 2 + slurm_mem: 32G + - group: "Integration Tests" steps: # SLABPLANET EXPERIMENTS @@ -218,7 +230,7 @@ steps: CLIMACOMMS_CONTEXT: "MPI" agents: slurm_ntasks: 4 - slurm_mem_per_cpu: 8GB + slurm_mem_per_cpu: 12GB # short high-res performance test - label: "Unthreaded AMIP FINE" # also reported by longruns with a flame graph diff --git a/NEWS.md b/NEWS.md index 5f9e17bffa..ea5a664e44 100644 --- a/NEWS.md +++ b/NEWS.md @@ -21,6 +21,20 @@ TOA radiation and net precipitation are added only if conservation is enabled. The coupler fields are also now stored as a ClimaCore Field of NamedTuples, rather than as a NamedTuple of ClimaCore Fields. +#### Restart simulations with JLD2 files PR[#1179](https://github.com/CliMA/ClimaCoupler.jl/pull/1179) + +`ClimaCoupler` can now use `JLD2` files to save state and cache for its model +component, allowing it to restart from saved checkpoints. Some restrictions +apply: + +- The number of MPI processes has to remain the same across checkpoints +- Restart files are generally not portable across machines, julia versions, and package versions +- Adding/changing new component models will probably require adding/changing code + +Please, refer to the +[documentation](https://clima.github.io/ClimaCoupler.jl/dev/checkpointer/) for +more information. + #### Remove extra `get_field` functions PR[#1203](https://github.com/CliMA/ClimaCoupler.jl/pull/1203) Removes the `get_field` functions for `air_density` for all models, which were unused except for the `BucketSimulation` method, which is replaced by a diff --git a/Project.toml b/Project.toml index 80ab0ed366..928df8677e 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" ClimaCore = "d414da3d-4745-48bb-8d80-42e94e092884" ClimaUtilities = "b3f4f4ca-9299-4f7f-bd9b-81e1242a7513" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -16,9 +17,10 @@ Thermodynamics = "b60c26fb-14c3-4610-9d3e-2d17fe7ff00c" [compat] ClimaComms = "0.6.2" -ClimaCore = "0.14.23" +ClimaCore = "0.14.25" ClimaUtilities = "0.1.22" Dates = "1" +JLD2 = "0.5.11" Logging = "1" SciMLBase = "2.11" StaticArrays = "1.6" diff --git a/docs/src/checkpointer.md b/docs/src/checkpointer.md index 9b84246e3d..5b3aae7020 100644 --- a/docs/src/checkpointer.md +++ b/docs/src/checkpointer.md @@ -1,12 +1,137 @@ # Checkpointer -This module contains general functions for logging the model states and restarting simulations. The `Checkpointer` uses `ClimaCore.InputOutput` infrastructure, which allows it to handle arbitrarily distributed logging and restart setups. +## How to save and restart from checkpoints + +`ClimaCoupler` supports saving and reading simulation checkpoints. This is +useful to split a long simulation into smaller, more manageable chunks. + +Checkpoints are a mix of HDF5 and JLD2 files and are typically saved in a +`checkpoints` folder in the simulation output. See +[`Utilities.setup_output_dirs`](@ref) for more information. + +!!! known limitations + + - The number of MPI processes has to remain the same across checkpoints + - Restart files are generally not portable across machines, julia versions, and package versions + - Adding/changing new component models will probably require adding/changing code + +### Saving checkpoints + +If you are running a model (such as AMIP), chances are that you can enable +checkpointing just by setting a command-line argument; The `checkpoint_dt` +option controls how frequently a checkpoint should be produced. + +If your model does not come with this option already, you can checkpoint the +simulation by adding a callback that calls the +[`Checkpointer.checkpoint_sims`](@ref) function. + +For example, to add a callback to checkpoint every hour of simulated time, +assuming you have a `start_date` +```julia +import Dates + +import ClimaCoupler: Checkpointer, TimeManager +import ClimaDiagnostics.Schedules: EveryCalendarDtSchedule + +schedule = EveryCalendarDtSchedule(Dates.Hour(1); start_date) +checkpoint_callback = TimeManager.Callback(schedule_checkpoint, Checkpointer.checkpoint_sims) + +# In the coupling loop: +TimeManager.maybe_trigger_callback(checkpoint_callback, coupled_simulation, time) +``` + +### Reading checkpoints + +There are two ways to restart a simulation from checkpoints. By default, +`ClimaCoupler` tries finding suitable checkpoints and automatically use them. +Alternatively, you can specify a directory `restart_dir` and a simulation time +`restart_t` and restart from files saved in the given directory at the given +time. If the model you are running supports writing checkpoints via command-line +argument, it will probably also support reading them. In this case, the +arguments `restart_dir` and `restart_t` identify the path of the top level +directory containing all the checkpoint files and the simulated times in second. + +If the model does not support directly reading a checkpoint, the `Checkpointer` +module provides a straightforward way to add this feature. +[`Checkpointer.restart!`](@ref) takes a coupled simulation, a `restart_dir`, and +a `restart_t` and overwrites the content of the coupled simulation with what is +in the checkpoint. + +## Developer notes + +In theory, the state of the component models should fully determine the state of +the coupled simulation and one should be able to restart a coupled simulation +just by using the states of the component models. Unfortunately, this is +currently not the case in `ClimaCoupler`. The main reason for this is the +complex interdependencies between component models and within `ClimaAtmos` which +make the initialization step inconsistent. For example, in a coupled simulation, +the surface albedo should be determined by the surface models and used by the +atmospheric model for radiation transfer, but `ClimaAtmos` also tries to set the +surface albedo (since it has to do so when run in standalone mode). In addition +to this, `ClimaAtmos` has a large cache that has internal interdependencies that +are hard to disentangle, and changing a field might require changing some other +field in a different part of the cache. As a result, it is not easy for +`ClimaCoupler` to consistently do initialization from a cold state. To conclude, +restarting a simulation exclusively using the states of the component models is +currently impossible. + +Given that restarting a simulation from the state is impossible, `ClimaCoupler` +needs to save the states and the caches. Let us review how we use +`ClimaCore.InputOutput` and `JLD2` package to accomplish this. + +`ClimaCore.InputOutput` provides a loss-less way to save the content of certain +`ClimaCore` objects to HDF5 files. Objects saved in this way are not tied to a +particular computing device or configuration. When running with MPI, +`ClimaCore.InputOutput` are also efficiently written in parallel. + +Unfortunately, `ClimaCore.InputOutput` only supports certain objects, such as +`Field`s and `Space`s, but the cache in component models is more complex than +this and contains complex objects with highly stateful quantities (e.g., C +pointers). Because of this, model states are saved to HDF5 but caches must be +saved to JLD2 files. + +`JLD2` allows us to save more complex objects without writing specific +serialization methods for every struct. `JLD2` allows us to take a big step +forward, but there are still several challenges that need to be solved: +1. `JLD2` does not support CUDA natively. To go around this, we have to move + everything onto the CPU first. Then, when the data is read back, we have to + move it back to the GPU. +2. `JLD2` does not support MPI natively. To go around this, each process writes + its `jld2` checkpoint and reads it back. This introduces the constraint that + the number of MPI processes cannot change across restarts. +3. Some quantities are best not saved and read (for example, anything with + pointers). For this, we write a recursive function that traverses the cache + and only restores quantities of a certain type (typically, `ClimaCore` + objects) + +Point 3. adds significant amount of code and requires component models to +specify how their cache has to be restored. + +If you are adding a component model, you have to extend the +``` +Checkpointer.get_model_prog_state +Checkpointer.get_model_cache +Checkpointer.restore_cache! +``` +methods. + +`ClimaCoupler` moves objects to the CPU with `Adapt(Array, x)`. `Adapt` +traverses the object recursively, and proper `Adapt` methods have to be defined +for every object involved in the chain. The easiest way to do this is using the +`Adapt.@adapt_structure` macro, which defines a recursive Adapt for the given +object. + +Types to watch for: +- `MPI` related objects (e.g., `MPICommsContext`) +- `TimeVaryingInputs` (because they contain `NCDatasets`, which contain pointers + to files) ## Checkpointer API ```@docs ClimaCoupler.Checkpointer.get_model_prog_state - ClimaCoupler.Checkpointer.restart_model_state! - ClimaCoupler.Checkpointer.checkpoint_model_state + ClimaCoupler.Checkpointer.get_model_cache + ClimaCoupler.Checkpointer.restart! ClimaCoupler.Checkpointer.checkpoint_sims + ClimaCoupler.Checkpointer.t_start_from_checkpoint ``` diff --git a/experiments/ClimaEarth/Manifest-v1.11.toml b/experiments/ClimaEarth/Manifest-v1.11.toml index e620101a7d..58da936549 100644 --- a/experiments/ClimaEarth/Manifest-v1.11.toml +++ b/experiments/ClimaEarth/Manifest-v1.11.toml @@ -337,7 +337,7 @@ uuid = "908f55d8-4145-4867-9c14-5dad1a479e4d" version = "0.4.6" [[deps.ClimaCoupler]] -deps = ["ClimaComms", "ClimaCore", "ClimaUtilities", "Dates", "Logging", "SciMLBase", "StaticArrays", "SurfaceFluxes", "Thermodynamics"] +deps = ["ClimaComms", "ClimaCore", "ClimaUtilities", "Dates", "JLD2", "Logging", "SciMLBase", "StaticArrays", "SurfaceFluxes", "Thermodynamics"] path = "../.." uuid = "4ade58fe-a8da-486c-bd89-46df092ec0c7" version = "0.1.2" @@ -718,11 +718,10 @@ git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" version = "0.1.10" -[[deps.Expronicon]] -deps = ["MLStyle", "Pkg", "TOML"] -git-tree-sha1 = "fc3951d4d398b5515f91d7fe5d45fc31dccb3c9b" -uuid = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636" -version = "0.8.5" +[[deps.ExproniconLite]] +git-tree-sha1 = "c13f0b150373771b0fdc1713c97860f8df12e6c2" +uuid = "55351af7-c7e9-48d6-89ff-24e801d99491" +version = "0.10.14" [[deps.Extents]] git-tree-sha1 = "063512a13dbe9c40d999c439268539aa552d1ae6" @@ -961,10 +960,10 @@ uuid = "5c1252a2-5f33-56bf-86c9-59e7332b4326" version = "0.4.11" [[deps.GeometryOps]] -deps = ["CoordinateTransformations", "DataAPI", "DelaunayTriangulation", "ExactPredicates", "GeoInterface", "GeometryBasics", "LinearAlgebra", "SortTileRecursiveTree", "Statistics", "Tables"] -git-tree-sha1 = "7eaffabf21dcdc7a5e543c309b903371af5c9b07" +deps = ["CoordinateTransformations", "DataAPI", "DelaunayTriangulation", "ExactPredicates", "GeoInterface", "GeometryBasics", "GeometryOpsCore", "LinearAlgebra", "SortTileRecursiveTree", "Statistics", "Tables"] +git-tree-sha1 = "226ebac075e4a477bbaeacb4f7e720f9dce019a9" uuid = "3251bfac-6a57-4b6d-aa61-ac1fef2975ab" -version = "0.1.14" +version = "0.1.15" [deps.GeometryOps.extensions] GeometryOpsFlexiJoinsExt = "FlexiJoins" @@ -976,6 +975,12 @@ version = "0.1.14" LibGEOS = "a90b1aa1-3769-5649-ba7e-abc5a9d163eb" Proj = "c94c279d-25a6-4763-9509-64d165bea63e" +[[deps.GeometryOpsCore]] +deps = ["DataAPI", "GeoInterface", "Tables"] +git-tree-sha1 = "390a7ff6a89a997d6a1fa76c857490e0bd34565f" +uuid = "05efe853-fabf-41c8-927e-7063c8b9f013" +version = "0.1.2" + [[deps.Gettext_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "XML2_jll"] git-tree-sha1 = "9b02998aba7bf074d14de89f9d37ca24a1a0b046" @@ -1201,6 +1206,12 @@ git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" uuid = "82899510-4779-5014-852e-03e436cf321d" version = "1.0.0" +[[deps.JLD2]] +deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"] +git-tree-sha1 = "91d501cb908df6f134352ad73cde5efc50138279" +uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +version = "0.5.11" + [[deps.JLLWrappers]] deps = ["Artifacts", "Preferences"] git-tree-sha1 = "a007feb38b422fbdab534406aeca1b86823cb4d6" @@ -1225,6 +1236,12 @@ version = "1.14.1" [deps.JSON3.weakdeps] ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" +[[deps.Jieko]] +deps = ["ExproniconLite"] +git-tree-sha1 = "2f05ed29618da60c06a87e9c033982d4f71d0b6c" +uuid = "ae98c720-c025-4a4a-838c-29b094483192" +version = "0.2.1" + [[deps.JpegTurbo]] deps = ["CEnum", "FileIO", "ImageCore", "JpegTurbo_jll", "TOML"] git-tree-sha1 = "fa6d0bcff8583bac20f1ffa708c3913ca605c611" @@ -1478,11 +1495,6 @@ git-tree-sha1 = "5de60bc6cb3899cd318d80d627560fae2e2d99ae" uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7" version = "2025.0.1+1" -[[deps.MLStyle]] -git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" -uuid = "d8e11817-5142-5d16-987a-aa16d5891078" -version = "0.4.17" - [[deps.MPICH_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] git-tree-sha1 = "e7159031670cee777cc2840aef7a521c3603e36c" @@ -1566,6 +1578,12 @@ git-tree-sha1 = "7b86a5d4d70a9f5cdf2dacb3cbe6d251d1a61dbe" uuid = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389" version = "0.3.4" +[[deps.Moshi]] +deps = ["ExproniconLite", "Jieko"] +git-tree-sha1 = "453de0fc2be3d11b9b93ca4d0fddd91196dcf1ed" +uuid = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d" +version = "0.3.5" + [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" version = "2023.12.12" @@ -1613,9 +1631,9 @@ version = "1.1.2" [[deps.NaNStatistics]] deps = ["PrecompileTools", "Static", "StaticArrayInterface"] -git-tree-sha1 = "1529c48f8c63a815c985890e0192b006b06de528" +git-tree-sha1 = "c1a9def67b8a871b51ccf84512f173e7979c186b" uuid = "b946abbf-3ea7-4610-9019-9858bfdeaf2d" -version = "0.6.45" +version = "0.6.47" [[deps.NaturalEarth]] deps = ["Downloads", "GeoJSON", "Pkg", "Scratch"] @@ -1734,9 +1752,9 @@ version = "0.11.32" [[deps.PNGFiles]] deps = ["Base64", "CEnum", "ImageCore", "IndirectArrays", "OffsetArrays", "libpng_jll"] -git-tree-sha1 = "67186a2bc9a90f9f85ff3cc8277868961fb57cbd" +git-tree-sha1 = "cf181f0b1e6a18dfeb0ee8acc4a9d1672499626c" uuid = "f57f5aa1-a3ce-4bc8-8ab9-96f992907883" -version = "0.4.3" +version = "0.4.4" [[deps.PROJ_jll]] deps = ["Artifacts", "JLLWrappers", "LibCURL_jll", "Libdl", "Libtiff_jll", "OpenSSL_jll", "SQLite_jll"] @@ -1935,9 +1953,9 @@ version = "1.3.4" [[deps.RecursiveArrayTools]] deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] -git-tree-sha1 = "fe9d37a17ab4d41a98951332ee8067f8dca8c4c2" +git-tree-sha1 = "a967273b0c96f9e55ccb93322ada38f43e685c49" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" -version = "3.29.0" +version = "3.30.0" [deps.RecursiveArrayTools.extensions] RecursiveArrayToolsFastBroadcastExt = "FastBroadcast" @@ -2034,13 +2052,14 @@ uuid = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" version = "0.1.0" [[deps.SciMLBase]] -deps = ["ADTypes", "Accessors", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "Expronicon", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "SciMLStructures", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface"] -git-tree-sha1 = "ffed2507209da5b42c6881944ef41a340ab5449b" +deps = ["ADTypes", "Accessors", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "Moshi", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "SciMLStructures", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface"] +git-tree-sha1 = "2242fd564bb0202a22a91f575dc58b8820612b6b" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -version = "2.74.1" +version = "2.75.0" [deps.SciMLBase.extensions] SciMLBaseChainRulesCoreExt = "ChainRulesCore" + SciMLBaseMLStyleExt = "MLStyle" SciMLBaseMakieExt = "Makie" SciMLBasePartialFunctionsExt = "PartialFunctions" SciMLBasePyCallExt = "PyCall" @@ -2051,6 +2070,7 @@ version = "2.74.1" [deps.SciMLBase.weakdeps] ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" @@ -2298,9 +2318,9 @@ weakdeps = ["ClimaParams"] [[deps.SymbolicIndexingInterface]] deps = ["Accessors", "ArrayInterface", "RuntimeGeneratedFunctions", "StaticArraysCore"] -git-tree-sha1 = "fd2d4f0499f6bb4a0d9f5030f5c7d61eed385e03" +git-tree-sha1 = "d6c04e26aa1c8f7d144e1a8c47f1c73d3013e289" uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5" -version = "0.3.37" +version = "0.3.38" [[deps.TOML]] deps = ["Dates"] @@ -2473,9 +2493,9 @@ version = "1.0.0" [[deps.XML2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"] -git-tree-sha1 = "ee6f41aac16f6c9a8cab34e2f7a200418b1cc1e3" +git-tree-sha1 = "b8b243e47228b4a3877f1dd6aee0c5d56db7fcf4" uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" -version = "2.13.6+0" +version = "2.13.6+1" [[deps.XSLT_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "XML2_jll", "Zlib_jll"] @@ -2550,9 +2570,9 @@ version = "1.2.13+1" [[deps.Zstd_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "622cf78670d067c738667aaa96c553430b65e269" +git-tree-sha1 = "446b23e73536f84e8037f5dce465e92275f6a308" uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" -version = "1.5.7+0" +version = "1.5.7+1" [[deps.isoband_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] diff --git a/experiments/ClimaEarth/cli_options.jl b/experiments/ClimaEarth/cli_options.jl index 6c994653ed..4aebdad2e6 100644 --- a/experiments/ClimaEarth/cli_options.jl +++ b/experiments/ClimaEarth/cli_options.jl @@ -75,7 +75,7 @@ function argparse_settings() arg_type = String default = "10days" "--checkpoint_dt" - help = "Time interval for hourly checkpointing [\"20days\" (default)]" + help = "Time interval for checkpointing [\"20days\" (default)]" arg_type = String default = "20days" # Restart information @@ -84,9 +84,9 @@ function argparse_settings() arg_type = String default = nothing "--restart_t" - help = "Time in seconds rounded to the nearest index to use at `t_start` for restarted simulation [0 (default)]" + help = "Time in seconds rounded to the nearest index to use at `t_start` for restarted simulation [nothing (default)]" arg_type = Int - default = 0 + default = nothing # Diagnostics information "--use_coupler_diagnostics" help = "Boolean flag indicating whether to compute and output coupler diagnostics [`true` (default), `false`]" diff --git a/experiments/ClimaEarth/components/atmosphere/climaatmos.jl b/experiments/ClimaEarth/components/atmosphere/climaatmos.jl index 2d5017db71..9dee582a19 100644 --- a/experiments/ClimaEarth/components/atmosphere/climaatmos.jl +++ b/experiments/ClimaEarth/components/atmosphere/climaatmos.jl @@ -14,6 +14,15 @@ import ClimaUtilities.TimeManager: ITime include("climaatmos_extra_diags.jl") +if pkgversion(CA) < v"0.28.6" + # Allow cache to be moved to CPU (this is a little bit of type piracy, but we + # allow it in this particular file) + CC.Adapt.@adapt_structure CA.AtmosCache + CC.Adapt.@adapt_structure CA.RRTMGPInterface.RRTMGPModel +end + +include("../shared/restore.jl") + ### ### Functions required by ClimaCoupler.jl for an AtmosModelSimulation ### @@ -103,6 +112,22 @@ function Checkpointer.get_model_prog_state(sim::ClimaAtmosSimulation) return sim.integrator.u end +function Checkpointer.get_model_cache(sim::ClimaAtmosSimulation) + return sim.integrator.p +end + +function Checkpointer.restore_cache!(sim::ClimaAtmosSimulation, new_cache) + comms_ctx = ClimaComms.context(sim.integrator.u.c) + restore!( + Checkpointer.get_model_cache(sim), + new_cache, + comms_ctx; + ignore = Set([:rc, :params, :ghost_buffer, :hyperdiffusion_ghost_buffer, :data_handler, :graph_context]), + ) + return nothing +end + + """ Interfacer.get_field(atmos_sim::ClimaAtmosSimulation, ::Val{:radiative_energy_flux_toa}) @@ -414,6 +439,11 @@ function get_atmos_config_dict(coupler_dict::Dict, job_id::String, atmos_output_ atmos_config["output_dir_style"] = "RemovePreexisting" atmos_config["output_dir"] = atmos_output_dir + # Ensure Atmos's own checkpoints are synced up with ClimaCoupler, so that we + # can pick up from where we have left. NOTE: This should not be needed, but + # there is no easy way to initialize ClimaAtmos with a different t_start + atmos_config["dt_save_state_to_disk"] = coupler_dict["checkpoint_dt"] + # Add all extra atmos diagnostic entries into the vector of atmos diagnostics # If atmos doesn't have any diagnostics, use the extra_atmos_diagnostics from the coupler atmos_config["diagnostics"] = @@ -568,3 +598,25 @@ function FluxCalculator.water_albedo_from_atmosphere!( diffuse_albedo .= CA.surface_albedo_diffuse(α_model).(λ, μ, LinearAlgebra.norm.(CC.Fields.level(Y.c.uₕ, 1))) end + +""" + climaatmos_restart_path(output_dir_root, t) + +Look at the most recent output in `output_dir_root` and find a checkpoint for time `t`. +""" +function climaatmos_restart_path(output_dir_root, t) + isdir(output_dir_root) || error("$(output_dir_root) does not exist") + name_rx = r"output_(\d\d\d\d)" + existing_outputs = filter(x -> !isnothing(match(name_rx, x)), readdir(output_dir_root)) + + day = floor(Int, t / (60 * 60 * 24)) + sec = floor(Int, t % (60 * 60 * 24)) + + # Walk back the folders and tyr to find a checkpoint + for output in sort(existing_outputs, rev = true) + previous_folder = joinpath(output_dir_root, output) + restart_file = joinpath(previous_folder, "clima_atmos", "day$day.$sec.hdf5") + ispath(restart_file) && return restart_file + end + error("Restart file for time $t not found") +end diff --git a/experiments/ClimaEarth/components/land/climaland_bucket.jl b/experiments/ClimaEarth/components/land/climaland_bucket.jl index da97ea3eb0..a626bb2eee 100644 --- a/experiments/ClimaEarth/components/land/climaland_bucket.jl +++ b/experiments/ClimaEarth/components/land/climaland_bucket.jl @@ -1,6 +1,7 @@ import Dates import SciMLBase import Statistics +import ClimaComms import ClimaCore as CC import ClimaTimeSteppers as CTS import Thermodynamics as TD @@ -10,6 +11,8 @@ import ClimaDiagnostics as CD import ClimaCoupler: Checkpointer, FluxCalculator, Interfacer using NCDatasets +include("../shared/restore.jl") + ### ### Functions required by ClimaCoupler.jl for a SurfaceModelSimulation ### @@ -407,6 +410,21 @@ function make_land_domain( return CL.Domains.SphericalShell{FT}(radius, depth, nothing, nelements, npolynomial, space, fields) end +function Checkpointer.get_model_cache(sim::BucketSimulation) + return sim.integrator.p +end + +function Checkpointer.restore_cache!(sim::BucketSimulation, new_cache) + old_cache = Checkpointer.get_model_cache(sim) + comms_ctx = ClimaComms.context(sim.model) + restore!( + old_cache, + new_cache, + comms_ctx, + ignore = Set([:rc, :params, :dss_buffer_2d, :dss_buffer_3d, :graph_context]), + ) +end + """ dss_state!(sim::BucketSimulation) diff --git a/experiments/ClimaEarth/components/ocean/prescr_ocean.jl b/experiments/ClimaEarth/components/ocean/prescr_ocean.jl index 75e3b6644b..3482471adb 100644 --- a/experiments/ClimaEarth/components/ocean/prescr_ocean.jl +++ b/experiments/ClimaEarth/components/ocean/prescr_ocean.jl @@ -112,3 +112,17 @@ at each timestep. function Interfacer.step!(sim::PrescribedOceanSimulation, t) evaluate!(sim.cache.T_sfc, sim.cache.SST_timevaryinginput, t) end + +function Checkpointer.get_model_cache(sim::PrescribedOceanSimulation) + return sim.cache +end + +function Checkpointer.restore_cache!(sim::PrescribedOceanSimulation, new_cache) + old_cache = Checkpointer.get_model_cache(sim) + for p in propertynames(old_cache) + if getproperty(old_cache, p) isa Field + ArrayType = ClimaComms.array_type(getproperty(old_cache, p)) + parent(getproperty(old_cache, p)) .= ArrayType(parent(getproperty(new_cache, p))) + end + end +end diff --git a/experiments/ClimaEarth/components/ocean/prescr_seaice.jl b/experiments/ClimaEarth/components/ocean/prescr_seaice.jl index ac2a7ef7e9..3fc8390faf 100644 --- a/experiments/ClimaEarth/components/ocean/prescr_seaice.jl +++ b/experiments/ClimaEarth/components/ocean/prescr_seaice.jl @@ -262,3 +262,17 @@ Perform DSS on the state of a component simulation, intended to be used before the initial step of a run. This method acts on prescribed ice simulations. """ dss_state!(sim::PrescribedIceSimulation) = CC.Spaces.weighted_dss!(sim.integrator.u, sim.integrator.p.dss_buffer) + +function Checkpointer.get_model_cache(sim::PrescribedIceSimulation) + return sim.integrator.p +end + +function Checkpointer.restore_cache!(sim::PrescribedIceSimulation, new_cache) + old_cache = Checkpointer.get_model_cache(sim) + for p in propertynames(old_cache) + if getproperty(old_cache, p) isa Field + ArrayType = ClimaComms.array_type(getproperty(old_cache, p)) + parent(getproperty(old_cache, p)) .= ArrayType(parent(getproperty(new_cache, p))) + end + end +end diff --git a/experiments/ClimaEarth/components/shared/restore.jl b/experiments/ClimaEarth/components/shared/restore.jl new file mode 100644 index 0000000000..fc56e5c21a --- /dev/null +++ b/experiments/ClimaEarth/components/shared/restore.jl @@ -0,0 +1,86 @@ +# Define shared methods to allow reading back a saved cache + +import ClimaComms +import ClimaCore +import ClimaCore: DataLayouts, Fields, Geometry +import ClimaCore.Fields: Field, FieldVector, field_values +import ClimaCore.DataLayouts: AbstractData +import ClimaCore.Geometry: AxisTensor +import ClimaCore.Spaces: AbstractSpace +import ClimaUtilities.TimeVaryingInputs: AbstractTimeVaryingInput +import StaticArrays +import NCDatasets + +""" + restore!(v1, v2, comms_ctx; ignore) + +Recursively traverse `v1` and `v2`, setting each field of `v1` with the +corresponding field in `v2`. In this, ignore all the properties that have name +within the `ignore` iterable. + +`ignore` is useful when there are stateful properties, such as live pointers. +""" +function restore!(v1::T1, v2::T2, comms_ctx; name = "", ignore) where {T1, T2} + # We pick fieldnames(T2) because v2 tend to be simpler (Array as opposed + # to CuArray) + fields = filter(x -> !(x in ignore), fieldnames(T2)) + if isempty(fields) + if !Base.issingletontype(typeof(v1)) + restore!(v1, v2, comms_ctx; name, ignore) + else + v1 == v2 || error("$v1 != $v2") + end + else + # Recursive case + for p in fields + restore!(getfield(v1, p), getfield(v2, p), comms_ctx; name = "$(name).$(p)", ignore) + end + end + return nothing +end + +# Ignoring certain types that don't need to be restored +# UnionAll and DataType are infinitely recursive, so we also ignore those +function restore!( + v1::Union{AbstractTimeVaryingInput, ClimaComms.AbstractCommsContext, ClimaComms.AbstractDevice, UnionAll, DataType}, + v2::Union{AbstractTimeVaryingInput, ClimaComms.AbstractCommsContext, ClimaComms.AbstractDevice, UnionAll, DataType}, + _comms_ctx; + name, + ignore, +) + return nothing +end + +function restore!( + v1::T1, + v2::T2, + comms_ctx; + name, + ignore, +) where {T1 <: Union{AbstractData, AbstractArray}, T2 <: Union{AbstractData, AbstractArray}} + ArrayType = parent(v1) isa Array ? Array : ClimaComms.array_type(ClimaComms.device(comms_ctx)) + moved_to_device = ArrayType(parent(v2)) + + parent(v1) .= moved_to_device + return nothing +end + +function restore!( + v1::T1, + v2::T2, + comms_ctx; + name, + ignore, +) where { + T1 <: Union{StaticArrays.StaticArray, Number, UnitRange, Symbol}, + T2 <: Union{StaticArrays.StaticArray, Number, UnitRange, Symbol}, +} + v1 == v2 || error("$name is a immutable but it inconsistent") + return nothing +end + +function restore!(v1::T1, v2::T2, comms_ctx; name, ignore) where {T1 <: Dict, T2 <: Dict} + # RRTGMP has some internal dictionaries + v1 == v2 || error("$name is inconsistent") + return nothing +end diff --git a/experiments/ClimaEarth/run_amip.jl b/experiments/ClimaEarth/run_amip.jl index c4b697e923..31a754986f 100644 --- a/experiments/ClimaEarth/run_amip.jl +++ b/experiments/ClimaEarth/run_amip.jl @@ -22,4 +22,4 @@ include("setup_run.jl") config_file = parse_commandline(argparse_settings())["config_file"] # Set up and run the coupled simulation -setup_and_run(config_file) +cs = setup_and_run(config_file) diff --git a/experiments/ClimaEarth/run_cloudless_aquaplanet.jl b/experiments/ClimaEarth/run_cloudless_aquaplanet.jl index 6db69b2d45..2c4945b94a 100644 --- a/experiments/ClimaEarth/run_cloudless_aquaplanet.jl +++ b/experiments/ClimaEarth/run_cloudless_aquaplanet.jl @@ -77,7 +77,6 @@ checkpoint_dt = "480hours" =# dir_paths = Utilities.setup_output_dirs(output_dir = coupler_output_dir, comms_ctx = ClimaComms.context()) -@info(config_dict) ## namelist config_dict = Dict( @@ -103,6 +102,7 @@ config_dict = Dict( "nh_poly" => 4, # output "dt_save_to_sol" => "1days", + "checkpoint_dt" => "1days", # numerics "apply_limiter" => false, "viscous_sponge" => false, diff --git a/experiments/ClimaEarth/run_cloudy_aquaplanet.jl b/experiments/ClimaEarth/run_cloudy_aquaplanet.jl index a5739f20c2..8e27e23373 100644 --- a/experiments/ClimaEarth/run_cloudy_aquaplanet.jl +++ b/experiments/ClimaEarth/run_cloudy_aquaplanet.jl @@ -74,7 +74,6 @@ checkpoint_dt = "480hours" =# dir_paths = Utilities.setup_output_dirs(output_dir = coupler_output_dir, comms_ctx = ClimaComms.context()) -@info(config_dict) ## namelist config_dict = Dict( @@ -100,6 +99,7 @@ config_dict = Dict( "nh_poly" => 4, # output "dt_save_to_sol" => "1days", + "checkpoint_dt" => "1days", # numerics "apply_limiter" => false, "viscous_sponge" => false, diff --git a/experiments/ClimaEarth/run_cloudy_slabplanet.jl b/experiments/ClimaEarth/run_cloudy_slabplanet.jl index ba3a061d94..e91753700e 100644 --- a/experiments/ClimaEarth/run_cloudy_slabplanet.jl +++ b/experiments/ClimaEarth/run_cloudy_slabplanet.jl @@ -82,7 +82,6 @@ dt_rad = "6hours" =# dir_paths = Utilities.setup_output_dirs(output_dir = coupler_output_dir, comms_ctx = ClimaComms.context()) -@info(config_dict) ## namelist @@ -109,6 +108,7 @@ config_dict = Dict( "nh_poly" => 4, # output "dt_save_to_sol" => "1days", + "checkpoint_dt" => "1days", # numerics "apply_limiter" => false, "viscous_sponge" => false, diff --git a/experiments/ClimaEarth/run_dry_held_suarez.jl b/experiments/ClimaEarth/run_dry_held_suarez.jl index b8ef9a54dd..17e282e8d9 100644 --- a/experiments/ClimaEarth/run_dry_held_suarez.jl +++ b/experiments/ClimaEarth/run_dry_held_suarez.jl @@ -77,7 +77,6 @@ checkpoint_dt = "480hours" =# dir_paths = Utilities.setup_output_dirs(output_dir = coupler_output_dir, comms_ctx = ClimaComms.context()) -@info(config_dict) ## namelist config_dict = Dict( @@ -102,6 +101,7 @@ config_dict = Dict( "nh_poly" => 4, # output "dt_save_to_sol" => "1days", + "checkpoint_dt" => "1days", # numerics "apply_limiter" => false, "viscous_sponge" => false, diff --git a/experiments/ClimaEarth/run_moist_held_suarez.jl b/experiments/ClimaEarth/run_moist_held_suarez.jl index 7b7b72d25b..8c13e01a27 100644 --- a/experiments/ClimaEarth/run_moist_held_suarez.jl +++ b/experiments/ClimaEarth/run_moist_held_suarez.jl @@ -80,7 +80,6 @@ checkpoint_dt = "480hours" dir_paths = Utilities.setup_output_dirs(output_dir = coupler_output_dir, comms_ctx = ClimaComms.context()) -@info(config_dict) ## namelist config_dict = Dict( @@ -106,6 +105,7 @@ config_dict = Dict( "nh_poly" => 4, # output "dt_save_to_sol" => "1days", + "checkpoint_dt" => "1days", # numerics "apply_limiter" => false, "viscous_sponge" => false, diff --git a/experiments/ClimaEarth/setup_run.jl b/experiments/ClimaEarth/setup_run.jl index 8e6ccce707..84a56475af 100644 --- a/experiments/ClimaEarth/setup_run.jl +++ b/experiments/ClimaEarth/setup_run.jl @@ -90,7 +90,6 @@ and exchanges combined fields and calculates fluxes using the selected turbulent Note that we want to implement this in a dispatchable function to allow for other forms of timestepping (e.g. leapfrog). """ - function solve_coupler!(cs) (; model_sims, Δt_cpl, tspan, comms_ctx) = cs (; atmos_sim, land_sim, ocean_sim, ice_sim) = model_sims @@ -164,6 +163,9 @@ input config file. It initializes the component models, all coupler objects, diagnostics, and conservation checks, and then runs the simulation. """ function setup_and_run(config_dict::AbstractDict) + # Make a copy so that we don't modify the original input + config_dict = copy(config_dict) + # Initialize communication context (do this first so all printing is only on root) comms_ctx = Utilities.get_comms_context(config_dict) # Select the correct timestep for each component model based on which are available @@ -209,8 +211,7 @@ function setup_and_run(config_dict::AbstractDict) #and `dir_paths.checkpoints`, where restart files are saved. =# - COUPLER_OUTPUT_DIR = joinpath(output_dir_root, job_id) - dir_paths = Utilities.setup_output_dirs(output_dir = COUPLER_OUTPUT_DIR, comms_ctx = comms_ctx) + dir_paths = Utilities.setup_output_dirs(output_dir = output_dir_root, comms_ctx = comms_ctx) @info "Coupler output directory $(dir_paths.output)" @info "Coupler artifacts directory $(dir_paths.artifacts)" @info "Coupler checkpoint directory $(dir_paths.checkpoints)" @@ -231,6 +232,23 @@ function setup_and_run(config_dict::AbstractDict) Random.seed!(random_seed) @info "Random seed set to $(random_seed)" + isnothing(restart_t) && (restart_t = Checkpointer.t_start_from_checkpoint(dir_paths.checkpoints)) + isnothing(restart_dir) && (restart_dir = dir_paths.checkpoints) + should_restart = !isnothing(restart_t) && !isnothing(restart_dir) + if should_restart + if t_start isa ITime + t_start, _ = promote(ITime(restart_t), t_start) + else + t_start = restart_t + end + + # TODO: Find a cleaner way to do this instead of having a second restart + # just for atmos + atmos_config_dict["restart_file"] = climaatmos_restart_path(output_dir_root, restart_t) + + @info "Starting from t_start $(t_start)" + end + tspan = (t_start, t_end) #= @@ -579,71 +597,66 @@ function setup_and_run(config_dict::AbstractDict) If a restart directory is specified and contains output files from the `checkpoint_cb` callback, the component model states are restarted from those files. The restart directory is specified in the `config_dict` dictionary. The `restart_t` field specifies the time step at which the restart is performed. =# + should_restart && Checkpointer.restart!(cs, restart_dir, restart_t) - if !isnothing(restart_dir) - for sim in cs.model_sims - if Checkpointer.get_model_prog_state(sim) !== nothing - Checkpointer.restart_model_state!(sim, comms_ctx, restart_t; input_dir = restart_dir) - end - end - end + if !should_restart + #= + ## Initialize Component Model Exchange - #= - ## Initialize Component Model Exchange + We need to ensure all models' initial conditions are shared to enable the coupler to calculate the first instance of surface fluxes. Some auxiliary variables (namely surface humidity and radiation fluxes) + depend on initial conditions of other component models than those in which the variables are calculated, which is why we need to step these models in time and/or reinitialize them. + The concrete steps for proper initialization are: + =# - We need to ensure all models' initial conditions are shared to enable the coupler to calculate the first instance of surface fluxes. Some auxiliary variables (namely surface humidity and radiation fluxes) - depend on initial conditions of other component models than those in which the variables are calculated, which is why we need to step these models in time and/or reinitialize them. - The concrete steps for proper initialization are: - =# + # 1.coupler updates surface model area fractions + FieldExchanger.update_surface_fractions!(cs) - # 1.coupler updates surface model area fractions - FieldExchanger.update_surface_fractions!(cs) - - # 2.surface density (`ρ_sfc`): calculated by the coupler by adiabatically extrapolating atmospheric thermal state to the surface. - # For this, we need to import surface and atmospheric fields. The model sims are then updated with the new surface density. - FieldExchanger.import_combined_surface_fields!(cs.fields, cs.model_sims, cs.turbulent_fluxes) - FieldExchanger.import_atmos_fields!(cs.fields, cs.model_sims, cs.boundary_space, cs.turbulent_fluxes) - FieldExchanger.update_model_sims!(cs.model_sims, cs.fields, cs.turbulent_fluxes) - - # 3.surface vapor specific humidity (`q_sfc`): step surface models with the new surface density to calculate their respective `q_sfc` internally - ## TODO: the q_sfc calculation follows the design of the bucket q_sfc, but it would be neater to abstract this from step! (#331) - Interfacer.step!(land_sim, tspan[1] + Δt_cpl) - Interfacer.step!(ocean_sim, tspan[1] + Δt_cpl) - Interfacer.step!(ice_sim, tspan[1] + Δt_cpl) - - # 4.turbulent fluxes: now we have all information needed for calculating the initial turbulent - # surface fluxes using either the combined state or the partitioned state method - if cs.turbulent_fluxes isa FluxCalculator.CombinedStateFluxesMOST - ## import the new surface properties into the coupler (note the atmos state was also imported in step 3.) - FieldExchanger.import_combined_surface_fields!(cs.fields, cs.model_sims, cs.turbulent_fluxes) # i.e. T_sfc, albedo, z0, beta, q_sfc - ## calculate turbulent fluxes inside the atmos cache based on the combined surface state in each grid box - FluxCalculator.combined_turbulent_fluxes!(cs.model_sims, cs.fields, cs.turbulent_fluxes) # this updates the atmos thermo state, sfc_ts - elseif cs.turbulent_fluxes isa FluxCalculator.PartitionedStateFluxes - ## calculate turbulent fluxes in surface models and save the weighted average in coupler fields - FluxCalculator.partitioned_turbulent_fluxes!( - cs.model_sims, - cs.fields, - cs.boundary_space, - FluxCalculator.MoninObukhovScheme(), - cs.thermo_params, - ) + # 2.surface density (`ρ_sfc`): calculated by the coupler by adiabatically extrapolating atmospheric thermal state to the surface. + # For this, we need to import surface and atmospheric fields. The model sims are then updated with the new surface density. + FieldExchanger.import_combined_surface_fields!(cs.fields, cs.model_sims, cs.turbulent_fluxes) + FieldExchanger.import_atmos_fields!(cs.fields, cs.model_sims, cs.boundary_space, cs.turbulent_fluxes) + FieldExchanger.update_model_sims!(cs.model_sims, cs.fields, cs.turbulent_fluxes) - ## update atmos sfc_conditions for surface temperature - ## TODO: this is hard coded and needs to be simplified (req. CA modification) (#479) - new_p = get_new_cache(atmos_sim, cs.fields) - CA.SurfaceConditions.update_surface_conditions!(atmos_sim.integrator.u, new_p, atmos_sim.integrator.t) ## sets T_sfc (but SF calculation not necessary - requires split functionality in CA) - atmos_sim.integrator.p.precomputed.sfc_conditions .= new_p.precomputed.sfc_conditions - end + # 3.surface vapor specific humidity (`q_sfc`): step surface models with the new surface density to calculate their respective `q_sfc` internally + ## TODO: the q_sfc calculation follows the design of the bucket q_sfc, but it would be neater to abstract this from step! (#331) + Interfacer.step!(land_sim, tspan[1] + Δt_cpl) + Interfacer.step!(ocean_sim, tspan[1] + Δt_cpl) + Interfacer.step!(ice_sim, tspan[1] + Δt_cpl) - # 5.reinitialize models + radiative flux: prognostic states and time are set to their initial conditions. For atmos, this also triggers the callbacks and sets a nonzero radiation flux (given the new sfc_conditions) - FieldExchanger.reinit_model_sims!(cs.model_sims) + # 4.turbulent fluxes: now we have all information needed for calculating the initial turbulent + # surface fluxes using either the combined state or the partitioned state method + if cs.turbulent_fluxes isa FluxCalculator.CombinedStateFluxesMOST + ## import the new surface properties into the coupler (note the atmos state was also imported in step 3.) + FieldExchanger.import_combined_surface_fields!(cs.fields, cs.model_sims, cs.turbulent_fluxes) # i.e. T_sfc, albedo, z0, beta, q_sfc + ## calculate turbulent fluxes inside the atmos cache based on the combined surface state in each grid box + FluxCalculator.combined_turbulent_fluxes!(cs.model_sims, cs.fields, cs.turbulent_fluxes) # this updates the atmos thermo state, sfc_ts + elseif cs.turbulent_fluxes isa FluxCalculator.PartitionedStateFluxes + ## calculate turbulent fluxes in surface models and save the weighted average in coupler fields + FluxCalculator.partitioned_turbulent_fluxes!( + cs.model_sims, + cs.fields, + cs.boundary_space, + FluxCalculator.MoninObukhovScheme(), + cs.thermo_params, + ) + + ## update atmos sfc_conditions for surface temperature + ## TODO: this is hard coded and needs to be simplified (req. CA modification) (#479) + new_p = get_new_cache(atmos_sim, cs.fields) + CA.SurfaceConditions.update_surface_conditions!(atmos_sim.integrator.u, new_p, atmos_sim.integrator.t) ## sets T_sfc (but SF calculation not necessary - requires split functionality in CA) + atmos_sim.integrator.p.precomputed.sfc_conditions .= new_p.precomputed.sfc_conditions + end + + # 5.reinitialize models + radiative flux: prognostic states and time are set to their initial conditions. For atmos, this also triggers the callbacks and sets a nonzero radiation flux (given the new sfc_conditions) + FieldExchanger.reinit_model_sims!(cs.model_sims) - # 6.update all fluxes: coupler re-imports updated atmos fluxes (radiative fluxes for both `turbulent_fluxes` types - # and also turbulent fluxes if `turbulent_fluxes isa CombinedStateFluxesMOST`, - # and sends them to the surface component models. If `turbulent_fluxes isa PartitionedStateFluxes` - # atmos receives the turbulent fluxes from the coupler. - FieldExchanger.import_atmos_fields!(cs.fields, cs.model_sims, cs.boundary_space, cs.turbulent_fluxes) - FieldExchanger.update_model_sims!(cs.model_sims, cs.fields, cs.turbulent_fluxes) + # 6.update all fluxes: coupler re-imports updated atmos fluxes (radiative fluxes for both `turbulent_fluxes` types + # and also turbulent fluxes if `turbulent_fluxes isa CombinedStateFluxesMOST`, + # and sends them to the surface component models. If `turbulent_fluxes isa PartitionedStateFluxes` + # atmos receives the turbulent fluxes from the coupler. + FieldExchanger.import_atmos_fields!(cs.fields, cs.model_sims, cs.boundary_space, cs.turbulent_fluxes) + FieldExchanger.update_model_sims!(cs.model_sims, cs.fields, cs.turbulent_fluxes) + end #= ## Precompilation of Coupling Loop @@ -653,13 +666,15 @@ function setup_and_run(config_dict::AbstractDict) beginning and end of the simulation timespan to the correct values. =# - ## run the coupled simulation for two timesteps to precompile - cs.tspan[2] = tspan[1] + Δt_cpl * 2 - solve_coupler!(cs) + if tspan[2] > 2Δt_cpl + tspan[1] + ## run the coupled simulation for two timesteps to precompile + cs.tspan[2] = tspan[1] + Δt_cpl * 2 + solve_coupler!(cs) - ## update the timespan to the correct values - cs.tspan[1] = tspan[1] + Δt_cpl * 2 - cs.tspan[2] = tspan[2] + ## update the timespan to the correct values + cs.tspan[1] = tspan[1] + Δt_cpl * 2 + cs.tspan[2] = tspan[2] + end ## Run garbage collection before solving for more accurate memory comparison to ClimaAtmos GC.gc() @@ -724,4 +739,5 @@ function setup_and_run(config_dict::AbstractDict) # Close all diagnostics file writers isnothing(cs.diags_handler) || foreach(diag -> close(diag.output_writer), cs.diags_handler.scheduled_diagnostics) isnothing(atmos_sim.output_writers) || foreach(close, atmos_sim.output_writers) + return cs end diff --git a/experiments/ClimaEarth/test/compare.jl b/experiments/ClimaEarth/test/compare.jl new file mode 100644 index 0000000000..3ec12b8674 --- /dev/null +++ b/experiments/ClimaEarth/test/compare.jl @@ -0,0 +1,122 @@ +# compare.jl provides function to recursively compare complex objects while also +# allowing for some numerical tolerance. + +import ClimaComms +import ClimaAtmos as CA +import ClimaCore +import ClimaCore: DataLayouts, Fields, Geometry +import ClimaCore.Fields: Field, FieldVector, field_values +import ClimaCore.DataLayouts: AbstractData +import ClimaCore.Geometry: AxisTensor +import ClimaCore.Spaces: AbstractSpace +import NCDatasets + +""" + _error(arr1::AbstractArray, arr2::AbstractArray; ABS_TOL = 100eps(eltype(arr1))) + +We compute the error in this way: +- when the absolute value is larger than ABS_TOL, we use the absolute error +- in the other cases, we compare the relative errors +""" +function _error(arr1::AbstractArray, arr2::AbstractArray; ABS_TOL = 100eps(eltype(arr1))) + # There are some parameters, e.g. Obukhov length, for which Inf + # is a reasonable value (implying a stability parameter in the neutral boundary layer + # regime, for instance). We account for such instances with the `isfinite` function. + arr1 = Array(arr1) .* isfinite.(Array(arr1)) + arr2 = Array(arr2) .* isfinite.(Array(arr2)) + diff = abs.(arr1 .- arr2) + denominator = abs.(arr1) + error = ifelse.(denominator .> ABS_TOL, diff ./ denominator, diff) + return error +end + +""" + compare(v1, v2; name = "", ignore = Set([:rc])) + +Return whether `v1` and `v2` are the same (up to floating point errors). +`compare` walks through all the properties in `v1` and `v2` until it finds +that there are no more properties. At that point, `compare` tries to match the +resulting objects. When such objects are arrays with floating point, `compare` +defines a notion of `error` that is the following: when the absolute value is +less than `100eps(eltype)`, `error = absolute_error`, otherwise it is relative +error. The `error` is then compared against a tolerance. +Keyword arguments +================= +- `name` is used to collect the name of the property while we go recursively + over all the properties. You can pass a base name. +- `ignore` is a collection of `Symbol`s that identify properties that are + ignored when walking through the tree. This is useful for properties that + are known to be different (e.g., `output_dir`). +`:rc` is some CUDA/CuArray internal object that we don't care about +""" +function compare( + v1::T1, + v2::T2; + name = "", + ignore = Set([:rc]), +) where { + T1 <: Union{FieldVector, AbstractSpace, NamedTuple, CA.AtmosCache}, + T2 <: Union{FieldVector, AbstractSpace, NamedTuple, CA.AtmosCache}, +} + pass = true + return _compare(pass, v1, v2; name, ignore) +end + +function _compare(pass, v1::T, v2::T; name, ignore) where {T} + properties = filter(x -> !(x in ignore), propertynames(v1)) + if isempty(properties) + pass &= _compare(v1, v2; name, ignore) + else + # Recursive case + for p in properties + pass &= _compare(pass, getproperty(v1, p), getproperty(v2, p); name = "$(name).$(p)", ignore) + end + end + return pass +end + +function _compare(v1::T, v2::T; name, ignore) where {T} + return print_maybe(v1 == v2, "$name differs") +end + +function _compare(v1::T, v2::T; name, ignore) where {T <: Union{AbstractString, Symbol}} + # What we can safely print without filling STDOUT + return print_maybe(v1 == v2, "$name differs: $v1 vs $v2") +end + +function _compare(v1::T, v2::T; name, ignore) where {T <: Number} + # We check with triple equal so that we also catch NaNs being equal + return print_maybe(v1 === v2, "$name differs: $v1 vs $v2") +end + +# We ignore NCDatasets. They contain a lot of state-ful information +function _compare(pass, v1::T, v2::T; name, ignore) where {T <: NCDatasets.NCDataset} + return pass +end + +function _compare(v1::T, v2::T; name, ignore) where {T <: Field{<:AbstractData{<:Real}}} + return _compare(parent(v1), parent(v2); name, ignore) +end + +function _compare(pass, v1::T, v2::T; name, ignore) where {T <: AbstractData} + return pass && _compare(parent(v1), parent(v2); name, ignore) +end + +# Handle views +function _compare(pass, v1::SubArray{FT}, v2::SubArray{FT}; name, ignore) where {FT <: AbstractFloat} + return pass && _compare(collect(v1), collect(v2); name, ignore) +end + +function _compare(v1::AbstractArray{FT}, v2::AbstractArray{FT}; name, ignore) where {FT <: AbstractFloat} + error = maximum(_error(v1, v2); init = zero(eltype(v1))) + return print_maybe(error <= 100eps(eltype(v1)), "$name error: $error") +end + +function _compare(pass, v1::T1, v2::T2; name, ignore) where {T1, T2} + error("v1 and v2 have different types") +end + +function print_maybe(exp, what) + exp || println(what) + return exp +end diff --git a/experiments/ClimaEarth/test/restart.jl b/experiments/ClimaEarth/test/restart.jl new file mode 100644 index 0000000000..9779391348 --- /dev/null +++ b/experiments/ClimaEarth/test/restart.jl @@ -0,0 +1,117 @@ +# This test runs a small AMIP simulation four times. +# +# - The first time the simulation is run for four steps +# - The second time the simulation is run for two steps +# - The third time the simulation is run for two steps, but restarting from the +# second simulation +# +# After all these simulations are run, we compare the first and last runs. They +# should be bit-wise identical. +# +# The content of the simulation is not the most important, but it helps if it +# has all of the complexity possible. + +import ClimaComms +ClimaComms.@import_required_backends +import ClimaUtilities.OutputPathGenerator: maybe_wait_filesystem +import YAML +import Logging +using Test + +# Uncomment the following for cleaner output (but more difficult debugging) +# Logging.disable_logging(Logging.Warn) + +include("compare.jl") +include("../setup_run.jl") + +comms_ctx = ClimaComms.context() +@info "Context: $(comms_ctx)" +ClimaComms.init(comms_ctx) + +# Make sure that all MPI processes agree on the output_loc +tmpdir = ClimaComms.iamroot(comms_ctx) ? mktempdir(pwd()) : "" +tmpdir = ClimaComms.bcast(comms_ctx, tmpdir) +# Sometimes the shared filesystem doesn't work properly and the folder is not +# synced across MPI processes. Let's add an additional check here. +maybe_wait_filesystem(ClimaComms.context(), tmpdir) + +default_config = parse_commandline(argparse_settings()) +base_config_file = joinpath(@__DIR__, "amip_test.yml") +base_config = YAML.load_file(base_config_file) +merge!(default_config, base_config) + +# Four steps +four_steps = deepcopy(default_config) + +four_steps["dt"] = "180secs" +four_steps["dt_cpl"] = "180secs" +four_steps["t_end"] = "720secs" +four_steps["dt_rad"] = "180secs" +four_steps["job_id"] = "four_steps" + +cs_four_steps = setup_and_run(four_steps) + +println("Simulating two steps") + +# Now, two steps plus one +two_steps = deepcopy(default_config) + +two_steps["dt"] = "180secs" +two_steps["dt_cpl"] = "180secs" +two_steps["t_end"] = "360secs" +two_steps["dt_rad"] = "180secs" +two_steps["coupler_output_dir"] = tmpdir +two_steps["checkpoint_dt"] = "360secs" +two_steps["job_id"] = "two_steps" + +# Copying since setup_and_run changes its content +cs_two_steps1 = setup_and_run(two_steps) + +println("Reading and simulating last step") +# Two additional steps +two_steps["t_end"] = "720secs" +cs_two_steps2 = setup_and_run(two_steps) + +@testset "Restarts" begin + # We put cs_four_steps.fields in a NamedTuple so that we can start the recursion in compare + @test compare((; coupler_fields = cs_four_steps.fields), (; coupler_fields = cs_two_steps2.fields)) + + @test compare(cs_four_steps.model_sims.atmos_sim.integrator.u, cs_two_steps2.model_sims.atmos_sim.integrator.u) + + @test compare( + cs_four_steps.model_sims.atmos_sim.integrator.p, + cs_two_steps2.model_sims.atmos_sim.integrator.p, + ignore = [ + :walltime_estimate, # Stateful + :output_dir, # Changes across runs + :scratch, # Irrelevant + :ghost_buffer, # Irrelevant + :hyperdiffusion_ghost_buffer, # Irrelevant + :data_handler, # Stateful + :face_clear_sw_direct_flux_dn, # Not filled by RRTGMP + :face_sw_direct_flux_dn, # Not filled by RRTGMP + :rc, # CUDA internal object + ], + ) + + @test compare(cs_four_steps.model_sims.ice_sim.integrator.u, cs_two_steps2.model_sims.ice_sim.integrator.u) + + @test compare(cs_four_steps.model_sims.land_sim.integrator.u, cs_two_steps2.model_sims.land_sim.integrator.u) + @test compare( + cs_four_steps.model_sims.land_sim.integrator.p, + cs_two_steps2.model_sims.land_sim.integrator.p, + ignore = [:dss_buffer_3d, :dss_buffer_2d, :rc], + ) + + # Ignoring SST_timevaryinginput because it contains closures (which should be + # reinitialized correctly). We have to remove it from the type, otherwise + # compare will not work + function delete(nt::NamedTuple, fieldnames...) + return (; filter(p -> !(first(p) in fieldnames), collect(pairs(nt)))...) + end + + ocean_cache_four = delete(cs_four_steps.model_sims.ocean_sim.cache, :SST_timevaryinginput) + ocean_cache_two2 = delete(cs_two_steps2.model_sims.ocean_sim.cache, :SST_timevaryinginput) + + @test compare(ocean_cache_four, ocean_cache_two2) +end diff --git a/experiments/ClimaEarth/test/runtests.jl b/experiments/ClimaEarth/test/runtests.jl index 92e9f69ef0..ccd36cb1bc 100644 --- a/experiments/ClimaEarth/test/runtests.jl +++ b/experiments/ClimaEarth/test/runtests.jl @@ -25,3 +25,7 @@ end @safetestset "AMIP test" begin include("amip_test.jl") end + +@safetestset "Restart test" begin + include("restart.jl") +end diff --git a/experiments/ClimaEarth/user_io/arg_parsing.jl b/experiments/ClimaEarth/user_io/arg_parsing.jl index 23032ddb28..918b81626a 100644 --- a/experiments/ClimaEarth/user_io/arg_parsing.jl +++ b/experiments/ClimaEarth/user_io/arg_parsing.jl @@ -44,6 +44,9 @@ This function may modify the input dictionary to remove unnecessary keys. - All arguments needed for the coupled simulation """ function get_coupler_args(config_dict::Dict) + # Make a copy so that we don't modify the original input + config_dict = copy(config_dict) + # Simulation-identifying information; Print `config_dict` if requested config_dict["print_config_dict"] && @info(config_dict) job_id = config_dict["job_id"] @@ -84,7 +87,8 @@ function get_coupler_args(config_dict::Dict) # Restart information restart_dir = config_dict["restart_dir"] - restart_t = Int(config_dict["restart_t"]) + restart_t = + isnothing(config_dict["restart_t"]) ? nothing : Int64(Utilities.time_to_seconds(config_dict["restart_t"])) # Diagnostics information use_coupler_diagnostics = config_dict["use_coupler_diagnostics"] @@ -101,7 +105,7 @@ function get_coupler_args(config_dict::Dict) conservation_softfail = config_dict["conservation_softfail"] # Output information - output_dir_root = config_dict["coupler_output_dir"] + output_dir_root = joinpath(config_dict["coupler_output_dir"], job_id) # ClimaLand-specific information land_domain_type = config_dict["land_domain_type"] diff --git a/src/Checkpointer.jl b/src/Checkpointer.jl index eec2463198..c018598405 100644 --- a/src/Checkpointer.jl +++ b/src/Checkpointer.jl @@ -7,10 +7,13 @@ module Checkpointer import ClimaComms import ClimaCore as CC +import ClimaUtilities.Utils: sort_by_creation_time import ..Interfacer import Dates -export get_model_prog_state, checkpoint_model_state, restart_model_state!, checkpoint_sims +import JLD2 + +export get_model_prog_state, checkpoint_model_state, checkpoint_sims """ get_model_prog_state(sim::Interfacer.ComponentModelSimulation) @@ -20,10 +23,18 @@ This is a template function that should be implemented for each component model. """ get_model_prog_state(sim::Interfacer.ComponentModelSimulation) = nothing +""" + get_model_cache(sim::Interfacer.ComponentModelSimulation) + +Returns the model cache of a simulation. +This is a template function that should be implemented for each component model. +""" +get_model_cache(sim::Interfacer.ComponentModelSimulation) = nothing + """ checkpoint_model_state(sim::Interfacer.ComponentModelSimulation, comms_ctx::ClimaComms.AbstractCommsContext, t::Int; output_dir = "output") -Checkpoints the model state of a simulation to a HDF5 file at a given time, t (in seconds). +Checkpoint the model state of a simulation to a HDF5 file at a given time, t (in seconds). """ function checkpoint_model_state( sim::Interfacer.ComponentModelSimulation, @@ -35,8 +46,7 @@ function checkpoint_model_state( day = floor(Int, t / (60 * 60 * 24)) sec = floor(Int, t % (60 * 60 * 24)) @info "Saving checkpoint " * Interfacer.name(sim) * " model state to HDF5 on day $day second $sec" - mkpath(joinpath(output_dir, "checkpoint")) - output_file = joinpath(output_dir, "checkpoint", "checkpoint_" * Interfacer.name(sim) * "_$t.hdf5") + output_file = joinpath(output_dir, "checkpoint_" * Interfacer.name(sim) * "_$t.hdf5") checkpoint_writer = CC.InputOutput.HDF5Writer(output_file, comms_ctx) CC.InputOutput.HDF5.write_attribute(checkpoint_writer.file, "time", t) CC.InputOutput.write!(checkpoint_writer, Y, "model_state") @@ -46,30 +56,41 @@ function checkpoint_model_state( end """ - restart_model_state!(sim::Interfacer.ComponentModelSimulation, comms_ctx::ClimaComms.AbstractCommsContext, t::Int; input_dir = "input") + checkpoint_model_cache(sim::Interfacer.ComponentModelSimulation, comms_ctx::ClimaComms.AbstractCommsContext, t::Int; output_dir = "output") + +Checkpoint the model cache to N JLD2 files at a given time, t (in seconds), +where N is the number of MPI ranks. -Sets the model state of a simulation from a HDF5 file from a given time, t (in seconds). +Objects are saved to JLD2 files because caches are generally not ClimaCore +objects (and ClimaCore.InputOutput can only save `Field`s or `FieldVector`s). """ -function restart_model_state!( +function checkpoint_model_cache( sim::Interfacer.ComponentModelSimulation, comms_ctx::ClimaComms.AbstractCommsContext, t::Int; - input_dir = "input", + output_dir = "output", ) - Y = get_model_prog_state(sim) + # Move p to CPU (because we cannot save CUArrays) + p = CC.Adapt.adapt(Array, get_model_cache(sim)) day = floor(Int, t / (60 * 60 * 24)) sec = floor(Int, t % (60 * 60 * 24)) - input_file = joinpath(input_dir, "checkpoint", "checkpoint_" * Interfacer.name(sim) * "_$t.hdf5") + @info "Saving checkpoint " * Interfacer.name(sim) * " model cache to JLD2 on day $day second $sec" + pid = ClimaComms.mypid(comms_ctx) + output_file = joinpath(output_dir, "checkpoint_cache_$(pid)_" * Interfacer.name(sim) * "_$t.jld2") + JLD2.jldsave(output_file, cache = p) + return nothing +end - @info "Setting " Interfacer.name(sim) " state to checkpoint: $input_file, corresponding to day $day second $sec" - # open file and read - restart_reader = CC.InputOutput.HDF5Reader(input_file, comms_ctx) - Y_new = CC.InputOutput.read_field(restart_reader, "model_state") - Base.close(restart_reader) +""" + restore_cache!(sim::Interfacer.ComponentModelSimulation, new_cache) + +Replace the cache in `sim` with `new_cache`. - # set new state - Y .= Y_new +Component models can define new methods for this to change how cache is restored. +""" +function restore_cache!(sim::Interfacer.ComponentModelSimulation, new_cache) + return nothing end """ @@ -78,18 +99,115 @@ end This is a callback function that checkpoints all simulations defined in the current coupled simulation. """ function checkpoint_sims(cs::Interfacer.CoupledSimulation) + t = Dates.datetime2epochms(cs.dates.date[1]) + t0 = Dates.datetime2epochms(cs.dates.date0[1]) + time = Int((t - t0) / 1e3) + day = floor(Int, time / (60 * 60 * 24)) + sec = floor(Int, time % (60 * 60 * 24)) + output_dir = cs.dirs.checkpoints + comms_ctx = cs.comms_ctx for sim in cs.model_sims - if Checkpointer.get_model_prog_state(sim) !== nothing - t = Dates.datetime2epochms(cs.dates.date[1]) - t0 = Dates.datetime2epochms(cs.dates.date0[1]) - Checkpointer.checkpoint_model_state( - sim, - cs.comms_ctx, - Int((t - t0) / 1e3), - output_dir = cs.dirs.checkpoints, - ) + if !isnothing(Checkpointer.get_model_prog_state(sim)) + Checkpointer.checkpoint_model_state(sim, comms_ctx, time; output_dir) + end + if !isnothing(Checkpointer.get_model_cache(sim)) + Checkpointer.checkpoint_model_cache(sim, comms_ctx, time; output_dir) + end + end + # Checkpoint the Coupler fields + pid = ClimaComms.mypid(comms_ctx) + @info "Saving coupler fields to JLD2 on day $day second $sec" + output_file = joinpath(output_dir, "checkpoint_coupler_fields_$(pid)_$time.jld2") + # Adapt to Array move fields to the CPU + JLD2.jldsave(output_file, coupler_fields = CC.Adapt.adapt(Array, cs.fields)) +end + +""" + restart!(cs::CoupledSimulation, checkpoint_dir, checkpoint_t) + +Overwrite the content of `cs` with checkpoints in `checkpoint_dir` at time `checkpoint_t`. + +Return a true if the simulation was restarted. +""" +function restart!(cs, checkpoint_dir, checkpoint_t) + @info "Restarting from time $(checkpoint_t) and directory $(checkpoint_dir)" + pid = ClimaComms.mypid(cs.comms_ctx) + for sim in cs.model_sims + if !isnothing(Checkpointer.get_model_prog_state(sim)) + input_file_state = + output_file = joinpath(checkpoint_dir, "checkpoint_$(Interfacer.name(sim))_$(checkpoint_t).hdf5") + restart_model_state!(sim, input_file_state, cs.comms_ctx) + end + if !isnothing(Checkpointer.get_model_cache(sim)) + input_file_cache = + joinpath(checkpoint_dir, "checkpoint_cache_$(pid)_$(Interfacer.name(sim))_$(checkpoint_t).jld2") + restart_model_cache!(sim, input_file_cache) end end + input_file_coupler_fields = joinpath(checkpoint_dir, "checkpoint_coupler_fields_$(pid)_$(checkpoint_t).jld2") + restart_coupler_fields!(cs, input_file_coupler_fields) + return true +end + +""" + restart_model_cache!(sim, input_file) + +Overwrite the content of `sim` with the cache from the `input_file`. + +It relies on `restore_cache!(sim, old_cache)`, which has to be implemented by +the component models that have a cache. +""" +function restart_model_cache!(sim, input_file) + ispath(input_file) || error("File $(input_file) not found") + # Component models are responsible for defining a method for this + restore_cache!(sim, JLD2.jldopen(input_file)["cache"]) +end + +""" + restart_model_state!(sim, input_file, comms_ctx) + +Overwrite the content of `sim` with the state from the `input_file`. +""" +function restart_model_state!(sim, input_file, comms_ctx) + ispath(input_file) || error("File $(input_file) not found") + Y = get_model_prog_state(sim) + # open file and read + CC.InputOutput.HDF5Reader(input_file, comms_ctx) do restart_reader + Y_new = CC.InputOutput.read_field(restart_reader, "model_state") + # set new state + Y .= Y_new + end + return nothing +end + +""" + restart_coupler_fields!(cs, input_file) + +Overwrite the content of the coupled simulation `cs` with the coupler fields +read from `input_file`. +""" +function restart_coupler_fields!(cs, input_file) + ispath(input_file) || error("File $(input_file) not found") + fields_read = JLD2.jldopen(input_file)["coupler_fields"] + for name in propertynames(cs.fields) + ArrayType = ClimaComms.array_type(ClimaComms.device(cs.comms_ctx)) + parent(getproperty(cs.fields, name)) .= ArrayType(parent(getproperty(fields_read, name))) + end +end + +""" + t_start_from_checkpoint(checkpoint_dir) + +Look for restart files in `checkpoint_dir`, if found, return the time of the latest. +If not found, return `nothing`. +""" +function t_start_from_checkpoint(checkpoint_dir) + isdir(checkpoint_dir) || return nothing + restart_file_rx = r"checkpoint_(\w+)_(\d+).hdf5" + restarts = filter(f -> !isnothing(match(restart_file_rx, f)), readdir(checkpoint_dir)) + isempty(restarts) && return nothing + latest_restart = last(sort_by_creation_time(restarts)) + return parse(Int, match(restart_file_rx, latest_restart)[2]) end end # module diff --git a/src/Utilities.jl b/src/Utilities.jl index 75b359a729..dbfabb27e9 100644 --- a/src/Utilities.jl +++ b/src/Utilities.jl @@ -115,7 +115,7 @@ CPU of this process since it began. function show_memory_usage() cpu_max_rss_GB = "" cpu_max_rss_GB = string(round(Sys.maxrss() / 1e9, digits = 3)) * " GiB" - @info cpu_max_rss_GB + @info "Memory in use: $(cpu_max_rss_GB)" return cpu_max_rss_GB end diff --git a/test/mpi_tests/checkpointer_mpi_tests.jl b/test/mpi_tests/checkpointer_mpi_tests.jl deleted file mode 100644 index be57f5d472..0000000000 --- a/test/mpi_tests/checkpointer_mpi_tests.jl +++ /dev/null @@ -1,47 +0,0 @@ -#= - Unit tests for ClimaCoupler Checkpointer module functions to exercise MPI - -These are in a separate testing file from the other Checkpointer unit tests so -that MPI can be enabled for testing of these functions. -=# -import Test: @test, @testset -import ClimaComms -@static pkgversion(ClimaComms) >= v"0.6" && ClimaComms.@import_required_backends -import ClimaCore as CC -import ClimaCoupler -import ClimaCoupler: Checkpointer, Interfacer - -include(joinpath("..", "..", "experiments", "ClimaEarth", "test", "TestHelper.jl")) -import .TestHelper - -# set up MPI communications context -const comms_ctx = ClimaComms.context(ClimaComms.CPUSingleThreaded()) -const pid, nprocs = ClimaComms.init(comms_ctx) -@info pid -ClimaComms.barrier(comms_ctx) - -FT = Float64 -struct DummySimulation{S} <: Interfacer.AtmosModelSimulation - state::S -end -Checkpointer.get_model_prog_state(sim::DummySimulation) = sim.state -@testset "checkpoint_model_state, restart_model_state!" begin - boundary_space = TestHelper.create_space(FT, comms_ctx = comms_ctx) - t = 1 - - # old sim run - sim = DummySimulation(CC.Fields.FieldVector(T = ones(boundary_space))) - Checkpointer.checkpoint_model_state(sim, comms_ctx, t, output_dir = "test_checkpoint") - ClimaComms.barrier(comms_ctx) - - # new sim run - sim_new = DummySimulation(CC.Fields.FieldVector(T = zeros(boundary_space))) - Checkpointer.restart_model_state!(sim_new, comms_ctx, t, input_dir = "test_checkpoint") - @test sim_new.state.T == sim.state.T - - # remove checkpoint directory - ClimaComms.barrier(comms_ctx) - if ClimaComms.iamroot(comms_ctx) - rm("./test_checkpoint/", force = true, recursive = true) - end -end diff --git a/test/runtests.jl b/test/runtests.jl index beff063424..a8bb21c21c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -36,6 +36,3 @@ end @safetestset "FluxCalculator tests" begin include("flux_calculator_tests.jl") end -@safetestset "Checkpointer tests" begin - include("checkpointer_tests.jl") -end From 5903070a6d1716a317d2c6c90458c7fd538db4ac Mon Sep 17 00:00:00 2001 From: Julia Sloan Date: Thu, 6 Mar 2025 14:06:12 -0800 Subject: [PATCH 3/4] update ocean, sea ice every 1 day --- NEWS.md | 17 +++++++++++ config/ci_configs/amip_component_dts.yml | 2 +- experiments/ClimaEarth/cli_options.jl | 4 +-- experiments/ClimaEarth/user_io/arg_parsing.jl | 29 +++++++------------ 4 files changed, 31 insertions(+), 21 deletions(-) diff --git a/NEWS.md b/NEWS.md index ea5a664e44..a1d6dbe41e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -6,6 +6,23 @@ ClimaCoupler.jl Release Notes ### ClimaCoupler features +#### Shared component `dt` can be overwritten for individual components +Previously, we required that the user either specify a shared `dt` to be +used by all component models, or specify values for all component models +(`dt_atmos`, `dt_ocean`, `dt_seaice`, `dt_land`). If fewer than 4 +model-specific timesteps were provided, they would be discarded and +`dt` would be used uniformly instead. After this PR, if a user provides +fewer than 4 model-specific timesteps, they will be used for those models, +and the generic `dt` will be used for any models that don't have a more +specific timestep. +This makes choosing the timesteps simpler and allows us to easily set +specific `dt`s only for the models we're interested in. + +This PR also changes the prescribed ocean and sea ice simulations +to update the stored SST/SIC based on a daily schedule. Now, the +input data will be interpolated from monthly to daily instead of +to every timestep. + #### Add default `get_field` methods for surface models PR[#1210](https://github.com/CliMA/ClimaCoupler.jl/pull/1210) Add default methods for `get_field` methods that are commonly not extended for surface models. These return reasonable default diff --git a/config/ci_configs/amip_component_dts.yml b/config/ci_configs/amip_component_dts.yml index d8770b9812..67c5168319 100644 --- a/config/ci_configs/amip_component_dts.yml +++ b/config/ci_configs/amip_component_dts.yml @@ -3,7 +3,7 @@ co2: "maunaloa" dt_atmos: "150secs" dt_cpl: "150secs" dt_land: "50secs" -dt_ocean: "30secs" +dt_ocean: "300secs" dt_rad: "1hours" dt_save_to_sol: "1days" dt_seaice: "37.5secs" diff --git a/experiments/ClimaEarth/cli_options.jl b/experiments/ClimaEarth/cli_options.jl index 4aebdad2e6..7cb987ebef 100644 --- a/experiments/ClimaEarth/cli_options.jl +++ b/experiments/ClimaEarth/cli_options.jl @@ -51,11 +51,11 @@ function argparse_settings() arg_type = String default = "20000101" "--dt_cpl" - help = "Coupling time step in seconds [400 (default); allowed formats: \"Nsecs\", \"Nmins\", \"Nhours\", \"Ndays\", \"Inf\"]" + help = "Coupling time step in seconds [400secs (default); allowed formats: \"Nsecs\", \"Nmins\", \"Nhours\", \"Ndays\", \"Inf\"]" arg_type = String default = "400secs" "--dt" - help = "Component model time step [allowed formats: \"Nsecs\", \"Nmins\", \"Nhours\", \"Ndays\", \"Inf\"]" + help = "Component model time step [400secs (default); allowed formats: \"Nsecs\", \"Nmins\", \"Nhours\", \"Ndays\", \"Inf\"]" arg_type = String default = "400secs" "--dt_atmos" diff --git a/experiments/ClimaEarth/user_io/arg_parsing.jl b/experiments/ClimaEarth/user_io/arg_parsing.jl index 918b81626a..bbd2d40c33 100644 --- a/experiments/ClimaEarth/user_io/arg_parsing.jl +++ b/experiments/ClimaEarth/user_io/arg_parsing.jl @@ -232,27 +232,20 @@ function parse_component_dts!(config_dict) # Specify component model names component_dt_names = ["dt_atmos", "dt_land", "dt_ocean", "dt_seaice"] component_dt_dict = Dict{String, typeof(Δt_cpl)}() - # check if all component dt's are specified - if all(key -> !isnothing(config_dict[key]), component_dt_names) - # when all component dt's are specified, ignore the dt field - if haskey(config_dict, "dt") - @warn "Removing dt in favor of individual component dt's" - delete!(config_dict, "dt") - end - for key in component_dt_names + + @assert all(key -> !isnothing(config_dict[key]), component_dt_names) || haskey(config_dict, "dt") "all model-specific timesteps (dt_atmos, dt_land, dt_ocean, and dt_seaice) or a generic timestep (dt) must be specified" + + for key in component_dt_names + if haskey(config_dict, key) + # Check if the component timestep is specified component_dt = Float64(Utilities.time_to_seconds(config_dict[key])) @assert isapprox(Δt_cpl % component_dt, 0.0) "Coupler dt must be divisible by all component dt's\n dt_cpl = $Δt_cpl\n $key = $component_dt" component_dt_dict[key] = component_dt - end - else - # when not all component dt's are specified, use the dt field - @assert haskey(config_dict, "dt") "dt or (dt_atmos, dt_land, dt_ocean, and dt_seaice) must be specified" - for key in component_dt_names - if !isnothing(config_dict[key]) - @warn "Removing $key from config in favor of dt because not all component dt's are specified" - end - delete!(config_dict, key) - component_dt_dict[key] = Float64(Utilities.time_to_seconds(config_dict["dt"])) + else + # If the component timestep is not specified, use the generic timestep + dt = Float64(Utilities.time_to_seconds(config_dict["dt"])) + @assert isapprox(Δt_cpl % dt, 0.0) "Coupler dt must be divisible by all component dt's\n dt_cpl = $Δt_cpl\n dt = $dt" + component_dt_dict[key] = dt end end config_dict["component_dt_dict"] = component_dt_dict From ed59e27bf6acfa160ec57a33d5d86815ec48b1ea Mon Sep 17 00:00:00 2001 From: Julia Sloan Date: Thu, 6 Mar 2025 14:41:34 -0800 Subject: [PATCH 4/4] add callbacks to read in prescribed data daily --- .../components/ocean/prescr_ocean.jl | 9 ++++-- .../components/ocean/prescr_seaice.jl | 28 ++++++++++++++++--- experiments/ClimaEarth/setup_run.jl | 2 +- experiments/ClimaEarth/user_io/arg_parsing.jl | 2 +- 4 files changed, 33 insertions(+), 8 deletions(-) diff --git a/experiments/ClimaEarth/components/ocean/prescr_ocean.jl b/experiments/ClimaEarth/components/ocean/prescr_ocean.jl index 3482471adb..742a052dce 100644 --- a/experiments/ClimaEarth/components/ocean/prescr_ocean.jl +++ b/experiments/ClimaEarth/components/ocean/prescr_ocean.jl @@ -3,6 +3,7 @@ import ClimaUtilities.ClimaArtifacts: @clima_artifact import Interpolations # triggers InterpolationsExt in ClimaUtilities import Thermodynamics as TD import ClimaCoupler: Checkpointer, FieldExchanger, Interfacer +import ClimaDiagnostics.Schedules: EveryCalendarDtSchedule """ PrescribedOceanSimulation{C} @@ -84,6 +85,8 @@ function PrescribedOceanSimulation( SST_init = zeros(space) evaluate!(SST_init, SST_timevaryinginput, t_start) + SST_schedule = CD.Schedules.EveryCalendarDtSchedule(TimeManager.time_to_period("1days"); start_date = date0) + # Create the cache cache = (; T_sfc = SST_init, @@ -97,6 +100,7 @@ function PrescribedOceanSimulation( phase = TD.Liquid(), thermo_params = thermo_params, SST_timevaryinginput = SST_timevaryinginput, + SST_schedule = SST_schedule, ) return PrescribedOceanSimulation(cache) end @@ -107,10 +111,11 @@ end Interfacer.step!(sim::PrescribedOceanSimulation, t) Update the cached surface temperature field using the prescribed data -at each timestep. +at each timestep. This doesn't happen at every timestep, +but only when the SST data is scheduled to be updated. """ function Interfacer.step!(sim::PrescribedOceanSimulation, t) - evaluate!(sim.cache.T_sfc, sim.cache.SST_timevaryinginput, t) + sim.cache.SST_schedule(t) && evaluate!(sim.cache.T_sfc, sim.cache.SST_timevaryinginput, t) end function Checkpointer.get_model_cache(sim::PrescribedOceanSimulation) diff --git a/experiments/ClimaEarth/components/ocean/prescr_seaice.jl b/experiments/ClimaEarth/components/ocean/prescr_seaice.jl index 3fc8390faf..e90c9b4849 100644 --- a/experiments/ClimaEarth/components/ocean/prescr_seaice.jl +++ b/experiments/ClimaEarth/components/ocean/prescr_seaice.jl @@ -1,5 +1,6 @@ import SciMLBase import ClimaCore as CC +import ClimaDiagnostics as CD import ClimaTimeSteppers as CTS import ClimaUtilities.TimeVaryingInputs: TimeVaryingInput, evaluate! import ClimaUtilities.ClimaArtifacts: @clima_artifact @@ -143,8 +144,20 @@ function PrescribedIceSimulation( tspan = Float64.(tspan) saveat = Float64.(saveat) end + + # Set up a callback to read in SST data daily + SIC_schedule = CD.Schedules.EveryCalendarDtSchedule(TimeManager.time_to_period("1days"); start_date = date0) + SIC_update_cb = TimeManager.Callback(SIC_schedule, read_sic_data!) + problem = SciMLBase.ODEProblem(ode_function, Y, tspan, (; cache..., params = params)) - integrator = SciMLBase.init(problem, ode_algo, dt = dt, saveat = saveat, adaptive = false) + integrator = SciMLBase.init( + problem, + ode_algo, + dt = dt, + saveat = saveat, + adaptive = false, + callback = SciMLBase.CallbackSet(SIC_update_cb), + ) sim = PrescribedIceSimulation(params, space, integrator) @@ -153,6 +166,16 @@ function PrescribedIceSimulation( return sim end +""" + read_sic_data!(integrator) + +Read in the sea ice concentration data at the current time step. +This function is intended to be used within a callback +""" +function read_sic_data!(integrator) + evaluate!(integrator.p.area_fraction, integrator.p.SIC_timevaryinginput, integrator.t) +end + # extensions required by Interfacer Interfacer.get_field(sim::PrescribedIceSimulation, ::Val{:area_fraction}) = sim.integrator.p.area_fraction Interfacer.get_field(sim::PrescribedIceSimulation, ::Val{:roughness_buoyancy}) = sim.integrator.p.params.z0b @@ -237,9 +260,6 @@ function ice_rhs!(dY, Y, p, t) FT = eltype(Y) params = p.params - # Update the cached area fraction with the current SIC - evaluate!(p.area_fraction, p.SIC_timevaryinginput, t) - # Overwrite ice fraction with the static land area fraction anywhere we have nonzero land area # max needed to avoid Float32 errors (see issue #271; Heisenbug on HPC) @. p.area_fraction = max(min(p.area_fraction, FT(1) - p.land_fraction), FT(0)) diff --git a/experiments/ClimaEarth/setup_run.jl b/experiments/ClimaEarth/setup_run.jl index 84a56475af..4ff8ccd5c3 100644 --- a/experiments/ClimaEarth/setup_run.jl +++ b/experiments/ClimaEarth/setup_run.jl @@ -527,7 +527,7 @@ function setup_and_run(config_dict::AbstractDict) NB: Eventually, we will call all of radiation from the coupler, in addition to the albedo calculation. =# schedule_checkpoint = EveryCalendarDtSchedule(TimeManager.time_to_period(checkpoint_dt); start_date = date0) - checkpoint_cb = TimeManager.TimeManager.Callback(schedule_checkpoint, Checkpointer.checkpoint_sims) + checkpoint_cb = TimeManager.Callback(schedule_checkpoint, Checkpointer.checkpoint_sims) if sim_mode <: AMIPMode schedule_albedo = EveryCalendarDtSchedule(TimeManager.time_to_period(dt_rad); start_date = date0) diff --git a/experiments/ClimaEarth/user_io/arg_parsing.jl b/experiments/ClimaEarth/user_io/arg_parsing.jl index bbd2d40c33..95c63c0d7f 100644 --- a/experiments/ClimaEarth/user_io/arg_parsing.jl +++ b/experiments/ClimaEarth/user_io/arg_parsing.jl @@ -236,7 +236,7 @@ function parse_component_dts!(config_dict) @assert all(key -> !isnothing(config_dict[key]), component_dt_names) || haskey(config_dict, "dt") "all model-specific timesteps (dt_atmos, dt_land, dt_ocean, and dt_seaice) or a generic timestep (dt) must be specified" for key in component_dt_names - if haskey(config_dict, key) + if !isnothing(config_dict[key]) # Check if the component timestep is specified component_dt = Float64(Utilities.time_to_seconds(config_dict[key])) @assert isapprox(Δt_cpl % component_dt, 0.0) "Coupler dt must be divisible by all component dt's\n dt_cpl = $Δt_cpl\n $key = $component_dt"