Skip to content

Commit f12640a

Browse files
committed
[Nonlinear.SymbolicAD] simplify quadratic functions if possible
1 parent f31be21 commit f12640a

File tree

2 files changed

+199
-68
lines changed

2 files changed

+199
-68
lines changed

src/Nonlinear/SymbolicAD/SymbolicAD.jl

Lines changed: 159 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -77,16 +77,19 @@ function simplify!(f::MOI.ScalarAffineFunction{T}) where {T}
7777
if isempty(f.terms)
7878
return f.constant
7979
end
80+
if iszero(f.constant) && length(f.terms) == 1
81+
term = only(f.terms)
82+
if isone(term.coefficient)
83+
return term.variable
84+
end
85+
end
8086
return f
8187
end
8288

8389
function simplify!(f::MOI.ScalarQuadraticFunction{T}) where {T}
8490
f = MOI.Utilities.canonicalize!(f)
8591
if isempty(f.quadratic_terms)
86-
if isempty(f.affine_terms)
87-
return f.constant
88-
end
89-
return MOI.ScalarAffineFunction(f.affine_terms, f.constant)
92+
return simplify!(MOI.ScalarAffineFunction(f.affine_terms, f.constant))
9093
end
9194
return f
9295
end
@@ -117,7 +120,7 @@ function simplify!(f::MOI.ScalarNonlinearFunction)
117120
push!(result_stack, arg)
118121
end
119122
end
120-
return _simplify_if_affine!(only(result_stack))
123+
return _simplify_if_quadratic!(only(result_stack))
121124
end
122125

123126
function simplify!(f::MOI.VectorAffineFunction{T}) where {T}
@@ -140,10 +143,12 @@ function simplify!(f::MOI.VectorQuadraticFunction{T}) where {T}
140143
end
141144

142145
function simplify!(f::MOI.VectorNonlinearFunction)
143-
for (i, row) in enumerate(f.rows)
144-
f.rows[i] = simplify!(row)
146+
rows = simplify!.(f.rows)
147+
Y = reduce(promote_type, typeof.(rows))
148+
if isconcretetype(Y)
149+
return MOI.Utilities.vectorize(convert(Vector{Y}, rows))
145150
end
146-
return f
151+
return MOI.VectorNonlinearFunction(rows)
147152
end
148153

149154
# If a ScalarNonlinearFunction has only constant arguments, we should return
@@ -1507,100 +1512,207 @@ function MOI.eval_hessian_lagrangian(model::Evaluator, H, x, σ, μ)
15071512
end
15081513

15091514
# A default fallback for all types
1510-
_add_to_affine!(::Any, ::Any, ::T) where {T} = nothing
1515+
_add_to_quadratic!(::Any, ::Real, ::Any) = nothing
1516+
_add_to_quadratic!(::Any, ::Real, ::Any, ::Any) = nothing
15111517

1512-
# The creation of `ret::MOI.ScalarAffineFunction` has been delayed until now.
1513-
function _add_to_affine!(
1514-
::Nothing,
1515-
f::Union{Real,MOI.VariableIndex,MOI.ScalarAffineFunction},
1518+
# The creation of `ret::MOI.ScalarQuadraticFunction` has been delayed until now.
1519+
function _add_to_quadratic!(
1520+
::Missing,
15161521
scale::T,
1517-
) where {T}
1518-
return _add_to_affine!(zero(MOI.ScalarAffineFunction{T}), f, scale)
1522+
f::Union{
1523+
Real,
1524+
MOI.VariableIndex,
1525+
MOI.ScalarAffineFunction,
1526+
MOI.ScalarQuadraticFunction,
1527+
}...,
1528+
) where {T<:Real}
1529+
return _add_to_quadratic!(zero(MOI.ScalarQuadraticFunction{T}), scale, f...)
1530+
end
1531+
1532+
function _add_to_quadratic!(
1533+
ret::MOI.ScalarQuadraticFunction{T},
1534+
scale::T,
1535+
x::S,
1536+
) where {T<:Real,S<:Real}
1537+
if promote_type(T, S) != T
1538+
return # We can't store `S` in `T`.
1539+
end
1540+
ret.constant += scale * convert(T, x)
1541+
return ret
15191542
end
15201543

