Skip to content

Commit d8636ce

Browse files
author
Alexander Ororbia
committed
init commit of exp-syn material
1 parent 5396eb3 commit d8636ce

File tree

4 files changed

+217
-1
lines changed

4 files changed

+217
-1
lines changed

ngclearn/components/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from .synapses.hebbian.eventSTDPSynapse import EventSTDPSynapse
3939
from .synapses.hebbian.BCMSynapse import BCMSynapse
4040
from .synapses.STPDenseSynapse import STPDenseSynapse
41+
from .synapses.exponentialSynapse import ExponentialSynapse
4142

4243
## point to convolutional component types
4344
from .synapses.convolution.convSynapse import ConvSynapse

ngclearn/components/synapses/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
## short-term plasticity components
66
from .STPDenseSynapse import STPDenseSynapse
7-
7+
from .exponentialSynapse import ExponentialSynapse
88

99
## dense synaptic components
1010
from .hebbian.hebbianSynapse import HebbianSynapse
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
from jax import random, numpy as jnp, jit
2+
from ngcsimlib.compilers.process import transition
3+
from ngcsimlib.component import Component
4+
from ngcsimlib.compartment import Compartment
5+
6+
from ngclearn.utils.weight_distribution import initialize_params
7+
from ngcsimlib.logger import info
8+
from ngclearn.components.synapses import DenseSynapse
9+
from ngclearn.utils import tensorstats
10+
11+
class ExponentialSynapse(DenseSynapse): ## dynamic exponential synapse cable
12+
"""
13+
A dynamic exponential synaptic cable; this synapse evolves according to exponential synaptic conductance dynamics.
14+
Specifically, the dynamics are as follows:
15+
16+
| g = g + (weight * gbase) // on the occurrence of a pulse
17+
| i = g * (erev - v), where: d g /dt = -g / tauDecay
18+
19+
20+
| --- Synapse Compartments: ---
21+
| inputs - input (takes in external signals)
22+
| outputs - output signals
23+
| weights - current value matrix of synaptic efficacies
24+
| biases - current value vector of synaptic bias values
25+
| --- Short-Term Plasticity Compartments: ---
26+
| resources - fixed value matrix of synaptic resources (U)
27+
| u - release probability; fraction of resources ready for use
28+
| x - fraction of resources available after neurotransmitter depletion
29+
30+
| Dynamics note:
31+
| If tau_d >> tau_f and resources U are large, then synapse is STD-dominated
32+
| If tau_d << tau_f and resources U are small, then synases is STF-dominated
33+
34+
Args:
35+
name: the string name of this cell
36+
37+
shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
38+
with number of inputs by number of outputs)
39+
40+
weight_init: a kernel to drive initialization of this synaptic cable's values;
41+
typically a tuple with 1st element as a string calling the name of
42+
initialization to use
43+
44+
bias_init: a kernel to drive initialization of biases for this synaptic cable
45+
(Default: None, which turns off/disables biases)
46+
47+
resist_scale: a fixed (resistance) scaling factor to apply to synaptic
48+
transform (Default: 1.), i.e., yields: out = ((W * Rscale) * in)
49+
50+
p_conn: probability of a connection existing (default: 1.); setting
51+
this to < 1 and > 0. will result in a sparser synaptic structure
52+
(lower values yield sparse structure)
53+
54+
tau_f: short-term facilitation (STF) time constant (default: `750` ms); note
55+
that setting this to `0` ms will disable STF
56+
57+
tau_d: shoft-term depression time constant (default: `50` ms); note
58+
that setting this to `0` ms will disable STD
59+
60+
resources_int: initialization kernel for synaptic resources matrix
61+
"""
62+
63+
# Define Functions
64+
def __init__(self, name, shape, weight_init=None, bias_init=None,
65+
resist_scale=1., p_conn=1., tau_f=750., tau_d=50.,
66+
resources_init=None, **kwargs):
67+
super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs)
68+
## STP meta-parameters
69+
self.resources_init = resources_init
70+
self.tau_f = tau_f
71+
self.tau_d = tau_d
72+
73+
## Set up short-term plasticity / dynamic synapse compartment values
74+
tmp_key, *subkeys = random.split(self.key.value, 4)
75+
preVals = jnp.zeros((self.batch_size, shape[0]))
76+
self.i = Compartment(preVals) ## electrical current output
77+
self.g = Compartment(preVals) ## conductance variable
78+
79+
80+
@transition(output_compartments=["outputs", "i", "g"])
81+
@staticmethod
82+
def advance_state(
83+
tau_f, tau_d, Rscale, inputs, weights, biases, i, g
84+
):
85+
s = inputs
86+
87+
outputs = None #jnp.matmul(inputs, Wdyn * Rscale) + biases
88+
return outputs
89+
90+
@transition(output_compartments=["inputs", "outputs", "i", "g"])
91+
@staticmethod
92+
def reset(batch_size, shape):
93+
preVals = jnp.zeros((batch_size, shape[0]))
94+
postVals = jnp.zeros((batch_size, shape[1]))
95+
inputs = preVals
96+
outputs = postVals
97+
i = preVals
98+
g = preVals
99+
return inputs, outputs, i, g
100+
101+
def save(self, directory, **kwargs):
102+
file_name = directory + "/" + self.name + ".npz"
103+
if self.bias_init != None:
104+
jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value)
105+
else:
106+
jnp.savez(file_name, weights=self.weights.value)
107+
108+
def load(self, directory, **kwargs):
109+
file_name = directory + "/" + self.name + ".npz"
110+
data = jnp.load(file_name)
111+
self.weights.set(data['weights'])
112+
if "biases" in data.keys():
113+
self.biases.set(data['biases'])
114+
115+
@classmethod
116+
def help(cls): ## component help function
117+
properties = {
118+
"synapse_type": "STPDenseSynapse - performs a synaptic transformation of inputs to produce "
119+
"output signals (e.g., a scaled linear multivariate transformation); "
120+
"this synapse is dynamic, adapting via a form of short-term plasticity"
121+
}
122+
compartment_props = {
123+
"inputs":
124+
{"inputs": "Takes in external input signal values"},
125+
"states":
126+
{"weights": "Synapse efficacy/strength parameter values",
127+
"biases": "Base-rate/bias parameter values",
128+
"key": "JAX PRNG key"},
129+
"outputs":
130+
{"outputs": "Output of synaptic transformation"},
131+
}
132+
hyperparams = {
133+
"shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
134+
"weight_init": "Initialization conditions for synaptic weight (W) values",
135+
"bias_init": "Initialization conditions for bias/base-rate (b) values",
136+
"resist_scale": "Resistance level scaling factor (applied to output of transformation)",
137+
"p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
138+
}
139+
info = {cls.__name__: properties,
140+
"compartments": compartment_props,
141+
"dynamics": "outputs = [(W * Rscale) * inputs] + b; "
142+
"dg/dt = ",
143+
"hyperparameters": hyperparams}
144+
return info
145+
146+
def __repr__(self):
147+
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
148+
maxlen = max(len(c) for c in comps) + 5
149+
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
150+
for c in comps:
151+
stats = tensorstats(getattr(self, c).value)
152+
if stats is not None:
153+
line = [f"{k}: {v}" for k, v in stats.items()]
154+
line = ", ".join(line)
155+
else:
156+
line = "None"
157+
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
158+
return lines
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
np.random.seed(42)
5+
from ngclearn.components import ExponentialSynapse
6+
from ngcsimlib.compilers import compile_command, wrap_command
7+
from numpy.testing import assert_array_equal
8+
9+
from ngcsimlib.compilers.process import Process, transition
10+
from ngcsimlib.component import Component
11+
from ngcsimlib.compartment import Compartment
12+
from ngcsimlib.context import Context
13+
import ngclearn.utils.weight_distribution as dist
14+
15+
def test_exponentialSynapse1():
16+
name = "expsyn_ctx"
17+
## create seeding keys
18+
dkey = random.PRNGKey(1234)
19+
dkey, *subkeys = random.split(dkey, 6)
20+
dt = 1. # ms
21+
# ---- build a single exp-synapse system ----
22+
with Context(name) as ctx:
23+
a = ExponentialSynapse(
24+
name="a", shape=(1,1), resources_init=dist.constant(value=1.),key=subkeys[0]
25+
)
26+
27+
advance_process = (Process("advance_proc")
28+
>> a.advance_state)
29+
# ctx.wrap_and_add_command(advance_process.pure, name="run")
30+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
31+
32+
reset_process = (Process("reset_proc")
33+
>> a.reset)
34+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
35+
36+
a.weights.set(jnp.ones((1, 1)))
37+
in_pulse = jnp.ones((1, 1)) * 0.425
38+
39+
outs_truth = jnp.array([[0.07676563, 0.14312361, 0.16848783]])
40+
Wdyn_truth = jnp.array([[0.180625, 0.33676142, 0.39644194]])
41+
42+
outs = []
43+
Wdyn = []
44+
ctx.reset()
45+
for t in range(3):
46+
a.inputs.set(in_pulse)
47+
ctx.run(t=t * dt, dt=dt)
48+
outs.append(a.outputs.value)
49+
Wdyn.append(a.Wdyn.value)
50+
outs = jnp.concatenate(outs, axis=1)
51+
Wdyn = jnp.concatenate(Wdyn, axis=1)
52+
# print(outs)
53+
# print(Wdyn)
54+
np.testing.assert_allclose(outs, outs_truth, atol=1e-8)
55+
np.testing.assert_allclose(Wdyn, Wdyn_truth, atol=1e-8)
56+
57+
#test_exponentialSynapse1()

0 commit comments

Comments
 (0)