Skip to content

Commit

Permalink
Merge pull request #339 from CliMA/ap/solver
Browse files Browse the repository at this point in the history
P3 shape solver update with approximate mu
  • Loading branch information
anastasia-popova authored Mar 6, 2024
2 parents 7b0de4f + 13b011d commit a7d088d
Show file tree
Hide file tree
Showing 5 changed files with 310 additions and 123 deletions.
21 changes: 16 additions & 5 deletions docs/src/P3Scheme.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ N_{ice} = \int_{0}^{\infty} \! N'(D) \mathrm{d}D = \int_{0}^{\infty} \! N_{0} D^

``q_{ice}`` depends on the variable mass-size relation ``m(D)`` defined above.
We solve for ``q_{ice}`` in a piece-wise fashion defined by the same thresholds as ``m(D)``.
As a result ``q\_{ice}`` can be expressed as a sum of inclomplete gamma functions.
As a result ``q_{ice}`` can be expressed as a sum of inclomplete gamma functions.
and the shape parameters are found using iterative solver.

| condition(s) | ``q_{ice} = \int \! m(D) N'(D) \mathrm{d}D`` | gamma representation |
Expand All @@ -133,18 +133,29 @@ As a result ``q\_{ice}`` can be expressed as a sum of inclomplete gamma function
where ``\Gamma \,(a, z) = \int_{z}^{\infty} \! t^{a - 1} e^{-t} \mathrm{d}D``
and ``\Gamma \,(a) = \Gamma \,(a, 0)`` for simplicity.

An initial guess for the non-linear solver is found by approximating the gamma functions as a simple power function.
Within our solver, we approximate ``\mu`` from q/N and keep it constant throughout the solving step. We approximate ``\mu`` by an exponential function given by the q/N points corresponding to ``\mu = 6`` and ``\mu = 0``. This is shown below as well as how this affects the solvers ``\lambda`` solutions.

```@example
include("plots/P3LambdaErrorPlots.jl")
```
![](MuApprox.svg)

An initial guess for the non-linear solver is found by approximating the gamma functions as a simple linear function from log(q\N) to log(``\lambda``).

```@example
include("plots/P3ShapeSolverPlots.jl")
```
![](SolverInitialGuess.svg)

This equation is given by ``(log(q_{approx}) - log(q1)) = slope (log(\lambda) - log(p1))``. Solving for ``q_{approx}`` we get `` q_{approx} = q_1 \frac{\lambda}{p_1} ^{slope}`` where `` slope = \frac{log(q1) - log(q2)}{log(p1) - log(p2)}``, ``p1`` and ``p2`` are defining ``\lambda`` values of the estimated line (we use p1 = 1e2, p2 = 1e6), ``q1 = q(p1)`` and ``q2 = q(p2)`` are the corresponding calculated q values for the given ``F_r`` and ``\rho_r`` values.
Let ``x = log(q/N)`` and ``y = log(\lambda)``. This equation is given by ``(x - x_1) = slope (y - y_1)`` where `` slope = \frac{x_1 - x_2}{y_1 - y_2}``, ``y_1`` and ``y_2`` are defining ``log(\lambda)`` values of the estimated line decided off of the ``q/N`` value (described below).

We use this approximation to calculate a ``\lambda_{guess}`` value which will set our initial guess. Solving for ``\lambda`` in the power function we get ``\lambda_{guess} = p1 (\frac{q}{q1})^{(\frac{log(q1)-log(q2)}{log(p1)-log(p2)})}``. Thus, given any q we can calculate a ``\lambda`` around which to expect the true solved ``\lambda`` value.
| q/N | ``y_1`` | ``y_2`` |
|:-------------------------------|:-----------------------|:---------------------|
| ``q/N >= 10^-8`` | ``1`` | ``6 * 10^3`` |
| ``2 * 10^9 <= q/N < 10^-8`` | ``6 * 10^3`` | ``3 * 10^4`` |
| ``q/N < 2 * 10^9`` | ``4 * 10^4`` | ``10^6`` |

For small values of ``\lambda_{guess}`` it was found to be more efficient to use constant initial guesses.
We use this approximation to calculate a ``\lambda_{guess}`` value which will set our initial guess. Solving for ``\lambda`` in the power function we get ``\lambda_{guess} = \lambda _1 (\frac{q}{q_1})^{(\frac{y_1 - y_2}{x_1 - x_2})}``. Thus, given any q we can calculate a ``\lambda`` around which to expect the true solved ``\lambda`` value.

Using this approach we get the following relative errors in our solved ``\lambda`` vs the expected ``lambda`` within the solver.

Expand Down
131 changes: 113 additions & 18 deletions docs/src/plots/P3LambdaErrorPlots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,30 @@ function λ_diff(F_r::FT, ρ_r::FT, N::FT, λ_ex::FT, p3::PSP3) where {FT}

