From a3d4ab7d2b09a9a7a09ab2f53b16c21aaeab640a Mon Sep 17 00:00:00 2001 From: Anna Jaruga Date: Tue, 5 Mar 2024 10:42:38 -0800 Subject: [PATCH] review --- {test_ml => ext}/Common.jl | 0 test/emulator_NN.jl | 19 +++++++++++-------- test_ml/Project.toml | 12 ------------ 3 files changed, 11 insertions(+), 20 deletions(-) rename {test_ml => ext}/Common.jl (100%) delete mode 100644 test_ml/Project.toml diff --git a/test_ml/Common.jl b/ext/Common.jl similarity index 100% rename from test_ml/Common.jl rename to ext/Common.jl diff --git a/test/emulator_NN.jl b/test/emulator_NN.jl index 1c5003eb67..200a98bf43 100644 --- a/test/emulator_NN.jl +++ b/test/emulator_NN.jl @@ -62,16 +62,17 @@ function MLJFlux.build(builder::NNBuilder, rng, n_in, n_out) return Flux.Chain(layers...) end -# Path to ML testing folder (stores some common functions, downloaded data -# and the created emulator) -fpath = joinpath(pkgdir(CM), "test_ml") # Load aerosol data reading and preprocessing functions -include(joinpath(fpath, "Common.jl")) +include(joinpath(pkgdir(CM), "ext", "Common.jl")) # Get the ML model -function get_ML_model(machine_name, FT) +function get_2modal_NN_model_FT32() + FT = Float32 + machine_name = "2modal_nn_machine_naive.jls" + # If the ML model already exists load it in. # If it does not exist, train a NN + fpath = joinpath(pkgdir(CM), "test") if isfile(joinpath(fpath, machine_name)) # Read-in the saved ML model emulator_filepath = joinpath(fpath, machine_name) @@ -154,12 +155,14 @@ function test_emulator_NN(FT) ad = AM.AerosolDistribution((crs, acc)) # Get the ML model - mach = get_ML_model("2modal_nn_machine_naive.jls", FT) + mach = get_2modal_NN_model_FT32() TT.@test AA.N_activated_per_mode(mach, ap, ad, aip, tps, T, p, w, q)[1] ≈ - FT(999964.4226498074) + AA.N_activated_per_mode(ap, ad, aip, tps, T, p, w, q)[1] rtol = + 1e-5 TT.@test AA.N_activated_per_mode(mach, ap, ad, aip, tps, T, p, w, q)[2] ≈ - FT(9.998556592247702e7) + AA.N_activated_per_mode(ap, ad, aip, tps, T, p, w, q)[2] rtol = + 1e-3 end @info "Aerosol activation NN test" diff --git a/test_ml/Project.toml b/test_ml/Project.toml deleted file mode 100644 index 48af46c9e1..0000000000 --- a/test_ml/Project.toml +++ /dev/null @@ -1,12 +0,0 @@ -[deps] -CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" -ClimaParams = "5c42b081-d73a-476f-9059-fd94b934656c" -CloudMicrophysics = "6a9e3e04-43cd-43ba-94b9-e8782df3c71b" -DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964" -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" -MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" -MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7" -Thermodynamics = "b60c26fb-14c3-4610-9d3e-2d17fe7ff00c"