Skip to content

Commit 0182340

Browse files
author
Alexander Ororbia
committed
minor tweaks + init of rl-snn exhibit lesson
1 parent 5396eb3 commit 0182340

File tree

5 files changed

+68
-1
lines changed

5 files changed

+68
-1
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ $ python install -e .
122122
</pre>
123123

124124
**Version:**<br>
125-
2.0.0 <!--1.2.3-Beta--> <!-- -Alpha -->
125+
2.0.1 <!--1.2.3-Beta--> <!-- -Alpha -->
126126

127127
Author:
128128
Alexander G. Ororbia II<br>

docs/museum/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ relevant, referenced publicly available ngc-learn simulation code.
1818
snn_dc
1919
snn_bfa
2020
sindy
21+
rl_snn

docs/museum/rl_snn.md

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Reinforcement Learning through a Spiking Controller
2+
3+
In this exhibit, we will see how to construct a simple biophysical model for
4+
reinforcement learning with a spiking neural network and modulated
5+
spike-timing-dependent plasticity.
6+
This model incorporates mechanisms from several different models, including
7+
the constrained RL-centric SNN of <b>[1]</b> as well as the simplifications
8+
made with respect to the model of <b>[2]</b>. The model code for this
9+
exhibit can be found
10+
[here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/rl_snn).
11+
12+
## Modeling Operant Conditioning through Modulation
13+
14+
15+
### Reward-Modulated Spike-Timing-Dependent Plasticity (R-STDP)
16+
17+
18+
## The Spiking Neural Circuit Model
19+
20+
21+
### Neuronal Dynamics
22+
23+
24+
## Running the RL-SNN Model
25+
26+
27+
<!-- References/Citations -->
28+
## References
29+
<b>[1]</b> Chevtchenko, Sérgio F., and Teresa B. Ludermir. "Learning from sparse
30+
and delayed rewards with a multilayer spiking neural network." 2020 International
31+
Joint Conference on Neural Networks (IJCNN). IEEE, 2020. <br>
32+
<b>[2]</b> Diehl, Peter U., and Matthew Cook. "Unsupervised learning of digit
33+
recognition using spike-timing-dependent plasticity." Frontiers in computational
34+
neuroscience 9 (2015): 99.
35+

ngclearn/components/neurons/graded/gaussianErrorCell.py

+7
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs):
6565
self.modulator = Compartment(restVals + 1.0) # to be set/consumed
6666
self.mask = Compartment(restVals + 1.0)
6767

68+
@staticmethod
69+
def eval_log_density(target, mu, Sigma):
70+
_dmu = (target - mu)
71+
log_density = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma)
72+
return log_density
73+
6874
@transition(output_compartments=["dmu", "dtarget", "dSigma", "L", "mask"])
6975
@staticmethod
7076
def advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian error cell output
@@ -79,6 +85,7 @@ def advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian e
7985
dtarget = -dmu # reverse of e
8086
dSigma = Sigma * 0 + 1. # no derivative is calculated at this time for sigma
8187
L = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma)
88+
#L = GaussianErrorCell.eval_log_density(target, mu, Sigma)
8289

8390
dmu = dmu * modulator * mask ## not sure how mask will apply to a full covariance...
8491
dtarget = dtarget * modulator * mask

ngclearn/utils/diffeq/ode_utils.py

+24
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,30 @@ def _euler(carry, dfx, dt, params, x_scale=1.):
112112
new_carry = (_t, _x)
113113
return new_carry, (new_carry, carry)
114114

115+
@partial(jit, static_argnums=(1))
116+
def _leapfrog(carry, dfq, dt, params):
117+
t, q, p = carry
118+
dq_dt = dfq(t, q, params)
119+
120+
_p = p + dq_dt * (dt/2.)
121+
_q = q + p * dt
122+
dq_dtpdt = dfq(t+dt, _q, params)
123+
_p = _p + dq_dtpdt * (dt/2.)
124+
_t = t + dt
125+
new_carry = (_t, _q, _p)
126+
return new_carry, (new_carry, carry)
127+
128+
@partial(jit, static_argnums=(3, 4))
def leapfrog(t_curr, q_curr, p_curr, dfq, L, step_size, params):
    """
    Run `L` leapfrog integration steps of size `step_size` starting from the
    state (t_curr, q_curr, p_curr), using `dfq` to evaluate the force at each
    step (forwarded to `_leapfrog` along with `params`).

    Args:
        t_curr: starting time
        q_curr: starting position
        p_curr: starting momentum
        dfq: force function dfq(t, q, params) (static under jit)
        L: number of leapfrog steps to take (static under jit)
        step_size: per-step integration step size
        params: extra parameters forwarded to `dfq`

    Returns:
        final (t, q, p) after L steps
    """
    ## promote the inputs to floating values (and detach from the originals)
    state = (t_curr + 0., q_curr + 0., p_curr + 0.)

    def _step(carry, _):  ## one scanned step; the loop index is unused
        return _leapfrog(carry, dfq, step_size, params)

    (t_final, q_final, p_final), _history = _scan(_step, init=state, xs=jnp.arange(L))
    return t_final, q_final, p_final
138+
115139
@partial(jit, static_argnums=(2))
116140
def step_heun(t, x, dfx, dt, params, x_scale=1.):
117141
"""

0 commit comments

Comments
 (0)