Skip to content

Commit 88c9776

Browse files
authored
[Nonlinear] fix performance of evaluating univariate operators (#2620)
1 parent 6468097 commit 88c9776

File tree

5 files changed

+227
-33
lines changed

5 files changed

+227
-33
lines changed

src/Nonlinear/ReverseAD/forward_over_reverse.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ function _forward_eval_ϵ(
376376
@inbounds child_idx = children_arr[ex.adj.colptr[k]]
377377
f′′ = Nonlinear.eval_univariate_hessian(
378378
user_operators,
379-
user_operators.univariate_operators[node.index],
379+
node.index,
380380
ex.forward_storage[child_idx],
381381
)
382382
partials_storage_ϵ[child_idx] = f′′ * storage_ϵ[child_idx]

src/Nonlinear/ReverseAD/reverse_mode.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -248,16 +248,13 @@ function _forward_eval(
248248
end
249249
elseif node.type == Nonlinear.NODE_CALL_UNIVARIATE
250250
child_idx = children_arr[f.adj.colptr[k]]
251-
f.forward_storage[k] = Nonlinear.eval_univariate_function(
251+
ret_f, ret_f′ = Nonlinear.eval_univariate_function_and_gradient(
252252
operators,
253-
operators.univariate_operators[node.index],
254-
f.forward_storage[child_idx],
255-
)
256-
f.partials_storage[child_idx] = Nonlinear.eval_univariate_gradient(
257-
operators,
258-
operators.univariate_operators[node.index],
253+
node.index,
259254
f.forward_storage[child_idx],
260255
)
256+
f.forward_storage[k] = ret_f
257+
f.partials_storage[child_idx] = ret_f′
261258
elseif node.type == Nonlinear.NODE_COMPARISON
262259
children_idx = SparseArrays.nzrange(f.adj, k)
263260
result = true

src/Nonlinear/model.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ function evaluate(
377377
child_idx = children_arr[adj.colptr[k]]
378378
storage[k] = eval_univariate_function(
379379
model.operators,
380-
model.operators.univariate_operators[node.index],
380+
node.index,
381381
storage[child_idx],
382382
)
383383
elseif node.type == NODE_COMPARISON

src/Nonlinear/operators.jl

Lines changed: 172 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,33 @@ struct _UnivariateOperator{F,F′,F′′}
6363
end
6464
end
6565

66+
function eval_univariate_function(operator::_UnivariateOperator, x::T) where {T}
67+
ret = operator.f(x)
68+
check_return_type(T, ret)
69+
return ret::T
70+
end
71+
72+
function eval_univariate_gradient(operator::_UnivariateOperator, x::T) where {T}
73+
ret = operator.f′(x)
74+
check_return_type(T, ret)
75+
return ret::T
76+
end
77+
78+
function eval_univariate_hessian(operator::_UnivariateOperator, x::T) where {T}
79+
ret = operator.f′′(x)
80+
check_return_type(T, ret)
81+
return ret::T
82+
end
83+
84+
function eval_univariate_function_and_gradient(
85+
operator::_UnivariateOperator,
86+
x::T,
87+
) where {T}
88+
ret_f = eval_univariate_function(operator, x)
89+
ret_f′ = eval_univariate_gradient(operator, x)
90+
return ret_f, ret_f′
91+
end
92+
6693
struct _MultivariateOperator{F,F′,F′′}
6794
N::Int
6895
f::F
@@ -517,81 +544,214 @@ end
517544
"""
518545
eval_univariate_function(
519546
registry::OperatorRegistry,
520-
op::Symbol,
547+
op::Union{Symbol,Integer},
521548
x::T,
522549
) where {T}
523550
524551
Evaluate the operator `op(x)::T`, where `op` is a univariate function in
525552
`registry`.
553+
554+
If `op isa Integer`, then `op` is the index in
555+
`registry.univariate_operators[op]`.
556+
557+
## Example
558+
559+
```jldoctest
560+
julia> import MathOptInterface as MOI
561+
562+
julia> r = MOI.Nonlinear.OperatorRegistry();
563+
564+
julia> MOI.Nonlinear.eval_univariate_function(r, :abs, -1.2)
565+
1.2
566+
567+
julia> r.univariate_operators[3]
568+
:abs
569+
570+
julia> MOI.Nonlinear.eval_univariate_function(r, 3, -1.2)
571+
1.2
572+
```
526573
"""
527574
function eval_univariate_function(
528575
registry::OperatorRegistry,
529576
op::Symbol,
530577
x::T,
531578
) where {T}
532579
id = registry.univariate_operator_to_id[op]
580+
return eval_univariate_function(registry, id, x)
581+
end
582+
583+
function eval_univariate_function(
584+
registry::OperatorRegistry,
585+
id::Integer,
586+
x::T,
587+
) where {T}
533588
if id <= registry.univariate_user_operator_start
534589
f, _ = _eval_univariate(id, x)
535590
return f::T
536591
end
537592
offset = id - registry.univariate_user_operator_start
538593
operator = registry.registered_univariate_operators[offset]
539-
ret = operator.f(x)
540-
check_return_type(T, ret)
541-
return ret::T
594+
return eval_univariate_function(operator, x)
542595
end
543596

544597
"""
545598
eval_univariate_gradient(
546599
registry::OperatorRegistry,
547-
op::Symbol,
600+
op::Union{Symbol,Integer},
548601
x::T,
549602
) where {T}
550603
551604
Evaluate the first-derivative of the operator `op(x)::T`, where `op` is a
552605
univariate function in `registry`.
606+
607+
If `op isa Integer`, then `op` is the index in
608+
`registry.univariate_operators[op]`.
609+
610+
## Example
611+
612+
```jldoctest
613+
julia> import MathOptInterface as MOI
614+
615+
julia> r = MOI.Nonlinear.OperatorRegistry();
616+
617+
julia> MOI.Nonlinear.eval_univariate_gradient(r, :abs, -1.2)
618+
-1.0
619+
620+
julia> r.univariate_operators[3]
621+
:abs
622+
623+
julia> MOI.Nonlinear.eval_univariate_gradient(r, 3, -1.2)
624+
-1.0
625+
```
553626
"""
554627
function eval_univariate_gradient(
555628
registry::OperatorRegistry,
556629
op::Symbol,
557630
x::T,
558631
) where {T}
559632
id = registry.univariate_operator_to_id[op]
633+
return eval_univariate_gradient(registry, id, x)
634+
end
635+
636+
function eval_univariate_gradient(
637+
registry::OperatorRegistry,
638+
id::Integer,
639+
x::T,
640+
) where {T}
560641
if id <= registry.univariate_user_operator_start
561642
_, f′ = _eval_univariate(id, x)
562643
return f′::T
563644
end
564645
offset = id - registry.univariate_user_operator_start
565646
operator = registry.registered_univariate_operators[offset]
566-
ret = operator.f′(x)
567-
check_return_type(T, ret)
568-
return ret::T
647+
return eval_univariate_gradient(operator, x)
648+
end
649+
650+
"""
651+
eval_univariate_function_and_gradient(
652+
registry::OperatorRegistry,
653+
op::Union{Symbol,Integer},
654+
x::T,
655+
)::Tuple{T,T} where {T}
656+
657+
Evaluate the function and first-derivative of the operator `op(x)::T`, where
658+
`op` is a univariate function in `registry`.
659+
660+
If `op isa Integer`, then `op` is the index in
661+
`registry.univariate_operators[op]`.
662+
663+
## Example
664+
665+
```jldoctest
666+
julia> import MathOptInterface as MOI
667+
668+
julia> r = MOI.Nonlinear.OperatorRegistry();
669+
670+
julia> MOI.Nonlinear.eval_univariate_function_and_gradient(r, :abs, -1.2)
671+
(1.2, -1.0)
672+
673+
julia> r.univariate_operators[3]
674+
:abs
675+
676+
julia> MOI.Nonlinear.eval_univariate_function_and_gradient(r, 3, -1.2)
677+
(1.2, -1.0)
678+
```
679+
"""
680+
function eval_univariate_function_and_gradient(
681+
registry::OperatorRegistry,
682+
op::Symbol,
683+
x::T,
684+
) where {T}
685+
id = registry.univariate_operator_to_id[op]
686+
return eval_univariate_function_and_gradient(registry, id, x)
687+
end
688+
689+
function eval_univariate_function_and_gradient(
690+
registry::OperatorRegistry,
691+
id::Integer,
692+
x::T,
693+
) where {T}
694+
if id <= registry.univariate_user_operator_start
695+
return _eval_univariate(id, x)::Tuple{T,T}
696+
end
697+
offset = id - registry.univariate_user_operator_start
698+
operator = registry.registered_univariate_operators[offset]
699+
return eval_univariate_function_and_gradient(operator, x)
569700
end
570701

571702
"""
572703
eval_univariate_hessian(
573704
registry::OperatorRegistry,
574-
op::Symbol,
705+
op::Union{Symbol,Integer},
575706
x::T,
576707
) where {T}
577708
578709
Evaluate the second-derivative of the operator `op(x)::T`, where `op` is a
579710
univariate function in `registry`.
711+
712+
If `op isa Integer`, then `op` is the index in
713+
`registry.univariate_operators[op]`.
714+
715+
## Example
716+
717+
```jldoctest
718+
julia> import MathOptInterface as MOI
719+
720+
julia> r = MOI.Nonlinear.OperatorRegistry();
721+
722+
julia> MOI.Nonlinear.eval_univariate_hessian(r, :sin, 1.0)
723+
-0.8414709848078965
724+
725+
julia> r.univariate_operators[16]
726+
:sin
727+
728+
julia> MOI.Nonlinear.eval_univariate_hessian(r, 16, 1.0)
729+
-0.8414709848078965
730+
731+
julia> -sin(1.0)
732+
-0.8414709848078965
733+
```
580734
"""
581735
function eval_univariate_hessian(
582736
registry::OperatorRegistry,
583737
op::Symbol,
584738
x::T,
585739
) where {T}
586740
id = registry.univariate_operator_to_id[op]
741+
return eval_univariate_hessian(registry, id, x)
742+
end
743+
744+
function eval_univariate_hessian(
745+
registry::OperatorRegistry,
746+
id::Integer,
747+
x::T,
748+
) where {T}
587749
if id <= registry.univariate_user_operator_start
588750
return _eval_univariate_2nd_deriv(id, x)::T
589751
end
590752
offset = id - registry.univariate_user_operator_start
591753
operator = registry.registered_univariate_operators[offset]
592-
ret = operator.f′′(x)
593-
check_return_type(T, ret)
594-
return ret::T
754+
return eval_univariate_hessian(operator, x)
595755
end
596756

597757
"""

test/Nonlinear/Nonlinear.jl

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -360,28 +360,65 @@ end
360360

361361
function test_eval_univariate_function()
362362
r = Nonlinear.OperatorRegistry()
363-
@test Nonlinear.eval_univariate_function(r, :+, 1.0) == 1.0
364-
@test Nonlinear.eval_univariate_function(r, :-, 1.0) == -1.0
365-
@test Nonlinear.eval_univariate_function(r, :abs, -1.1) == 1.1
366-
@test Nonlinear.eval_univariate_function(r, :abs, 1.1) == 1.1
363+
for (op, x, y) in [
364+
(:+, 1.0, 1.0),
365+
(:-, 1.0, -1.0),
366+
(:abs, -1.1, 1.1),
367+
(:abs, 1.1, 1.1),
368+
(:sin, 1.1, sin(1.1)),
369+
]
370+
id = r.univariate_operator_to_id[op]
371+
@test Nonlinear.eval_univariate_function(r, op, x) == y
372+
@test Nonlinear.eval_univariate_function(r, id, x) == y
373+
end
367374
return
368375
end
369376

370377
function test_eval_univariate_gradient()
371378
r = Nonlinear.OperatorRegistry()
372-
@test Nonlinear.eval_univariate_gradient(r, :+, 1.2) == 1.0
373-
@test Nonlinear.eval_univariate_gradient(r, :-, 1.2) == -1.0
374-
@test Nonlinear.eval_univariate_gradient(r, :abs, -1.1) == -1.0
375-
@test Nonlinear.eval_univariate_gradient(r, :abs, 1.1) == 1.0
379+
for (op, x, y) in [
380+
(:+, 1.2, 1.0),
381+
(:-, 1.2, -1.0),
382+
(:abs, -1.1, -1.0),
383+
(:abs, 1.1, 1.0),
384+
(:sin, 1.1, cos(1.1)),
385+
]
386+
id = r.univariate_operator_to_id[op]
387+
@test Nonlinear.eval_univariate_gradient(r, op, x) == y
388+
@test Nonlinear.eval_univariate_gradient(r, id, x) == y
389+
end
390+
return
391+
end
392+
393+
function test_eval_univariate_function_and_gradient()
394+
r = Nonlinear.OperatorRegistry()
395+
for (op, x, y) in [
396+
(:+, 1.2, (1.2, 1.0)),
397+
(:-, 1.2, (-1.2, -1.0)),
398+
(:abs, -1.1, (1.1, -1.0)),
399+
(:abs, 1.1, (1.1, 1.0)),
400+
(:sin, 1.1, (sin(1.1), cos(1.1))),
401+
]
402+
id = r.univariate_operator_to_id[op]
403+
@test Nonlinear.eval_univariate_function_and_gradient(r, op, x) == y
404+
@test Nonlinear.eval_univariate_function_and_gradient(r, id, x) == y
405+
end
376406
return
377407
end
378408

379409
function test_eval_univariate_hessian()
380410
r = Nonlinear.OperatorRegistry()
381-
@test Nonlinear.eval_univariate_hessian(r, :+, 1.2) == 0.0
382-
@test Nonlinear.eval_univariate_hessian(r, :-, 1.2) == 0.0
383-
@test Nonlinear.eval_univariate_hessian(r, :abs, -1.1) == 0.0
384-
@test Nonlinear.eval_univariate_hessian(r, :abs, 1.1) == 0.0
411+
for (op, x, y) in [
412+
(:+, 1.2, 0.0),
413+
(:-, 1.2, 0.0),
414+
(:abs, -1.1, 0.0),
415+
(:abs, 1.1, 0.0),
416+
(:sin, 1.0, -sin(1.0)),
417+
]
418+
id = r.univariate_operator_to_id[op]
419+
@test Nonlinear.eval_univariate_hessian(r, op, x) == y
420+
@test Nonlinear.eval_univariate_hessian(r, id, x) == y
421+
end
385422
return
386423
end
387424

0 commit comments

Comments
 (0)