Skip to content

Commit

Permalink
Add AMIP calibration pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
nefrathenrici committed Feb 19, 2025
1 parent 5d6b473 commit e753038
Show file tree
Hide file tree
Showing 10 changed files with 260 additions and 11 deletions.
13 changes: 8 additions & 5 deletions experiments/ClimaEarth/components/atmosphere/climaatmos.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ include("climaatmos_extra_diags.jl")
###
### Functions required by ClimaCoupler.jl for an AtmosModelSimulation
###
struct ClimaAtmosSimulation{P, D, I} <: Interfacer.AtmosModelSimulation
struct ClimaAtmosSimulation{P, D, I, OW} <: Interfacer.AtmosModelSimulation
params::P
domain::D
integrator::I
output_writers::OW
end
Interfacer.name(::ClimaAtmosSimulation) = "ClimaAtmosSimulation"

Expand All @@ -35,7 +36,7 @@ function ClimaAtmosSimulation(atmos_config)
# By passing `parsed_args` to `AtmosConfig`, `parsed_args` overwrites the default atmos config
FT = atmos_config.parsed_args["FLOAT_TYPE"] == "Float64" ? Float64 : Float32
simulation = CA.get_simulation(atmos_config)
(; integrator) = simulation
(; integrator, output_writers) = simulation
Y = integrator.u
center_space = axes(Y.c.ρe_tot)
face_space = axes(Y.f.u₃)
Expand Down Expand Up @@ -68,7 +69,7 @@ function ClimaAtmosSimulation(atmos_config)
@. ᶠradiation_flux = CC.Geometry.WVector(FT(0))
end

sim = ClimaAtmosSimulation(integrator.p.params, spaces, integrator)
sim = ClimaAtmosSimulation(integrator.p.params, spaces, integrator, output_writers)

