Skip to content

Commit feb3548

Browse files
committed
Fix Base.copy for ScalarNonlinearFunction
1 parent 48ac449 commit feb3548

File tree

2 files changed

+65
-1
lines changed

2 files changed

+65
-1
lines changed

src/functions.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,30 @@ struct ScalarNonlinearFunction <: AbstractScalarFunction
360360
end
361361
end
362362

363+
# copy() doesn't recursively copy the children, and deepcopy seems to have a
364+
# performance problem for deeply nested structs.
363365
function Base.copy(f::ScalarNonlinearFunction)
364-
return ScalarNonlinearFunction(f.head, copy(f.args))
366+
stack, result_stack = Any[f], Any[]
367+
while !isempty(stack)
368+
arg = pop!(stack)
369+
if arg isa ScalarNonlinearFunction
370+
# We need some sort of hint so that the next time we see this on the
371+
# stack we evaluate it using the args in `result_stack`. One option
372+
# would be a custom type. Or we can just wrap in (,) and then check
373+
# for a Tuple, which isn't (curretly) a valid argument.
374+
push!(stack, (arg,))
375+
for child in arg.args
376+
push!(stack, child)
377+
end
378+
elseif arg isa Tuple{<:ScalarNonlinearFunction}
379+
result = only(arg)
380+
args = Any[pop!(result_stack) for i in 1:length(result.args)]
381+
push!(result_stack, ScalarNonlinearFunction(result.head, args))
382+
else
383+
push!(result_stack, copy(arg))
384+
end
385+
end
386+
return only(result_stack)
365387
end
366388

367389
constant(f::ScalarNonlinearFunction, ::Type{T} = Float64) where {T} = zero(T)

test/functions.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,48 @@ function test_convert_VectorAffineFunction_VectorQuadraticFunction()
469469
return
470470
end
471471

472+
function test_copy_ScalarNonlinearFunction()
473+
N = 10_000
474+
x = MOI.VariableIndex.(1:N)
475+
f1 = MOI.ScalarNonlinearFunction(:^, Any[x[1], 1])
476+
for i in 2:N
477+
g = MOI.ScalarNonlinearFunction(:^, Any[x[i], 1])
478+
f1 = MOI.ScalarNonlinearFunction(:+, Any[f1, g])
479+
end
480+
f2 = MOI.ScalarNonlinearFunction(:^, Any[x[1], 1])
481+
for i in 2:N
482+
g = MOI.ScalarNonlinearFunction(:^, Any[x[i], 1])
483+
f2 = MOI.ScalarNonlinearFunction(:+, Any[f2, g])
484+
end
485+
f_copy = copy(f1)
486+
@test (f_copy, f2)
487+
f1.args[2].args[2] = 2.0 # x[1]^1 --> x[1]^2
488+
@test !isapprox(f_copy, f1)
489+
@test isapprox(f_copy, f2)
490+
return
491+
end
492+
493+
function test_copy_ScalarNonlinearFunction_with_arg()
494+
N = 10_000
495+
x = MOI.VariableIndex.(1:N)
496+
f1 = 1.0 * x[1] + 1.0
497+
for i in 2:N
498+
g = f1 = Float64(i) * x[i] + Float64(i)
499+
f1 = MOI.ScalarNonlinearFunction(:+, Any[f1, g])
500+
end
501+
f2 = 1.0 * x[1] + 1.0
502+
for i in 2:N
503+
g = f2 = Float64(i) * x[i] + Float64(i)
504+
f2 = MOI.ScalarNonlinearFunction(:+, Any[f2, g])
505+
end
506+
f_copy = copy(f1)
507+
@test (f_copy, f2)
508+
f1.args[2].constant += 1
509+
@test !isapprox(f_copy, f1)
510+
@test isapprox(f_copy, f2)
511+
return
512+
end
513+
472514
function runtests()
473515
for name in names(@__MODULE__; all = true)
474516
if startswith("$name", "test_")

0 commit comments

Comments
 (0)