
Commit 54ec2dd

Authored by ago109, rxng8, Alexander Ororbia, and willgebhardt
Major release update (to 2.0.0) (#100)
* add initial patch mask features
* minor edit to bern-cell
* fixed bernoulli error cell
* example rate cell test
* made some corrections to bern err-cell and heb syn
* made some corrections to bern err-cell and heb syn
* cleaned up bern-cell, hebb-syn
* minor mod to model-utils
* attempted rewrite of bernoulli-cell
* got bernoulli-cell rewritten and unit-tested
* edit to bern-cell
* bernoulli and poisson cells revised, unit-tested
* latency-cell refactored and unit-tested
* refactored Rate Cell
* minor revisions to input-encoders, revised phasor-cell w/ unit-test
* revised and add unit-test for varTrace
* revised and added unit-test for exp-kernel
* revised and added unit-test for exp-kernel
* revised slif cell w/ unit-test; needed mod to diffeq
* revised slif-cell w/ unit-test; cleaned up ode_utils to play nicer w/ new sim-lib
* revised lif-cell w/ unit-test
* revised unit-tests to pass globally; some minor patches to phasor-cell and lif
* minor cleanup of unit-test for phasor
* revised if-cell w/ unit-test
* revised if-cell w/ unit-test
* revised quad-lif w/ unit-test
* revised adex-cell w/ unit test, minor cleanup of quad-lif
* minor edit to adex unit-test
* refactor bernoulli, laplacian, and rewarderror cells
* revised raf-cell w/ unit test; fixed typos/mistakes in all spiking cells
* revised wtas-cell w/ unit test
* revised fh-cell w/ unit test
* revised izh-cell w/ unit test
* patched ode_utils backend wrt jax, cleaned up unit-tests, added disable flag for phasor-cell
* update rate cell
* fix test rate cell
* update test for bernoulli cell
* update refactoring for gaussian error cell
* update unit testing for all graded neurons
* wrote+unit-test of hodgkin-huxley spike cell, minor tweaks/clean-up elsewhere
* added rk2 support for H-H cell
* update rate cell and fix bug of passing a tuple of (jax Array -- not hashable) to jax jit functions. Basically, simplify the codebase by using a hashmap of functions
* update test rate cell
* refactored dense and trace-stdp syn w/ unit-test
* refactored exp-stdp syn w/ unit-test
* refactored event-stdp w/ unit-test
* cleanup of stdp-syn
* refactored bcm syn w/ unit-test
* refactored stp-syn with unit-test
* cleaned up modulated
* refactored mstdp-et syn w/ unit-test
* refactored lava components to new sim-lib
* refactored conv/hebb-conv syn w/ unit-test
* refactored/revised hebb-deconv syn w/ unit-test
* revised/refactored hebb/stdp conv/deconv syn w/ unit-tests
* updated modeling doc to point to hodgkin-huxley cell
* updated modeling docs
* fixed typo in adex-cell tutorial doc
* revised tutorials to reflect new sim-lib config/syntax
* revised tutorials to reflect new sim-lib config/syntax
* patched docs to reflect revisions/refactor
* tweaked requirements in prep for major release
* cleaned up a few unit tests to use deterministic syn init vals
* mod to requirements
* nudge toml to upcoming 2.0.0
* update to support docs in prep for 2.0.0
* update patched synapses and their test cases
* cleaned up syn modeling doc
* push hebbian synapse
* push reinforce synapse
* push np seed
* patched minor prior None arg issue in hebb-syn
* moved reinforce-syn to right spot
* update reinforce synapse and testing
* tweaked trace-stdp and mstdpet
* patched mstdpet unit-test
* update reinforce synapse and test cases
* add reinforce synapse fix
* minor mod to mstdpet
* update test code for more than 1 steps
* Updated monitors
* patched tests to use process naming
* Added wrapper for reset and advance_state
* Added a JaxProcess: Added Jax Process to allow for scanning over the process.
* update the old rate cell
* update old hebbian synapse
* minor edit to if-cell
* ported over adex tutorial to new ngclearn format
* hh-cell supports rk4 integration
* clean up and integrated hodgkin-huxley mini lesson in neurocog tutorials
* Update jaxProcess.py: Updated the jax process to allow for more configurations of inputs.
* update working reinforce synapse
* update correct reinforce and testing
* update documentation
* update features, documentation, and testing
* update testing for REINFORCE cell
* update code and test
* update code
* add clipping gradient to model utils
* update reinforce cell to the new model utils clip
* major cleanup in prep for merge over to main/prep for major release
* update test cases
* update to require file in docs

---------

Co-authored-by: Viet Dung Nguyen <vietdungnguyen233@gmail.com>
Co-authored-by: Alexander Ororbia <ago@hal3.cs.rit.edu>
Co-authored-by: Will Gebhardt <will@gebhardts.net>
1 parent bbea397 commit 54ec2dd

23 files changed (+72, −72 lines changed)
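
The recurring change across the tutorial diffs below is the swap from `ngcsimlib.compilers.process.Process` to the new `ngclearn.utils.JaxProcess` wrapper. As a quick orientation before the per-file diffs, here is a minimal sketch of the new-style setup; it simply mirrors the `error_cell.md` tutorial changed in this commit (the cell choice and its arguments are taken from that tutorial, not introduced here):

```python
from jax import numpy as jnp, jit
from ngcsimlib.context import Context
from ngclearn.utils import JaxProcess  ## 2.0.0-style import (replaces ngcsimlib.compilers.process.Process)
## import model-specific mechanisms
from ngclearn.components.neurons.graded.gaussianErrorCell import GaussianErrorCell

with Context("Model") as model:
    cell = GaussianErrorCell("z0", n_units=3)

    ## processes are still chained with `>>`; only the process class changes
    advance_process = (JaxProcess()
                       >> cell.advance_state)
    model.wrap_and_add_command(jit(advance_process.pure), name="advance")

    reset_process = (JaxProcess()
                     >> cell.reset)
    model.wrap_and_add_command(jit(reset_process.pure), name="reset")
```

Per the commit notes, `JaxProcess` was added to allow scanning over a process; the tutorial diffs in this commit, however, only exercise it as the drop-in replacement shown above.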

README.md

+6-10
@@ -2,19 +2,15 @@

<img src="docs/images/ngc-learn-logo.png" width="300">

-<b>ngc-learn</b> is a Python library for building, simulating, and analyzing
-biomimetic systems, neurobiological agents, spiking neuronal networks,
-predictive coding circuitry, and models that learn via biologically-plausible
-forms of credit assignment. This simulation toolkit is built on top of JAX and is
-distributed under the 3-Clause BSD license.
+<b>ngc-learn</b> is a Python library for building, simulating, and analyzing biophysical / neurobiological systems, spiking neuronal networks, predictive coding circuitry, and biomimetic (NeuroAI) agents that learn in a biologically-plausible manner. This simulation toolkit, meant to support computational neuroscience and brain-inspired computing research, is built on top of JAX and is distributed under the 3-Clause BSD license.

It is currently maintained by the
<a href="https://www.cs.rit.edu/~ago/nac_lab.html">Neural Adaptive Computing (NAC) laboratory</a>.

## <b>Documentation</b>

Official documentation, including tutorials, can be found
-<a href="https://ngc-learn.readthedocs.io/en/latest/#">here</a>. The model museum repo,
+<a href="https://ngc-learn.readthedocs.io/en/latest/#">here</a>. The model museum repo (ngc-museum),
which implements several historical models, can be found
<a href="https://github.com/NACLab/ngc-museum">here</a>.

@@ -36,8 +32,8 @@ ngc-learn requires:
1) Python (>=3.10)
2) NumPy (>=1.26.0)
3) SciPy (>=1.7.0)
-4) ngcsimlib (>=0.3.b4), (visit official page <a href="https://github.com/NACLab/ngc-sim-lib">here</a>)
-5) JAX (>= 0.4.28) (to enable GPU use, make sure to install one of the CUDA variants)
+4) ngcsimlib (>=1.0.0), (visit official page <a href="https://github.com/NACLab/ngc-sim-lib">here</a>)
+5) JAX (>=0.4.28) (to enable GPU use, make sure to install one of the CUDA variants)
<!--
5) scikit-learn (>=1.3.1) if using `ngclearn.utils.density`
6) matplotlib (>=3.4.3) if using `ngclearn.utils.viz`
@@ -46,7 +42,7 @@ ngc-learn requires:
-->

---
-ngc-learn 1.2.beta2 and later require Python 3.10 or newer as well as ngcsimlib >=0.3.b4.
+ngc-learn 2.0.0 and later require Python 3.10 or newer as well as ngcsimlib >=1.0.0.
ngc-learn's plotting capabilities (routines within `ngclearn.utils.viz`) require
Matplotlib (>=3.8.0) and imageio (>=2.31.5) and both plotting and density estimation
tools (routines within ``ngclearn.utils.density``) will require Scikit-learn (>=0.24.2).
@@ -66,7 +62,7 @@ running the above pip command if you want to use the GPU version.

The documentation includes more detailed
<a href="https://ngc-learn.readthedocs.io/en/latest/installation.html">installation instructions</a>.
-Note that this library was developed on Ubuntu 20.04 and tested on Ubuntu(s) 18.04 and 20.04.
+Note that this library was developed on Ubuntu 20.04/22.04 and tested on Ubuntu(s) 20.04 and 22.04.

If the installation was successful, you should see the following if you test
it against your Python interpreter, i.e., run the <code>$ python</code> command
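
The README hunk above is cut off by the diff context, but the sanity check it refers to is along the lines of the snippet below. This is only a sketch, assuming the package exposes a standard `__version__` attribute; see the installation instructions linked above for the exact expected output.

```python
## quick post-install sanity check (assumes ngclearn exposes __version__, as most PyPI packages do)
import ngclearn
print(ngclearn.__version__)  ## for this release, this should report 2.0.0 or newer
```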

docs/museum/snn_bfa.md

+1-1
@@ -233,7 +233,7 @@ of $1000$ `SLIF` cells) similar to the one below:
<img src="../images/museum/bfa_snn/bfasnn_codes.jpg" width="450" />

Intriguingly, we see that the latent codes represented by the BFA-SNN's hidden
-layer spikes yield a rather (piecewise) linearly-separable transformation
+layer spikes yield a rather (piecewise) linearly-separable representation
of the input digits, making the process of mapping inputs to label vectors
much easier for the model's second layer of classification LIF units.
Note that, in the `BFA_SNN` model exhibit class, we estimated

docs/requirements.txt

+3-3
@@ -5,7 +5,7 @@ numpy>=1.26.0
scikit-learn>=0.24.2
scipy>=1.7.0
matplotlib>=3.8.0
-jax>=0.4.18
-jaxlib>=0.4.18
+jax>=0.4.28
+jaxlib>=0.4.28
imageio>=2.31.5
-ngcsimlib>=0.3.b4
+ngcsimlib>=1.0.0

docs/tutorials/model_basics/evolving_synapses.md

+5-6
@@ -19,7 +19,7 @@ We do this specifically as follows:
```python
from jax import numpy as jnp, random, jit
from ngcsimlib.context import Context
-from ngcsimlib.compilers.process import Process
+from ngclearn.utils import JaxProcess
from ngclearn.components import HebbianSynapse, RateCell
import ngclearn.utils.weight_distribution as dist

@@ -49,16 +49,16 @@ with Context("Circuit") as circuit:
Wab.post << b.zF

## create and compile core simulation commands
-evolve_process = (Process()
+evolve_process = (JaxProcess()
>> a.evolve)
circuit.wrap_and_add_command(jit(evolve_process.pure), name="evolve")

-advance_process = (Process()
+advance_process = (JaxProcess()
>> a.advance_state)
circuit.wrap_and_add_command(jit(advance_process.pure), name="advance")

-reset_process = (Process()
->> a.reset)
+reset_process = (JaxProcess()
+>> a.reset)
circuit.wrap_and_add_command(jit(reset_process.pure), name="reset")

## set up non-compiled utility commands
@@ -83,7 +83,6 @@ for ts in range(x_seq.shape[1]):
circuit.advance(t=ts*1., dt=1.)
circuit.evolve(t=ts*1., dt=1.)
print(" {}: input = {} ~> Wab = {}".format(ts, x_t, Wab.weights.value))
-
```

Your code should produce the same output (towards the bottom):

docs/tutorials/model_basics/model_building.md

+3-2
@@ -10,6 +10,7 @@ While building our dynamical system we will set up a Context and then add the th
```python
from jax import numpy as jnp, random
from ngclearn import Context
+from ngclearn.utils import JaxProcess
from ngcsimlib.compilers.process import Process
from ngclearn.components import RateCell, HebbianSynapse
import ngclearn.utils.weight_distribution as dist
@@ -71,13 +72,13 @@ This is simply done with the use of the following convenience function calls:

```python
## configure desired commands for simulation object
-reset_process = (Process()
+reset_process = (JaxProcess()
>> a.reset
>> Wab.reset
>> b.reset)
model.wrap_and_add_command(jit(reset_process.pure), name="reset")

-advance_process = (Process()
+advance_process = (JaxProcess()
>> a.advance_state
>> Wab.advance_state
>> b.advance_state)

docs/tutorials/neurocog/adex_cell.md

+3-3
@@ -24,7 +24,7 @@ import numpy as np

from ngclearn.utils.model_utils import scanner
from ngcsimlib.context import Context
-from ngcsimlib.compilers.process import Process
+from ngclearn.utils import JaxProcess
## import model-specific mechanisms
from ngclearn.components.neurons.spiking.adExCell import AdExCell

@@ -48,11 +48,11 @@ with Context("Model") as model:
)

## create and compile core simulation commands
-advance_process = (Process()
+advance_process = (JaxProcess()
>> cell.advance_state)
model.wrap_and_add_command(jit(advance_process.pure), name="advance")

-reset_process = (Process()
+reset_process = (JaxProcess()
>> cell.reset)
model.wrap_and_add_command(jit(reset_process.pure), name="reset")

docs/tutorials/neurocog/error_cell.md

+3-3
@@ -54,7 +54,7 @@ The code you would write amounts to the below:
```python
from jax import numpy as jnp, jit
from ngcsimlib.context import Context
-from ngcsimlib.compilers.process import Process, transition
+from ngclearn.utils import JaxProcess
## import model-specific mechanisms
from ngclearn.components.neurons.graded.gaussianErrorCell import GaussianErrorCell

@@ -64,11 +64,11 @@ T = 5 ## number time steps to simulate
with Context("Model") as model:
cell = GaussianErrorCell("z0", n_units=3)

-advance_process = (Process()
+advance_process = (JaxProcess()
>> cell.advance_state)
model.wrap_and_add_command(jit(advance_process.pure), name="advance")

-reset_process = (Process()
+reset_process = (JaxProcess()
>> cell.reset)
model.wrap_and_add_command(jit(reset_process.pure), name="reset")

docs/tutorials/neurocog/fitzhugh_nagumo_cell.md

+3-3
@@ -18,7 +18,7 @@ from jax import numpy as jnp, random, jit
import numpy as np

from ngcsimlib.context import Context
-from ngcsimlib.compilers.process import Process
+from ngclearn.utils import JaxProcess
## import model-specific mechanisms
from ngclearn.components.neurons.spiking.fitzhughNagumoCell import FitzhughNagumoCell

@@ -40,11 +40,11 @@ with Context("Model") as model:
gamma=gamma, v0=v0, w0=w0, integration_type="euler")

## create and compile core simulation commands
-advance_process = (Process()
+advance_process = (JaxProcess()
>> cell.advance_state)
model.wrap_and_add_command(jit(advance_process.pure), name="advance")

-reset_process = (Process()
+reset_process = (JaxProcess()
>> cell.reset)
model.wrap_and_add_command(jit(reset_process.pure), name="reset")

docs/tutorials/neurocog/hebbian.md

+2-2
@@ -38,11 +38,11 @@ Wab.post << b.zF
as well as (a bit later in the model construction code):

```python
-evolve_process = (Process()
+evolve_process = (JaxProcess()
>> a.evolve)
circuit.wrap_and_add_command(jit(evolve_process.pure), name="evolve")

-advance_process = (Process()
+advance_process = (JaxProcess()
>> a.advance_state)
circuit.wrap_and_add_command(jit(advance_process.pure), name="advance")
```

docs/tutorials/neurocog/hodgkin_huxley_cell.md

+3-3
@@ -24,7 +24,7 @@ import numpy as np

from ngclearn.utils.model_utils import scanner
from ngcsimlib.context import Context
-from ngcsimlib.compilers.process import Process
+from ngclearn.utils import JaxProcess
## import model-specific mechanisms
from ngclearn.components.neurons.spiking.hodgkinHuxleyCell import HodgkinHuxleyCell

@@ -52,11 +52,11 @@ with Context("Model") as model:
)

## create and compile core simulation commands
-advance_process = (Process()
+advance_process = (JaxProcess()
>> cell.advance_state)
model.wrap_and_add_command(jit(advance_process.pure), name="advance")

-reset_process = (Process()
+reset_process = (JaxProcess()
>> cell.reset)
model.wrap_and_add_command(jit(reset_process.pure), name="reset")

docs/tutorials/neurocog/input_cells.md

+3-3
@@ -40,7 +40,7 @@ spike train over $100$ steps in time as follows:
```python
from jax import numpy as jnp, random, jit
from ngcsimlib.context import Context
-from ngcsimlib.compilers.process import Process
+from ngclearn.utils import JaxProcess

from ngclearn.utils.viz.raster import create_raster_plot
## import model-specific mechanisms
@@ -56,11 +56,11 @@ T = 100 ## number time steps to simulate
with Context("Model") as model:
cell = BernoulliCell("z0", n_units=10, key=subkeys[0])

-advance_process = (Process()
+advance_process = (JaxProcess()
>> cell.advance_state)
model.wrap_and_add_command(jit(advance_process.pure), name="advance")

-reset_process = (Process()
+reset_process = (JaxProcess()
>> cell.reset)
model.wrap_and_add_command(jit(reset_process.pure), name="reset")

docs/tutorials/neurocog/izhikevich_cell.md

+3-3
@@ -20,7 +20,7 @@ from jax import numpy as jnp, random, jit
import numpy as np

from ngcsimlib.context import Context
-from ngcsimlib.compilers.process import Process
+from ngclearn.utils import JaxProcess
## import model-specific mechanisms
from ngclearn.components.neurons.spiking.izhikevichCell import IzhikevichCell

@@ -44,11 +44,11 @@ with Context("Model") as model:
integration_type="euler", v0=v0, w0=w0, key=subkeys[0])

## create and compile core simulation commands
-advance_process = (Process()
+advance_process = (JaxProcess()
>> cell.advance_state)
model.wrap_and_add_command(jit(advance_process.pure), name="advance")

-reset_process = (Process()
+reset_process = (JaxProcess()
>> cell.reset)
model.wrap_and_add_command(jit(reset_process.pure), name="reset")

docs/tutorials/neurocog/lif.md

+3-3
@@ -25,7 +25,7 @@ cell, you would write code akin to the following:
from jax import numpy as jnp, random, jit

from ngcsimlib.context import Context
-from ngcsimlib.compilers.process import Process
+from ngclearn.utils import JaxProcess
## import model-specific mechanisms
from ngclearn.components.neurons.spiking.LIFCell import LIFCell
from ngclearn.utils.viz.spike_plot import plot_spiking_neuron
@@ -47,11 +47,11 @@ with Context("Model") as model:
refract_time=2., key=subkeys[0])

## create and compile core simulation commands
-advance_process = (Process()
+advance_process = (JaxProcess()
>> cell.advance_state)
model.wrap_and_add_command(jit(advance_process.pure), name="advance")

-reset_process = (Process()
+reset_process = (JaxProcess()
>> cell.reset)
model.wrap_and_add_command(jit(reset_process.pure), name="reset")

docs/tutorials/neurocog/mod_stdp.md

+4-4
@@ -42,7 +42,7 @@ and the required compiled simulation and dynamic commands, can be done as follow
```python
from jax import numpy as jnp, random, jit
from ngcsimlib.context import Context
-from ngcsimlib.compilers.process import Process
+from ngclearn.utils import JaxProcess
## import model-specific mechanisms
from ngclearn.components import (TraceSTDPSynapse, MSTDPETSynapse,
RewardErrorCell, VarTrace)
@@ -75,13 +75,13 @@ with Context("Model") as model:
tr1 = VarTrace("tr1", n_units=1, tau_tr=tau_post, a_delta=Aminus)
rpe = RewardErrorCell("r", n_units=1, alpha=0.)

-evolve_process = (Process()
+evolve_process = (JaxProcess()
>> W_stdp.evolve
>> W_mstdp.evolve
>> W_mstdpet.evolve)
model.wrap_and_add_command(jit(evolve_process.pure), name="evolve")

-advance_process = (Process()
+advance_process = (JaxProcess()
>> tr0.advance_state
>> tr1.advance_state
>> rpe.advance_state
@@ -90,7 +90,7 @@ with Context("Model") as model:
>> W_mstdpet.advance_state)
model.wrap_and_add_command(jit(advance_process.pure), name="advance")

-reset_process = (Process()
+reset_process = (JaxProcess()
>> W_stdp.reset
>> W_mstdp.reset
>> W_mstdpet.reset

docs/tutorials/neurocog/rate_cell.md

+3-3
@@ -16,7 +16,7 @@ specifically the rate-cell (RateCell). Let's start with the file's header

```python
from jax import numpy as jnp, random, jit
-from ngcsimlib.compilers.process import Process, transition
+from ngclearn.utils import JaxProcess
from ngcsimlib.context import Context
## import model-specific elements
from ngclearn.components.neurons.graded.rateCell import RateCell
@@ -40,11 +40,11 @@ with Context("Model") as model: ## model/simulation definition
prior=("gaussian", gamma), integration_type="euler", key=subkeys[0])

## instantiate desired core commands that drive the simulation
-advance_process = (Process()
+advance_process = (JaxProcess()
>> cell.advance_state)
model.wrap_and_add_command(jit(advance_process.pure), name="advance")

-reset_process = (Process()
+reset_process = (JaxProcess()
>> cell.reset)
model.wrap_and_add_command(jit(reset_process.pure), name="reset")

docs/tutorials/neurocog/short_term_plasticity.md

+3-3
@@ -60,7 +60,7 @@ STF-dominated dynamics):
```python
from jax import numpy as jnp, random, jit
from ngcsimlib.context import Context
-from ngcsimlib.compilers.process import Process
+from ngclearn.utils import JaxProcess
## import model-specific mechanisms
from ngclearn.components import PoissonCell, STPDenseSynapse, LIFCell
import ngclearn.utils.weight_distribution as dist
@@ -98,13 +98,13 @@ with Context("Model") as model:
W.inputs << z0.outputs ## z0 -> W
z1.j << W.outputs ## W -> z1

-advance_process = (Process()
+advance_process = (JaxProcess()
>> z0.advance_state
>> W.advance_state
>> z1.advance_state)
model.wrap_and_add_command(jit(advance_process.pure), name="advance")

-reset_process = (Process()
+reset_process = (JaxProcess()
>> z0.reset
>> z1.reset
>> W.reset)
