Skip to content

Commit eff34b0

Browse files
authored
Fix block diag (#321)
* Fix block diag * Fix rebase * Fixes * Clean up * Fix format * Fixes * Switch back to CSDP * Fixes * Fix format * Fixes * Fix format * Fix doc
1 parent 2321de2 commit eff34b0

File tree

4 files changed

+169
-77
lines changed

4 files changed

+169
-77
lines changed

docs/src/constraints.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,5 +364,6 @@ PolyJuMP.Bridges.Constraint.ToPolynomialBridge
364364

365365
```@docs
366366
SumOfSquares.Certificate.Symmetry.orthogonal_transformation_to
367-
SumOfSquares.Certificate.Symmetry._permutation_quasi_upper_triangular
367+
SumOfSquares.Certificate.Symmetry._reorder!
368+
SumOfSquares.Certificate.Symmetry._rotate_complex
368369
```

docs/src/tutorials/Symmetry/dihedral.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,17 +149,17 @@ function solve(G)
149149

150150
g = gram_matrix(con_ref).blocks #src
151151
@test length(g) == 5 #src
152-
@test g[1].basis.polynomials == [y^3, x^2*y, -y] #src
153-
@test g[2].basis.polynomials == [-x^3, -x*y^2, x] #src
152+
@test g[1].basis.polynomials == [y^3, x^2*y, y] #src
153+
@test g[2].basis.polynomials == [-x^3, -x*y^2, -x] #src
154154
for i in 1:2 #src
155155
I = 3:-1:1 #src
156156
Q = g[i].Q[I, I] #src
157157
@test size(Q) == (3, 3) #src
158158
@test Q[2, 2] ≈ 1 rtol=1e-2 #src
159-
@test Q[1, 2] ≈ -5/8 rtol=1e-2 #src
159+
@test Q[1, 2] ≈ 5/8 rtol=1e-2 #src
160160
@test Q[2, 3] ≈ -1 rtol=1e-2 #src
161161
@test Q[1, 1] ≈ 25/64 rtol=1e-2 #src
162-
@test Q[1, 3] ≈ 5/8 rtol=1e-2 #src
162+
@test Q[1, 3] ≈ -5/8 rtol=1e-2 #src
163163
@test Q[3, 3] ≈ 1 rtol=1e-2 #src
164164
end #src
165165
@test length(g[3].basis.polynomials) == 2 #src

src/Certificate/Symmetry/block_diag.jl

Lines changed: 119 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,94 +1,91 @@
11
import DataStructures
22

3+
_isapproxless(a::Real, b::Real) = a < b
4+
function _isapproxless(a::Complex, b::Complex)
5+
if real(a) ≈ real(b)
6+
return isless(imag(a), imag(b))
7+
else
8+
return isless(real(a), real(b))
9+
end
10+
end
11+
312
"""
4-
_permutation_quasi_upper_triangular(S)
13+
_reorder!(F::LinearAlgebra.Schur{T}) where {T}
514
6-
Given a (quasi) upper triangular matrix `S`
7-
returns the permutation `P` so that
8-
`P' * S * P` has its eigenvalues in increasing order.
15+
Given a Schur decomposition `F` of a matrix, reorder it in place so that its
16+
eigenvalues are in increasing order.
917
10-
By (quasi), we mean that if `S` is a `Matrix{<:Real}`,
11-
then there may be nonzero entries in `S[i+1,i]` representing
18+
Note that if `T<:Real`, `F.Schur` is quasi upper triangular.
19+
By (quasi), we mean that there may be nonzero entries in `S[i+1,i]` representing
1220
complex conjugates.
1321
In that case, the complex conjugates are permuted together.
14-
If `S` is a `Matrix{<:Complex}`, then `S` is triangular.
22+
If `T<:Complex`, then `S` is triangular.
1523
"""
16-
function _permutation_quasi_upper_triangular(S::AbstractMatrix{T}) where {T}
17-
n = LinearAlgebra.checksquare(S)
24+
function _reorder!(F::LinearAlgebra.Schur{T}) where {T}
25+
n = length(F.values)
1826
# Bubble sort
1927
sorted = false
20-
P = SparseArrays.sparse(one(T) * LinearAlgebra.I, n, n)
21-
function permute!(i, j)
22-
I = collect(1:n)
23-
J = copy(I)
24-
J[i] = j
25-
J[j] = i
26-
swap = sparse(I, J, ones(T, n), n, n)
27-
S = swap' * S * swap
28-
P *= swap
29-
return
30-
end
3128
while !sorted
3229
prev_i = nothing
3330
sorted = true
3431
i = 1
3532
while i <= n
36-
permute =
37-
!isnothing(prev_i) &&
38-
(real(S[i, i]), imag(S[i, i])) <
39-
(real(S[prev_i, prev_i]), imag(S[prev_i, prev_i]))
33+
S = F.Schur
4034
if (T <: Real) && i < n && !iszero(S[i+1, i])
41-
#if S[i+1, i] < S[i, i+1]
42-
# permute!(i, i + 1)
43-
#end
44-
if permute
45-
if i - prev_i == 2
46-
permute!(prev_i, i)
47-
permute!(prev_i + 1, i + 1)
48-
else
49-
permute!(prev_i, i)
50-
permute!(i, i + 1)
51-
end
52-
sorted = false
53-
end
54-
# complex
55-
prev_i = i
56-
i += 2
35+
# complex pair
36+
next_i = i + 2
5737
else
58-
if permute
59-
permute!(prev_i, i)
60-
if i - prev_i == 2
61-
permute!(i - 1, i)
62-
end
63-
sorted = false
64-
end
65-
prev_i = i
66-
i += 1
38+
next_i = i + 1
39+
end
40+
if !isnothing(prev_i) && _isapproxless(S[i, i], S[prev_i, prev_i])
41+
select = trues(n)
42+
select[prev_i:(i-1)] .= false
43+
select[next_i:end] .= false
44+
LinearAlgebra.ordschur!(F, select)
45+
sorted = false
6746
end
47+
prev_i = i
48+
i = next_i
6849
end
6950
end
70-
return P
7151
end
7252

73-
# `A` and `B` may be not upper triangular because of the permutations
74-
function _sign_diag(A::AbstractMatrix{T}, B::AbstractMatrix{T}) where {T}
53+
# We can multiply by `Diagonal(d)` if `d[i] * conj(d[i]) = 1`.
54+
# So in the real case, `d = ±1` but in the complex case, we have more freedom.
55+
function _sign_diag(
56+
A::AbstractMatrix{T},
57+
B::AbstractMatrix{T};
58+
tol = Base.rtoldefault(real(T)),
59+
) where {T}
7560
n = LinearAlgebra.checksquare(A)
7661
d = ones(T, n)
77-
for i in 1:n
78-
minus = zero(real(T))
79-
not_minus = zero(real(T))
80-
for j in 1:(i-1)
81-
for (I, J) in [(i, j), (j, i)]
82-
a = A[I, J]
83-
b = B[I, J]
62+
for j in 2:n
63+
if T <: Real
64+
minus = zero(real(T))
65+
not_minus = zero(real(T))
66+
for i in 1:(j-1)
67+
a = A[i, j]
68+
b = B[i, j]
8469
minus = max(minus, abs(a + b))
8570
not_minus = max(not_minus, abs(a - b))
8671
end
87-
end
88-
if minus < not_minus
89-
d[i] = -one(T)
90-
B[:, i] = -B[:, i]
91-
B[i, :] = -B[i, :]
72+
if minus < not_minus
73+
d[j] = -one(T)
74+
B[:, j] = -B[:, j]
75+
B[j, :] = -B[j, :]
76+
end
77+
else
78+
i = argmax(abs.(B[1:(j-1), j]))
79+
if abs(B[i, j]) <= tol
80+
continue
81+
end
82+
rot = A[i, j] / B[i, j]
83+
# It should be unitary but there might be small numerical errors
84+
# so let's normalize
85+
rot /= abs(rot)
86+
d[j] = rot
87+
B[:, j] *= rot
88+
B[j, :] *= conj(rot)
9289
end
9390
end
9491
return d
@@ -104,6 +101,55 @@ function _try_integer!(A::Matrix)
104101
end
105102
end
106103

104+
"""
105+
_rotate_complex(A::AbstractMatrix{T}, B::AbstractMatrix{T}; tol = Base.rtoldefault(real(T))) where {T}
106+
107+
Given (quasi) upper triangular matrices `A` and `B` that have their eigenvalues in
108+
the same order except the complex pairs which may need to be (signed) permuted,
109+
returns an orthogonal matrix `P` such that `P' * A * P` and `B` have matching
110+
lower triangular part.
111+
The upper triangular part will be dealt with by `_sign_diag`.
112+
113+
By (quasi), we mean that if `S` is a `Matrix{<:Real}`,
114+
then there may be nonzero entries in `S[i+1,i]` representing
115+
complex conjugates.
116+
If `S` is a `Matrix{<:Complex}`, then `S` is upper triangular so there is
117+
nothing to do.
118+
"""
119+
function _rotate_complex(
120+
A::AbstractMatrix{T},
121+
B::AbstractMatrix{T};
122+
tol = Base.rtoldefault(real(T)),
123+
) where {T}
124+
n = LinearAlgebra.checksquare(A)
125+
I = collect(1:n)
126+
J = copy(I)
127+
V = ones(T, n)
128+
pair = false
129+
for i in 1:n
130+
if pair || i == n
131+
continue
132+
end
133+
pair = abs(A[i+1, i]) > tol
134+
if pair
135+
a = (A[i+1, i], A[i, i+1])
136+
b = (B[i+1, i], B[i, i+1])
137+
c = a[2:-1:1]
138+
if LinearAlgebra.norm(abs.(a) .- abs.(b)) >
139+
LinearAlgebra.norm(abs.(c) .- abs.(b))
140+
a = c
141+
J[i] = i + 1
142+
J[i+1] = i
143+
end
144+
c = (-).(a)
145+
if LinearAlgebra.norm(a .- b) > LinearAlgebra.norm(c .- b)
146+
V[i+1] = -V[i]
147+
end
148+
end
149+
end
150+
return SparseArrays.sparse(I, J, V, n, n)
151+
end
152+
107153
"""
108154
orthogonal_transformation_to(A, B)
109155
@@ -113,22 +159,23 @@ Return an orthogonal transformation `U` such that
113159
Given Schur decompositions
114160
`A = Z_A * S_A * Z_A'`
115161
`B = Z_B * S_B * Z_B'`
116-
We further decompose the triangular matrices `S_A`, `S_B`
117-
to order the eigenvalues:
118-
`S_A = P_A * T_A * P_A'`
119-
`S_B = P_B * T_B * P_B'`
162+
Since `P' * S_A * P = D' * S_B * D`, we have
163+
`A = Z_A * P * D' * Z_B' * B * Z_B * D * P' * Z_A'`
120164
"""
121165
function orthogonal_transformation_to(A, B)
122166
As = LinearAlgebra.schur(A)
167+
_reorder!(As)
123168
T_A = As.Schur
124169
Z_A = As.vectors
125-
P_A = _permutation_quasi_upper_triangular(T_A)
126170
Bs = LinearAlgebra.schur(B)
171+
_reorder!(Bs)
127172
T_B = Bs.Schur
128173
Z_B = Bs.vectors
129-
P_B = _permutation_quasi_upper_triangular(T_B)
130-
d = _sign_diag(P_A' * T_A * P_A, P_B' * T_B * P_B)
131-
return _try_integer!(Z_B * P_B * LinearAlgebra.Diagonal(d) * P_A' * Z_A')
174+
P = _rotate_complex(T_A, T_B)
175+
T_A = P' * T_A * P
176+
d = _sign_diag(T_A, T_B)
177+
D = LinearAlgebra.Diagonal(d)
178+
return _try_integer!(Z_B * D * P' * Z_A')
132179
end
133180

134181
function ordered_block_diag(As, d)

test/symmetry.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,50 @@ function _test_orthogonal_transformation_to(T::Type)
108108
0 0 1
109109
]
110110
_test_orthogonal_transformation_to(A1, A2)
111+
A1 = T[
112+
-1 0 1
113+
1 1 0
114+
0 1 -1
115+
]
116+
A2 = T[
117+
-1 1 0
118+
0 1 1
119+
1 0 -1
120+
]
121+
_test_orthogonal_transformation_to(A1, A2)
122+
A1 = T[
123+
-1 1 1
124+
2 1 0
125+
0 1 0
126+
]
127+
A2 = T[
128+
0 1 0
129+
0 1 2
130+
1 1 -1
131+
]
132+
_test_orthogonal_transformation_to(A1, A2)
133+
A1 = T[
134+
0 1 1
135+
2 0 0
136+
0 1 -1
137+
]
138+
A2 = T[
139+
-1 1 0
140+
0 0 2
141+
1 1 0
142+
]
143+
_test_orthogonal_transformation_to(A1, A2)
144+
A1 = ComplexF64[
145+
0 1 1
146+
2 0 0
147+
0 1 -1
148+
]
149+
A2 = ComplexF64[
150+
-1 1 0
151+
0 0 2
152+
1 1 0
153+
]
154+
_test_orthogonal_transformation_to(A1, A2)
111155
return
112156
end
113157

0 commit comments

Comments
 (0)