Skip to content

Commit e428fa8

Browse files
committed
Add _simplify_if_affine
1 parent 148398f commit e428fa8

File tree

3 files changed

+162
-29
lines changed

3 files changed

+162
-29
lines changed

docs/src/submodules/Nonlinear/SymbolicAD.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ julia> f = MOI.ScalarNonlinearFunction(
103103
+(+(1.0, MOI.VariableIndex(1)), 3.0 + 2.0 MOI.VariableIndex(1))
104104
105105
julia> MOI.Nonlinear.SymbolicAD.simplify(f)
106-
+(1.0, MOI.VariableIndex(1), 3.0 + 2.0 MOI.VariableIndex(1))
106+
4.0 + 3.0 MOI.VariableIndex(1)
107107
```
108108

109109
and trivial identities such as ``x^1 = x``:

src/Nonlinear/SymbolicAD/SymbolicAD.jl

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ function simplify!(f::MOI.ScalarNonlinearFunction)
117117
push!(result_stack, arg)
118118
end
119119
end
120-
return only(result_stack)
120+
return _simplify_if_affine!(only(result_stack))
121121
end
122122

123123
function simplify!(f::MOI.VectorAffineFunction{T}) where {T}
@@ -1509,4 +1509,101 @@ function MOI.eval_hessian_lagrangian(model::Evaluator, H, x, σ, μ)
15091509
return
15101510
end
15111511

1512+
# A default fallback for all types
1513+
_add_to_affine!(::Any, ::Any, ::T) where {T} = nothing
1514+
1515+
# The creation of `ret::MOI.ScalarAffineFunction` has been delayed until now!
1516+
function _add_to_affine!(
1517+
::Nothing,
1518+
f::Union{Real,MOI.VariableIndex,MOI.ScalarAffineFunction},
1519+
scale::T,
1520+
) where {T}
1521+
return _add_to_affine!(zero(MOI.ScalarAffineFunction{T}), f, scale)
1522+
end
1523+
1524+
function _add_to_affine!(
1525+
ret::MOI.ScalarAffineFunction{T},
1526+
x::S,
1527+
scale::T,
1528+
) where {T,S<:Real}
1529+
if promote_type(T, S) != T
1530+
return # We can't store `S` in `T`.
1531+
end
1532+
ret.constant += scale * convert(T, x)
1533+
return ret
1534+
end
1535+
1536+
function _add_to_affine!(
1537+
ret::MOI.ScalarAffineFunction{T},
1538+
x::MOI.VariableIndex,
1539+
scale::T,
1540+
) where {T}
1541+
push!(ret.terms, MOI.ScalarAffineTerm(scale, x))
1542+
return ret
1543+
end
1544+
1545+
function _add_to_affine!(
1546+
ret::MOI.ScalarAffineFunction{T},
1547+
f::MOI.ScalarAffineFunction{S},
1548+
scale::T,
1549+
) where {T,S}
1550+
if promote_type(T, S) != T
1551+
return # We can't store `S` in `T`.
1552+
end
1553+
ret = _add_to_affine!(ret, f.constant, scale)
1554+
for term in f.terms
1555+
ret = _add_to_affine!(ret, term.variable, scale * term.coefficient)
1556+
end
1557+
return ret
1558+
end
1559+
1560+
function _add_to_affine!(
1561+
ret::Union{Nothing,MOI.ScalarAffineFunction{T}},
1562+
f::MOI.ScalarNonlinearFunction,
1563+
scale::T,
1564+
) where {T}
1565+
if f.head == :+
1566+
for arg in f.args
1567+
ret = _add_to_affine!(ret, arg, scale)
1568+
if ret === nothing
1569+
return
1570+
end
1571+
end
1572+
return ret
1573+
elseif f.head == :-
1574+
if length(f.args) == 1
1575+
return _add_to_affine!(ret, only(f.args), -scale)
1576+
end
1577+
@assert length(f.args) == 2
1578+
ret = _add_to_affine!(ret, f.args[1], scale)
1579+
if ret === nothing
1580+
return
1581+
end
1582+
return _add_to_affine!(ret, f.args[2], -scale)
1583+
elseif f.head == :*
1584+
y = nothing
1585+
for arg in f.args
1586+
if arg isa Real
1587+
scale *= arg
1588+
elseif y === nothing
1589+
y = arg
1590+
else
1591+
return # We already have a `y`. Can't multiple factors.
1592+
end
1593+
end
1594+
return _add_to_affine!(ret, something(y, one(T)), convert(T, scale))
1595+
end
1596+
return # An unsupported f.head
1597+
end
1598+
1599+
function _simplify_if_affine!(f::MOI.ScalarNonlinearFunction)
1600+
ret = _add_to_affine!(nothing, f, 1.0)
1601+
if ret === nothing
1602+
return f
1603+
end
1604+
return simplify!(ret::MOI.ScalarAffineFunction{Float64})
1605+
end
1606+
1607+
_simplify_if_affine!(f::Any) = f
1608+
15121609
end # module

test/Nonlinear/SymbolicAD.jl

Lines changed: 63 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ end
239239
# simplify(::Val{:*}, f::MOI.ScalarNonlinearFunction)
240240
function test_simplify_ScalarNonlinearFunction_multiplication()
241241
x, y, z = MOI.VariableIndex.(1:3)
242+
sinx = op(:sin, x)
242243
# *(x, *(y, z)) -> *(x, y, z)
243244
@test (SymbolicAD.simplify(op(:*, x, op(:*, y, z))), op(:*, x, y, z))
244245
# *(x, *(y, z, *(x, 2))) -> *(x, y, z, x, 2)
@@ -248,11 +249,11 @@ function test_simplify_ScalarNonlinearFunction_multiplication()
248249
op(:*, x, y, z, x, 2),
249250
)
250251
# *(x, 3, 2) -> *(x, 6)
251-
ret = op(:*, x, 3, 2)
252-
@test (SymbolicAD.simplify(ret), op(:*, x, 6))
252+
@test (SymbolicAD.simplify(op(:*, x, 3, 2)), 6.0 * x)
253+
@test (SymbolicAD.simplify(op(:*, sinx, 3, 2)), op(:*, sinx, 6))
253254
# *(3, x, 2) -> *(6, x)
254-
ret = op(:*, 3, x, 2)
255-
@test (SymbolicAD.simplify(ret), op(:*, 6, x))
255+
@test (SymbolicAD.simplify(op(:*, 3, x, 2)), 6.0 * x)
256+
@test (SymbolicAD.simplify(op(:*, 3, sinx, 2)), op(:*, 6, sinx))
256257
# *(x, 1) -> x
257258
ret = op(:*, x, 1)
258259
@test (SymbolicAD.simplify(ret), x)
@@ -272,55 +273,56 @@ end
272273
# simplify(::Val{:+}, f::MOI.ScalarNonlinearFunction)
273274
function test_simplify_ScalarNonlinearFunction_addition()
274275
x, y, z = MOI.VariableIndex.(1:3)
275-
# (+(x, +(y, z)))=>(+(x, y, z)),
276-
@test (SymbolicAD.simplify(op(:+, x, op(:+, y, z))), op(:+, x, y, z))
277-
# +(sin(x), -cos(x))=>sin(x)-cos(x),
278276
sinx = op(:sin, x)
279277
cosx = op(:cos, x)
278+
# (+(x, +(y, z)))=>(+(x, y, z)),
279+
@test (SymbolicAD.simplify(op(:+, sinx, op(:+, y, z))), op(:+, sinx, y, z))
280+
@test (
281+
SymbolicAD.simplify(op(:+, x, op(:+, y, z))),
282+
1.0 * x + 1.0 * y + 1.0 * z,
283+
)
284+
# +(sin(x), -cos(x))=>sin(x)-cos(x),
280285
@test (SymbolicAD.simplify(op(:+, sinx, op(:-, cosx))), op(:-, sinx, cosx))
281286
# (+(x, 1, 2))=>(+(x, 3)),
282-
ret = op(:+, x, 1, 2)
283-
@test (SymbolicAD.simplify(ret), op(:+, x, 3))
284-
# (+(1, x, 2))=>(+(3, x)),
285-
ret = op(:+, 1, x, 2)
286-
@test (SymbolicAD.simplify(ret), op(:+, 3, x))
287+
@test (SymbolicAD.simplify(op(:+, x, 1, 2)), x + 3.0)
288+
@test (SymbolicAD.simplify(op(:+, sinx, 1, 2)), op(:+, sinx, 3))
289+
# (+(1, x, 2))=>(+(3, x)),ret =
290+
@test (SymbolicAD.simplify(op(:+, 1, x, 2)), x + 3.0)
291+
@test (SymbolicAD.simplify(op(:+, 1, sinx, 2)), op(:+, 3, sinx))
287292
# +(x, 0) -> x
288-
ret = op(:+, x, 0)
289-
@test SymbolicAD.simplify(ret) x
293+
@test SymbolicAD.simplify(op(:+, x, 0)) x
290294
# +(0, x) -> x
291-
ret = op(:+, 0, x)
292-
@test SymbolicAD.simplify(ret) x
295+
@test SymbolicAD.simplify(op(:+, 0, x)) x
293296
# +(-(x, x), 0) -> 0
294-
f = op(:+, op(:-, x, x), 0)
295-
@test SymbolicAD.simplify(f) === false
297+
@test SymbolicAD.simplify(op(:+, op(:-, x, x), 0)) === false
296298
return
297299
end
298300

299301
# simplify(::Val{:-}, f::MOI.ScalarNonlinearFunction)
300302
function test_simplify_ScalarNonlinearFunction_subtraction()
301303
x, y = MOI.VariableIndex(1), MOI.VariableIndex(2)
304+
sinx = op(:sin, x)
302305
f = op(:-, x)
303306
# -x -> -x
304-
@test SymbolicAD.simplify(f) f
307+
@test SymbolicAD.simplify(op(:-, x)) -1.0 * x
308+
@test SymbolicAD.simplify(op(:-, sinx)) op(:-, sinx)
305309
# -(-(x)) -> x
306310
ret = op(:-, f)
307311
@test SymbolicAD.simplify(ret) x
308312
# -(x, 0) -> x
309313
ret = op(:-, x, 0)
310314
@test SymbolicAD.simplify(ret) x
311315
# -(0, x) -> -x
312-
ret = op(:-, 0, x)
313-
@test SymbolicAD.simplify(ret) f
316+
@test SymbolicAD.simplify(op(:-, 0, sinx)) op(:-, sinx)
314317
# -(x, x) -> 0
315318
ret = op(:-, x, x)
316319
@test SymbolicAD.simplify(ret) 0
317320
# -(x, -y) -> +(x, y)
318-
f = op(:-, x, op(:-, y))
319-
target = op(:+, x, y)
320-
@test SymbolicAD.simplify(f) target
321+
@test SymbolicAD.simplify(op(:-, x, op(:-, y))) 1.0 * x + 1.0 * y
322+
@test SymbolicAD.simplify(op(:-, sinx, op(:-, y))) op(:+, sinx, y)
321323
# -(x, y) -> -(x, y)
322-
f = op(:-, x, y)
323-
@test SymbolicAD.simplify(f) f
324+
@test SymbolicAD.simplify(op(:-, sinx, y)) op(:-, sinx, y)
325+
@test SymbolicAD.simplify(op(:-, x, y)) 1.0 * x - 1.0 * y
324326
return
325327
end
326328

@@ -412,7 +414,8 @@ function test_simplify_deep()
412414
g = op(:^, x[i], 1)
413415
f = op(:+, f, g)
414416
end
415-
@test (SymbolicAD.simplify(f), op(:+, convert(Vector{Any}, x)))
417+
ret = MOI.ScalarAffineFunction(MOI.ScalarAffineTerm.(1.0, x), 0.0)
418+
@test (SymbolicAD.simplify(f), ret)
416419
return
417420
end
418421

@@ -700,6 +703,39 @@ function test_SymbolicAD_univariate_registered()
700703
return
701704
end
702705

706+
function test_simplify_if_affine()
707+
x = MOI.VariableIndex(1)
708+
for (f, ret) in Any[
709+
op(:*, 2)=>2,
710+
op(:*, 2 // 3)=>2/3,
711+
op(:*, 2, 3)=>6,
712+
op(:*, 2, x, 3)=>6*x,
713+
op(:+, 2, 3)=>5,
714+
op(:-, 2)=>-2,
715+
op(:-, 2, 3)=>-1,
716+
op(:-, x)=>-1*x,
717+
op(:-, x, 2)=>1.0*x-2.0,
718+
op(:-, 2, x)=>2.0+-1.0*x,
719+
op(:+, 2, x)=>2+x,
720+
op(:+, x, x)=>2.0*x,
721+
op(:+, x, 2, x)=>2.0*x+2.0,
722+
op(:+, x, 2, op(:+, x))=>2.0*x+2.0,
723+
# Early termination because not affine
724+
op(:+, op(:sin, x))=>nothing,
725+
op(:-, op(:sin, x))=>nothing,
726+
op(:-, op(:sin, x), 1)=>nothing,
727+
op(:-, x, op(:sin, x))=>nothing,
728+
op(:*, 2, x, 3, x)=>nothing,
729+
op(:*, 2, 3, op(:sin, x))=>nothing,
730+
op(:log, x)=>nothing,
731+
op(:+, big(1) * x, big(2))=>nothing,
732+
op(:+, x, big(2))=>nothing,
733+
]
734+
@test SymbolicAD._simplify_if_affine!(f) something(ret, f)
735+
end
736+
return
737+
end
738+
703739
end # module
704740

705741
TestMathOptSymbolicAD.runtests()

0 commit comments

Comments
 (0)