-
Notifications
You must be signed in to change notification settings - Fork 104
Add COCG method for complex symmetric linear systems #289
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 11 commits
c5b440e
91208cd
3337884
f7df543
d5b31f4
97315be
593e88c
835a894
9c47e7f
3cd7969
d16fecb
04c3c16
4800129
9a7fb26
f3710c3
b457247
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,21 +1,22 @@ | ||||||
import Base: iterate | ||||||
using Printf | ||||||
export cg, cg!, CGIterable, PCGIterable, cg_iterator!, CGStateVariables | ||||||
export cg, cg!, cocg, cocg!, CGIterable, PCGIterable, cg_iterator!, CGStateVariables | ||||||
|
||||||
mutable struct CGIterable{matT, solT, vecT, numT <: Real} | ||||||
mutable struct CGIterable{matT, solT, vecT, numT <: Real, paramT <: Number, dotT <: AbstractDot} | ||||||
A::matT | ||||||
x::solT | ||||||
r::vecT | ||||||
c::vecT | ||||||
u::vecT | ||||||
tol::numT | ||||||
residual::numT | ||||||
prev_residual::numT | ||||||
ρ_prev::paramT | ||||||
maxiter::Int | ||||||
mv_products::Int | ||||||
dotproduct::dotT | ||||||
end | ||||||
|
||||||
mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Number} | ||||||
mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Number, dotT <: AbstractDot} | ||||||
Pl::precT | ||||||
A::matT | ||||||
x::solT | ||||||
|
@@ -24,9 +25,10 @@ mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Numb | |||||
u::vecT | ||||||
tol::numT | ||||||
residual::numT | ||||||
ρ::paramT | ||||||
ρ_prev::paramT | ||||||
maxiter::Int | ||||||
mv_products::Int | ||||||
dotproduct::dotT | ||||||
end | ||||||
|
||||||
@inline converged(it::Union{CGIterable, PCGIterable}) = it.residual ≤ it.tol | ||||||
|
@@ -47,18 +49,20 @@ function iterate(it::CGIterable, iteration::Int=start(it)) | |||||
end | ||||||
|
||||||
# u := r + βu (almost an axpy) | ||||||
β = it.residual^2 / it.prev_residual^2 | ||||||
ρ = isa(it.dotproduct, ConjugatedDot) ? it.residual^2 : _norm(it.r, it.dotproduct)^2 | ||||||
β = ρ / it.ρ_prev | ||||||
|
||||||
it.u .= it.r .+ β .* it.u | ||||||
|
||||||
# c = A * u | ||||||
mul!(it.c, it.A, it.u) | ||||||
α = it.residual^2 / dot(it.u, it.c) | ||||||
α = ρ / _dot(it.u, it.c, it.dotproduct) | ||||||
|
||||||
# Improve solution and residual | ||||||
it.ρ_prev = ρ | ||||||
it.x .+= α .* it.u | ||||||
it.r .-= α .* it.c | ||||||
|
||||||
it.prev_residual = it.residual | ||||||
it.residual = norm(it.r) | ||||||
|
||||||
# Return the residual at item and iteration number as state | ||||||
|
@@ -78,18 +82,17 @@ function iterate(it::PCGIterable, iteration::Int=start(it)) | |||||
# Apply left preconditioner | ||||||
ldiv!(it.c, it.Pl, it.r) | ||||||
|
||||||
ρ_prev = it.ρ | ||||||
it.ρ = dot(it.c, it.r) | ||||||
|
||||||
# u := c + βu (almost an axpy) | ||||||
β = it.ρ / ρ_prev | ||||||
ρ = _dot(it.r, it.c, it.dotproduct) | ||||||
β = ρ / it.ρ_prev | ||||||
it.u .= it.c .+ β .* it.u | ||||||
|
||||||
# c = A * u | ||||||
mul!(it.c, it.A, it.u) | ||||||
α = it.ρ / dot(it.u, it.c) | ||||||
α = ρ / _dot(it.u, it.c, it.dotproduct) | ||||||
|
||||||
# Improve solution and residual | ||||||
it.ρ_prev = ρ | ||||||
it.x .+= α .* it.u | ||||||
it.r .-= α .* it.c | ||||||
|
||||||
|
@@ -122,7 +125,8 @@ function cg_iterator!(x, A, b, Pl = Identity(); | |||||
reltol::Real = sqrt(eps(real(eltype(b)))), | ||||||
maxiter::Int = size(A, 2), | ||||||
statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)), | ||||||
initially_zero::Bool = false) | ||||||
initially_zero::Bool = false, | ||||||
dotproduct::AbstractDot = ConjugatedDot()) | ||||||
u = statevars.u | ||||||
r = statevars.r | ||||||
c = statevars.c | ||||||
|
@@ -143,14 +147,12 @@ function cg_iterator!(x, A, b, Pl = Identity(); | |||||
# Return the iterable | ||||||
if isa(Pl, Identity) | ||||||
return CGIterable(A, x, r, c, u, | ||||||
tolerance, residual, one(residual), | ||||||
maxiter, mv_products | ||||||
) | ||||||
tolerance, residual, one(eltype(r)), | ||||||
maxiter, mv_products, dotproduct) | ||||||
else | ||||||
return PCGIterable(Pl, A, x, r, c, u, | ||||||
tolerance, residual, one(eltype(x)), | ||||||
maxiter, mv_products | ||||||
) | ||||||
tolerance, residual, one(eltype(r)), | ||||||
maxiter, mv_products, dotproduct) | ||||||
end | ||||||
end | ||||||
|
||||||
|
@@ -211,6 +213,7 @@ function cg!(x, A, b; | |||||
statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)), | ||||||
verbose::Bool = false, | ||||||
Pl = Identity(), | ||||||
dotproduct::AbstractDot = ConjugatedDot(), | ||||||
kwargs...) | ||||||
history = ConvergenceHistory(partial = !log) | ||||||
history[:abstol] = abstol | ||||||
|
@@ -219,7 +222,7 @@ function cg!(x, A, b; | |||||
|
||||||
# Actually perform CG | ||||||
iterable = cg_iterator!(x, A, b, Pl; abstol = abstol, reltol = reltol, maxiter = maxiter, | ||||||
statevars = statevars, kwargs...) | ||||||
statevars = statevars, dotproduct = dotproduct, kwargs...) | ||||||
if log | ||||||
history.mvps = iterable.mv_products | ||||||
end | ||||||
|
@@ -237,3 +240,19 @@ function cg!(x, A, b; | |||||
|
||||||
log ? (iterable.x, history) : iterable.x | ||||||
end | ||||||
|
||||||
""" | ||||||
cocg(A, b; kwargs...) -> x, [history] | ||||||
|
||||||
Same as [`cocg!`](@ref), but allocates a solution vector `x` initialized with zeros. | ||||||
""" | ||||||
cocg(A, b; kwargs...) = cocg!(zerox(A, b), A, b; initially_zero = true, kwargs...) | ||||||
|
||||||
""" | ||||||
cocg!(x, A, b; kwargs...) -> x, [history] | ||||||
|
||||||
Same as [`cg!`](@ref), but uses the unconjugated dot product (`xᵀy`) instead of the usual, | ||||||
conjugated dot product (`x'y`) in the algorithm. It is for solving linear systems with | ||||||
matrices `A` that are complex-symmetric (`Aᵀ == A`) rather than Hermitian (`A' == A`). | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for finding out the typo! Will correct this in the next commit. |
||||||
""" | ||||||
cocg!(x, A, b; kwargs...) = cg!(x, A, b; dotproduct = UnconjugatedDot(), kwargs...) |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,6 +1,6 @@ | ||||||
import LinearAlgebra: ldiv!, \ | ||||||
|
||||||
export Identity | ||||||
export Identity, ConjugatedDot, UnconjugatedDot | ||||||
|
||||||
#### Type-handling | ||||||
""" | ||||||
|
@@ -30,3 +30,16 @@ struct Identity end | |||||
\(::Identity, x) = copy(x) | ||||||
ldiv!(::Identity, x) = x | ||||||
ldiv!(y, ::Identity, x) = copyto!(y, x) | ||||||
|
||||||
""" | ||||||
Conjugated and unconjugated dot products | ||||||
""" | ||||||
abstract type AbstractDot end | ||||||
struct ConjugatedDot <: AbstractDot end | ||||||
struct UnconjugatedDot <: AbstractDot end | ||||||
|
||||||
_norm(x, ::ConjugatedDot) = norm(x) | ||||||
_dot(x, y, ::ConjugatedDot) = dot(x, y) | ||||||
|
||||||
_norm(x, ::UnconjugatedDot) = sqrt(sum(xₖ->xₖ^2, x)) | ||||||
_dot(x, y, ::UnconjugatedDot) = transpose(@view(x[:])) * @view(y[:]) # allocating, but faster than sum(prod, zip(x,y)) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
this shouldn't allocate, at least not on julia 1.5 and above; so let's drop that comment. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it is allocating when If it is allocating, how about changing the comment in the code from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, but I'm not following why you're using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and the only thing that might allocate if julia is not smart enough is the array wrapper, but it's O(1) not O(n), so not worth documenting There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
What if It is true that multi-dimensional arrays If the unconjugated dot product had been one of those operators required for making objects behave as vectors, I wouldn't have had to go through these complications...
Fair enough. I will remove the comment in the next commit. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The CG method only solves Orthogonal to that is how you internally represent your data. If you want X and B to be a tensor because that maps well to your PDE finite differences scheme, then the user should still wrap it in a vector type
and define the interface for it. Or you don't and you call But generally we assume the interface of vectors in iterative methods, and later we can think about supporting block methods to solve AX=B for multiple right-hand sides, where we assume a matrix interface. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure if we should support multiple right-hand sides with the same That said, my current proposal _norm(x, ::UnconjugatedDot) = sqrt(sum(xₖ->xₖ^2, x))
_dot(x, y, ::UnconjugatedDot) = transpose(@view(x[:])) * @view(y[:]) is not the perfect solution that supports the most general type of (I think the best way to solve the present issue might be to define a lazy wrapper There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
Uh oh!
There was an error while loading. Please reload this page.