From 86648c9f6d51b2fcd2be519938ab072329b70e64 Mon Sep 17 00:00:00 2001 From: Mikhail Mints Date: Sun, 20 Aug 2023 09:32:10 -0700 Subject: [PATCH] ML aerosol activation emulators and data preprocessing --- Project.toml | 15 ++ docs/src/plots/ARGdata.jl | 18 +++ docs/src/plots/ARGplots.jl | 59 ++++++-- ext/EmulatorModelsExt.jl | 211 +++++++++++++++++++++++++++ src/PreprocessAerosolData.jl | 275 +++++++++++++++++++++++++++++++++++ 5 files changed, 567 insertions(+), 11 deletions(-) create mode 100644 ext/EmulatorModelsExt.jl create mode 100644 src/PreprocessAerosolData.jl diff --git a/Project.toml b/Project.toml index cac9a0d805..f0d178b11f 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,21 @@ RootSolvers = "7181ea78-2dcb-4de3-ab41-2b8ab5a31e74" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Thermodynamics = "b60c26fb-14c3-4610-9d3e-2d17fe7ff00c" +[weakdeps] +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +GaussianProcesses = "891a1506-143c-57d2-908e-e1f8e92e6de9" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" +MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" +MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" + +[extensions] +EmulatorModelsExt = "DataFrames" + [compat] CLIMAParameters = "0.9" DocStringExtensions = "0.8, 0.9" diff --git a/docs/src/plots/ARGdata.jl b/docs/src/plots/ARGdata.jl index 7cee348864..fbf843f7c6 100644 --- a/docs/src/plots/ARGdata.jl +++ b/docs/src/plots/ARGdata.jl @@ -15,6 +15,8 @@ Fig1_y_obs = [ 0.34946619217081853, 0.17580071174377226, ] +Fig1_x_PySDM = [100.0, 1325.0, 2550.0, 3775.0, 5000.0] +Fig1_y_PySDM = [0.76, 0.5, 0.38, 0.3, 0.26] Fig1_x_param = [ 0, 12.891344383057117, @@ -129,6 +131,8 @@ Fig2a_y_obs = [ 0.5986159169550174, 0.4975778546712803, ] +Fig2a_x_PySDM = [100.0, 1325.0, 2550.0, 3775.0, 5000.0] +Fig2a_y_PySDM = [0.8, 0.68, 0.62, 0.6, 0.56] Fig2a_x_param = [ 12.567324955116646, 44.88330341113101, @@ -191,6 +195,8 @@ Fig2b_y_obs = [ 0.1949367088607595, 0.1329113924050633, ] +Fig2b_x_PySDM = [100.0, 1325.0, 2550.0, 3775.0, 5000.0] +Fig2b_y_PySDM = [0.38, 0.26, 0.22, 0.2, 0.18] Fig2b_x_param = [ 78.80910683012257, 113.83537653239921, @@ -233,6 +239,8 @@ Fig3a_x_obs = [ 1.0034965034965033, ] Fig3a_y_obs = [0.7586666666666667, 0.7453333333333334, 0.732, 0.72, 0.712] +Fig3a_x_PySDM = [0.1, 0.325, 0.55, 0.775, 1.0] +Fig3a_y_PySDM = [0.8, 0.78, 0.76, 0.76, 0.76] Fig3a_x_param = [ 0.10069930069930072, 0.14405594405594407, @@ -269,6 +277,8 @@ Fig3b_y_obs = [ 0.6734982332155477, 0.7074204946996467, ] +Fig3b_x_PySDM = [0.1, 0.325, 0.55, 0.775, 1.0] +Fig3b_y_PySDM = [0.38, 0.58, 0.66, 0.72, 0.76] Fig3b_x_param = [ 0.10270270270270272, 0.12000000000000002, @@ -321,6 +331,8 @@ Fig4a_y_obs = [ 0.4013333333333333, 0.12266666666666667, ] +Fig4a_x_PySDM = [0.01, 0.02659148, 0.07071068, 0.18803015, 0.5] +Fig4a_y_PySDM = [0.82, 0.78, 0.74, 0.58, 0.18] Fig4a_x_param = [ 0.010398558176237407, 0.011922822170487846, @@ -379,6 +391,8 @@ Fig4b_y_obs = [ 0.9805194805194806, 0.9844155844155846, ] +Fig4b_x_PySDM = [0.01, 0.02659148, 0.07071068, 0.18803015, 0.5] +Fig4b_y_PySDM = [0.08, 0.46, 0.86, 0.98, 1.0] Fig4b_x_param = [ 0.010822678662544178, 0.012760436250146357, @@ -441,6 +455,8 @@ Fig5a_y_obs = [ 0.9223427331887204, 0.975704989154013, ] +Fig5a_x_PySDM = [0.01, 0.04728708, 0.2236068, 1.05737126, 5.0] +Fig5a_y_PySDM = [0.12, 0.42, 0.68, 0.88, 0.98] Fig5a_x_param = [ 0.012498045610787637, 0.0206913808111479, @@ -493,6 +509,8 @@ Fig5b_y_obs = [ 0.6230508474576271, 0.7993220338983051, ] +Fig5b_x_PySDM = [0.01, 0.04728708, 0.2236068, 1.05737126, 5.0] +Fig5b_y_PySDM = [0.02, 0.1, 0.26, 0.54, 0.84] Fig5b_x_param = [ 0.011331300304886677, 0.013602279391211175, diff --git a/docs/src/plots/ARGplots.jl b/docs/src/plots/ARGplots.jl index 380d1c5b48..8591afc34c 100644 --- a/docs/src/plots/ARGplots.jl +++ b/docs/src/plots/ARGplots.jl @@ -1,15 +1,14 @@ -import Plots +import Plots as PL import CloudMicrophysics -import CLIMAParameters -import Thermodynamics - -const PL = Plots -const AM = CloudMicrophysics.AerosolModel -const AA = CloudMicrophysics.AerosolActivation -const CP = CLIMAParameters -const CMP = CloudMicrophysics.Parameters -const TD = Thermodynamics +import CloudMicrophysics: + AerosolModel as AM, + AerosolActivation as AA, + Parameters as CMP, + CommonTypes as CMT +import CLIMAParameters as CP +import Thermodynamics as TD +import DataFrames as DF FT = Float64 @@ -42,6 +41,16 @@ M_insol = 0.044 # molar mass of insol ρ_insol = 1770.0 # density of insol κ_insol = 0.0 # hygroscopicity of insol +# Schemes +ARG_scheme = CMT.ARG2000Type() +#ML_scheme = AA.MLEmulatedAerosolActivation( +# joinpath( +# pkgdir(CloudMicrophysics), +# "aerosol_activation_emulators", +# "2modal_nn_machine_naive.jls", +# ), +#) + function mass2vol(mass_mixing_ratios) if length(mass_mixing_ratios) == 2 densities = (sulfate.ρ, ρ_insol) @@ -56,7 +65,7 @@ end # Abdul-Razzak and Ghan 2000 # https://doi.org/10.1029/1999JD901161 -function make_ARG_figX(X) +function make_ARG_figX(X, scheme = ARG_scheme) p1 = PL.plot() p2 = PL.plot() @@ -181,11 +190,15 @@ function make_ARG_figX(X) x1_obs = Fig1_x_obs y1_obs = Fig1_y_obs + x1_PySDM = Fig1_x_PySDM + y1_PySDM = Fig1_y_PySDM x1_param = Fig1_x_param y1_param = Fig1_y_param x2_obs = Fig1_x_obs y2_obs = Fig1_y_obs + x2_PySDM = Fig1_x_PySDM + y2_PySDM = Fig1_y_PySDM x2_param = Fig1_x_param y2_param = Fig1_y_param @@ -237,11 +250,15 @@ function make_ARG_figX(X) x1_obs = Fig2a_x_obs y1_obs = Fig2a_y_obs + x1_PySDM = Fig2a_x_PySDM + y1_PySDM = Fig2a_y_PySDM x1_param = Fig2a_x_param y1_param = Fig2a_y_param x2_obs = Fig2b_x_obs y2_obs = Fig2b_y_obs + x2_PySDM = Fig2b_x_PySDM + y2_PySDM = Fig2b_y_PySDM x2_param = Fig2b_x_param y2_param = Fig2b_y_param @@ -295,11 +312,15 @@ function make_ARG_figX(X) x1_obs = Fig3a_x_obs y1_obs = Fig3a_y_obs + x1_PySDM = Fig3a_x_PySDM + y1_PySDM = Fig3a_y_PySDM x1_param = Fig3a_x_param y1_param = Fig3a_y_param x2_obs = Fig3b_x_obs y2_obs = Fig3b_y_obs + x2_PySDM = Fig3b_x_PySDM + y2_PySDM = Fig3b_y_PySDM x2_param = Fig3b_x_param y2_param = Fig3b_y_param @@ -350,11 +371,15 @@ function make_ARG_figX(X) x1_obs = Fig4a_x_obs y1_obs = Fig4a_y_obs + x1_PySDM = Fig4a_x_PySDM + y1_PySDM = Fig4a_y_PySDM x1_param = Fig4a_x_param y1_param = Fig4a_y_param x2_obs = Fig4b_x_obs y2_obs = Fig4b_y_obs + x2_PySDM = Fig4b_x_PySDM + y2_PySDM = Fig4b_y_PySDM x2_param = Fig4b_x_param y2_param = Fig4b_y_param @@ -407,11 +432,15 @@ function make_ARG_figX(X) x1_obs = Fig5a_x_obs y1_obs = Fig5a_y_obs + x1_PySDM = Fig5a_x_PySDM + y1_PySDM = Fig5a_y_PySDM x1_param = Fig5a_x_param y1_param = Fig5a_y_param x2_obs = Fig5b_x_obs y2_obs = Fig5b_y_obs + x2_PySDM = Fig5b_x_PySDM + y2_PySDM = Fig5b_y_PySDM x2_param = Fig5b_x_param y2_param = Fig5b_y_param @@ -454,6 +483,14 @@ function make_ARG_figX(X) ) PL.scatter!(p2, x2_obs, y2_obs, markercolor = :black) PL.plot!(p2, x2_param, y2_param, linecolor = :black) + PL.scatter!( + p1, + x1_PySDM, + y1_PySDM, + markercolor = :green, + label = "PySDM observations", + ) + PL.scatter!(p2, x2_PySDM, y2_PySDM, markercolor = :green) end end diff --git a/ext/EmulatorModelsExt.jl b/ext/EmulatorModelsExt.jl new file mode 100644 index 0000000000..2baa6c5de1 --- /dev/null +++ b/ext/EmulatorModelsExt.jl @@ -0,0 +1,211 @@ +module EmulatorModelsExt + +import MLJ +import MLJFlux +import Flux +import DataFrames as DF +import GaussianProcesses +import StatsBase +import Distributions + +import CloudMicrophysics as CM +import CloudMicrophysics.AerosolActivation as AA +import CloudMicrophysics.AerosolModel as AM +import CloudMicrophysics.Parameters as CMP + +struct NNBuilder <: MLJFlux.Builder + layer_sizes::Vector{Integer} + dropout::Vector{Float64} +end + +function MLJFlux.build(builder::NNBuilder, rng, n_in, n_out) + @assert length(builder.layer_sizes) == length(builder.dropout) + num_hidden_layers = length(builder.layer_sizes) + init = Flux.glorot_uniform(rng) + layers::Vector{Any} = [] + if num_hidden_layers == 0 + push!(layers, Flux.Dense(n_in => n_out, init = init)) + else + push!( + layers, + Flux.Dense( + n_in => builder.layer_sizes[1], + Flux.sigmoid_fast, + init = init, + ), + ) + end + for i in 1:num_hidden_layers + push!(layers, Flux.Dropout(builder.dropout[i])) + if i == num_hidden_layers + push!( + layers, + Flux.Dense(builder.layer_sizes[i] => n_out, init = init), + ) + else + push!( + layers, + Flux.Dense( + builder.layer_sizes[i] => builder.layer_sizes[i + 1], + Flux.sigmoid_fast, + init = init, + ), + ) + end + end + return Flux.Chain(layers...) +end + +mutable struct GPRegressor <: MLJ.Deterministic + num_gps::Integer + sample_size::Integer + use_ARG_weights::Bool + use_DTC::Bool + sample_size_inducing::Integer +end + +function MLJ.fit(model::GPRegressor, verbosity, X, y) + gps = [] + for i in 1:(model.num_gps) + if model.use_ARG_weights + weights = StatsBase.Weights([ + Distributions.pdf(Distributions.Normal(0.0, 0.5), x) for + x in X.mode_1_ARG_act_frac + ]) + inds = StatsBase.sample( + 1:DF.nrow(X), + weights, + model.sample_size, + replace = false, + ) + inds_inducing = StatsBase.sample( + 1:DF.nrow(X), + weights, + model.sample_size_inducing, + replace = false, + ) + else + inds = StatsBase.sample( + 1:DF.nrow(X), + model.sample_size, + replace = false, + ) + inds_inducing = StatsBase.sample( + 1:DF.nrow(X), + model.sample_size_inducing, + replace = false, + ) + end + if model.use_DTC + gp = GaussianProcesses.DTC( + Matrix(X[inds, :])', + Matrix(X[inds_inducing, :])', + y[inds], + GaussianProcesses.MeanConst(StatsBase.mean(y)), + GaussianProcesses.SEArd(fill(4.0, DF.ncol(X)), 0.0), + 2.0, + ) + else + gp = GaussianProcesses.GPA( + Matrix(X[inds, :])', + y[inds], + GaussianProcesses.MeanConst(StatsBase.mean(y)), + GaussianProcesses.SEArd(fill(4.0, DF.ncol(X)), 0.0), + GaussianProcesses.GaussLik(2.0), + ) + end + GaussianProcesses.optimize!(gp) + push!(gps, gp) + end + return gps, nothing, nothing +end + +function MLJ.predict(::GPRegressor, fitresult, Xnew) + means = reduce( + hcat, + [GaussianProcesses.predict_f(gp, Matrix(Xnew)')[1] for gp in fitresult], + ) + variances = reduce( + hcat, + [GaussianProcesses.predict_f(gp, Matrix(Xnew)')[2] for gp in fitresult], + ) + return (sum( + means ./ variances, + dims = 2, + ) ./ sum(1.0 ./ variances, dims = 2))[ + :, + 1, + ] +end + + +""" + MLEmulatedAerosolActivation + +The type for aerosol activation schemes that are emulated with an ML model +""" +struct MLAerosolActivationParameters <: CMP.ParametersType + ap::CMP.AerosolActivationParameters + machine::MLJ.Machine +end + +function MLEmulatedAerosolActivation( + ::Type{FT}, + emulator_filepath::String, + toml_dict::CP.AbstractTOMLDict = CP.create_toml_dict(FT), +) where {FT} + (; data) = toml_dict + machine = MLJ.machine(emulator_filepath) + + M_w = FT(data["molar_mass_water"]["value"]) + R = FT(data["gas_constant"]["value"]) + ρ_w = FT(data["density_liquid_water"]["value"]) + σ = FT(data["surface_tension_water"]["value"]) + g = FT(data["gravitational_acceleration"]["value"]) + f1 = FT(data["ARG2000_f_coeff_1"]["value"]) + f2 = FT(data["ARG2000_f_coeff_2"]["value"]) + g1 = FT(data["ARG2000_g_coeff_1"]["value"]) + g2 = FT(data["ARG2000_g_coeff_2"]["value"]) + p1 = FT(data["ARG2000_pow_1"]["value"]) + p2 = FT(data["ARG2000_pow_2"]["value"]) + + activation_params = CMP.AerosolActivationParameters(M_w, R, ρ_w, σ, g, f1, f2, g1, g2, p1, p2) + return MLEmulatedAerosolActivation(activation_params, machine) +end + +function AA.N_activated_per_mode( + ml::MLAerosolActivationParameters, + ad::CMP.AerosolDistributionType, + aip::CMP.AirProperties, + tps::TDP.ThermodynamicsParameters, + T::FT, + p::FT, + w::FT, + q::TD.PhasePartition{FT}, +) where {FT <: Real} + hygro = mean_hygroscopicity_parameter(ml.ap, ad) + return ntuple(Val(AM.n_modes(ad))) do i + # Model predicts activation of the first mode. So, swap each mode + # with the first mode repeatedly to predict all activations. + modes_perm = collect(1:AM.n_modes(ad)) + modes_perm[[1, i]] = modes_perm[[i, 1]] + per_mode_data = [ + (; + Symbol("mode_$(j)_N") => ad.Modes[modes_perm[j]].N, + Symbol("mode_$(j)_mean") => ad.Modes[modes_perm[j]].r_dry, + Symbol("mode_$(j)_stdev") => ad.Modes[modes_perm[j]].stdev, + Symbol("mode_$(j)_kappa") => hygro[modes_perm[j]], + ) for j in 1:AM.n_modes(ad) + ] + additional_data = (; + :velocity => w, + :initial_temperature => T, + :initial_pressure => p, + ) + X = DF.DataFrame([merge(reduce(merge, per_mode_data), additional_data)]) + max(FT(0), min(FT(1), MLJ.predict(ml.machine, X)[1])) * + ad.Modes[i].N + end +end + +end # module diff --git a/src/PreprocessAerosolData.jl b/src/PreprocessAerosolData.jl new file mode 100644 index 0000000000..38f62aa0fc --- /dev/null +++ b/src/PreprocessAerosolData.jl @@ -0,0 +1,275 @@ +module PreprocessAerosolData + +import CloudMicrophysics as CM +import ..AerosolActivation as AA +import ..AerosolModel as AM +import ..Parameters as CMP +import ..CommonTypes as CT +import ..Parameters as CMP +import ..Common as CO +import CLIMAParameters as CP +import Thermodynamics as TD +import DataFrames as DF +using DataFramesMeta + +const FT = Float64 +const APS = CMP.AbstractCloudMicrophysicsParameters + +include(joinpath(pkgdir(CM), "test", "create_parameters.jl")) +toml_dict = CP.create_toml_dict(FT; dict_type = "alias") +default_param_set = cloud_microphysics_parameters(toml_dict) + +function get_num_modes(df::DataFrame) + i = 1 + while true + if !("mode_$(i)_N" in names(df)) + return i - 1 + end + i += 1 + end +end + +function get_num_modes(data_row::NamedTuple) + i = 1 + while true + if !(Symbol("mode_$(i)_N") in keys(data_row)) + return i - 1 + end + i += 1 + end +end + +function convert_to_ARG_params(data_row::NamedTuple, param_set::APS) + num_modes = get_num_modes(data_row) + @assert num_modes > 0 + mode_Ns = [] + mode_means = [] + mode_stdevs = [] + mode_kappas = [] + w = data_row.velocity + T = data_row.initial_temperature + p = data_row.initial_pressure + for i in 1:num_modes + push!(mode_Ns, data_row[Symbol("mode_$(i)_N")]) + push!(mode_means, data_row[Symbol("mode_$(i)_mean")]) + push!(mode_stdevs, data_row[Symbol("mode_$(i)_stdev")]) + push!(mode_kappas, data_row[Symbol("mode_$(i)_kappa")]) + end + ad = AM.AerosolDistribution( + Tuple( + AM.Mode_κ( + mode_means[i], + mode_stdevs[i], + mode_Ns[i], + FT(1), + FT(1), + FT(0), + mode_kappas[i], + 1, + ) for i in 1:num_modes + ), + ) + thermo_params = CMP.thermodynamics_params(param_set) + pv0 = TD.saturation_vapor_pressure(thermo_params, T, TD.Liquid()) + vapor_mix_ratio = + pv0 / TD.Parameters.molmass_ratio(thermo_params) / (p - pv0) + q_vap = vapor_mix_ratio / (vapor_mix_ratio + 1) + q = TD.PhasePartition(q_vap, FT(0), FT(0)) + return (; ad, T, p, w, q, mode_Ns) +end + +function convert_to_ARG_params(data_row::NamedTuple) + return convert_to_ARG_params(data_row, default_param_set) +end + +function convert_to_ARG_intermediates(data_row::NamedTuple, param_set::APS) + num_modes = get_num_modes(data_row) + @assert num_modes > 0 + + (; ad, T, p, w, q) = convert_to_ARG_params(data_row, param_set) + + thermo_params = CMP.thermodynamics_params(param_set) + _grav::FT = CMP.grav(param_set) + _ρ_cloud_liq::FT = CMP.ρ_cloud_liq(param_set) + + _ϵ::FT = 1 / CMP.molmass_ratio(param_set) + R_m::FT = TD.gas_constant_air(thermo_params, q) + cp_m::FT = TD.cp_m(thermo_params, q) + + L::FT = TD.latent_heat_vapor(thermo_params, T) + p_vs::FT = TD.saturation_vapor_pressure(thermo_params, T, TD.Liquid()) + G::FT = CO.G_func(param_set, T, TD.Liquid()) / _ρ_cloud_liq + + α::FT = L * _grav * _ϵ / R_m / cp_m / T^2 - _grav / R_m / T + γ::FT = R_m * T / _ϵ / p_vs + _ϵ * L^2 / cp_m / T / p + + A::FT = AA.coeff_of_curvature(param_set, T) + ζ::FT = 2 * A / 3 * sqrt(α * w / G) + + Sm = AA.critical_supersaturation(param_set, ad, T) + η = [ + (α * w / G)^FT(3 / 2) / (FT(2 * pi) * _ρ_cloud_liq * γ * ad.Modes[i].N) for i in 1:num_modes + ] + + per_mode_intermediates = [ + (; + Symbol("mode_$(i)_stdev") => ad.Modes[i].stdev, + Symbol("mode_$(i)_η") => η[i], + Symbol("mode_$(i)_Sm") => Sm[i], + ) for i in 1:num_modes + ] + + return merge(reduce(merge, per_mode_intermediates), (; ζ)) +end + +function convert_to_ARG_intermediates(data_row::NamedTuple) + return convert_to_ARG_intermediates(data_row, default_param_set) +end + +function get_ARG_S_max(data_row::NamedTuple, param_set::APS) + (; ad, T, p, w, q) = convert_to_ARG_params(data_row, param_set) + max_supersaturation = + AA.max_supersaturation(param_set, CT.ARG2000Type(), ad, T, p, w, q) + return max_supersaturation +end + +function get_ARG_S_max(data_row::NamedTuple) + return get_ARG_S_max(data_row, default_param_set) +end + +function get_ARG_S_max(X::DataFrame, param_set::APS) + return get_ARG_S_max.(NamedTuple.(eachrow(X)), param_set) +end + +function get_ARG_S_max(X::DataFrame) + return get_ARG_S_max(X, default_param_set) +end + +function get_ARG_S_crit(data_row::NamedTuple, param_set::APS) + (; ad, T) = convert_to_ARG_params(data_row, param_set) + return AA.critical_supersaturation(param_set, ad, T) +end + +function get_ARG_S_crit(data_row::NamedTuple) + return get_ARG_S_crit(data_row, default_param_set) +end + +function get_ARG_S_crit(X::DataFrame, param_set::APS) + return get_ARG_S_crit.(NamedTuple.(eachrow(X)), param_set) +end + +function get_ARG_S_crit(X::DataFrame) + return get_ARG_S_crit(X, default_param_set) +end + +function get_ARG_act_N(data_row::NamedTuple, param_set::APS, S_max = nothing) + (; ad, T, p, w, q) = convert_to_ARG_params(data_row, param_set) + if S_max === nothing + return collect( + AA.N_activated_per_mode( + param_set, + CT.ARG2000Type(), + ad, + T, + p, + w, + q, + ), + ) + else + critical_supersaturation = AA.critical_supersaturation(param_set, ad, T) + return collect( + AA.N_activated_per_mode( + param_set, + CT.ARG2000Type(), + ad, + T, + p, + w, + q, + S_max, + critical_supersaturation, + ), + ) + end +end + +function get_ARG_act_N(data_row::NamedTuple, S_max = nothing) + return get_ARG_act_N(data_row, default_param_set, S_max) +end + +function get_ARG_act_N(X::DataFrame, param_set::APS, S_max = nothing) + return transpose( + hcat(get_ARG_act_N.(NamedTuple.(eachrow(X)), param_set, S_max)...), + ) +end + +function get_ARG_act_N(X::DataFrame, S_max = nothing) + return get_ARG_act_N(X, default_param_set, S_max) +end + +function get_ARG_act_frac(data_row::NamedTuple, param_set::APS, S_max = nothing) + (; mode_Ns) = convert_to_ARG_params(data_row, param_set) + return get_ARG_act_N(data_row, param_set, S_max) ./ mode_Ns +end + +function get_ARG_act_frac(data_row::NamedTuple, S_max = nothing) + return get_ARG_act_frac(data_row, default_param_set, S_max) +end + +function get_ARG_act_frac(X::DataFrame, param_set::APS, S_max = nothing) + return transpose( + hcat(get_ARG_act_frac.(NamedTuple.(eachrow(X)), param_set, S_max)...), + ) +end + +function get_ARG_act_frac(X::DataFrame, S_max = nothing) + return get_ARG_act_frac(X, default_param_set, S_max) +end + +function preprocess_aerosol_data(X::DataFrame, add_ARG_act_frac::Bool) + num_modes = get_num_modes(X) + if add_ARG_act_frac + X = DF.transform( + X, + AsTable(All()) => + ByRow(x -> get_ARG_act_frac(x)) => + [Symbol("mode_$(i)_ARG_act_frac") for i in 1:num_modes], + ) + end + for i in 1:num_modes + X = DF.transform( + X, + Symbol("mode_$(i)_N") => ByRow(log) => Symbol("mode_$(i)_N"), + ) + X = DF.transform( + X, + Symbol("mode_$(i)_mean") => + ByRow(log) => Symbol("mode_$(i)_mean"), + ) + end + X = DF.transform(X, :velocity => ByRow(log) => :velocity) + return X +end + +function preprocess_aerosol_data_standard(X::DataFrame) + return preprocess_aerosol_data(X, false) +end + +function preprocess_aerosol_data_with_ARG_act_frac(X::DataFrame) + return preprocess_aerosol_data(X, true) +end + +function preprocess_aerosol_data_with_ARG_intermediates(X::DataFrame) + return DF.DataFrame(convert_to_ARG_intermediates.(NamedTuple.(eachrow(X)))) +end + +function target_transform(act_frac) + return @. atanh(2.0 * 0.99 * (act_frac - 0.5)) +end + +function inverse_target_transform(transformed_act_frac) + return @. (1.0 / (2.0 * 0.99)) * tanh(transformed_act_frac) + 0.5 +end + +end # module