1521-
function _add_to_affine!(
1522-
ret::MOI.ScalarAffineFunction{T},
1523-
x::S,
1544+
function _add_to_quadratic!(
1545+
ret::MOI.ScalarQuadraticFunction{T},
15241546
scale::T,
1525-
) where {T,S<:Real}
1547+
f::MOI.ScalarAffineTerm{S},
1548+
) where {T<:Real,S}
15261549
if promote_type(T, S) != T
15271550
return # We can't store `S` in `T`.
15281551
end
1529-
ret.constant += scale * convert(T, x)
1552+
push!(
1553+
ret.affine_terms,
1554+
MOI.ScalarAffineTerm{T}(scale * f.coefficient, f.variable),
1555+
)
15301556
return ret
15311557
end
15321558

1533-
function _add_to_affine!(
1534-
ret::MOI.ScalarAffineFunction{T},
1535-
x::MOI.VariableIndex,
1559+
function _add_to_quadratic!(
1560+
ret::MOI.ScalarQuadraticFunction{T},
15361561
scale::T,
1537-
) where {T}
1538-
push!(ret.terms, MOI.ScalarAffineTerm(scale, x))
1562+
f::MOI.ScalarQuadraticTerm{S},
1563+
) where {T<:Real,S}
1564+
if promote_type(T, S) != T
1565+
return # We can't store `S` in `T`.
1566+
end
1567+
push!(
1568+
ret.quadratic_terms,
1569+
MOI.ScalarQuadraticTerm{T}(
1570+
scale * f.coefficient,
1571+
f.variable_1,
1572+
f.variable_2,
1573+
),
1574+
)
15391575
return ret
15401576
end
15411577

1542-
function _add_to_affine!(
1543-
ret::MOI.ScalarAffineFunction{T},
1544-
f::MOI.ScalarAffineFunction{S},
1578+
function _add_to_quadratic!(
1579+
ret::MOI.ScalarQuadraticFunction{T},
15451580
scale::T,
1546-
) where {T,S}
1581+
x::MOI.VariableIndex,
1582+
) where {T<:Real}
1583+
return _add_to_quadratic!(ret, scale, MOI.ScalarAffineTerm(one(T), x))
1584+
end
1585+
1586+
function _add_to_quadratic!(
1587+
ret::MOI.ScalarQuadraticFunction{T},
1588+
scale::T,
1589+
f::MOI.ScalarAffineFunction{S},
1590+
) where {T<:Real,S}
15471591
if promote_type(T, S) != T
15481592
return # We can't store `S` in `T`.
15491593
end
1550-
ret = _add_to_affine!(ret, f.constant, scale)
1594+
ret = _add_to_quadratic!(ret, scale, f.constant)
15511595
for term in f.terms
1552-
ret = _add_to_affine!(ret, term.variable, scale * term.coefficient)
1596+
ret = _add_to_quadratic!(ret, scale, term)
15531597
end
15541598
return ret
15551599
end
15561600

