Skip to content

Commit b80be2a

Browse files
authored
Merge pull request #127 from SymbolicML/n-arity-v4
Permit n-argument operators
2 parents 294f789 + 73ad287 commit b80be2a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

73 files changed

+3386
-1878
lines changed

.github/codecov.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
coverage:
2+
status:
3+
patch:
4+
default:
5+
informational: true

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
4-
version = "1.10.3"
4+
version = "2.0.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8+
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
89
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
910
Interfaces = "85a1e053-f937-4924-92a5-1367d23b7b87"
1011
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
@@ -30,6 +31,7 @@ DynamicExpressionsZygoteExt = "Zygote"
3031
[compat]
3132
Bumper = "0.6"
3233
ChainRulesCore = "1"
34+
Compat = "4.16"
3335
DispatchDoctor = "0.4"
3436
Interfaces = "0.3"
3537
LoopVectorization = "0.12"

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ A dynamic expression is a snippet of code that can change throughout runtime - c
2727
```julia
2828
using DynamicExpressions
2929

30-
operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos])
30+
operators = OperatorEnum(1 => (cos,), 2 => (+, -, *))
3131
variable_names = ["x1", "x2"]
3232

3333
x1 = Expression(Node{Float64}(feature=1); operators, variable_names)
@@ -98,7 +98,7 @@ We can also compute gradients with the same speed:
9898
```julia
9999
using Zygote # trigger extension
100100

101-
operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos])
101+
operators = OperatorEnum(1 => (cos,), 2 => (+, -, *))
102102
variable_names = ["x1", "x2"]
103103
x1, x2 = (Expression(Node{Float64}(feature=i); operators, variable_names) for i in 1:2)
104104

@@ -149,7 +149,7 @@ using DynamicExpressions: @declare_expression_operator
149149
my_string_func(x::String) = "ello $x"
150150
@declare_expression_operator(my_string_func, 1)
151151

152-
operators = GenericOperatorEnum(; binary_operators=[*], unary_operators=[my_string_func])
152+
operators = GenericOperatorEnum(1 => (my_string_func,), 2 => (*,))
153153

154154
x1 = Expression(_x1; operators, variable_names)
155155
```
@@ -192,7 +192,7 @@ vec_square(x) = x .* x
192192
@declare_expression_operator(vec_square, 1)
193193

194194
# Set up an operator enum:
195-
operators = GenericOperatorEnum(;binary_operators=[vec_add], unary_operators=[vec_square])
195+
operators = GenericOperatorEnum(1 => (vec_square,), 2 => (vec_add,))
196196

197197
# Construct the expression:
198198
variable_names = ["x1"]

docs/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,6 @@ DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
44
Interfaces = "85a1e053-f937-4924-92a5-1367d23b7b87"
55
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
66
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
7+
8+
[sources]
9+
DynamicExpressions = { path = "../" }

docs/src/api.md

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,11 @@ This `enum` is defined as follows:
1010
OperatorEnum
1111
```
1212

13-
Construct this operator specification as follows:
14-
15-
```@docs
16-
OperatorEnum(; binary_operators=[], unary_operators=[], define_helper_functions::Bool=true)
17-
```
18-
1913
This is just for scalar operators. However, you can use
2014
the following for more general operators:
2115

2216
```@docs
23-
GenericOperatorEnum(; binary_operators=[], unary_operators=[], define_helper_functions::Bool=true)
17+
GenericOperatorEnum
2418
```
2519

2620
By default, these operators will define helper functions for constructing trees,
@@ -60,7 +54,7 @@ When using these node constructors, types will automatically be promoted.
6054
You can convert the type of a node using `convert`:
6155

6256
```@docs
63-
convert(::Type{AbstractExpressionNode{T1}}, tree::AbstractExpressionNode{T2}) where {T1, T2}
57+
convert(::Type{N1}, tree::N2) where {T1,T2,D1,D2,N1<:AbstractExpressionNode{T1,D1},N2<:AbstractExpressionNode{T2,D2}}
6458
```
6559

