Skip to content

Commit 1f80353

Browse files
author
Alexander Ororbia
committed
wrote dynamics for exp-syn
1 parent 00eeff9 commit 1f80353

File tree

2 files changed

+22
-18
lines changed

2 files changed

+22
-18
lines changed

ngclearn/components/synapses/exponentialSynapse.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class ExponentialSynapse(DenseSynapse): ## dynamic exponential synapse cable
2828
| i_syn - derived total electrical current variable
2929
3030
Args:
31-
name: the string name of this cell
31+
name: the string name of this synapse
3232
3333
shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
3434
with number of inputs by number of outputs)
@@ -76,7 +76,7 @@ def __init__(
7676
self.i_syn = Compartment(postVals) ## electrical current output
7777
self.g_syn = Compartment(postVals) ## conductance variable
7878
if is_nonplastic:
79-
self.weights.set(self.weights * 0 + 1.)
79+
self.weights.set(self.weights.value * 0 + 1.)
8080

8181
@transition(output_compartments=["outputs", "i_syn", "g_syn"])
8282
@staticmethod
@@ -88,7 +88,7 @@ def advance_state(
8888
_out = jnp.matmul(s, weights) ## sum all pre-syn spikes at t going into post-neuron)
8989
dgsyn_dt = _out * g_syn_bar - g_syn/tau_syn
9090
g_syn = g_syn + dgsyn_dt * dt ## run Euler step to move conductance
91-
i_syn = g_syn * (v - syn_rest)
91+
i_syn = -g_syn * (v - syn_rest)
9292
outputs = i_syn #jnp.matmul(inputs, Wdyn * Rscale) + biases
9393
return outputs, i_syn, g_syn
9494

tests/components/synapses/test_exponentialSynapse.py

+19-15
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
from ngcsimlib.compilers import compile_command, wrap_command
77
from numpy.testing import assert_array_equal
88

9-
from ngcsimlib.compilers.process import Process, transition
10-
from ngcsimlib.component import Component
11-
from ngcsimlib.compartment import Compartment
9+
from ngcsimlib.compilers.process import Process
1210
from ngcsimlib.context import Context
1311
import ngclearn.utils.weight_distribution as dist
1412

@@ -18,10 +16,17 @@ def test_exponentialSynapse1():
1816
dkey = random.PRNGKey(1234)
1917
dkey, *subkeys = random.split(dkey, 6)
2018
dt = 1. # ms
19+
## excitatory properties
20+
tau_syn = 2.
21+
E_rest = 0.
22+
## inhibitory properties
23+
#tau_syn = 5.
24+
#E_rest = -80.
2125
# ---- build a single exp-synapse system ----
2226
with Context(name) as ctx:
2327
a = ExponentialSynapse(
24-
name="a", shape=(1,1), resources_init=dist.constant(value=1.),key=subkeys[0]
28+
name="a", shape=(1,1), tau_syn=tau_syn, g_syn_bar=2.4, syn_rest=E_rest, weight_init=dist.constant(value=1.),
29+
key=subkeys[0]
2530
)
2631

2732
advance_process = (Process("advance_proc")
@@ -33,25 +38,24 @@ def test_exponentialSynapse1():
3338
>> a.reset)
3439
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
3540

36-
a.weights.set(jnp.ones((1, 1)))
37-
in_pulse = jnp.ones((1, 1)) * 0.425
41+
sp_train = jnp.array([1., 0., 1.], dtype=jnp.float32)
42+
post_syn_neuron_volt = jnp.ones((1, 1)) * -65. ## post-syn neuron is at rest
3843

39-
outs_truth = jnp.array([[0.07676563, 0.14312361, 0.16848783]])
40-
Wdyn_truth = jnp.array([[0.180625, 0.33676142, 0.39644194]])
44+
outs_truth = jnp.array([[156., 78., 195.]])
4145

4246
outs = []
43-
Wdyn = []
4447
ctx.reset()
4548
for t in range(3):
49+
in_pulse = jnp.expand_dims(sp_train[t], axis=0)
4650
a.inputs.set(in_pulse)
51+
a.v.set(post_syn_neuron_volt)
4752
ctx.run(t=t * dt, dt=dt)
53+
print("g: ",a.g_syn.value)
54+
print("i: ", a.i_syn.value)
4855
outs.append(a.outputs.value)
49-
Wdyn.append(a.Wdyn.value)
5056
outs = jnp.concatenate(outs, axis=1)
51-
Wdyn = jnp.concatenate(Wdyn, axis=1)
52-
# print(outs)
53-
# print(Wdyn)
57+
#print(outs)
58+
5459
np.testing.assert_allclose(outs, outs_truth, atol=1e-8)
55-
np.testing.assert_allclose(Wdyn, Wdyn_truth, atol=1e-8)
5660

57-
#test_exponentialSynapse1()
61+
test_exponentialSynapse1()

0 commit comments

Comments
 (0)