6
6
from ngcsimlib .compilers import compile_command , wrap_command
7
7
from numpy .testing import assert_array_equal
8
8
9
- from ngcsimlib .compilers .process import Process , transition
10
- from ngcsimlib .component import Component
11
- from ngcsimlib .compartment import Compartment
9
+ from ngcsimlib .compilers .process import Process
12
10
from ngcsimlib .context import Context
13
11
import ngclearn .utils .weight_distribution as dist
14
12
@@ -18,10 +16,17 @@ def test_exponentialSynapse1():
18
16
dkey = random .PRNGKey (1234 )
19
17
dkey , * subkeys = random .split (dkey , 6 )
20
18
dt = 1. # ms
19
+ ## excitatory properties
20
+ tau_syn = 2.
21
+ E_rest = 0.
22
+ ## inhibitory properties
23
+ #tau_syn = 5.
24
+ #E_rest = -80.
21
25
# ---- build a single exp-synapse system ----
22
26
with Context (name ) as ctx :
23
27
a = ExponentialSynapse (
24
- name = "a" , shape = (1 ,1 ), resources_init = dist .constant (value = 1. ),key = subkeys [0 ]
28
+ name = "a" , shape = (1 ,1 ), tau_syn = tau_syn , g_syn_bar = 2.4 , syn_rest = E_rest , weight_init = dist .constant (value = 1. ),
29
+ key = subkeys [0 ]
25
30
)
26
31
27
32
advance_process = (Process ("advance_proc" )
@@ -33,25 +38,24 @@ def test_exponentialSynapse1():
33
38
>> a .reset )
34
39
ctx .wrap_and_add_command (jit (reset_process .pure ), name = "reset" )
35
40
36
- a . weights . set ( jnp .ones (( 1 , 1 )) )
37
- in_pulse = jnp .ones ((1 , 1 )) * 0.425
41
+ sp_train = jnp .array ([ 1. , 0. , 1. ], dtype = jnp . float32 )
42
+ post_syn_neuron_volt = jnp .ones ((1 , 1 )) * - 65. ## post-syn neuron is at rest
38
43
39
- outs_truth = jnp .array ([[0.07676563 , 0.14312361 , 0.16848783 ]])
40
- Wdyn_truth = jnp .array ([[0.180625 , 0.33676142 , 0.39644194 ]])
44
+ outs_truth = jnp .array ([[156. , 78. , 195. ]])
41
45
42
46
outs = []
43
- Wdyn = []
44
47
ctx .reset ()
45
48
for t in range (3 ):
49
+ in_pulse = jnp .expand_dims (sp_train [t ], axis = 0 )
46
50
a .inputs .set (in_pulse )
51
+ a .v .set (post_syn_neuron_volt )
47
52
ctx .run (t = t * dt , dt = dt )
53
+ print ("g: " ,a .g_syn .value )
54
+ print ("i: " , a .i_syn .value )
48
55
outs .append (a .outputs .value )
49
- Wdyn .append (a .Wdyn .value )
50
56
outs = jnp .concatenate (outs , axis = 1 )
51
- Wdyn = jnp .concatenate (Wdyn , axis = 1 )
52
- # print(outs)
53
- # print(Wdyn)
57
+ #print(outs)
58
+
54
59
np .testing .assert_allclose (outs , outs_truth , atol = 1e-8 )
55
- np .testing .assert_allclose (Wdyn , Wdyn_truth , atol = 1e-8 )
56
60
57
- # test_exponentialSynapse1()
61
+ test_exponentialSynapse1 ()
0 commit comments