Skip to content

Commit 875387a

Browse files
authored
[Nonlinear] improve test coverage of operators.jl (#2650)
1 parent 899fd5a commit 875387a

File tree

2 files changed

+102
-29
lines changed

2 files changed

+102
-29
lines changed

src/Nonlinear/operators.jl

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,42 +11,31 @@ function _create_binary_switch(ids, exprs)
1111
push!(out.args, _create_binary_switch(ids[2:end], exprs[2:end]))
1212
end
1313
return out
14-
else
15-
mid = length(exprs) >>> 1
16-
return Expr(
17-
:if,
18-
Expr(:call, :(<=), :id, ids[mid]),
19-
_create_binary_switch(ids[1:mid], exprs[1:mid]),
20-
_create_binary_switch(ids[mid+1:end], exprs[mid+1:end]),
21-
)
2214
end
15+
mid = length(exprs) >>> 1
16+
return Expr(
17+
:if,
18+
Expr(:call, :(<=), :id, ids[mid]),
19+
_create_binary_switch(ids[1:mid], exprs[1:mid]),
20+
_create_binary_switch(ids[mid+1:end], exprs[mid+1:end]),
21+
)
2322
end
2423

2524
# We use a let block here for `expr` to create a local variable that does not
2625
# persist in the scope of the module. All we care about is the _eval_univariate
2726
# function that is eval'd as a result.
28-
let exprs = map(SYMBOLIC_UNIVARIATE_EXPRESSIONS) do arg
27+
let
28+
exprs = map(SYMBOLIC_UNIVARIATE_EXPRESSIONS) do arg
2929
return :(return $(arg[1])(x), $(arg[2]))
3030
end
3131
@eval @inline function _eval_univariate(id, x::T) where {T}
3232
$(_create_binary_switch(1:length(exprs), exprs))
33-
return error("Invalid operator_id")
34-
end
35-
end
36-
37-
# We use a let block here for `expr` to create a local variable that does not
38-
# persist in the scope of the module. All we care about is the function that is
39-
# eval'd as a result.
40-
let exprs = map(SYMBOLIC_UNIVARIATE_EXPRESSIONS) do arg
41-
if arg === :(nothing) # f''(x) isn't defined
42-
:(error("Invalid operator_id"))
43-
else
44-
:(return $(arg[3]))
45-
end
33+
return error("Invalid id for univariate operator: $id")
4634
end
35+
∇²f_exprs = map(arg -> :(return $(arg[3])), SYMBOLIC_UNIVARIATE_EXPRESSIONS)
4736
@eval @inline function _eval_univariate_2nd_deriv(id, x::T) where {T}
48-
$(_create_binary_switch(1:length(exprs), exprs))
49-
return error("Invalid operator_id")
37+
$(_create_binary_switch(1:length(∇²f_exprs), ∇²f_exprs))
38+
return error("Invalid id for univariate operator: $id")
5039
end
5140
end
5241

@@ -339,7 +328,7 @@ function _validate_register_assumptions(
339328
y = f(zeros(dimension)...)
340329
end
341330
catch
342-
# We hit some other error, perhaps we called a function like log(0).
331+
# We hit some other error, perhaps we called a function like log(-1).
343332
# Ignore for now, and hope that a useful error is shown to the user
344333
# during the solve.
345334
end
@@ -363,7 +352,7 @@ function _validate_register_assumptions(
363352
_FORWARD_DIFF_METHOD_ERROR_HELPER,
364353
)
365354
end
366-
# We hit some other error, perhaps we called a function like log(0).
355+
# We hit some other error, perhaps we called a function like log(-1).
367356
# Ignore for now, and hope that a useful error is shown to the user
368357
# during the solve.
369358
end
@@ -747,7 +736,12 @@ function eval_univariate_hessian(
747736
x::T,
748737
) where {T}
749738
if id <= registry.univariate_user_operator_start
750-
return _eval_univariate_2nd_deriv(id, x)::T
739+
ret = _eval_univariate_2nd_deriv(id, x)
740+
if ret === nothing
741+
op = registry.univariate_operators[id]
742+
error("Hessian is not defined for operator $op")
743+
end
744+
return ret::T
751745
end
752746
offset = id - registry.univariate_user_operator_start
753747
operator = registry.registered_univariate_operators[offset]
@@ -910,13 +904,15 @@ _nan_to_zero(x) = isnan(x) ? 0.0 : x
910904
op::Symbol,
911905
H::AbstractMatrix,
912906
x::AbstractVector{T},
913-
) where {T}
907+
)::Bool where {T}
914908
915909
Evaluate the Hessian of operator `∇²op(x)`, where `op` is a multivariate
916910
function in `registry`.
917911
918912
The Hessian is stored in the lower-triangular part of the matrix `H`.
919913
914+
Returns a `Bool` indicating whether non-zeros were stored in the matrix.
915+
920916
!!! note
921917
Implementations of the Hessian operators will not fill structural zeros.
922918
Therefore, before calling this function you should pre-populate the matrix

