Skip to content

Commit 1e1d5c8

Browse files
committed
Add function barrier to operators
1 parent fa119e1 commit 1e1d5c8

File tree

1 file changed

+31
-14
lines changed

1 file changed

+31
-14
lines changed

src/Nonlinear/operators.jl

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,33 @@ struct _UnivariateOperator{F,F′,F′′}
6363
end
6464
end
6565

66+
function eval_univariate_function(operator::_UnivariateOperator, x::T) where {T}
67+
ret = operator.f(x)
68+
check_return_type(T, ret)
69+
return ret::T
70+
end
71+
72+
function eval_univariate_gradient(operator::_UnivariateOperator, x::T) where {T}
73+
ret = operator.f′(x)
74+
check_return_type(T, ret)
75+
return ret::T
76+
end
77+
78+
function eval_univariate_hessian(operator::_UnivariateOperator, x::T) where {T}
79+
ret = operator.f′′(x)
80+
check_return_type(T, ret)
81+
return ret::T
82+
end
83+
84+
function eval_univariate_function_and_gradient(
85+
operator::_UnivariateOperator,
86+
x::T,
87+
) where {T}
88+
ret_f = eval_univariate_function(operator, x)
89+
ret_f′ = eval_univariate_gradient(operator, x)
90+
return ret_f, ret_f′
91+
end
92+
6693
struct _MultivariateOperator{F,F′,F′′}
6794
N::Int
6895
f::F
@@ -564,9 +591,7 @@ function eval_univariate_function(
564591
end
565592
offset = id - registry.univariate_user_operator_start
566593
operator = registry.registered_univariate_operators[offset]
567-
ret = operator.f(x)
568-
check_return_type(T, ret)
569-
return ret::T
594+
return eval_univariate_function(operator, x)
570595
end
571596

572597
"""
@@ -619,9 +644,7 @@ function eval_univariate_gradient(
619644
end
620645
offset = id - registry.univariate_user_operator_start
621646
operator = registry.registered_univariate_operators[offset]
622-
ret = operator.f′(x)
623-
check_return_type(T, ret)
624-
return ret::T
647+
return eval_univariate_gradient(operator, x)
625648
end
626649

627650
"""
@@ -673,11 +696,7 @@ function eval_univariate_function_and_gradient(
673696
end
674697
offset = id - registry.univariate_user_operator_start
675698
operator = registry.registered_univariate_operators[offset]
676-
ret_f = operator.f(x)
677-
check_return_type(T, ret_f)
678-
ret_f′ = operator.f′(x)
679-
check_return_type(T, ret_f′)
680-
return ret_f::T, ret_f′::T
699+
return eval_univariate_function_and_gradient(operator, x)
681700
end
682701

683702
"""
@@ -732,9 +751,7 @@ function eval_univariate_hessian(
732751
end
733752
offset = id - registry.univariate_user_operator_start
734753
operator = registry.registered_univariate_operators[offset]
735-
ret = operator.f′′(x)
736-
check_return_type(T, ret)
737-
return ret::T
754+
return eval_univariate_hessian(operator, x)
738755
end
739756

740757
"""

0 commit comments

Comments
 (0)