6660
You can set a `tree` (in-place) with `set_node!`:
@@ -75,6 +69,41 @@ You can create a copy of a node with `copy_node`:
7569
copy_node
7670
```
7771

72+
## Generic Node Accessors
73+
74+
For working with nodes of arbitrary arity:
75+
76+
```@docs
77+
get_child
78+
set_child!
79+
get_children
80+
set_children!
81+
```
82+
83+
Examples:
84+
85+
```julia
86+
# Define operators including ternary
87+
my_ternary(x, y, z) = x + y * z
88+
operators = OperatorEnum(((sin,), (+, *), (my_ternary,))) # (unary, binary, ternary)
89+
90+
tree = Node{Float64,3}(; op=1, children=(Node{Float64,3}(; val=1.0), Node{Float64,3}(; val=2.0)))
91+
new_child = Node{Float64,3}(; val=3.0)
92+
93+
left_child = get_child(tree, 1)
94+
right_child = get_child(tree, 2)
95+
96+
set_child!(tree, new_child, 1)
97+
98+
left, right = get_children(tree, Val(2)) # type stable
99+
100+
# Transform to ternary operation
101+
child1, child2, child3 = Node{Float64,3}(; val=4.0), Node{Float64,3}(; val=5.0), Node{Float64,3}(; val=6.0)
102+
set_children!(tree, (child1, child2, child3))
103+
tree.op = 1 # my_ternary
104+
tree.degree = 3
105+
```
106+
78107
## Graph Nodes
79108

80109
You can describe an equation as a *graph* rather than a tree
@@ -88,9 +117,7 @@ This makes it so you can have multiple parents for a given node,
88117
and share parts of an expression. For example:
89118

90119
```julia
91-
julia> operators = OperatorEnum(;
92-
binary_operators=[+, -, *], unary_operators=[cos, sin, exp]
93-
);
120+
julia> operators = OperatorEnum(1 => (cos, sin, exp), 2 => (+, -, *));
94121

95122
julia> x1, x2 = GraphNode(feature=1), GraphNode(feature=2)
96123
(x1, x2)
@@ -109,7 +136,7 @@ This means that we only need to change it once
109136
to have changes propagate across the expression:
110137

111138
```julia
112-
julia> y.r.val *= 0.9
139+
julia> get_child(y, 2).val *= 0.9
113140
1.35
114141

115142
julia> z

docs/src/eval.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ For example,
4040
```@example
4141
using DynamicExpressions
4242
43-
operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos])
43+
operators = OperatorEnum(1 => (cos, sin), 2 => (+, -, *, /))
4444
tree = Node(; feature=1) * cos(Node(; feature=2) - 3.2)
4545
4646
tree([1 2 3; 4 5 6.], operators)
@@ -155,7 +155,7 @@ Let's look at an example. First, let's create a tree:
155155
```julia
156156
using DynamicExpressions
157157

158-
operators = OperatorEnum(binary_operators=(+, -, *, /), unary_operators=(cos, sin))
158+
operators = OperatorEnum(1 => (cos, sin), 2 => (+, -, *, /))
159159

160160
x1 = Node{Float64}(feature=1)
161161
x2 = Node{Float64}(feature=2)

docs/src/utils.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ mapreduce(f::F, op::G, tree::AbstractNode; return_type, f_on_shared, break_shari
1515
any(f::F, tree::AbstractNode) where {F<:Function}
1616
all(f::F, tree::AbstractNode) where {F<:Function}
1717
map(f::F, tree::AbstractNode, result_type::Type{RT}=Nothing; break_sharing::Val=Val(false)) where {F<:Function,RT}
18-
convert(::Type{<:AbstractExpressionNode{T1}}, n::AbstractExpressionNode{T2}) where {T1,T2}
1918
hash(tree::AbstractExpressionNode{T}, h::UInt; break_sharing::Val=Val(false)) where {T}
2019
```
2120

ext/DynamicExpressionsBumperExt.jl

Lines changed: 30 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@ using DynamicExpressions:
55
OperatorEnum, AbstractExpressionNode, tree_mapreduce, is_valid_array, EvalOptions
66
using DynamicExpressions.UtilsModule: ResultOk, counttuple
77

8-
import DynamicExpressions.ExtensionInterfaceModule:
9-
bumper_eval_tree_array, bumper_kern1!, bumper_kern2!
8+
import DynamicExpressions.ExtensionInterfaceModule: bumper_eval_tree_array, bumper_kern!
109

