Skip to content

Commit df3f0e4

Browse files
author
pluto-azzaare
committed
Fixes for learning ICNs with CBLS
1 parent ed2690d commit df3f0e4

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

src/icn.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ mutable struct ICN
1717
weights::BitVector
1818

1919
function ICN(;
20-
param = Vector{Symbol}(),
21-
tr_layer = transformation_layer(param),
22-
ar_layer = arithmetic_layer(),
23-
ag_layer = aggregation_layer(),
24-
co_layer = comparison_layer(param),
20+
param=Vector{Symbol}(),
21+
tr_layer=transformation_layer(param),
22+
ar_layer=arithmetic_layer(),
23+
ag_layer=aggregation_layer(),
24+
co_layer=comparison_layer(param),
2525
)
2626
w = generate_weights([tr_layer, ar_layer, ag_layer, co_layer])
2727
return new(tr_layer, ar_layer, ag_layer, co_layer, w)
@@ -107,7 +107,7 @@ function regularization(icn)
107107
return Σop / (Σmax + 1)
108108
end
109109

110-
max_icn_length(icn = ICN(; param = [:val])) = length(icn.transformation)
110+
max_icn_length(icn=ICN(; param=[:val])) = length(icn.transformation)
111111

112112
"""
113113
_compose(icn)
@@ -116,7 +116,7 @@ Internal function called by `compose` and `show_composition`.
116116
function _compose(icn::ICN)
117117
!is_viable(icn) && (
118118
return (
119-
(x; X = zeros(length(x), max_icn_length()), param = nothing, dom_size = 0) -> typemax(Float64)
119+
(x; X=zeros(length(x), max_icn_length()), param=nothing, dom_size=0) -> typemax(Float64)
120120
),
121121
[]
122122
)
@@ -133,6 +133,7 @@ function _compose(icn::ICN)
133133

134134
if exclu(layer)
135135
f_id = as_int(@view weights(icn)[_start:_end])
136+
# @warn "debug" f_id _end _start weights(icn) (exclu(layer) ? "nbits_exclu(layer)" : "length(layer)") (@view weights(icn)[_start:_end])
136137
s = symbol(layer, f_id + 1)
137138
push!(funcs, [functions(layer)[s]])
138139
push!(symbols, [s])
@@ -151,11 +152,11 @@ function _compose(icn::ICN)
151152
end
152153
end
153154

154-
function composition(x; X = zeros(length(x), length(funcs[1])), dom_size, params...)
155+
function composition(x; X=zeros(length(x), length(funcs[1])), dom_size, params...)
155156
tr_in(Tuple(funcs[1]), X, x; params...)
156157
X[1:length(x), 1] .=
157158
1:length(x) .|> (i -> funcs[2][1](@view X[i, 1:length(funcs[1])]))
158-
return (y -> funcs[4][1](y; dom_size, nvars = length(x), params...))(
159+
return (y -> funcs[4][1](y; dom_size, nvars=length(x), params...))(
159160
funcs[3][1](@view X[:, 1]),
160161
)
161162
end

src/layer.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,12 @@ exclu(layer) = layer.exclusive
3030
symbol(layer, i)
3131
Return the i-th symbols of the operations in a given layer.
3232
"""
33-
symbol(layer, i) = collect(keys(functions(layer)))[i]
33+
symbol(layer, i) = begin
34+
if i > length(layer)
35+
@info layer i functions(layer)
36+
end
37+
collect(keys(functions(layer)))[i]
38+
end
3439

3540
"""
3641
nbits_exclu(layer)

0 commit comments

Comments
 (0)