test/Nonlinear/Nonlinear.jl

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,28 @@ function test_eval_univariate_function()
374374
return
375375
end
376376

377+
function test_eval_univariate_missing_hessian()
378+
r = Nonlinear.OperatorRegistry()
379+
x = 2.0
380+
@test Nonlinear.eval_univariate_function(r, :asec, x) asec(x)
381+
@test Nonlinear.eval_univariate_gradient(r, :asec, x)
382+
1 / (abs(x) * sqrt(x^2 - 1))
383+
@test_throws(
384+
ErrorException("Hessian is not defined for operator asec"),
385+
Nonlinear.eval_univariate_hessian(r, :asec, x),
386+
)
387+
return
388+
end
389+
390+
function test_eval_univariate_hessian_bad_id()
391+
r = Nonlinear.OperatorRegistry()
392+
err = ErrorException("Invalid id for univariate operator: -1")
393+
@test_throws err Nonlinear.eval_univariate_function(r, -1, 1.0)
394+
@test_throws err Nonlinear.eval_univariate_gradient(r, -1, 1.0)
395+
@test_throws err Nonlinear.eval_univariate_hessian(r, -1, 1.0)
396+
return
397+
end
398+
377399
function test_eval_univariate_gradient()
378400
r = Nonlinear.OperatorRegistry()
379401
for (op, x, y) in [
@@ -594,7 +616,29 @@ function test_eval_multivariate_gradient_mult()
594616
x = [1.1, 0.0, 2.2]
595617
g = zeros(3)
596618
Nonlinear.eval_multivariate_gradient(r, :*, g, x)
597-
@test g == [0.0, 1.1 * 2.2, 0.0]
619+
@test g [0.0, 1.1 * 2.2, 0.0]
620+
x = [1.1, 3.3, 2.2]
621+
Nonlinear.eval_multivariate_gradient(r, :*, g, x)
622+
@test g [3.3 * 2.2, 1.1 * 2.2, 1.1 * 3.3]
623+
return
624+
end
625+
626+
function test_eval_multivariate_gradient_univariate_mult()
627+
r = Nonlinear.OperatorRegistry()
628+
x = [1.1]
629+
g = zeros(1)
630+
Nonlinear.eval_multivariate_gradient(r, :*, g, x)
631+
@test g == [1.0]
632+
return
633+
end
634+
635+
function test_eval_multivariate_hessian_shortcut()
636+
r = Nonlinear.OperatorRegistry()
637+
x = [1.1]
638+
H = LinearAlgebra.LowerTriangular(zeros(1, 1))
639+
for op in (:+, :-, :ifelse)
640+
@test !MOI.Nonlinear.eval_multivariate_hessian(r, op, H, x)
641+
end
598642
return
599643
end
600644

@@ -670,6 +714,18 @@ function test_eval_multivariate_function_registered()
670714
return
671715
end
672716

717+
function test_eval_multivariate_function_registered_log()
718+
r = Nonlinear.OperatorRegistry()
719+
f(x...) = log(x[1] - 1)
720+
Nonlinear.register_operator(r, :f, 2, f)
721+
x = [1.1, 2.2]
722+
@test Nonlinear.eval_multivariate_function(r, :f, x) f(x...)
723+
x = [0.0, 0.0]
724+
g = zeros(2)
725+
@test_throws DomainError Nonlinear.eval_multivariate_gradient(r, :f, g, x)
726+
return
727+
end
728+
673729
function test_eval_multivariate_function_method_error()
674730
r = Nonlinear.OperatorRegistry()
675731
function f(x...)
@@ -1327,6 +1383,27 @@ function test_convert_to_expr()
13271383
return
13281384
end
13291385

1386+
function test_create_binary_switch()
1387+
target = Expr(
1388+
:if,
1389+
Expr(:call, :(<=), :id, 2),
1390+
Expr(
1391+
:if,
1392+
Expr(:call, :(==), :id, 1),
1393+
:a,
1394+
Expr(:if, Expr(:call, :(==), :id, 2), :b),
1395+
),
1396+
Expr(
1397+
:if,
1398+
Expr(:call, :(==), :id, 3),
1399+
:c,
1400+
Expr(:if, Expr(:call, :(==), :id, 4), :d),
1401+
),
1402+
)
1403+
@test MOI.Nonlinear._create_binary_switch(1:4, [:a, :b, :c, :d]) == target
1404+
return
1405+
end
1406+
13301407
end # TestNonlinear
13311408

13321409
TestNonlinear.runtests()

0 commit comments

Comments
 (0)