Skip to content

Commit 65ca503

Browse files
author
Alexander Ororbia
committed
integrated alpha-synapse
1 parent 9836005 commit 65ca503

File tree

4 files changed

+183
-3
lines changed

4 files changed

+183
-3
lines changed

ngclearn/components/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from .synapses.hebbian.BCMSynapse import BCMSynapse
4040
from .synapses.STPDenseSynapse import STPDenseSynapse
4141
from .synapses.exponentialSynapse import ExponentialSynapse
42+
from .synapses.alphaSynapse import AlphaSynapse
4243

4344
## point to convolutional component types
4445
from .synapses.convolution.convSynapse import ConvSynapse

ngclearn/components/synapses/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
## short-term plasticity components
66
from .STPDenseSynapse import STPDenseSynapse
77
from .exponentialSynapse import ExponentialSynapse
8+
from .alphaSynapse import AlphaSynapse
89

910
## dense synaptic components
1011
from .hebbian.hebbianSynapse import HebbianSynapse
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
from jax import random, numpy as jnp, jit
2+
from ngcsimlib.compilers.process import transition
3+
from ngcsimlib.component import Component
4+
from ngcsimlib.compartment import Compartment
5+
6+
from ngclearn.utils.weight_distribution import initialize_params
7+
from ngcsimlib.logger import info
8+
from ngclearn.components.synapses import DenseSynapse
9+
from ngclearn.utils import tensorstats
10+
11+
class AlphaSynapse(DenseSynapse): ## dynamic alpha synapse cable
12+
"""
13+
A dynamic alpha synaptic cable; this synapse evolves according to alpha synaptic conductance dynamics.
14+
Specifically, the conductance dynamics are as follows:
15+
16+
| dh/dt = -h/tau_syn + gBar sum_k (t - t_k) // h is an intermediate variable
17+
| dg/dt = -g/tau_syn + h/tau_syn
18+
| i_syn = g * (syn_rest - v) // g is `g_syn` and h is `h_syn` in this synapse implementation
19+
| where: syn_rest is the post-synaptic reverse potential for this synapse
20+
| t_k marks time of -pre-synaptic k-th pulse received by post-synaptic unit
21+
22+
23+
| --- Synapse Compartments: ---
24+
| inputs - input (takes in external signals, e.g., pre-synaptic pulses/spikes)
25+
| outputs - output signals (also equal to i_syn, total electrical current)
26+
| v - coupled voltages from post-synaptic neurons this synaptic cable connects to
27+
| weights - current value matrix of synaptic efficacies
28+
| biases - current value vector of synaptic bias values
29+
| --- Dynamic / Short-term Plasticity Compartments: ---
30+
| g_syn - fixed value matrix of synaptic resources (U)
31+
| i_syn - derived total electrical current variable
32+
33+
Args:
34+
name: the string name of this synapse
35+
36+
shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
37+
with number of inputs by number of outputs)
38+
39+
tau_syn: synaptic time constant (ms)
40+
41+
g_syn_bar: maximum conductance elicited by each incoming spike ("synaptic weight")
42+
43+
syn_rest: synaptic reversal potential
44+
45+
weight_init: a kernel to drive initialization of this synaptic cable's values;
46+
typically a tuple with 1st element as a string calling the name of
47+
initialization to use
48+
49+
bias_init: a kernel to drive initialization of biases for this synaptic cable
50+
(Default: None, which turns off/disables biases) <unused>
51+
52+
resist_scale: a fixed (resistance) scaling factor to apply to synaptic
53+
transform (Default: 1.), i.e., yields: out = ((W * Rscale) * in)
54+
55+
p_conn: probability of a connection existing (default: 1.); setting
56+
this to < 1 and > 0. will result in a sparser synaptic structure
57+
(lower values yield sparse structure)
58+
59+
is_nonplastic: boolean indicating if this synapse permits plasticity adjustments (Default: True)
60+
61+
"""
62+
63+
# Define Functions
64+
def __init__(
65+
self, name, shape, tau_syn, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., p_conn=1.,
66+
is_nonplastic=True, **kwargs
67+
):
68+
super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs)
69+
## dynamic synapse meta-parameters
70+
self.tau_syn = tau_syn
71+
self.g_syn_bar = g_syn_bar
72+
self.syn_rest = syn_rest ## synaptic resting potential
73+
74+
## Set up short-term plasticity / dynamic synapse compartment values
75+
#tmp_key, *subkeys = random.split(self.key.value, 4)
76+
#preVals = jnp.zeros((self.batch_size, shape[0]))
77+
postVals = jnp.zeros((self.batch_size, shape[1]))
78+
self.v = Compartment(postVals) ## coupled voltage (from a post-synaptic neuron)
79+
self.i_syn = Compartment(postVals) ## electrical current output
80+
self.g_syn = Compartment(postVals) ## conductance variable
81+
self.h_syn = Compartment(postVals) ## intermediate conductance variable
82+
if is_nonplastic:
83+
self.weights.set(self.weights.value * 0 + 1.)
84+
85+
@transition(output_compartments=["outputs", "i_syn", "g_syn", "h_syn"])
86+
@staticmethod
87+
def advance_state(
88+
dt, tau_syn, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, h_syn, v
89+
):
90+
s = inputs
91+
## advance conductance variable
92+
_out = jnp.matmul(s, weights) ## sum all pre-syn spikes at t going into post-neuron)
93+
dhsyn_dt = _out * g_syn_bar - h_syn/tau_syn
94+
h_syn = h_syn + dhsyn_dt * dt ## run Euler step to move intermediate conductance h
95+
96+
dgsyn_dt = -g_syn/tau_syn + h_syn/tau_syn
97+
g_syn = g_syn + dgsyn_dt * dt ## run Euler step to move conductance g
98+
99+
i_syn = -g_syn * (v - syn_rest)
100+
outputs = i_syn #jnp.matmul(inputs, Wdyn * Rscale) + biases
101+
return outputs, i_syn, g_syn, h_syn
102+
103+
@transition(output_compartments=["inputs", "outputs", "i_syn", "g_syn", "h_syn", "v"])
104+
@staticmethod
105+
def reset(batch_size, shape):
106+
preVals = jnp.zeros((batch_size, shape[0]))
107+
postVals = jnp.zeros((batch_size, shape[1]))
108+
inputs = preVals
109+
outputs = postVals
110+
i_syn = postVals
111+
g_syn = postVals
112+
h_syn = postVals
113+
v = postVals
114+
return inputs, outputs, i_syn, g_syn, h_syn, v
115+
116+
def save(self, directory, **kwargs):
117+
file_name = directory + "/" + self.name + ".npz"
118+
if self.bias_init != None:
119+
jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value)
120+
else:
121+
jnp.savez(file_name, weights=self.weights.value)
122+
123+
def load(self, directory, **kwargs):
124+
file_name = directory + "/" + self.name + ".npz"
125+
data = jnp.load(file_name)
126+
self.weights.set(data['weights'])
127+
if "biases" in data.keys():
128+
self.biases.set(data['biases'])
129+
130+
@classmethod
131+
def help(cls): ## component help function
132+
properties = {
133+
"synapse_type": "STPDenseSynapse - performs a synaptic transformation of inputs to produce "
134+
"output signals (e.g., a scaled linear multivariate transformation); "
135+
"this synapse is dynamic, adapting via a form of short-term plasticity"
136+
}
137+
compartment_props = {
138+
"inputs":
139+
{"inputs": "Takes in external input signal values"},
140+
"states":
141+
{"weights": "Synapse efficacy/strength parameter values",
142+
"biases": "Base-rate/bias parameter values",
143+
"key": "JAX PRNG key"},
144+
"outputs":
145+
{"outputs": "Output of synaptic transformation"},
146+
}
147+
hyperparams = {
148+
"shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
149+
"weight_init": "Initialization conditions for synaptic weight (W) values",
150+
"bias_init": "Initialization conditions for bias/base-rate (b) values",
151+
"resist_scale": "Resistance level scaling factor (applied to output of transformation)",
152+
"p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
153+
"tau_syn": "Synaptic time constant (ms)",
154+
"g_bar_syn": "Maximum conductance value",
155+
"syn_rest": "Synaptic reversal potential"
156+
}
157+
info = {cls.__name__: properties,
158+
"compartments": compartment_props,
159+
"dynamics": "outputs = g_syn * (v - syn_rest); "
160+
"dgsyn_dt = (W * inputs) * g_syn_bar - g_syn/tau_syn ",
161+
"hyperparameters": hyperparams}
162+
return info
163+
164+
def __repr__(self):
165+
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
166+
maxlen = max(len(c) for c in comps) + 5
167+
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
168+
for c in comps:
169+
stats = tensorstats(getattr(self, c).value)
170+
if stats is not None:
171+
line = [f"{k}: {v}" for k, v in stats.items()]
172+
line = ", ".join(line)
173+
else:
174+
line = "None"
175+
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
176+
return lines

ngclearn/components/synapses/exponentialSynapse.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
class ExponentialSynapse(DenseSynapse): ## dynamic exponential synapse cable
1212
"""
1313
A dynamic exponential synaptic cable; this synapse evolves according to exponential synaptic conductance dynamics.
14-
Specifically, the dynamics are as follows:
14+
Specifically, the conductance dynamics are as follows:
1515
16-
| g = g + (weight * gbase) // on the occurrence of a pulse
17-
| i = g * (erev - v), where: d g /dt = -g / tauDecay
16+
| dg/dt = -g/tau_syn + gBar sum_k (t - t_k)
17+
| i_syn = g * (syn_rest - v) // g is `g_syn` in this synapse implementation
18+
| where: syn_rest is the post-synaptic reverse potential for this synapse
19+
| t_k marks time of -pre-synaptic k-th pulse received by post-synaptic unit
1820
1921
2022
| --- Synapse Compartments: ---

0 commit comments

Comments
 (0)