# DSS state to ensure we have continuous fields
dss_state!(sim)
Expand Down Expand Up @@ -342,9 +343,11 @@ function get_atmos_config_dict(coupler_dict::Dict, job_id::String, atmos_output_
atmos_toml = joinpath.(pkgdir(CA), atmos_config["toml"])
coupler_toml = joinpath.(pkgdir(ClimaCoupler), coupler_dict["coupler_toml"])
toml = isempty(coupler_toml) ? atmos_toml : coupler_toml

if haskey(atmos_config, "calibration_toml")
push!(toml, atmos_config["calibration_toml"])
end
if !isempty(toml)
@info "Overwriting Atmos parameters from input TOML file(s)"
@info "Overwriting Atmos parameters from input TOML file(s): $toml"
atmos_config = merge(atmos_config, Dict("toml" => toml))
end

Expand Down
1 change: 0 additions & 1 deletion experiments/ClimaEarth/components/land/climaland_bucket.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import SciMLBase
import Statistics
import ClimaCore as CC
import ClimaTimeSteppers as CTS
import ClimaParams
import Thermodynamics as TD
import ClimaLand as CL
import ClimaLand.Parameters as LP
Expand Down
11 changes: 7 additions & 4 deletions experiments/ClimaEarth/setup_run.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,16 +151,18 @@ function solve_coupler!(cs)
return nothing
end

function setup_and_run(config_file = joinpath(pkgdir(ClimaCoupler), "config/ci_configs/amip_default.yml"))
config_dict = get_coupler_config_dict(config_file)
return setup_and_run(config_dict)
end
"""
setup_and_run(config_file = joinpath(pkgdir(ClimaCoupler), "config/ci_configs/amip_default.yml"))
This function sets up and runs the coupled model simulation specified by the
input config file. It initializes the component models, all coupler objects,
diagnostics, and conservation checks, and then runs the simulation.
"""
function setup_and_run(config_file = joinpath(pkgdir(ClimaCoupler), "config/ci_configs/amip_default.yml"))
# Parse the configuration file
config_dict = get_coupler_config_dict(config_file)
function setup_and_run(config_dict::AbstractDict)
# Initialize communication context (do this first so all printing is only on root)
comms_ctx = Utilities.get_comms_context(config_dict)
# Select the correct timestep for each component model based on which are available
Expand Down Expand Up @@ -733,5 +735,6 @@ function setup_and_run(config_file = joinpath(pkgdir(ClimaCoupler), "config/ci_c
end

# Close all diagnostics file writers
!isnothing(cs.diags_handler) && map(diag -> close(diag.output_writer), cs.diags_handler.scheduled_diagnostics)
isnothing(cs.diags_handler) || foreach(diag -> close(diag.output_writer), cs.diags_handler.scheduled_diagnostics)
isnothing(atmos_sim.output_writers) || foreach(close, atmos_sim.output_writers)
end
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ for FT in (Float32, Float64)
u = (; state_field1 = CC.Fields.ones(boundary_space), state_field2 = CC.Fields.zeros(boundary_space)),
p = (; cache_field = CC.Fields.zeros(boundary_space)),
)
sim = ClimaAtmosSimulation(nothing, nothing, integrator)
sim = ClimaAtmosSimulation(nothing, nothing, integrator, nothing)

# make field non-constant to check the impact of the dss step
coords_lat = CC.Fields.coordinate_field(sim.integrator.u.state_field2).lat
Expand Down
34 changes: 34 additions & 0 deletions experiments/calibration/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
[deps]
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
ClimaAnalysis = "29b5916a-a76c-4e73-9657-3c8fd22e65e6"
ClimaAtmos = "b2c96348-7fb7-4fe0-8da9-78d88439e717"
ClimaCalibrate = "4347a170-ebd6-470c-89d3-5c705c0cacc2"
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
ClimaCore = "d414da3d-4745-48bb-8d80-42e94e092884"
ClimaCoreMakie = "908f55d8-4145-4867-9c14-5dad1a479e4d"
ClimaCoupler = "4ade58fe-a8da-486c-bd89-46df092ec0c7"
ClimaDiagnostics = "1ecacbb8-0713-4841-9a07-eb5aa8a2d53f"
ClimaLand = "08f4d4ce-cf43-44bb-ad95-9d2d5f413532"
ClimaTimeSteppers = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
ClimaUtilities = "b3f4f4ca-9299-4f7f-bd9b-81e1242a7513"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d"
GeoMakie = "db073c08-6b98-4ee5-b6a4-5efafb3259c6"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
Poppler_jll = "9c32591e-4766-534b-9725-b71a8799265b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SurfaceFluxes = "49b00bb7-8bd4-4f2b-b78c-51cd0450215f"
Thermodynamics = "b60c26fb-14c3-4610-9d3e-2d17fe7ff00c"
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"

[compat]
ClimaCalibrate = "0.0.10"
ClimaTimeSteppers = "=0.8.1"
18 changes: 18 additions & 0 deletions experiments/calibration/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# ClimaCoupler Calibration Experiments

This folder contains a trivial perfect-model calibration of the atmosphere coupled with the bucket model.
The calibration uses 30-day and lat/lon averages of top-of-atmosphere shortwave
radiation to calibrate the `total_solar_irradiance` parameter in a perfect model setting.
The current run script uses the `ClimaCalibrate.SlurmManager` to add Slurm workers
which run each ensemble member in parallel.

To run this calibration on a Slurm cluster, ensure that `run_calibration.sh` is
configured for your cluster and run `sbatch run_calibration.sh`. The output will
be generated in `experiments/calibration/output`.

Components:
- run_calibration.sh: SBATCH script used to instantiate the project and run the calibration on a Slurm cluster.
- run_calibration.jl: Julia script for the overall calibration and postprocessing. Contains the expriment configuration, such as ensemble size and number of iterations.
- model_interface.jl: Contains `forward_model`, the function that gets run during calibration. This basically just uses the `setup_run` function.
- model_config.yml: Contains the configuration for the coupler
-
38 changes: 38 additions & 0 deletions experiments/calibration/model_config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
FLOAT_TYPE: "Float32"
albedo_model: "CouplerAlbedo"
atmos_config_file: "config/longrun_configs/amip_target_diagedmf.yml"
checkpoint_dt: "720hours"
coupler_toml: ["toml/amip.toml"]
deep_atmosphere: false
dt: "240secs"
dt_cpl: "240secs"
dz_bottom: 100.0
energy_check: false
h_elem: 8
land_albedo_type: "map_temporal"
mode_name: "amip"
mono_surface: false
netcdf_interpolation_num_points: [90, 45, 31]
netcdf_output_at_levels: true
output_default_diagnostics: false
use_coupler_diagnostics: false
override_precip_timescale: false
rayleigh_sponge: true
start_date: "20100101"
surface_setup: "PrescribedSurface"
coupler_output_dir: "experiments/calibration/output"
t_end: "30days"
topo_smoothing: true
topography: "Earth"
turb_flux_partition: "CombinedStateFluxesMOST"
viscous_sponge: true
z_elem: 39
z_max: 60000.0
insolation: timevarying
dt_rad: 6hours
rad: clearsky
extra_atmos_diagnostics:
- reduction_time: average
short_name: rsut
period: 30days
writer: nc
17 changes: 17 additions & 0 deletions experiments/calibration/model_interface.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import ClimaCoupler
import ClimaCalibrate
include(joinpath(pkgdir(ClimaCoupler), "experiments", "ClimaEarth", "setup_run.jl"))

function ClimaCalibrate.forward_model(iter, member)
config_file = joinpath(ClimaCalibrate.project_dir(), "model_config.yml")
config_dict = get_coupler_config_dict(config_file)

output_dir_root = config_dict["coupler_output_dir"]
# Set member parameter file
sampled_parameter_file = ClimaCalibrate.parameter_path(output_dir_root, iter, member)
config_dict["calibration_toml"] = sampled_parameter_file
# Set member output directory
member_output_dir = ClimaCalibrate.path_to_ensemble_member(output_dir_root, iter, member)
config_dict["coupler_output_dir"] = member_output_dir
return setup_and_run(config_dict)
end
124 changes: 124 additions & 0 deletions experiments/calibration/run_calibration.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
using Distributed
import ClimaCalibrate as CAL
using ClimaCalibrate
using ClimaAnalysis
import ClimaAnalysis: SimDir, get, slice, average_xy
import ClimaComms
import EnsembleKalmanProcesses: I, ParameterDistributions.constrained_gaussian

# Ensure ClimaComms doesn't use MPI
ENV["CLIMACOMMS_CONTEXT"] = "SINGLETON"
ClimaComms.@import_required_backends

single_member_dims = (1,)
function CAL.observation_map(iteration)
G_ensemble = Array{Float64}(undef, single_member_dims..., ensemble_size)

for m in 1:ensemble_size
member_path = CAL.path_to_ensemble_member(output_dir, iteration, m)
simdir_path = joinpath(member_path, "model_config/output_active/clima_atmos")
if isdir(simdir_path)
simdir = SimDir(simdir_path)
G_ensemble[:, m] .= process_member_data(simdir)
else
@info "No data found for member $m."
G_ensemble[:, m] .= NaN
end
end
return G_ensemble
end

function process_member_data(simdir::SimDir)
output = zeros(single_member_dims...)
days = 86_400
isempty(simdir) && return NaN

rsut = get(simdir; short_name = "rsut", reduction = "average", period = "30d")
rsut_slice = slice(average_lon(average_lat(rsut)); time = 30days).data
return rsut_slice
end

addprocs(CAL.SlurmManager())
# Make variables and the forward model available on the worker sessions
@everywhere begin
import ClimaComms, CUDA
ENV["CLIMACOMMS_DEVICE"] = "CUDA"
ENV["CLIMACOMMS_CONTEXT"] = "SINGLETON"
import ClimaCalibrate as CAL
import JLD2

experiment_dir = CAL.project_dir()
include(joinpath(experiment_dir, "model_interface.jl"))
output_dir = joinpath(experiment_dir, "output")
obs_path = joinpath(experiment_dir, "observations.jld2")
end

# Experiment Configuration
ensemble_size = 10
n_iterations = 5
noise = 0.1 * I
prior = constrained_gaussian("total_solar_irradiance", 1000, 500, 250, 2000)

# Generate observations if needed
if !isfile(obs_path)
import JLD2
@info "Generating observations"
obs_output_dir = CAL.path_to_ensemble_member(output_dir, 0, 0)
mkpath(obs_output_dir)
touch(joinpath(obs_output_dir, "parameters.toml"))
CAL.forward_model(0, 0)
observations = Vector{Float64}(undef, 1)
observations .= process_member_data(SimDir(joinpath(obs_output_dir, "amip_config/output_active/clima_atmos")))
JLD2.save_object(obs_path, observations)
end

# Initialize experiment data
@everywhere observations = JLD2.load_object(obs_path)

eki = CAL.calibrate(CAL.WorkerBackend, ensemble_size, n_iterations, observations, noise, prior, output_dir)

# Postprocessing
import EnsembleKalmanProcesses as EKP
import Statistics: var, mean
using Test
import CairoMakie

function scatter_plot(eki::EKP.EnsembleKalmanProcess)
f = CairoMakie.Figure(resolution = (800, 600))
ax = CairoMakie.Axis(f[1, 1], ylabel = "Parameter Value", xlabel = "Top of atmosphere radiative SW flux")

g = vec.(EKP.get_g(eki; return_array = true))
params = vec.((EKP.get_ϕ(prior, eki)))

for (gg, uu) in zip(g, params)
CairoMakie.scatter!(ax, gg, uu)
end

CairoMakie.vlines!(ax, observations, linestyle = :dash)

output = joinpath(output_dir, "scatter.png")
CairoMakie.save(output, f)
return output
end

function param_versus_iter_plot(eki::EKP.EnsembleKalmanProcess)
f = CairoMakie.Figure(resolution = (800, 600))
ax = CairoMakie.Axis(f[1, 1], ylabel = "Parameter Value", xlabel = "Iteration")
params = EKP.get_ϕ(prior, eki)
for (i, param) in enumerate(params)
CairoMakie.scatter!(ax, fill(i, length(param)), vec(param))
end

output = joinpath(output_dir, "param_vs_iter.png")
CairoMakie.save(output, f)
return output
end

scatter_plot(eki)
param_versus_iter_plot(eki)

params = EKP.get_ϕ(prior, eki)
spread = map(var, params)

# Spread should be heavily decreased as particles have converged
@test last(spread) / first(spread) < 0.1
13 changes: 13 additions & 0 deletions experiments/calibration/run_calibration.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/bin/bash

#SBATCH --partition=a3
#SBATCH --output="run_calibration.txt"
#SBATCH --time=05:00:00
#SBATCH --ntasks=10
#SBATCH --gpus-per-task=1
#SBATCH --cpus-per-task=4

julia --project=experiments/calibration -e 'using Pkg; Pkg.develop(;path="."); Pkg.instantiate(;verbose=true)'

julia --project=experiments/calibration experiments/calibration/run_calibration.jl

0 comments on commit e753038

Please sign in to comment.