Skip to content

Commit f1c39de

Browse files
authored
Extend the parametric keywords to include anything (#66)
* Extend the parametric keywords to include anything * Fix spelling an format
1 parent 9c623ff commit f1c39de

File tree

7 files changed

+160
-161
lines changed

7 files changed

+160
-161
lines changed

src/composition.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ function generate(c::Composition, name, ::Val{:Julia})
7575
co = reduce_symbols(symbs[4], ", ", false; prefix = CN * "co_")
7676

7777
documentation = """\"\"\"
78-
$name(x; X = zeros(length(x), $tr_length), param=nothing, dom_size)
78+
$name(x; X = zeros(length(x), $tr_length), params...)
7979
8080
Composition `$name` generated by CompositionalNetworks.jl.
8181
```
@@ -85,10 +85,10 @@ function generate(c::Composition, name, ::Val{:Julia})
8585
"""
8686

8787
output = """
88-
function $name(x; X = zeros(length(x), $tr_length), param=nothing, dom_size)
89-
$(CN)tr_in(Tuple($tr), X, x, param)
88+
function $name(x; X = zeros(length(x), $tr_length), dom_size, params...)
89+
$(CN)tr_in(Tuple($tr), X, x; params)
9090
X[1:length(x), 1] .= 1:length(x) .|> (i -> $ar(@view X[i, 1:$tr_length]))
91-
return $ag(@view X[:, 1]) |> (y -> $co(y; param, dom_size, nvars=length(x)))
91+
return $ag(@view X[:, 1]) |> (y -> $co(y; dom_size, nvars=length(x), params...))
9292
end
9393
"""
9494
return documentation * format_text(output, BlueStyle(); pipe_to_function_call = false)

src/icn.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -151,16 +151,11 @@ function _compose(icn::ICN)
151151
end
152152
end
153153

154-
function composition(
155-
x;
156-
X = zeros(length(x), length(funcs[1])),
157-
param = nothing,
158-
dom_size,
159-
)
160-
tr_in(Tuple(funcs[1]), X, x, param)
154+
function composition(x; X = zeros(length(x), length(funcs[1])), dom_size, params...)
155+
tr_in(Tuple(funcs[1]), X, x; params...)
161156
X[1:length(x), 1] .=
162157
1:length(x) .|> (i -> funcs[2][1](@view X[i, 1:length(funcs[1])]))
163-
return (y -> funcs[4][1](y; param, dom_size, nvars = length(x)))(
158+
return (y -> funcs[4][1](y; dom_size, nvars = length(x), params...))(
164159
funcs[3][1](@view X[:, 1]),
165160
)
166161
end

src/layer.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ end
7474

7575
"""
7676
generate_exclusive_operation(max_op_number)
77-
Generates the operations (weigths) of a layer with exclusive operations.
77+
Generates the operations (weights) of a layer with exclusive operations.
7878
"""
7979
function generate_exclusive_operation(max_op_number)
8080
op = rand(1:max_op_number)

src/layers/comparison.jl

Lines changed: 39 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,61 +2,59 @@
22
co_identity(x)
33
Identity function. Already defined in Julia as `identity`, specialized for scalars in the `comparison` layer.
44
"""
5-
co_identity(x; param = nothing, dom_size = 0, nvars = 0) = identity(x)
5+
co_identity(x; params...) = identity(x)
66

77
"""
8-
co_abs_diff_val_param(x; param)
9-
Return the absolute difference between `x` and `param`.
8+
co_abs_diff_var_val(x; val)
9+
Return the absolute difference between `x` and `val`.
1010
"""
11-
co_abs_diff_val_param(x; param, dom_size = 0, nvars = 0) = abs(x - param)
11+
co_abs_diff_var_val(x; val, params...) = abs(x - val)
1212

1313
"""
14-
co_val_minus_param(x; param)
15-
Return the difference `x - param` if positive, `0.0` otherwise.
14+
co_var_minus_val(x; val)
15+
Return the difference `x - val` if positive, `0.0` otherwise.
1616
"""
17-
co_val_minus_param(x; param, dom_size = 0, nvars = 0) = max(0.0, x - param)
17+
co_var_minus_val(x; val, params...) = max(0.0, x - val)
1818

1919
"""
20-
co_param_minus_val(x; param)
21-
Return the difference `param - x` if positive, `0.0` otherwise.
20+
co_val_minus_var(x; val)
21+
Return the difference `val - x` if positive, `0.0` otherwise.
2222
"""
23-
co_param_minus_val(x; param, dom_size = 0, nvars = 0) = max(0.0, param - x)
23+
co_val_minus_var(x; val, params...) = max(0.0, val - x)
2424

2525
"""
26-
co_euclidean_param(x; param, dom_size)
27-
Compute an euclidean norm with domain size `dom_size`, weighted by `param`, of a scalar.
26+
co_euclidean_val(x; val, dom_size)
27+
Compute an euclidean norm with domain size `dom_size`, weighted by `val`, of a scalar.
2828
"""
29-
function co_euclidean_param(x; param, dom_size, nvars = 0)
30-
return x == param ? 0.0 : (1.0 + abs(x - param) / dom_size)
29+
function co_euclidean_val(x; val, dom_size, params...)
30+
return x == val ? 0.0 : (1.0 + abs(x - val) / dom_size)
3131
end
3232

3333
"""
3434
co_euclidean(x; dom_size)
3535
Compute an euclidean norm with domain size `dom_size` of a scalar.
3636
"""
37-
function co_euclidean(x; param = nothing, dom_size, nvars = 0)
38-
return co_euclidean_param(x; param = 0.0, dom_size = dom_size)
37+
function co_euclidean(x; dom_size, params...)
38+
return co_euclidean_val(x; val = 0.0, dom_size)
3939
end
4040

4141
"""
42-
co_abs_diff_val_vars(x; nvars)
42+
co_abs_diff_var_vars(x; nvars)
4343
Return the absolute difference between `x` and the number of variables `nvars`.
4444
"""
45-
co_abs_diff_val_vars(x; param = nothing, dom_size = 0, nvars) = abs(x - nvars)
45+
co_abs_diff_var_vars(x; nvars, params...) = abs(x - nvars)
4646

4747
"""
48-
co_val_minus_vars(x; nvars)
48+
co_var_minus_vars(x; nvars)
4949
Return the difference `x - nvars` if positive, `0.0` otherwise, where `nvars` denotes the numbers of variables.
5050
"""
51-
co_val_minus_vars(x; param = nothing, dom_size = 0, nvars) =
52-
co_val_minus_param(x; param = nvars)
51+
co_var_minus_vars(x; nvars, params...) = co_var_minus_val(x; val = nvars)
5352

5453
"""
55-
co_vars_minus_val(x; nvars)
54+
co_vars_minus_var(x; nvars)
5655
Return the difference `nvars - x` if positive, `0.0` otherwise, where `nvars` denotes the numbers of variables.
5756
"""
58-
co_vars_minus_val(x; param = nothing, dom_size = 0, nvars) =
59-
co_param_minus_val(x; param = nvars)
57+
co_vars_minus_var(x; nvars, params...) = co_val_minus_var(x; val = nvars)
6058

6159

6260
# Parametric layers
@@ -66,18 +64,18 @@ function make_comparisons(::Val{:none})
6664
return LittleDict{Symbol,Function}(
6765
:identity => co_identity,
6866
:euclidean => co_euclidean,
69-
:abs_diff_val_vars => co_abs_diff_val_vars,
70-
:val_minus_vars => co_val_minus_vars,
71-
:vars_minus_val => co_vars_minus_val,
67+
:abs_diff_var_vars => co_abs_diff_var_vars,
68+
:var_minus_vars => co_var_minus_vars,
69+
:vars_minus_var => co_vars_minus_var,
7270
)
7371
end
7472

7573
function make_comparisons(::Val{:val})
7674
return LittleDict{Symbol,Function}(
77-
:abs_diff_val_param => co_abs_diff_val_param,
78-
:val_minus_param => co_val_minus_param,
79-
:param_minus_val => co_param_minus_val,
80-
:euclidean_param => co_euclidean_param,
75+
:abs_diff_var_val => co_abs_diff_var_val,
76+
:var_minus_val => co_var_minus_val,
77+
:val_minus_var => co_val_minus_var,
78+
:euclidean_val => co_euclidean_val,
8179
)
8280
end
8381

@@ -113,21 +111,21 @@ end
113111
end
114112

115113
funcs_param = [
116-
CN.co_abs_diff_val_param => [2, 5],
117-
CN.co_val_minus_param => [2, 0],
118-
CN.co_param_minus_val => [0, 5],
114+
CN.co_abs_diff_var_val => [2, 5],
115+
CN.co_var_minus_val => [2, 0],
116+
CN.co_val_minus_var => [0, 5],
119117
]
120118

121119
for (f, results) in funcs_param
122120
for (key, vals) in enumerate(data)
123-
@test f(vals.first; param = vals.second[1]) == results[key]
121+
@test f(vals.first; val = vals.second[1]) == results[key]
124122
end
125123
end
126124

127125
funcs_vars = [
128-
CN.co_abs_diff_val_vars => [2, 0],
129-
CN.co_val_minus_vars => [0, 0],
130-
CN.co_vars_minus_val => [2, 0],
126+
CN.co_abs_diff_var_vars => [2, 0],
127+
CN.co_var_minus_vars => [0, 0],
128+
CN.co_vars_minus_var => [2, 0],
131129
]
132130

133131
for (f, results) in funcs_vars
@@ -136,11 +134,11 @@ end
136134
end
137135
end
138136

139-
funcs_param_dom = [CN.co_euclidean_param => [1.4, 2.0]]
137+
funcs_val_dom = [CN.co_euclidean_val => [1.4, 2.0]]
140138

141-
for (f, results) in funcs_param_dom
139+
for (f, results) in funcs_val_dom
142140
for (key, vals) in enumerate(data)
143-
@test f(vals.first, param = vals.second[1], dom_size = vals.second[2])
141+
@test f(vals.first, val = vals.second[1], dom_size = vals.second[2])
144142
results[key]
145143
end
146144
end

0 commit comments

Comments
 (0)