Skip to content

Commit fa49e3a

Browse files
haampiemohamed82008
authored andcommitted
Simplify the tests of lsmr and remove the 5-argument mul! stuff since that is not in LinearAlgebra for dense matrices.
1 parent 76d5ed7 commit fa49e3a

File tree

2 files changed

+38
-92
lines changed

2 files changed

+38
-92
lines changed

src/lsmr.jl

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,14 @@ function lsmr_method!(log::ConvergenceHistory, x, A, b, v, h, hbar;
107107
normArs = Tr[]
108108
conlim > 0 ? ctol = convert(Tr, inv(conlim)) : ctol = zero(Tr)
109109
# form the first vectors u and v (satisfy β*u = b, α*v = A'u)
110-
u = mul!(b, A, x, -1, 1)
110+
tmp_u = similar(b)
111+
tmp_v = similar(v)
112+
mul!(tmp_u, A, x)
113+
b .-= tmp_u
114+
u = b
111115
β = norm(u)
112116
u .*= inv(β)
113-
mul!(v, adjoint(A), u, 1, 0)
117+
mul!(v, adjoint(A), u)
114118
α = norm(v)
115119
v .*= inv(α)
116120

@@ -158,12 +162,14 @@ function lsmr_method!(log::ConvergenceHistory, x, A, b, v, h, hbar;
158162
while iter < maxiter
159163
nextiter!(log,mvps=1)
160164
iter += 1
161-
mul!(u, A, v, 1, -α)
165+
mul!(tmp_u, A, v)
166+
u .= tmp_u .+ u .* -α
162167
β = norm(u)
163168
if β > 0
164169
log.mtvps+=1
165170
u .*= inv(β)
166-
mul!(v, adjoint(A), u, 1, -β)
171+
mul!(tmp_v, adjoint(A), u)
172+
v .= tmp_v .+ v .* -β
167173
α = norm(v)
168174
v .*= inv(α)
169175
end
@@ -278,11 +284,3 @@ function lsmr_method!(log::ConvergenceHistory, x, A, b, v, h, hbar;
278284
setconv(log, istop (3, 6, 7))
279285
x
280286
end
281-
282-
function LinearAlgebra.mul!(y::AbstractVector, A::StridedVecOrMat, x::AbstractVector, α::Number, β::Number)
283-
BLAS.gemm!('N', 'N', convert(eltype(y), α), A, x, convert(eltype(y), β), y)
284-
end
285-
286-
function LinearAlgebra.mul!(y::AbstractVector, A::Adjoint{F, T}, x::AbstractVector, α::Number, β::Number) where {F, T <: StridedVecOrMat{F}}
287-
BLAS.gemm!('T', 'N', convert(eltype(y), α), adjoint(A), x, convert(eltype(y), β), y)
288-
end

test/lsmr.jl

Lines changed: 28 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -4,56 +4,17 @@ using LinearAlgebra
44
using Random
55
using SparseArrays
66

7-
import Base: size, eltype, similar, copyto!, fill!, length
7+
import Base: size, eltype, similar, copyto!, fill!, length, axes, getindex, setindex!
88
import LinearAlgebra: norm, mul!, rmul!, lmul!
99

1010
# Type used in Dampenedtest
11-
# solve (A'A + diag(v).^2 ) x = b
12-
# using LSMR in the augmented space A' = [A ; diag(v)] b' = [b; zeros(size(A, 2)]
13-
mutable struct DampenedVector{Ty, Tx}
14-
y::Ty
15-
x::Tx
16-
end
17-
18-
eltype(a::DampenedVector) = promote_type(eltype(a.y), eltype(a.x))
19-
norm(a::DampenedVector) = sqrt(norm(a.y)^2 + norm(a.x)^2)
20-
21-
function Base.Broadcast.broadcast!(f::Tf, to::DampenedVector, from::DampenedVector, args...) where {Tf}
22-
to.x .= f.(from.x, args...)
23-
to.y .= f.(from.y, args...)
24-
to
25-
end
26-
27-
function copyto!(a::DampenedVector{Ty, Tx}, b::DampenedVector{Ty, Tx}) where {Ty, Tx}
28-
copyto!(a.y, b.y)
29-
copyto!(a.x, b.x)
30-
a
31-
end
32-
33-
function fill!(a::DampenedVector, α::Number)
34-
fill!(a.y, α)
35-
fill!(a.x, α)
36-
a
37-
end
38-
39-
lmul!::Number, a::DampenedVector) = rmul!(a, α)
40-
41-
function rmul!(a::DampenedVector, α::Number)
42-
rmul!(a.y, α)
43-
rmul!(a.x, α)
44-
a
45-
end
46-
47-
similar(a::DampenedVector, T) = DampenedVector(similar(a.y, T), similar(a.x, T))
48-
length(a::DampenedVector) = length(a.y) + length(a.x)
49-
50-
mutable struct DampenedMatrix{TA, Tx}
11+
# solve (A'A + diag(v).^2 ) x = A'b
12+
# using LSMR in the augmented space à = [A ; diag(v)] b̃ = [b; zeros(size(A, 2)]
13+
struct DampenedMatrix{Tv,TA<:AbstractMatrix{Tv},TD<:AbstractVector{Tv}} <: AbstractMatrix{Tv}
5114
A::TA
52-
diagonal::Tx
15+
diagonal::TD
5316
end
5417

55-
eltype(A::DampenedMatrix) = promote_type(eltype(A.A), eltype(A.diagonal))
56-
5718
function size(A::DampenedMatrix)
5819
m, n = size(A.A)
5920
l = length(A.diagonal)
@@ -63,36 +24,23 @@ end
6324
function size(A::DampenedMatrix, dim::Integer)
6425
m, n = size(A.A)
6526
l = length(A.diagonal)
66-
dim == 1 ? (m + l) :
67-
dim == 2 ? n : 1
27+
dim == 1 ? (m + l) : (dim == 2 ? n : 1)
6828
end
6929

70-
function mul!(b::DampenedVector{Ty, Tx}, mw::DampenedMatrix{TA, Tx}, a::Tx,
71-
α::Number, β::Number) where {TA, Tx, Ty}
72-
if β != 1.
73-
if β == 0.
74-
fill!(b, 0.)
75-
else
76-
rmul!(b, β)
77-
end
78-
end
79-
mul!(b.y, mw.A, a, α, 1.0)
80-
map!((z, x, y)-> z + α * x * y, b.x, b.x, a, mw.diagonal)
81-
return b
30+
function mul!(y::AbstractVector{Tv}, mw::DampenedMatrix, x::AbstractVector{Tv}) where {Tv}
31+
m₁ = size(mw.A, 1)
32+
m₂ = size(mw, 1)
33+
mul!(view(y, 1:m₁), mw.A, x)
34+
y[m₁+1:m₂] .= mw.diagonal .* x
35+
return y
8236
end
8337

84-
function mul!(b::Tx, mw::Adjoint{DampenedMatrix{TA, Tx}}, a::DampenedVector{Ty, Tx},
85-
α::Number, β::Number) where {TA, Tx, Ty}
86-
if β != 1.
87-
if β == 0.
88-
fill!(b, 0.)
89-
else
90-
rmul!(b, β)
91-
end
92-
end
93-
mul!(b, adjoint(mw.A), a.y, α, 1.0)
94-
map!((z, x, y)-> z + α * x * y, b, b, a.x, mw.diagonal)
95-
return b
38+
function mul!(y::AbstractVector, mw::Adjoint{Tv,<:DampenedMatrix}, x::AbstractVector) where {Tv}
39+
m₁ = size(mw.parent.A, 1)
40+
m₂ = size(mw.parent, 1)
41+
mul!(y, adjoint(mw.parent.A), view(x, 1:m₁))
42+
y .+= mw.parent.diagonal .* view(x, m₁+1:m₂)
43+
return y
9644
end
9745

9846
"""
@@ -138,14 +86,14 @@ end
13886
@test norm(b - A * x) 1e-4
13987
end
14088

141-
# @testset "Dampened test" for (m, n) = ((10, 10), (20, 10))
142-
# # Test used to make sure A, b can be generic matrix / vector
143-
# b = rand(m)
144-
# A = rand(m, n)
145-
# v = rand(n)
146-
# Adampened = DampenedMatrix(A, v)
147-
# bdampened = DampenedVector(b, zeros(n))
148-
# x, ch = lsmr(Adampened, bdampened, log=true)
149-
# @test norm((A'A + Matrix(Diagonal(v)) .^ 2)x - A'b) ≤ 1e-3
150-
# end
89+
@testset "Dampened test" for (m, n) = ((10, 10), (20, 10))
90+
# Test used to make sure A, b can be generic matrix / vector
91+
b = rand(m)
92+
A = rand(m, n)
93+
v = rand(n)
94+
A′ = DampenedMatrix(A, v)
95+
b′ = [b; zeros(n)]
96+
x, ch = lsmr(A′, b′, log=true)
97+
@test norm((A'A + Matrix(Diagonal(v)) .^ 2)x - A'b) 1e-3
98+
end
15199
end

0 commit comments

Comments
 (0)