Skip to content

Commit 95e454b

Browse files
committed
Fixes
1 parent 588fa61 commit 95e454b

File tree

5 files changed

+97
-76
lines changed

5 files changed

+97
-76
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ jobs:
3737
run: |
3838
using Pkg
3939
Pkg.add([
40-
PackageSpec(name="StarAlgebras", rev="mk/quadratic_form"),
40+
PackageSpec(name="StarAlgebras", rev="bl/quad_form"),
4141
PackageSpec(name="SymbolicWedderburn", rev="master"),
4242
PackageSpec(name="MultivariateBases", rev="master"),
4343
PackageSpec(name="MultivariateMoments", rev="master"),

src/Bridges/Variable/kernel.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@ function MOI.Bridges.Variable.bridge_constrained_variable(
2020
gram, vars, con = SOS.add_gram_matrix(model, M, gram_basis, T)
2121
push!(variables, vars)
2222
push!(constraints, con)
23-
MA.operate!(SA.UnsafeAddMul(*), acc, gram, weight)
23+
if isone(weight)
24+
MA.operate!(SA.UnsafeAdd(), acc, SA.QuadraticForm(gram))
25+
else
26+
MA.operate!(SA.UnsafeAddMul(*), acc, gram, weight)
27+
end
2428
end
2529
MA.operate!(SA.canonical, SA.coeffs(acc))
2630
return KernelBridge{T,M}(

src/Certificate/ideal.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,20 @@ function _combine_with_gram(
2929
)
3030
end
3131
for (gram, weight) in zip(gram_bases, weights)
32-
MA.operate!(
33-
SA.UnsafeAddMul(*),
34-
p,
35-
GramMatrix{_NonZero}((_, _) -> _NonZero(), gram),
36-
weight,
37-
)
32+
if isone(weight)
33+
MA.operate!(
34+
SA.UnsafeAdd(),
35+
p,
36+
SA.QuadraticForm(GramMatrix{_NonZero}((_, _) -> _NonZero(), gram)),
37+
)
38+
else
39+
MA.operate!(
40+
SA.UnsafeAddMul(*),
41+
p,
42+
GramMatrix{_NonZero}((_, _) -> _NonZero(), gram),
43+
weight,
44+
)
45+
end
3846
end
3947
MA.operate!(SA.canonical, SA.coeffs(p))
4048
return MB.SubBasis{B}(keys(SA.coeffs(p)))

src/Certificate/newton_polytope.jl

Lines changed: 61 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,8 @@ Base.iszero(::SignChange) = false
573573
MA.scaling_convert(::Type, s::SignChange) = s
574574
Base.:*(s::SignChange, α::Real) = SignChange(s.sign * α, s.Δ)
575575
Base.:*::Real, s::SignChange) = SignChange* s.sign, s.Δ)
576+
#Base.convert(::Type{SignChange{T}}, s::SignChange) where {T} = SignChange{T}(s.sign, s.Δ)
577+
#Base.:+(a::SignChange, b::SignChange) = convert(SignCount, a) + b
576578

577579
struct SignCount
578580
unknown::Int
@@ -593,6 +595,18 @@ function _sign(c::SignCount)
593595
end
594596
end
595597

598+
function Base.:*(α, a::SignCount)
599+
if α > 0
600+
return a
601+
elseif α < 0
602+
return SignCount(a.unknown, a.negative, a.positive)
603+
else
604+
error("Cannot multiply `SignCount`` with ``")
605+
end
606+
end
607+
608+
Base.:*(a::SignCount, α) = α * a
609+
596610
function Base.:+(a::SignCount, b::SignCount)
597611
return SignCount(
598612
a.unknown + b.unknown,
@@ -602,16 +616,16 @@ function Base.:+(a::SignCount, b::SignCount)
602616
end
603617

604618
function Base.:+(c::SignCount, a::SignChange{Missing})
605-
@assert c.unknown >= -a.Δ
619+
#@assert c.unknown >= -a.Δ
606620
return SignCount(c.unknown + a.Δ, c.positive, c.negative)
607621
end
608622

609623
function Base.:+(c::SignCount, a::SignChange{<:Number})
610624
if a.sign > 0
611-
@assert c.positive >= -a.Δ
625+
#@assert c.positive >= -a.Δ
612626
return SignCount(c.unknown, c.positive + a.Δ, c.negative)
613627
elseif a.sign < 0
614-
@assert c.negative >= -a.Δ
628+
#@assert c.negative >= -a.Δ
615629
return SignCount(c.unknown, c.positive, c.negative + a.Δ)
616630
elseif iszero(a.sign)
617631
error(
@@ -624,26 +638,16 @@ end
624638

625639
Base.convert(::Type{SignCount}, Δ::SignChange) = SignCount() + Δ
626640

627-
function increase(cache, counter, generator_sign, monos, mult)
628-
for a in monos
629-
for b in monos
630-
MA.operate_to!(
631-
cache,
632-
*,
633-
MB.algebra_element(mult),
634-
MB.algebra_element(a),
635-
MB.algebra_element(b),
636-
)
637-
MA.operate!(
638-
SA.UnsafeAddMul(*),
639-
counter,
640-
_term_constant_monomial(
641-
SignChange((a != b) ? missing : generator_sign, 1),
642-
mult,
643-
),
644-
cache,
645-
)
646-
end
641+
struct SignGram{T,B}
642+
sign::T
643+
basis::B
644+
end
645+
SA.basis(g::SignGram) = g.basis
646+
function Base.getindex(g::SignGram, i, j)
647+
if i == j
648+
return SignChange(g.sign, 1)
649+
else
650+
return SignChange(missing, 2)
647651
end
648652
end
649653

@@ -708,7 +712,8 @@ function post_filter(
708712
_DictCoefficients(Dict{MP.monomial_type(typeof(poly)),SignCount}()),
709713
MB.implicit_basis(SA.basis(poly)),
710714
)
711-
cache = zero(Float64, MB.algebra(MB.implicit_basis(SA.basis(poly))))
715+
cache = zero(SignCount, MB.algebra(MB.implicit_basis(SA.basis(poly))))
716+
cache2 = zero(SignCount, MB.algebra(MB.implicit_basis(SA.basis(poly))))
712717
for (mono, v) in SA.nonzero_pairs(SA.coeffs(poly))
713718
MA.operate!(
714719
SA.UnsafeAdd(),
@@ -717,29 +722,21 @@ function post_filter(
717722
)
718723
end
719724
for (mult, gram_monos) in zip(generators, multipliers_gram_monos)
720-
for (mono, v) in SA.nonzero_pairs(SA.coeffs(mult))
721-
increase(
722-
cache,
723-
counter,
724-
-_sign(v),
725-
gram_monos,
726-
SA.basis(mult)[mono],
727-
)
728-
end
725+
MA.operate_to!(cache, copy, SA.QuadraticForm(SignGram(-1, gram_monos)))
726+
MA.operate!(SA.UnsafeAddMul(*), counter, mult, cache)
729727
end
730-
function decrease(sign, a, b, c)
728+
function decrease(sign, a, b, generator)
731729
MA.operate_to!(
732730
cache,
733731
*,
734-
MB.algebra_element(a),
732+
_term(SignChange(1, -1), a),
735733
MB.algebra_element(b),
736-
MB.algebra_element(c),
737734
)
738735
MA.operate!(
739736
SA.UnsafeAddMul(*),
740737
counter,
741-
_term_constant_monomial(SignChange(sign, -1), a),
742738
cache,
739+
generator,
743740
)
744741
for mono in SA.supp(cache)
745742
count = SA.coeffs(counter)[SA.basis(counter)[mono]]
@@ -765,36 +762,38 @@ function post_filter(
765762
end
766763
keep[i][j] = false
767764
a = multipliers_gram_monos[i][j]
768-
for (k, v) in SA.nonzero_pairs(SA.coeffs(generators[i]))
769-
mono = SA.basis(generators[i])[k]
770-
sign = -_sign(v)
771-
decrease(sign, mono, a, a)
772-
for (j, b) in enumerate(multipliers_gram_monos[i])
773-
if keep[i][j]
774-
decrease(missing, mono, a, b)
775-
decrease(missing, mono, b, a)
776-
end
765+
decrease(-1, a, a, generators[i])
766+
for (k, b) in enumerate(multipliers_gram_monos[i])
767+
if keep[i][k]
768+
decrease(missing, a, b, generators[i])
769+
decrease(missing, b, a, generators[i])
777770
end
778771
end
779772
end
780773
for i in eachindex(generators)
781-
for k in SA.supp(generators[i])
782-
for (j, mono) in enumerate(multipliers_gram_monos[i])
783-
MA.operate_to!(
784-
cache,
785-
*,
786-
MB.algebra_element(k),
787-
MB.algebra_element(mono),
788-
MB.algebra_element(mono),
774+
for (j, mono) in enumerate(multipliers_gram_monos[i])
775+
MA.operate_to!(
776+
cache,
777+
*,
778+
# Dummy coef to help convert to `SignCount` which is the `eltype` of `cache`
779+
_term(SignChange(1, 1), mono),
780+
MB.algebra_element(mono),
781+
)
782+
# The `eltype` of `cache` is `SignCount`
783+
# so there is no risk of term cancellation
784+
MA.operate_to!(
785+
cache2,
786+
*,
787+
cache,
788+
generators[i],
789+
)
790+
for w in SA.supp(cache)
791+
if ismissing(
792+
_sign(SA.coeffs(counter)[SA.basis(counter)[w]]),
789793
)
790-
for w in SA.supp(cache)
791-
if ismissing(
792-
_sign(SA.coeffs(counter)[SA.basis(counter)[w]]),
793-
)
794-
push!(get!(back, w, Tuple{Int,Int}[]), (i, j))
795-
else
796-
delete(i, j)
797-
end
794+
push!(get!(back, w, Tuple{Int,Int}[]), (i, j))
795+
else
796+
delete(i, j)
798797
end
799798
end
800799
end

src/gram_matrix.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,21 @@ end
308308
# convert(PT, MP.polynomial(p))
309309
#end
310310

311+
function MB.algebra_element(
312+
p::Union{GramMatrix{T,B,U},BlockDiagonalGramMatrix{T,B,U}},
313+
) where {T,B,U}
314+
return MB.algebra_element(p, U)
315+
end
316+
317+
function MB.algebra_element(
318+
g::Union{GramMatrix,BlockDiagonalGramMatrix},
319+
::Type{T},
320+
) where {T}
321+
a = zero(T, MB.algebra(MB.implicit_basis(g)))
322+
MA.operate_to!(a, copy, SA.QuadraticForm(g))
323+
return a
324+
end
325+
311326
function MP.polynomial(
312327
p::Union{GramMatrix{T,B,U},BlockDiagonalGramMatrix{T,B,U}},
313328
) where {T,B,U}
@@ -318,10 +333,5 @@ function MP.polynomial(
318333
g::Union{GramMatrix,BlockDiagonalGramMatrix},
319334
::Type{T},
320335
) where {T}
321-
p = zero(T, MB.algebra(MB.implicit_basis(g)))
322-
MA.operate!(SA.UnsafeAddMul(*), p, g)
323-
MA.operate!(SA.canonical, SA.coeffs(p))
324-
return MP.polynomial(
325-
SA.coeffs(p, MB.FullBasis{MB.Monomial,MP.monomial_type(g)}()),
326-
)
336+
return MP.polynomial(MB.algebra_element(g, T))
327337
end

0 commit comments

Comments
 (0)