Skip to content

Commit 108629e

Browse files
committed
Update
1 parent fad13e6 commit 108629e

File tree

2 files changed

+69
-6
lines changed

2 files changed

+69
-6
lines changed

src/Nonlinear/SymbolicAD/SymbolicAD.jl

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -241,27 +241,64 @@ function simplify!(::Val{:*}, f::MOI.ScalarNonlinearFunction)
241241
return f
242242
end
243243

244+
# The MOI.Utilities.operate!(+, T, x, y) methods do not cope with mixed input
245+
# types. However, ScalarNonlinearFunction can hold various <:Real coefficient
246+
# types.
247+
_add_to!(x::Real, y::Real) = x + y
248+
249+
_add_to!(x::Real, y::MOI.ScalarAffineFunction{<:Real}) = _add_to!(y, x)
250+
251+
function _add_to!(x::MOI.ScalarAffineFunction{T}, y::T) where {T<:Real}
252+
return MOI.Utilities.operate!(+, T, x, y)
253+
end
254+
255+
function _add_to!(x::MOI.ScalarAffineFunction{S}, y::T) where {S<:Real,T<:Real}
256+
U = promote_type(S, T)
257+
F = MOI.ScalarAffineFunction{U}
258+
return MOI.Utilities.operate!(+, U, convert(F, x), convert(U, y))
259+
end
260+
261+
function _add_to!(
262+
x::MOI.ScalarAffineFunction{T},
263+
y::MOI.ScalarAffineFunction{T},
264+
) where {T<:Real}
265+
return MOI.Utilities.operate!(+, T, x, y)
266+
end
267+
268+
function _add_to!(
269+
x::MOI.ScalarAffineFunction{S},
270+
y::MOI.ScalarAffineFunction{T},
271+
) where {S<:Real,T<:Real}
272+
U = promote_type(S, T)
273+
F = MOI.ScalarAffineFunction{U}
274+
return MOI.Utilities.operate!(+, U, convert(F, x), convert(F, y))
275+
end
276+
244277
function simplify!(::Val{:+}, f::MOI.ScalarNonlinearFunction)
245278
new_args = Any[]
246-
first_constant = 0
279+
first_affine_term = 0
247280
for arg in f.args
248281
if _isexpr(arg, :+)
249282
# If a child is a :+, lift its arguments to the parent
250283
append!(new_args, arg.args)
251284
elseif _iszero(arg)
252285
# Skip any zero arguments
253-
elseif arg isa Real
254-
# Collect all constant arguments into a single value
255-
if first_constant == 0
286+
elseif arg isa Real || arg isa MOI.ScalarAffineFunction{<:Real}
287+
# Collect all affine arguments into a single value
288+
if first_affine_term == 0
256289
push!(new_args, arg)
257-
first_constant = length(new_args)
290+
first_affine_term = length(new_args)
258291
else
259-
new_args[first_constant] += arg
292+
new_args[first_affine_term] =
293+
_add_to!(new_args[first_affine_term], arg)
260294
end
261295
else
262296
push!(new_args, arg)
263297
end
264298
end
299+
if first_affine_term !== 0 && !(new_args[first_affine_term] isa Real)
300+
MOI.Utilities.canonicalize!(new_args[first_affine_term])
301+
end
265302
if length(new_args) == 0
266303
# +() -> false
267304
return false

test/Nonlinear/SymbolicAD.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,32 @@ function test_simplify_ScalarNonlinearFunction_addition()
348348
return
349349
end
350350

351+
function test_simplify_ScalarNonlinearFunction_addition_terms()
352+
x = MOI.VariableIndex(1)
353+
for (args, ret) in Any[
354+
# add_to!(::Real, ::Real)
355+
Any[2.0, 3.0]=>5.0,
356+
Any[2.5, 3]=>5.5,
357+
Any[3, 2.5]=>5.5,
358+
# add_to!(::ScalarAffineFunction{T}, ::ScalarAffineFunction{T})
359+
Any[2.0*x, 3.0*x+4.0]=>5.0*x+4.0,
360+
# add_to!(::ScalarAffineFunction{T}, ::T)
361+
Any[2.0*x, 1.0]=>2.0*x+1.0,
362+
# add_to!(::ScalarAffineFunction{S}, ::T)
363+
Any[2.0*x, 3]=>2.0*x+3.0,
364+
Any[2*x, 3.0]=>2.0*x+3.0,
365+
# add_to!(::ScalarAffineFunction{S}, ::ScalarAffineFunction{T})
366+
Any[2*x, 3.0*x+4.0]=>5.0*x+4.0,
367+
Any[3.0*x+4.0, 2*x]=>5.0*x+4.0,
368+
Any[3*x+4, 1.5, 2*x]=>5.0*x+5.5,
369+
Any[1.5, 3*x+4, 2*x]=>5.0*x+5.5,
370+
]
371+
f = MOI.ScalarNonlinearFunction(:+, args)
372+
@test SymbolicAD.simplify(f) ret
373+
end
374+
return
375+
end
376+
351377
# simplify(::Val{:-}, f::MOI.ScalarNonlinearFunction)
352378
function test_simplify_ScalarNonlinearFunction_subtraction()
353379
x, y = MOI.VariableIndex(1), MOI.VariableIndex(2)

0 commit comments

Comments
 (0)