# Find the P3 scheme thresholds
th = P3.thresholds(p3, ρ_r, F_r)
# Get μ corresponding to λ
μ = P3.DSD_μ(p3, λ_ex)
# Convert λ to ensure it remains positive
x = log(λ_ex)
# Compute mass density based on input shape parameters
q_calc = P3.q_gamma(p3, F_r, N, x, th)
q_calc = N * P3.q_over_N_gamma(p3, F_r, x, μ, th)

(λ_calculated,) = P3.distribution_parameter_solver(p3, q_calc, N, ρ_r, F_r)
return abs(λ_ex - λ_calculated)
end

function get_errors(
p3::PSP3,
λ_min::FT,
λ_max::FT,
log10_λ_min::FT,
log10_λ_max::FT,
F_r_min::FT,
F_r_max::FT,
ρ_r::FT,
N::FT,
λSteps::Int,
F_rSteps::Int,
) where {FT}
λs = range(FT(λ_min), stop = λ_max, length = λSteps)
logλs = range(FT(log10_λ_min), stop = log10_λ_max, length = λSteps)
λs = [10^logλ for logλ in logλs]
F_rs = range(F_r_min, stop = F_r_max, length = F_rSteps)
E = zeros(λSteps, F_rSteps)
min = Inf
Expand All @@ -46,8 +49,7 @@ function get_errors(
λ = λs[i]
F_r = F_rs[j]

diff = λ_diff(F_r, ρ_r, N, λ, p3)
er = log(diff / λ)
er = log(λ_diff(F_r, ρ_r, N, λ, p3) / λ)

E[i, j] = er

Expand All @@ -60,13 +62,13 @@ function get_errors(

end
end
return (λs = λs, F_rs = F_rs, E = E, min = min, max = max)
return (; λs, F_rs, E, min, max)
end

function plot_relerrors(
N::FT,
λ_min::FT,
λ_max::FT,
log10_λ_min::FT,
log10_λ_max::FT,
F_r_min::FT,
F_r_max::FT,
ρ_r_min::FT,
Expand Down Expand Up @@ -97,12 +99,13 @@ function plot_relerrors(
),
width = 400,
height = 300,
xscale = log10,
)

(λs, F_rs, E, min, max) = get_errors(
p3,
λ_min,
λ_max,
log10_λ_min,
log10_λ_max,
F_r_min,
F_r_max,
ρ,
Expand All @@ -111,13 +114,23 @@ function plot_relerrors(
F_rSteps,
)

Plt.heatmap!(λs, F_rs, E)
Plt.Colorbar(
f[x, y + 1],
limits = (min, max),
limits = (-10, 0),
colormap = :viridis,
flipaxis = false,
highclip = :red,
lowclip = :indigo,
)
Plt.heatmap!(
λs,
F_rs,
E,
colorrange = (-10, 0),
highclip = :red,
lowclip = :indigo,
)


y = y + 2
if (y > 6)
Expand All @@ -130,12 +143,89 @@ function plot_relerrors(
Plt.save("P3LambdaHeatmap.svg", f)
end

function μ_approximation_effects(F_r::FT, ρ_r::FT) where {FT}

f = Plt.Figure()

ax1 = Plt.Axis(
f[1, 1],
xlabel = "q/N",
ylabel = "μ",
title = string("μ vs q/N for F_r = ", F_r, " ρ_r = ", ρ_r),
width = 400,
height = 300,
xscale = log10,
)

ax2 = Plt.Axis(
f[1, 2],
xlabel = "λ",
ylabel = "μ",
title = string("μ vs λ for F_r = ", F_r, " ρ_r = ", ρ_r),
width = 400,
height = 300,
xscale = log10,
)

ax3 = Plt.Axis(
f[1, 3],
xlabel = "λ",
ylabel = "q/N",
title = string("q/N vs λ for F_r = ", F_r, " ρ_r = ", ρ_r),
width = 400,
height = 300,
xscale = log10,
yscale = log10,
)

Plt.linkxaxes!(ax2, ax3)
Plt.linkyaxes!(ax1, ax2)

numpts = 100

# Set up vectors
th = P3.thresholds(p3, ρ_r, F_r)
log_λs = range(FT(3.6), stop = FT(4.6), length = numpts)
λs = [10^log_λ for log_λ in log_λs]
μs = [P3.DSD_μ(p3, λ) for λ in λs]

μs_approx = [FT(0) for λ in λs]
qs = [FT(0) for λ in λs]
λ_solved = [FT(0) for λ in λs]

for i in 1:numpts
q = P3.q_over_N_gamma(p3, F_r, log(λs[i]), μs[i], th)
qs[i] = q
N = FT(1e6)
(L, N) = P3.distribution_parameter_solver(p3, q * N, N, ρ_r, F_r)
λ_solved[i] = L
μs_approx[i] = P3.DSD_μ_approx(p3, N * q, N, ρ_r, F_r)
end

# Plot
Plt.lines!(ax3, λs, qs, label = "true distribution")
Plt.lines!(ax3, λ_solved, qs, label = "approximated")

Plt.lines!(ax1, qs, μs, label = "true distribution")
Plt.lines!(ax1, qs, μs_approx, label = "approximated")
Plt.lines!(ax2, λs, μs, label = "true distribution")
Plt.lines!(ax2, λ_solved, μs_approx, label = "approximated")

Plt.axislegend(ax1, position = :lb)
Plt.axislegend(ax2, position = :lt)
Plt.axislegend(ax3, position = :lb)

Plt.resize_to_layout!(f)
Plt.save("MuApprox.svg", f)

end

# Define variables for heatmap relative error plots:

λ_min = FT(1e2)
λ_max = FT(1e6)
log10_λ_min = FT(1)
log10_λ_max = FT(6)
F_r_min = FT(0)
F_r_max = FT(1 - eps(FT))
F_r_max = FT(0.9)
ρ_r_min = FT(100)
ρ_r_max = FT(900)
N = FT(1e8)
Expand All @@ -146,8 +236,8 @@ NumPlots = 9

plot_relerrors(
N,
λ_min,
λ_max,
log10_λ_min,
log10_λ_max,
F_r_min,
F_r_max,
ρ_r_min,
Expand All @@ -157,3 +247,8 @@ plot_relerrors(
NumPlots,
p3,
)

F_r = FT(0.5)
ρ_r = FT(500)

μ_approximation_effects(F_r, ρ_r)
94 changes: 62 additions & 32 deletions docs/src/plots/P3ShapeSolverPlots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,44 +13,74 @@ function guess_value(λ::FT, p1::FT, p2::FT, q1::FT, q2::FT)
return q1 */ p1)^((log(q1) - log(q2)) / (log(p1) - log(p2)))
end

function lambda_guess_plot(F_r::FT, ρ_r::FT) where {FT}
N = FT(1e8)
function lambda_guess_plot()
N = FT(1e6)

λs = FT(1e2):FT(1e2):FT(1e6 + 1)
th = P3.thresholds(p3, ρ_r, F_r)
qs = [P3.q_gamma(p3, F_r, N, log(λ), th) for λ in λs]

guesses = [guess_value(λ, λs[1], last(λs), qs[1], last(qs)) for λ in λs]
F_r_s = [FT(0.0), FT(0.5), FT(0.8)]
ρ_r_s = [FT(200), FT(400), FT(800)]

f = Plt.Figure()
Plt.Axis(
f[1, 1],
xscale = log,
yscale = log,
xticks = [10^2, 10^3, 10^4, 10^5, 10^6],
yticks = [10^3, 1, 10^-3, 10^-6],
xlabel = "λ",
ylabel = "q",
title = "q vs λ",
height = 300,
width = 400,
)

l1 = Plt.lines!(λs, qs, linewidth = 3, color = "Black", label = "q")
l2 = Plt.lines!(
λs,
guesses,
linewidth = 2,
linestyle = :dash,
color = "Red",
label = "q_approximated",
)

Plt.axislegend("Legend", position = :lb)

for i in 1:length(F_r_s)
for j in 1:length(ρ_r_s)
F_r = F_r_s[i]
ρ_r = ρ_r_s[j]

Plt.Axis(
f[i, j],
xlabel = "log(q/N)",
ylabel = "log(λ)",
title = string("λ vs q/N for F_r = ", F_r, " and ρ_r = ", ρ_r),
height = 300,
width = 400,
)


logλs = FT(1):FT(0.01):FT(6)
λs = [10^logλ for logλ in logλs]
th = P3.thresholds(p3, ρ_r, F_r)
qs = [
P3.q_over_N_gamma(p3, F_r, log(λ), P3.DSD_μ(p3, λ), th) for
λ in λs
]
guesses = [FT(0) for λ in λs]

for i in 1:length(λs)
(min,) = P3.get_bounds(
N,
qs[i] * N,
P3.DSD_μ_approx(p3, qs[i] * N, N, ρ_r, F_r),
F_r,
p3,
th,
)
guesses[i] = exp(min)
end


Plt.lines!(
log10.(qs),
log10.(λs),
linewidth = 3,
color = "Black",
label = "true",
)
Plt.lines!(
log10.(qs),
log10.(guesses),
linewidth = 2,
linestyle = :dash,
color = "Red",
label = "approximated",
)

Plt.axislegend("Legend", position = :lb)
end
end

Plt.resize_to_layout!(f)
Plt.save("SolverInitialGuess.svg", f)

end

lambda_guess_plot(FT(0.5), FT(200))
lambda_guess_plot()
Loading

0 comments on commit a7d088d

Please sign in to comment.