1110
function bumper_eval_tree_array(
1211
tree::AbstractExpressionNode{T},
@@ -37,8 +36,7 @@ function bumper_eval_tree_array(
3736
branch_node -> branch_node,
3837
# In the evaluation kernel, we combine the branch nodes
3938
# with the arrays created by the leaf nodes:
40-
((args::Vararg{Any,M}) where {M}) ->
41-
dispatch_kerns!(operators, args..., eval_options),
39+
KernelDispatcher(operators, eval_options),
4240
tree;
4341
break_sharing=Val(true),
4442
)
@@ -49,63 +47,44 @@ function bumper_eval_tree_array(
4947
return (result, all_ok[])
5048
end
5149

52-
function dispatch_kerns!(
53-
operators, branch_node, cumulator, eval_options::EvalOptions{<:Any,true,early_exit}
54-
) where {early_exit}
55-
cumulator.ok || return cumulator
56-
57-
out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x, eval_options)
58-
return ResultOk(out, early_exit ? is_valid_array(out) : true)
59-
end
60-
function dispatch_kerns!(
61-
operators,
62-
branch_node,
63-
cumulator1,
64-
cumulator2,
65-
eval_options::EvalOptions{<:Any,true,early_exit},
66-
) where {early_exit}
67-
cumulator1.ok || return cumulator1
68-
cumulator2.ok || return cumulator2
69-
70-
out = dispatch_kern2!(
71-
operators.binops, branch_node.op, cumulator1.x, cumulator2.x, eval_options
72-
)
73-
return ResultOk(out, early_exit ? is_valid_array(out) : true)
50+
struct KernelDispatcher{O<:OperatorEnum,E<:EvalOptions{<:Any,true,<:Any}} <: Function
51+
operators::O
52+
eval_options::E
7453
end
7554

76-
@generated function dispatch_kern1!(unaops, op_idx, cumulator, eval_options::EvalOptions)
77-
nuna = counttuple(unaops)
55+
@generated function (kd::KernelDispatcher{<:Any,<:EvalOptions{<:Any,true,early_exit}})(
56+
branch_node, inputs::Vararg{Any,degree}
57+
) where {degree,early_exit}
7858
quote
79-
Base.@nif(
80-
$nuna,
81-
i -> i == op_idx,
82-
i -> let op = unaops[i]
83-
return bumper_kern1!(op, cumulator, eval_options)
84-
end,
85-
)
59+
Base.Cartesian.@nexprs($degree, i -> inputs[i].ok || return inputs[i])
60+
cumulators = Base.Cartesian.@ntuple($degree, i -> inputs[i].x)
61+
out = dispatch_kerns!(kd.operators, branch_node, cumulators, kd.eval_options)
62+
return ResultOk(out, early_exit ? is_valid_array(out) : true)
8663
end
8764
end
88-
@generated function dispatch_kern2!(
89-
binops, op_idx, cumulator1, cumulator2, eval_options::EvalOptions
90-
)
91-
nbin = counttuple(binops)
65+
@generated function dispatch_kerns!(
66+
operators::OperatorEnum{OPS},
67+
branch_node,
68+
cumulators::Tuple{Vararg{Any,degree}},
69+
eval_options::EvalOptions,
70+
) where {OPS,degree}
71+
nops = length(OPS.types[degree].types)
9272
quote
93-
Base.@nif(
94-
$nbin,
73+
op_idx = branch_node.op
74+
Base.Cartesian.@nif(
75+
$nops,
9576
i -> i == op_idx,
96-
i -> let op = binops[i]
97-
return bumper_kern2!(op, cumulator1, cumulator2, eval_options)
98-
end,
77+
i -> bumper_kern!(operators[$degree][i], cumulators, eval_options)
9978
)
10079
end
10180
end
102-
function bumper_kern1!(op::F, cumulator, ::EvalOptions{false,true}) where {F}
103-
@. cumulator = op(cumulator)
104-
return cumulator
105-
end
106-
function bumper_kern2!(op::F, cumulator1, cumulator2, ::EvalOptions{false,true}) where {F}
107-
@. cumulator1 = op(cumulator1, cumulator2)
108-
return cumulator1
81+
82+
function bumper_kern!(
83+
op::F, cumulators::Tuple{Vararg{Any,degree}}, ::EvalOptions{false,true,early_exit}
84+
) where {F,degree,early_exit}
85+
cumulator_1 = first(cumulators)
86+
@. cumulator_1 = op(cumulators...)
87+
return cumulator_1
10988
end
11089

11190
end

0 commit comments

Comments
 (0)