1557-
function _add_to_affine!(
1558-
ret::Union{Nothing,MOI.ScalarAffineFunction{T}},
1559-
f::MOI.ScalarNonlinearFunction,
1601+
function _add_to_quadratic!(
1602+
ret::MOI.ScalarQuadraticFunction{T},
15601603
scale::T,
1561-
) where {T}
1604+
f::MOI.ScalarQuadraticFunction{S},
1605+
) where {T<:Real,S}
1606+
if promote_type(T, S) != T
1607+
return # We can't store `S` in `T`.
1608+
end
1609+
ret = _add_to_quadratic!(ret, scale, f.constant)
1610+
for term in f.affine_terms
1611+
ret = _add_to_quadratic!(ret, scale, term)
1612+
end
1613+
for q_term in f.quadratic_terms
1614+
ret = _add_to_quadratic!(ret, scale, q_term)
1615+
end
1616+
return ret
1617+
end
1618+
1619+
function _add_to_quadratic!(
1620+
ret::MOI.ScalarQuadraticFunction{T},
1621+
scale::T,
1622+
f::MOI.VariableIndex,
1623+
g::MOI.VariableIndex,
1624+
) where {T<:Real}
1625+
return _add_to_quadratic!(ret, scale, one(T) * f * g)
1626+
end
1627+
1628+
function _add_to_quadratic!(
1629+
ret::MOI.ScalarQuadraticFunction{T},
1630+
scale::T,
1631+
f::MOI.ScalarAffineFunction{F},
1632+
g::MOI.ScalarAffineFunction{G},
1633+
) where {T<:Real,F,G}
1634+
H = MOI.ScalarAffineFunction{promote_type(F, G)}
1635+
return _add_to_quadratic!(ret, scale, convert(H, f) * convert(H, g))
1636+
end
1637+
1638+
function _add_to_quadratic!(
1639+
ret::MOI.ScalarQuadraticFunction{T},
1640+
scale::T,
1641+
f::MOI.VariableIndex,
1642+
g::MOI.ScalarAffineFunction,
1643+
) where {T<:Real}
1644+
return _add_to_quadratic!(ret, scale, f * g)
1645+
end
1646+
1647+
function _add_to_quadratic!(
1648+
ret::MOI.ScalarQuadraticFunction{T},
1649+
scale::T,
1650+
f::MOI.ScalarAffineFunction,
1651+
g::MOI.VariableIndex,
1652+
) where {T<:Real}
1653+
return _add_to_quadratic!(ret, scale, g, f)
1654+
end
1655+
1656+
function _add_to_quadratic!(
1657+
ret::Union{Missing,MOI.ScalarQuadraticFunction{T}},
1658+
scale::T,
1659+
f::MOI.ScalarNonlinearFunction,
1660+
) where {T<:Real}
15621661
if f.head == :+
15631662
for arg in f.args
1564-
ret = _add_to_affine!(ret, arg, scale)
1663+
ret = _add_to_quadratic!(ret, scale, arg)
15651664
if ret === nothing
15661665
return
15671666
end
15681667
end
15691668
return ret
15701669
elseif f.head == :-
15711670
if length(f.args) == 1
1572-
return _add_to_affine!(ret, only(f.args), -scale)
1671+
return _add_to_quadratic!(ret, -scale, only(f.args))
15731672
end
15741673
@assert length(f.args) == 2
1575-
ret = _add_to_affine!(ret, f.args[1], scale)
1674+
ret = _add_to_quadratic!(ret, scale, f.args[1])
15761675
if ret === nothing
15771676
return
15781677
end
1579-
return _add_to_affine!(ret, f.args[2], -scale)
1678+
return _add_to_quadratic!(ret, -scale, f.args[2])
15801679
elseif f.head == :*
1581-
y = nothing
1680+
y1, y2 = nothing, nothing
15821681
for arg in f.args
15831682
if arg isa Real
15841683
scale *= arg
1585-
elseif y === nothing
1586-
y = arg
1684+
elseif y1 === nothing
1685+
y1 = arg
1686+
elseif y2 === nothing
1687+
y2 = arg
15871688
else
15881689
return # We already have a `y`. Can't multiple factors.
15891690
end
15901691
end
1591-
return _add_to_affine!(ret, something(y, one(T)), convert(T, scale))
1692+
if y1 === nothing
1693+
@assert y2 === nothing
1694+
return _add_to_quadratic!(ret, one(T), scale)
1695+
elseif y2 === nothing
1696+
return _add_to_quadratic!(ret, scale, y1)
1697+
else
1698+
return _add_to_quadratic!(ret, scale, y1, y2)
1699+
end
1700+
elseif f.head == :^ && f.args[2] isa Real && f.args[2] == 2
1701+
return _add_to_quadratic!(ret, scale, f.args[1], f.args[1])
1702+
elseif f.head == :/ && f.args[2] isa Real
1703+
return _add_to_quadratic!(ret, convert(T, scale / f.args[2]), f.args[1])
15921704
end
15931705
return # An unsupported f.head
15941706
end
15951707

1596-
function _simplify_if_affine!(f::MOI.ScalarNonlinearFunction)
1597-
ret = _add_to_affine!(nothing, f, 1.0)
1708+
function _simplify_if_quadratic!(f::MOI.ScalarNonlinearFunction)
1709+
ret = _add_to_quadratic!(missing, 1.0, f)
15981710
if ret === nothing
15991711
return f
16001712
end
1601-
return simplify!(ret::MOI.ScalarAffineFunction{Float64})
1713+
return simplify!(ret::MOI.ScalarQuadraticFunction{Float64})
16021714
end
16031715

1604-
_simplify_if_affine!(f::Any) = f
1716+
_simplify_if_quadratic!(f::Any) = f
16051717

16061718
end # module

0 commit comments

Comments
 (0)