From c10af795b1ec6b43613275629b4110d8333813bc Mon Sep 17 00:00:00 2001
From: Alex Ororbia
Date: Sat, 12 Apr 2025 18:24:15 -0400
Subject: [PATCH 1/4] further nudge from main to release (#104)

* generalized rate-cell a bit * touched up rate-cell further * minor mod to lif * updated lif-cell to use units/tags and minor cleanup and edits * Monitor plot (#66) * Update base_monitor.py * added plotting viewed compartments * added meta-data to rate-cell, input encoders, adex * fixed minor saving/loading in rate-cell w/ vectorized compartments * Added auto resolving for monitors (#67) * fixed surr arg in lif-cell * modded bernoulli-cell to include max-frequency constraint * added warning check to bernoulli, some cleanup * integrated if-cell, cleaned up lif and inits * mod to latency-cell * updated the poissonCell to be a true poisson * fixed minor bug in deprecation for poiss/bern * fixed minor bug in deprecation for poiss/bern * fixed validation fun in bern/poiss * moved back and cleaned up bernoulli and poisson cells * added threshold-clipping to latency cell * updates to if/lif * added batch-size arg to slif * fixed minor load bug in lif-cell * fixed a blocking jit-partial call in lif update_theta method; when loading * minor edit to dim-reduce * Patched synapses added (#68) * Patched synapses added * Update __init__.py * Update patch_utils.py patch_with_stride & patch_with_overlap functions + Create_Patches class added * Update patchedSynapse.py * Update hebbianPatchedSynapse.py * Update synapse_plot.py order added * updated monitor plot code * update to dim-reduce * integrated phasor-cell, minor cleanup of latency * tweak to adex thr arg * tweak to adex thr arg * integrated resonate-and-fire neuronal cell * mod to raf-cell * cleaned up raf * cleaned up raf * cleaned up raf-cell * cleaned up raf-cell * cleaned up raf-cell * minor tweak to dim-reduce in utils * Fix typo in pcn_discrim.md (#69) * model_utils and rate cell (#70) * Patched synapses added * Update __init__.py * Update patch_utils.py patch_with_stride & patch_with_overlap functions + Create_Patches class added * Update patchedSynapse.py * Update hebbianPatchedSynapse.py * Update synapse_plot.py order added * Create hierarchical_sc.md 1 * Update hierarchical_sc.md (long series of incremental edits) * Update sparse_coding.md (series of incremental edits, interleaved with further hierarchical_sc.md updates) * Add files via upload * Delete docs/images/hgpc_network.pdf * Add files via upload * Update hierarchical_sc.md (series of incremental edits) * Create hgpc * Delete docs/images/museum/hgpc * Create d * Add files via upload * Delete docs/images/hgpc_model.png * Delete docs/images/museum/hgpc/d * Update hierarchical_sc.md (series of incremental edits) * Add files via upload * Delete docs/images/museum/hgpc/Input_layer.png * Add files via upload * Update hierarchical_sc.md (long series of incremental edits) * Create Generative_PC.md * Update and rename Generative_PC.md to generative_pc.md * Update generative_pc.md * Update generative_pc.md * Update model_utils.py * Update model_utils.py * Update model_utils.py * Update model_utils.py * Update rateCell.py * Update generative_pc.md * Create pc-sindy.md * Update pc-sindy.md * Update model_utils.py sine activation function added * Update model_utils.py * Update ode_utils.py jitified * Delete docs/museum/hierarchical_sc.md * Delete docs/museum/generative_pc.md * Delete ngclearn/components/synapses/patched directory * Update __init__.py * Add files via upload ode with scanner added * Update ode_solver.py _ removed * Fix/reorganize feature library (#74) * Update ode_utils.py * Update ode_solver.py rk4 revised and __main__ added * Delete ngclearn/utils/diffeq/ode_functions.py * Create odes.py odes name and structure changed * Update __init__.py * Create feature_library.py * Create __init__.py * Create base.py * Delete docs/museum/pc-sindy.md * Create m.md * Add files via upload * Delete docs/images/museum/sindy/m.md * Add files via upload * Create sindy.md * Update sindy.md (long series of incremental edits) * fix: correct feature library path and directory name * Delete ngclearn/utils/dymbolic_dictionary directory * Update model_utils.py (#78) * Additions for inhibition stuff * add sindy documentation for exhibits (#81) * Add files via upload * Add files via upload * Update ode_utils.py (#79) refactor: delete @partial(jit, static_argnums=(2, )) lines Co-authored-by: Will Gebhardt * Add patched synapse (#80) * Update __init__.py Add point to patched components * Add patched in __init__.py Add patched synapses importing * Add patched synaptic components * Delete ngclearn/components/synapses/patched/__pycache__ directory * Update __init__.py new line characters added * Update hebbianPatchedSynapse.py * Update patchedSynapse.py new line characters added * Update staticPatchedSynapse.py new line characters added * Update staticPatchedSynapse.py new line characters + comments describing each input var * Update patchedSynapse.py removed a comment line * Update hebbianPatchedSynapse.py remove unused arguments * Update hebbianPatchedSynapse.py * Update hebbianPatchedSynapse.py add description for w_mask * Update hebbianPatchedSynapse.py * Update hebbianPatchedSynapse.py * Update patchedSynapse.py * Update patchedSynapse.py * Update hebbianPatchedSynapse.py * Update __init__.py (#83) * Update __init__.py typo fixed * Update staticPatchedSynapse.py a typo fixed * Update hebbianPatchedSynapse.py typo fixed * Add l1 decay term to update calculation (#84) * Update hebbianSynapse.py * update main update main at the end * Update hebbianSynapse.py add regularization argument and w_decay is deprecated * Update hebbianSynapse.py add elastic_net * Update hebbianSynapse.py * Update hebbianSynapse.py * feat NGC module regression (#86) * feat npc module regression * Update __init__.py * Update __init__.py * Update elastic_net.py * Update lasso.py * Update ridge.py * Update elastic_net.py * Update ridge.py * Update lasso.py * Update odes.py removed @partial(jit, static_argnums=(0,)) * Update odes.py (#87) removed @partial(jit, static_argnums=(0,)) * Update odes.py typo fixed in __main__ * Update __init__.py add dot * Update __init__.py add dot * Add attribute 'lr' (#90) * Update elastic_net.py * add lr as attribute to lasso.py * add lr as attribute to ridge.py * refactor w_bound=0. for weights elastic_net.py deactivated w_bound for weights elastic_net.py * Update lasso.py * deactivated w_bound for weights ridge.py * commit probes/mods to utils to analysis_tools branch * commit probes/mods to utils to analysis_tools branch * update documentation * cleaned up probes/docs for probes * change heads_dim to attn_dim, and modify the mlp to be as similar as possible to the attentive probing pattern * in layer normalization, or any other Gaussian computation, the standard deviation can never be zero. Additionally, if the subtraction inside the square root goes to zero, the gradient will become NaN. Therefore, adding a clipping is necessary.
* update attentive probe code * minor tweak to attentive probe code comments * cleaned up probe parent fit routine * cleaned up probe parent fit routine * cleaned up probe parent fit routine * cleaned up probe parent fit routine * minor edits to attn probe * update attentive probe with input layer norm * update input layer normalization * update code to fix nan bug * minor tweak to attn probe * cleaned up probes * cleaned up probes * cleaned up probes * cleaned up probes * generalized dropout in terms of shape * tweak to attn probe * tweak to attn probe * added silu/swish/elu to model_utils * cleaned up model_utils * fix bug in attention probe dropout, fix bug in None noise_key passed in the probing jit function, add the splitting of noise_keys for the two dropouts in the two cross-attention blocks * hyperparameter tuning arguments added * Merging over Dynamics feature branch to main (#92) [squashed commit list repeats entries from the log above, plus:] * Additions for inhibition stuff * update to API modeling docs to reflect RAF neuronal cell --------- Co-authored-by: Alexander Ororbia Co-authored-by: Will Gebhardt * remove unused local variables * update note * update model utils * remove notes * Update ode utils (#94) * Update ode_utils.py merge ode_solver into ode_utils * Delete ngclearn/utils/diffeq/ode_solver.py * Update ode_utils.py refactor doc-string * minor fix to header in diffeq * Update files with ode_solver (#95) * Update ode_utils.py merge ode_solver into ode_utils * Delete ngclearn/utils/diffeq/ode_solver.py * Update ode_utils.py refactor doc-string * Update odes.py * Update sindy.md ode_solver to ode_utils * revised/cleaned up sindy tutorial doc/imgs * add prior for hebbian patched synapse (#96) * prior replaced w_decay hebbianPatchedSynapse.py remove w_decay add prior_type and prior_lmbda * revised typo hebbianSynapse.py dWweight was typo * cleaned up doc-strings in odes.py to comply w/ ngc-learn format * minor tweak to sig-figs printing in probe utils * add-sigma-to-gaussianErrorCell (#97) * add-sigma-to-gaussianErrorCell add not updating scalar variance for gaussian errors * Update gaussianErrorCell.py * cleaned up ode_utils, cleaned up gaussian/laplacian cell * Update gaussianErrorCell.py (#98) added `and not isinstance(sigma, int)` * cleaned up gauss/laplace error cells * integrated bernoulli err-cell * Major release update merge to main (in prep for 2.0.0 release on release branch/pip) (#99) * add initial patch mask features * minor edit to bern-cell * fixed bernoulli error cell * example rate cell test * made some
corrections to bern err-cell and heb syn * cleaned up bern-cell, hebb-syn * minor mod to model-utils * attempted rewrite of bernoulli-cell * got bernoulli-cell rewritten and unit-tested * edit to bern-cell * bernoulli and poisson cells revised, unit-tested * latency-cell refactored and unit-tested * refactored Rate Cell * minor revisions to input-encoders, revised phasor-cell w/ unit-test * revised and add unit-test for varTrace * revised and added unit-test for exp-kernel * revised and added unit-test for exp-kernel * revised slif cell w/ unit-test; needed mod to diffeq * revised slif-cell w/ unit-test; cleaned up ode_utils to play nicer w/ new sim-lib * revised lif-cell w/ unit-test * revised unit-tests to pass globally; some minor patches to phasor-cell and lif * minor cleanup of unit-test for phasor * revised if-cell w/ unit-test * revised if-cell w/ unit-test * revised quad-lif w/ unit-test * revised adex-cell w/ unit test, minor cleanup of quad-lif * minor edit to adex unit-test * refactor bernoulli, laplacian, and rewarderror cells * revised raf-cell w/ unit test; fixed typos/mistakes in all spiking cells * revised wtas-cell w/ unit test * revised fh-cell w/ unit test * revised izh-cell w/ unit test * patched ode_utils backend wrt jax, cleaned up unit-tests, added disable flag for phasor-cell * update rate cell * fix test rate cell * update test for bernoulli cell * update refactoring for gaussian error cell * update unit testing for all graded neurons * wrote+unit-test of hodgkin-huxley spike cell, minor tweaks/clean-up elsewhere * added rk2 support for H-H cell * update rate cell and fix bug of passing a tuple of (jax Array -- not hashable) to jax jit functions. Basically, simplify the codebase by using a hashmap of functions * update test rate cell * refactored dense and trace-stdp syn w/ unit-test * refactored exp-stdp syn w/ unit-test * refactored event-stdp w/ unit-test * cleanup of stdp-syn * refactored bcm syn w/ unit-test * refactored stp-syn with unit-test * cleaned up modulated * refactored mstdp-et syn w/ unit-test * refactored lava components to new sim-lib * refactored conv/hebb-conv syn w/ unit-test * refactored/revised hebb-deconv syn w/ unit-test * revised/refactored hebb/stdp conv/deconv syn w/ unit-tests * updated modeling doc to point to hodgkin-huxley cell * updated modeling docs * fixed typo in adex-cell tutorial doc * revised tutorials to reflect new sim-lib config/syntax * revised tutorials to reflect new sim-lib config/syntax * patched docs to reflect revisions/refactor * tweaked requirements in prep for major release * cleaned up a few unit tests to use deterministic syn init vals * mod to requirements * nudge toml to upcoming 2.0.0 * update to support docs in prep for 2.0.0 * update patched synapses and their test cases * cleaned up syn modeling doc * push hebbian synapse * push reinforce synapse * push np seed * patched minor prior None arg issue in hebb-syn * moved reinforce-syn to right spot * update reinforce synapse and testing * tweaked trace-stdp and mstdpet * patched mstdpet unit-test * update reinforce synapse and test cases * add reinforce synapse fix * minor mod to mstdpet * update test code for more than 1 steps * Updated monitors * patched tests to use process naming * Added wrapper for reset and advance_state * Added a JaxProcess Added Jax Process to allow for scanning over the process. 
* update the old rate cell * update old hebbian synapse * minor edit to if-cell * ported over adex tutorial to new ngclearn format * hh-cell supports rk4 integration * clean up and integrated hodgkin-huxley mini lesson in neurocog tutorials * Update jaxProcess.py Updated the jax process to allow for more configurations of inputs. * update working reinforce synapse * update correct reinforce and testing * update documentation * update features, documentation, and testing * update testing for REINFORCE cell * update code and test * update code * add clipping gradient to model utils * update reinforce cell to the new model utils clip * update test cases --------- Co-authored-by: Viet Dung Nguyen Co-authored-by: Alexander Ororbia Co-authored-by: Will Gebhardt

* Major release update (to 2.0.0) (#100) [squashed commit list identical to #99 above, plus:] * major cleanup in prep for merge over to main/prep for major release * update to require file in docs --------- Co-authored-by: Viet Dung Nguyen Co-authored-by: Alexander Ororbia Co-authored-by: Will Gebhardt

* Major release update merge to main (sync up) (#101) [squashed commit list identical to #100 above] --------- Co-authored-by: Viet Dung Nguyen Co-authored-by: Alexander Ororbia Co-authored-by: Will Gebhardt

* update test cases * added hh-plot for hh tutorial * tweak to img folder for sindy

---------

Co-authored-by: Will Gebhardt
Co-authored-by: Alexander Ororbia
Co-authored-by: Faezeh Habibi <155960330+Faezehabibi@users.noreply.github.com>
Co-authored-by: Sonny George <56851635+sonnygeorge@users.noreply.github.com>
Co-authored-by: Viet Dung Nguyen <60036798+rxng8@users.noreply.github.com>
Co-authored-by: Alexander Ororbia
Co-authored-by: Viet Dung Nguyen
Co-authored-by: Viet Nguyen
Co-authored-by: Alexander Ororbia
---
 docs/museum/sindy.md                           | 3 +--
 docs/tutorials/neurocog/hodgkin_huxley_cell.md | 6 +++---
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/docs/museum/sindy.md b/docs/museum/sindy.md
index 17ef4fee..d6e5a229 100644
--- a/docs/museum/sindy.md
+++ b/docs/museum/sindy.md
@@ -8,8 +8,7 @@ Flow diagrams lack clear directional indicators Inconsistent color schemes across visualizations -->
-
-# Sparse Identification of Non-linear Dynamical Systems (SINDy)[1]
+# Sparse Identification of Non-linear Dynamical Systems (SINDy)

 In this section, we will study, create, simulate, and visualize a model known as the sparse identification of non-linear dynamical systems (SINDy) [1], implementing it in NGC-Learn and JAX. After going through this demonstration, you will:

diff --git a/docs/tutorials/neurocog/hodgkin_huxley_cell.md b/docs/tutorials/neurocog/hodgkin_huxley_cell.md
index 1580ab1c..44e3b0a7 100755
--- a/docs/tutorials/neurocog/hodgkin_huxley_cell.md
+++ b/docs/tutorials/neurocog/hodgkin_huxley_cell.md
@@ -83,9 +83,9 @@ Formally, the core dynamics of the H-H cell can be written out as follows:

 $$
 \tau_v \frac{\partial \mathbf{v}_t}{\partial t} &= \mathbf{j}_t - g_Na * \mathbf{m}^3_t * \mathbf{h}_t * (\mathbf{v}_t - v_Na) - g_K * \mathbf{n}^4_t * (\mathbf{v}_t - v_K) - g_L * (\mathbf{v}_t - v_L) \\
-\frac{\partial \mathbf{n}_t}{\partial t} &= alpha_n(\mathbf{v}_t) * (1 - \mathbf{n}_t) - beta_n(\mathbf{v}_t) * \mathbf{n}_t \\
-\frac{\partial \mathbf{m}_t}{\partial t} &= alpha_m(\mathbf{v}_t) * (1 - \mathbf{m}_t) - beta_m(\mathbf{v}_t) * \mathbf{m}_t \\
-\frac{\partial \mathbf{h}_t}{\partial t} &= alpha_h(\mathbf{v}_t) * (1 - \mathbf{h}_t) - beta_h(\mathbf{v}_t) * \mathbf{h}_t
+\frac{\partial \mathbf{n}_t}{\partial t} &= \alpha_n(\mathbf{v}_t) * (1 - \mathbf{n}_t) - \beta_n(\mathbf{v}_t) * \mathbf{n}_t \\
+\frac{\partial \mathbf{m}_t}{\partial t} &= \alpha_m(\mathbf{v}_t) * (1 - \mathbf{m}_t) - \beta_m(\mathbf{v}_t) * \mathbf{m}_t \\
+\frac{\partial \mathbf{h}_t}{\partial t} &= \alpha_h(\mathbf{v}_t) * (1 - \mathbf{h}_t) - \beta_h(\mathbf{v}_t) * \mathbf{h}_t
 $$

 where we observe that the above four-dimensional set of dynamics is composed of nonlinear ODEs.
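(As a concrete aside: the four ODEs above can be integrated directly. Below is a minimal NumPy sketch using forward Euler with $\tau_v = 1$; the specific $\alpha_x$/$\beta_x$ generator functions and conductance constants follow the standard textbook Hodgkin-Huxley parameterization and are assumptions of this sketch, not necessarily what ngc-learn's cell uses.)

```python
import numpy as np

# Standard textbook H-H rate (generator) functions; V is in mV.
# (Assumed parameterization; note alpha_m/alpha_n are singular exactly at
# V = -40/-55 mV, so a real implementation would guard those points.)
alpha_n = lambda V: 0.01 * (V + 55.) / (1. - np.exp(-(V + 55.) / 10.))
beta_n  = lambda V: 0.125 * np.exp(-(V + 65.) / 80.)
alpha_m = lambda V: 0.1 * (V + 40.) / (1. - np.exp(-(V + 40.) / 10.))
beta_m  = lambda V: 4. * np.exp(-(V + 65.) / 18.)
alpha_h = lambda V: 0.07 * np.exp(-(V + 65.) / 20.)
beta_h  = lambda V: 1. / (1. + np.exp(-(V + 35.) / 10.))

# Standard conductances (mS/cm^2) and reversal potentials (mV)
g_Na, g_K, g_L = 120., 36., 0.3
v_Na, v_K, v_L = 50., -77., -54.4
tau_v = 1.  # membrane time constant assumed normalized to 1

def hh_step(v, n, m, h, j, dt=0.01):
    """One forward-Euler step of the four H-H ODEs written out above."""
    dv = (j - g_Na * m**3 * h * (v - v_Na)
            - g_K * n**4 * (v - v_K)
            - g_L * (v - v_L)) / tau_v
    dn = alpha_n(v) * (1. - n) - beta_n(v) * n
    dm = alpha_m(v) * (1. - m) - beta_m(v) * m
    dh = alpha_h(v) * (1. - h) - beta_h(v) * h
    return v + dt * dv, n + dt * dn, m + dt * dm, h + dt * dh
```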
Notice that, in each gate or channel probability ODE, there are two generator functions (each of which is a function of the membrane potential $\mathbf{v}_t$) that produce the necessary dynamic coefficients at time $t$; $\alpha_x(\mathbf{v}_t)$ and $\beta_x(\mathbf{v}_t)$ produce different biophysical weighting values depending on which channel $x = \{n, m, h\}$ they are related to.

From ffc56605c5a548ba0bcee853f8a7c67232392664 Mon Sep 17 00:00:00 2001
From: Alex Ororbia
Date: Sat, 12 Apr 2025 22:44:13 -0400
Subject: [PATCH 2/4] nudge from main to release (update to sindy tutorial) (#105)

[squashed commit message repeats the PATCH 1/4 log above, adding:] * update to sindy tutorial to adhere to readthedocs formatting

---------

Co-authored-by: Will Gebhardt
Co-authored-by: Alexander Ororbia
Co-authored-by: Faezeh Habibi <155960330+Faezehabibi@users.noreply.github.com>
Co-authored-by: Sonny George <56851635+sonnygeorge@users.noreply.github.com>
Co-authored-by: Viet Dung Nguyen <60036798+rxng8@users.noreply.github.com>
Co-authored-by: Alexander Ororbia
Co-authored-by: Viet Dung Nguyen
Co-authored-by: Viet Nguyen
Co-authored-by: Alexander Ororbia
---
 docs/museum/sindy.md | 317 +++++++++++--------------------------------
 1 file changed, 77 insertions(+), 240 deletions(-)

diff --git a/docs/museum/sindy.md b/docs/museum/sindy.md
index d6e5a229..04426d70 100644
--- a/docs/museum/sindy.md
+++ b/docs/museum/sindy.md
@@ -27,19 +27,24 @@ SINDy is a data-driven algorithm that discovers the governing behavior of a dyna

 ### SINDy Dynamics

-If $\mathbf{X}$ is a system that only depends on variable $t$, a very small change in the independent variable ($dt$) can cause a change in the system by $dX$ amount.
-$$$
+If $\mathbf{X}$ is a system that only depends on variable $t$, a very small change in the independent variable ($dt$) can cause a change in the system by $dX$ amount:
+
+$$
 d\mathbf{X} = \mathbf{Ẋ}(t)~dt
-$$$
+$$
+
 SINDy models the derivative[^1] (a linear operation) as a linear transformation:

 [^1]: The derivative is a linear operation that acts on $dt$ and gives a differential that is the linearized approximation of the Taylor series of the function.

-$$$
+$$
 \frac{d\mathbf{X}(t)}{dt} = \mathbf{Ẋ}(t) = \mathbf{f}(\mathbf{X}(t))
-$$$
+$$
+
 SINDy assumes that this linear operation, i.e., $\mathbf{f}(\mathbf{X}(t))$, is a matrix multiplication that linearly combines the relevant predictors in order to describe the system's equation.

-$$$
+$$
 \mathbf{f}(\mathbf{X}(t)) = \mathbf{\Theta}(\mathbf{X})~\mathbf{W}
-$$$
+$$

 Given a group of candidate functions within the library $\mathbf{\Theta}(\mathbf{X})$, the coefficients in $\mathbf{W}$ that choose the library terms are to be **sparse**. In other words, only a few functions exist in the system's differential equation. Given these assumptions, SINDy solves a sparse regression problem in order to find the $\mathbf{W}$ that maps the library of selected terms to each feature of the system being identified. SINDy imposes parsimony constraints over the resulting symbolic regression (i.e., genetic programming) to describe a dynamical system's behavior with as few terms as possible. In order to select a sparse set of the given features, the model either adds the LASSO regularization penalty (i.e., an L1 norm constraint) to the regression problem and solves the sparse regression, or solves the regression problem via STLSQ. We will describe STLSQ in the third step of the SINDy dynamics/process.
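(To make the shapes concrete before walking through the phases below: with a library matrix $\mathbf{\Theta}(\mathbf{X})$ of size $m \times p$ and a derivative matrix $\mathbf{Ẋ}$ of size $m \times n$, the unregularized fit for the $p \times n$ coefficient matrix $\mathbf{W}$ is ordinary least squares, with sparsity then imposed on top via LASSO or STLSQ. A minimal NumPy sketch, illustrative only and not ngc-learn's own regression modules:)

```python
import numpy as np

def fit_dense(Theta, dX):
    """Dense least-squares fit of dX ~= Theta @ W.

    Theta: (m, p) library of candidate functions evaluated on the data
    dX:    (m, n) time-derivatives of the state
    W:     (p, n) coefficients mapping library terms to each state dimension
    """
    W, *_ = np.linalg.lstsq(Theta, dX, rcond=None)
    return W  # sparsity (LASSO / STLSQ) must still be layered on top
```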
@@ -47,206 +52,101 @@ In essence, SINDy's dynamics can be presented in three main phases, visualized i
------------------------------------------------------------------------------------------
+**Figure 1:** **The flow of the three phases in SINDy.** **Phase-1)** Data collection: capturing system states that are changing in time and creating the state vector. **Phase-2A)** Library formation: manually creating the library of candidate predictors that could appear in the model. **Phase-2B)** Derivative computation: using the data collected in phase 1 to compute its derivative with respect to time. **Phase-3)** Solving the sparse regression problem.
------------------------------------------------------------------------------------------
## Phase 1: Collecting Dataset → $\mathbf{X}_{(m \times n)}$

This phase involves gathering the raw data points representing the system's states across time. In this example, this means capturing the $x$, $y$, and $z$ coordinates of the system's states. Here, $m$ represents the number of data points (the number of snapshots/length of time) and $n$ is the system's dimensionality.

*(Figure: dataset collection showing the x, y, z coordinates.)*
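As a purely illustrative realization of this phase, the sketch below generates such an $\mathbf{X}$ by integrating the Lorenz system with a simple forward-Euler loop; the parameter values, step size, and initial condition are assumptions chosen for demonstration, not values mandated by the tutorial:

```python
import numpy as np

def lorenz(s, sigma=10., rho=28., beta=8. / 3.):
    """Right-hand side of the Lorenz system (assumed demo parameters)."""
    x, y, z = s
    return np.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])

dt, T = 0.01, 10.0                 # assumed sampling interval and duration
m = int(T / dt)
X = np.zeros((m, 3))               # the (m x n) state matrix, n = 3 here
X[0] = np.array([1., 1., 1.])      # arbitrary initial condition
for t in range(m - 1):             # forward-Euler integration, for brevity
    X[t + 1] = X[t] + dt * lorenz(X[t])
```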
## Phase 2: Processing
### 2.A: Making the Library → $\mathbf{\Theta}_{(m \times p)}$

In this step, using the dataset collected in phase 1 and a set of pre-defined function terms, we construct a dictionary of candidate predictors for identifying the target system's differential equations. These functions form the columns of our library matrix $\mathbf{\Theta}(\mathbf{X})$, and $p$ is the number of candidate predictors. To identify the dynamical structure of the system, this library of candidate functions enters the regression problem, proposing the model structure; the coefficient matrix will later weight these functions according to the problem setup. We assume that sparse models are sufficient to identify the system and enforce this through sparsification (LASSO or thresholding of the weights) in order to decide which structure best describes the system's behavior using the predictors. Given a set of time-series measurements of a dynamical system's state variables ($\mathbf{X}_{(m \times n)}$), we construct the following library of candidate functions (see the standalone sketch below):

$\Theta(\mathbf{X}) = [\mathbf{1} \quad \mathbf{X} \quad \mathbf{X}^2 \quad \mathbf{X}^3 \quad \sin(\mathbf{X}) \quad \cos(\mathbf{X}) \quad ...]$
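As a rough standalone illustration of this step (the tutorial code further below instead uses ngc-learn's `PolynomialLibrary` for this), a degree-2 polynomial library over the columns of `X` could be assembled as follows:

```python
import numpy as np

def poly_library(X, degree=2):
    """Build Theta(X) with constant, linear, and pairwise quadratic columns.

    A minimal sketch of library construction; real feature libraries may also
    include cubes, sines, cosines, and so on, as in the formula above.
    """
    m, n = X.shape
    cols = [np.ones((m, 1))]                        # the constant term "1"
    cols.append(X)                                  # linear terms: x, y, z, ...
    if degree >= 2:
        for i in range(n):                          # quadratic terms: x^2, xy, ...
            for j in range(i, n):
                cols.append((X[:, i] * X[:, j])[:, None])
    return np.hstack(cols)                          # shape (m, p)

Theta = poly_library(X)   # X: the (m x n) state matrix from phase 1
```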
### 2.B: Compute State Derivatives → $\mathbf{Ẋ}_{(m \times n)}$

Given a set of time-series measurements of a dynamical system's state variables $\mathbf{X}_{(m \times n)}$, we next construct the derivative matrix $\mathbf{Ẋ}_{(m \times n)}$ (computed numerically). In this step, using the dataset collected in phase 1, we compute the derivatives of each state variable with respect to time. In this example, we compute $ẋ$, $ẏ$, and $ż$ in order to capture how the system evolves over time.
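One minimal numerical realization of this step, assuming the rows of `X` are uniformly sampled `dt` apart as in the phase-1 sketch, is a finite-difference estimate:

```python
import numpy as np

# Second-order central differences in the interior, one-sided at the ends;
# dt is the (assumed uniform) sampling interval from data collection.
dX = np.gradient(X, dt, axis=0)   # (m x n) matrix of estimated derivatives
```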
## Phase 3: Solving the Sparse Regression Problem → $\mathbf{W_s}_{(p \times n)}$

The sparse regression (SR) problem that results from the phases above can be solved using various methods, such as LASSO, STLSQ, and Elastic Net, among many other schemes. Here, we describe the STLSQ approach for solving the SR problem according to the SINDy process.
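For the LASSO route mentioned above, scikit-learn (already among ngc-learn's listed dependencies) offers a direct sketch; the regularization strength `alpha` below is an arbitrary illustrative choice:

```python
from sklearn.linear_model import Lasso

# L1-penalized regression: the penalty drives most coefficients exactly to zero
lasso = Lasso(alpha=0.1, fit_intercept=False)
lasso.fit(Theta, dX)          # fit dX ≈ Theta @ W, column by column
W_lasso = lasso.coef_.T       # (p x n) sparse coefficient matrix
```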
-### Sequential Thresholding Least Square (STLSQ)
+### Solving Sparse Regression by Sequentially Thresholded Least Squares (STLSQ)

**Figure 1:** **The flow of the three phases in SINDy.** **Phase-1)** Data collection: capturing the system's states that are changing in time and creating the state vector. **Phase-2A)** Library formation: manually creating the library of candidate predictors that could appear in the model. **Phase-2B)** Derivative computation: using the data collected in phase 1 to compute its derivative with respect to time. **Phase-3)** Solving the sparse regression problem via STLSQ.

------------------------------------------------------------------------------------------
#### 3.A: Least Squares Method (LSQ) → $\mathbf{W}$

This step entails finding the library coefficients by solving the regression problem $\mathbf{Ẋ} = \mathbf{\Theta}\mathbf{W}$ analytically:

$$
\mathbf{W} = (\mathbf{\Theta}^T \mathbf{\Theta})^{-1} \mathbf{\Theta}^T \mathbf{Ẋ}
$$
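In code, this analytic solve is a one-liner (a sketch over the `Theta` and `dX` arrays from the earlier phases; in practice a least-squares solver is preferred to explicitly inverting $\mathbf{\Theta}^T \mathbf{\Theta}$):

```python
import numpy as np

# Normal-equation form of the least-squares fit for dX ≈ Theta @ W
W = np.linalg.inv(Theta.T @ Theta) @ Theta.T @ dX
# Numerically safer equivalent:
W, *_ = np.linalg.lstsq(Theta, dX, rcond=None)
```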
#### 3.B: Thresholding → $\mathbf{W_s}$

This step entails sparsifying $\mathbf{W}$ by keeping only some of the terms within it, namely those that correspond to the effective terms in the library; coefficients whose magnitudes fall below a chosen threshold are zeroed out.
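In array terms, this is a single elementwise mask, mirroring the `jnp.where` call in the tutorial code further below (the cutoff `threshold` is a user-chosen hyperparameter; the value here is illustrative):

```python
import numpy as np

threshold = 0.1                                 # illustrative cutoff value
W_s = np.where(np.abs(W) >= threshold, W, 0.)   # zero out weak library terms
```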
#### 3.C: Masking → $\mathbf{\Theta_s}$

This step sparsifies $\mathbf{\Theta}$ by keeping only the library columns whose corresponding terms in $\mathbf{W}$ remain (from the prior step).
#### 3.D: Repeat A → B → C Until Convergence

We continue by solving LSQ with the sparse matrices $\mathbf{\Theta_s}$ and $\mathbf{W_s}$ to find a new $\mathbf{W}$, repeating steps B and C until convergence; a compact sketch of the full loop follows.
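Putting steps 3.A-3.D together, here is a compact reference sketch of STLSQ (an illustration under assumed defaults, not ngc-learn's own implementation):

```python
import numpy as np

def stlsq(Theta, dX, threshold=0.1, max_iter=10):
    """Sequentially thresholded least squares for dX ≈ Theta @ W."""
    W, *_ = np.linalg.lstsq(Theta, dX, rcond=None)   # 3.A: initial LSQ fit
    for _ in range(max_iter):
        mask = np.abs(W) >= threshold                # 3.B: flag surviving terms
        W = np.where(mask, W, 0.)                    #      zero out weak ones
        for k in range(dX.shape[1]):                 # 3.C/3.D: refit survivors
            keep = mask[:, k]
            if keep.any():
                W[keep, k], *_ = np.linalg.lstsq(
                    Theta[:, keep], dX[:, k], rcond=None)
    return W

W_s = stlsq(Theta, dX)   # sparse coefficients, one column per state dimension
```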
## Code: Simulating SINDy

We finally present ngc-learn code below for using and simulating the SINDy process to identify several dynamical systems.

```python
import numpy as np
import jax.numpy as jnp
from ngclearn.utils.feature_dictionaries.polynomialLibrary import PolynomialLibrary
```

@@ -334,31 +224,14 @@
```python
for dim in range(dX.shape[1]):
    coeff = jnp.where(jnp.abs(coef) >= threshold, coef, 0.)
    print(f"coefficients for dimension {dim+1}: \n", coeff.T)
```

## Results: System Identification

Running the above code should produce results similar to the findings we present next.
-## Oscillator
+## Oscillator

True model's equation \
$\mathbf{ẋ} = \mu_1\mathbf{x} + \sigma \mathbf{xy}$ \
@@ -379,19 +252,13 @@ $\mathbf{ż} = \mu_2\mathbf{z} - (\omega + \alpha \mathbf{y} + \beta \mathbf{z})$

+which should produce the following results:

```python
[ 0. -0.009 0. -2.99 4.99 1.99 0. 0. 0. 0.]]
```
-## Lorenz
+## Lorenz

True model's equation \
$\mathbf{ẋ} = 10(\mathbf{y} - \mathbf{x})$ \
@@ -399,7 +266,6 @@ $\mathbf{ẏ} = \mathbf{x}(28 - \mathbf{z}) - \mathbf{y}$ \
$\mathbf{ż} = \mathbf{xy} - \frac{8}{3}~\mathbf{z}$

+which should produce the following results:

```python
--- SINDy results ----
ẋ = 9.969 𝑦 -9.966 𝑥
```
@@ -412,19 +278,12 @@
```python
[-2.656 0. 0. 0. 0. 0. 0. 0.996 0.]]
```
-## Linear-2D
+## Linear-2D

True model's equation \
$\mathbf{ẋ} = -0.1\mathbf{x} + 2.0\mathbf{y}$ \
@@ -439,21 +298,14 @@ $\mathbf{ẏ} = -2.0\mathbf{x} - 0.1\mathbf{y}$

+which should produce the following results:

```python
[[ 1.999 0. -0.100 0. 0.]
 [-0.099 0. -1.999 0. 0.]]
```
-## Linear-3D
+## Linear-3D

True model's equation \
$\mathbf{ẋ} = -0.1\mathbf{x} + 2\mathbf{y}$ \
@@ -472,22 +324,13 @@ $\mathbf{ż} = -0.3\mathbf{z}$

+which should produce the following results:

```python
[-0.299 0. 0. 0. 0. 0. 0. 0. 0.]]
```
-## Cubic-2D
+## Cubic-2D

True model's equation \
$\mathbf{ẋ} = -0.1\mathbf{x}^3 + 2.0\mathbf{y}^3$ \
@@ -503,16 +346,10 @@ $\mathbf{ẏ} = -2.0\mathbf{x}^3 - 0.1\mathbf{y}^3$

```python
[ 0. 0. -0.099 0. 0. 0. 0. 0. -1.99]]
```
+which should produce the following results: + + + ## References From 7004c3249af21a097fda68163c0566c842f2f314 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Sun, 13 Apr 2025 01:12:42 -0400 Subject: [PATCH 3/4] removed some clutter - old files --- .../components/neurons/graded/rateCellOld.py | 350 ------------------ .../synapses/hebbian/hebbianSynapseOld.py | 326 ---------------- 2 files changed, 676 deletions(-) delete mode 100644 ngclearn/components/neurons/graded/rateCellOld.py delete mode 100644 ngclearn/components/synapses/hebbian/hebbianSynapseOld.py diff --git a/ngclearn/components/neurons/graded/rateCellOld.py b/ngclearn/components/neurons/graded/rateCellOld.py deleted file mode 100644 index 6962810c..00000000 --- a/ngclearn/components/neurons/graded/rateCellOld.py +++ /dev/null @@ -1,350 +0,0 @@ -from jax import numpy as jnp, random, jit -from functools import partial -from ngclearn.utils import tensorstats -from ngclearn import resolver, Component, Compartment -from ngclearn.components.jaxComponent import JaxComponent -from ngclearn.utils.model_utils import create_function, threshold_soft, \ - threshold_cauchy -from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ - step_euler, step_rk2, step_rk4 - -def _dfz_internal_gaussian(z, j, j_td, tau_m, leak_gamma): - z_leak = z # * 2 ## Default: assume Gaussian - dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) - return dz_dt - -def _dfz_internal_laplacian(z, j, j_td, tau_m, leak_gamma): - z_leak = jnp.sign(z) ## d/dx of Laplace is signum - dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) - return dz_dt - -def _dfz_internal_cauchy(z, j, j_td, tau_m, leak_gamma): - z_leak = (z * 2)/(1. + jnp.square(z)) - dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) - return dz_dt - -def _dfz_internal_exp(z, j, j_td, tau_m, leak_gamma): - z_leak = jnp.exp(-jnp.square(z)) * z * 2 - dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) - return dz_dt - - -def _dfz_gaussian(t, z, params): ## diff-eq dynamics wrapper - j, j_td, tau_m, leak_gamma = params - dz_dt = _dfz_internal_gaussian(z, j, j_td, tau_m, leak_gamma) - return dz_dt - -def _dfz_laplacian(t, z, params): ## diff-eq dynamics wrapper - j, j_td, tau_m, leak_gamma = params - dz_dt = _dfz_internal_laplacian(z, j, j_td, tau_m, leak_gamma) - return dz_dt - -def _dfz_cauchy(t, z, params): ## diff-eq dynamics wrapper - j, j_td, tau_m, leak_gamma = params - dz_dt = _dfz_internal_cauchy(z, j, j_td, tau_m, leak_gamma) - return dz_dt - -def _dfz_exp(t, z, params): ## diff-eq dynamics wrapper - j, j_td, tau_m, leak_gamma = params - dz_dt = _dfz_internal_exp(z, j, j_td, tau_m, leak_gamma) - return dz_dt - -@jit -def _modulate(j, dfx_val): - """ - Apply a signal modulator to j (typically of the form of a derivative/dampening function) - - Args: - j: current/stimulus value to modulate - - dfx_val: modulator signal - - Returns: - modulated j value - """ - return j * dfx_val - -def _run_cell(dt, j, j_td, z, tau_m, leak_gamma=0., integType=0, priorType=0): - """ - Runs leaky rate-coded state dynamics one step in time. 
- - Args: - dt: integration time constant - - j: input (bottom-up) electrical/stimulus current - - j_td: modulatory (top-down) electrical/stimulus pressure - - z: current value of membrane/state - - tau_m: membrane/state time constant - - leak_gamma: strength of leak to apply to membrane/state - - integType: integration type to use (0 --> Euler/RK1, 1 --> Midpoint/RK2, 2 --> RK4) - - priorType: scale-shift prior distribution to impose over neural dynamics - - Returns: - New value of membrane/state for next time step - """ - _dfz = { - 0: _dfz_gaussian, - 1: _dfz_laplacian, - 2: _dfz_cauchy, - 3: _dfz_exp - }.get(priorType, _dfz_gaussian) - if integType == 1: - params = (j, j_td, tau_m, leak_gamma) - _, _z = step_rk2(0., z, _dfz, dt, params) - elif integType == 2: - params = (j, j_td, tau_m, leak_gamma) - _, _z = step_rk4(0., z, _dfz, dt, params) - else: - params = (j, j_td, tau_m, leak_gamma) - _, _z = step_euler(0., z, _dfz, dt, params) - return _z - -@jit -def _run_cell_stateless(j): - """ - A simplification of running a stateless set of dynamics over j (an identity - functional form of dynamics). - - Args: - j: stimulus to do nothing to - - Returns: - the stimulus - """ - return j + 0 - -class RateCell(JaxComponent): ## Rate-coded/real-valued cell - """ - A non-spiking cell driven by the gradient dynamics of neural generative - coding-driven predictive processing. - - The specific differential equation that characterizes this cell - is (for adjusting v, given current j, over time) is: - - | tau_m * dz/dt = lambda * prior(z) + (j + j_td) - | where j is the set of general incoming input signals (e.g., message-passed signals) - | and j_td is taken to be the set of top-down pressure signals - - | --- Cell Input Compartments: --- - | j - input pressure (takes in external signals) - | j_td - input/top-down pressure input (takes in external signals) - | --- Cell State Compartments --- - | z - rate activity - | --- Cell Output Compartments: --- - | zF - post-activation function activity, i.e., fx(z) - - Args: - name: the string name of this cell - - n_units: number of cellular entities (neural population size) - - tau_m: membrane/state time constant (milliseconds) - - prior: a kernel for specifying the type of centered scale-shift distribution - to impose over neuronal dynamics, applied to each neuron or - dimension within this component (Default: ("gaussian", 0)); this is - a tuple with 1st element containing a string name of the distribution - one wants to use while the second value is a `leak rate` scalar - that controls the influence/weighting that this distribution - has on the dynamics; for example, ("laplacian, 0.001") means that a - centered laplacian distribution scaled by `0.001` will be injected - into this cell's dynamics ODE each step of simulated time - - :Note: supported scale-shift distributions include "laplacian", - "cauchy", "exp", and "gaussian" - - act_fx: string name of activation function/nonlinearity to use - - integration_type: type of integration to use for this cell's dynamics; - current supported forms include "euler" (Euler/RK-1 integration) - and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler") - - :Note: setting the integration type to the midpoint method will - increase the accuray of the estimate of the cell's evolution - at an increase in computational cost (and simulation time) - - resist_scale: a scaling factor applied to incoming pressure `j` (default: 1) - """ - - # Define Functions - def __init__(self, name, n_units, tau_m, 
prior=("gaussian", 0.), act_fx="identity", - threshold=("none", 0.), integration_type="euler", - batch_size=1, resist_scale=1., shape=None, is_stateful=True, **kwargs): - super().__init__(name, **kwargs) - - ## membrane parameter setup (affects ODE integration) - self.tau_m = tau_m ## membrane time constant -- setting to 0 triggers "stateless" mode - self.is_stateful = is_stateful - if isinstance(tau_m, float): - if tau_m <= 0: ## trigger stateless mode - self.is_stateful = False - priorType, leakRate = prior - self.priorType = { - "gaussian": 0, - "laplacian": 1, - "cauchy": 2, - "exp": 3 - }.get(priorType, 0) ## type of scale-shift prior to impose over the leak - self.priorLeakRate = leakRate ## degree to which rate neurons leak (according to prior) - thresholdType, thr_lmbda = threshold - self.thresholdType = thresholdType ## type of thresholding function to use - self.thr_lmbda = thr_lmbda ## scale to drive thresholding dynamics - self.resist_scale = resist_scale ## a "resistance" scaling factor - - ## integration properties - self.integrationType = integration_type - self.intgFlag = get_integrator_code(self.integrationType) - - ## Layer size setup - _shape = (batch_size, n_units) ## default shape is 2D/matrix - if shape is None: - shape = (n_units,) ## we set shape to be equal to n_units if nothing provided - else: - _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor - self.shape = shape - self.n_units = n_units - self.batch_size = batch_size - - omega_0 = None - if act_fx == "sine": - omega_0 = kwargs["omega_0"] - self.fx, self.dfx = create_function(fun_name=act_fx, args=omega_0) - - # compartments (state of the cell & parameters will be updated through stateless calls) - restVals = jnp.zeros(_shape) - self.j = Compartment(restVals, display_name="Input Stimulus Current", units="mA") # electrical current - self.zF = Compartment(restVals, display_name="Transformed Rate Activity") # rate-coded output - activity - self.j_td = Compartment(restVals, display_name="Modulatory Stimulus Current", units="mA") # top-down electrical current - pressure - self.z = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity - - @staticmethod - def _advance_state(dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType, - resist_scale, thresholdType, thr_lmbda, is_stateful, j, j_td, z): - #if tau_m > 0.: - if is_stateful: - ### run a step of integration over neuronal dynamics - ## Notes: - ## self.pressure <-- "top-down" expectation / contextual pressure - ## self.current <-- "bottom-up" data-dependent signal - dfx_val = dfx(z) - j = _modulate(j, dfx_val) - j = j * resist_scale - tmp_z = _run_cell(dt, j, j_td, z, - tau_m, leak_gamma=priorLeakRate, - integType=intgFlag, priorType=priorType) - ## apply optional thresholding sub-dynamics - if thresholdType == "soft_threshold": - tmp_z = threshold_soft(tmp_z, thr_lmbda) - elif thresholdType == "cauchy_threshold": - tmp_z = threshold_cauchy(tmp_z, thr_lmbda) - z = tmp_z ## pre-activation function value(s) - zF = fx(z) ## post-activation function value(s) - else: - ## run in "stateless" mode (when no membrane time constant provided) - j_total = j + j_td - z = _run_cell_stateless(j_total) - zF = fx(z) - return j, j_td, z, zF - - @resolver(_advance_state) - def advance_state(self, j, j_td, z, zF): - self.j.set(j) - self.j_td.set(j_td) - self.z.set(z) - self.zF.set(zF) - - @staticmethod - def _reset(batch_size, shape): #n_units - _shape = (batch_size, shape[0]) - if len(shape) > 1: - _shape = (batch_size, shape[0], 
shape[1], shape[2]) - restVals = jnp.zeros(_shape) - return tuple([restVals for _ in range(4)]) - - @resolver(_reset) - def reset(self, j, zF, j_td, z): - self.j.set(j) # electrical current - self.zF.set(zF) # rate-coded output - activity - self.j_td.set(j_td) # top-down electrical current - pressure - self.z.set(z) # rate activity - - def save(self, directory, **kwargs): - ## do a protected save of constants, depending on whether they are floats or arrays - tau_m = (self.tau_m if isinstance(self.tau_m, float) - else jnp.ones([[self.tau_m]])) - priorLeakRate = (self.priorLeakRate if isinstance(self.priorLeakRate, float) - else jnp.ones([[self.priorLeakRate]])) - resist_scale = (self.resist_scale if isinstance(self.resist_scale, float) - else jnp.ones([[self.resist_scale]])) - - file_name = directory + "/" + self.name + ".npz" - jnp.savez(file_name, - tau_m=tau_m, priorLeakRate=priorLeakRate, - resist_scale=resist_scale) #, key=self.key.value) - - def load(self, directory, seeded=False, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - ## constants loaded in - self.tau_m = data['tau_m'] - self.priorLeakRate = data['priorLeakRate'] - self.resist_scale = data['resist_scale'] - #if seeded: - # self.key.set(data['key']) - - @classmethod - def help(cls): ## component help function - properties = { - "cell_type": "RateCell - evolves neurons according to rate-coded/" - "continuous dynamics " - } - compartment_props = { - "inputs": - {"j": "External input stimulus value(s)", - "j_td": "External top-down input stimulus value(s); these get " - "multiplied by the derivative of f(x), i.e., df(x)"}, - "states": - {"z": "Update to rate-coded continuous dynamics; value at time t"}, - "outputs": - {"zF": "Nonlinearity/function applied to rate-coded dynamics; f(z)"}, - } - hyperparams = { - "n_units": "Number of neuronal cells to model in this layer", - "batch_size": "Batch size dimension of this component", - "tau_m": "Cell state/membrane time constant", - "prior": "What kind of kurtotic prior to place over neuronal dynamics?", - "act_fx": "Elementwise activation function to apply over cell state `z`", - "threshold": "What kind of iterative thresholding function to place over neuronal dynamics?", - "integration_type": "Type of numerical integration to use for the cell dynamics", - } - info = {cls.__name__: properties, - "compartments": compartment_props, - "dynamics": "tau_m * dz/dt = Prior(z; gamma) + (j + j_td)", - "hyperparameters": hyperparams} - return info - - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - -if __name__ == '__main__': - from ngcsimlib.context import Context - with Context("Bar") as bar: - X = RateCell("X", 9, 0.03) - print(X) \ No newline at end of file diff --git a/ngclearn/components/synapses/hebbian/hebbianSynapseOld.py b/ngclearn/components/synapses/hebbian/hebbianSynapseOld.py deleted file mode 100644 index 04ebd4cb..00000000 --- a/ngclearn/components/synapses/hebbian/hebbianSynapseOld.py +++ /dev/null @@ -1,326 +0,0 @@ -from jax import random, numpy as jnp, jit -from functools import partial -from ngclearn.utils.optim 
import get_opt_init_fn, get_opt_step_fn -from ngclearn import resolver, Component, Compartment -from ngclearn.components.synapses import DenseSynapse -from ngclearn.utils import tensorstats -from ngcsimlib.deprecators import deprecate_args - -@partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9]) -def _calc_update(pre, post, W, w_bound, is_nonnegative=True, signVal=1., - prior_type=None, prior_lmbda=0., - pre_wght=1., post_wght=1.): - """ - Compute a tensor of adjustments to be applied to a synaptic value matrix. - - Args: - pre: pre-synaptic statistic to drive Hebbian update - - post: post-synaptic statistic to drive Hebbian update - - W: synaptic weight values (at time t) - - w_bound: maximum value to enforce over newly computed efficacies - - is_nonnegative: (Unused) - - signVal: multiplicative factor to modulate final update by (good for - flipping the signs of a computed synaptic change matrix) - - prior_type: prior type or name (Default: None) - - prior_lmbda: prior parameter (Default: 0.0) - - pre_wght: pre-synaptic weighting term (Default: 1.) - - post_wght: post-synaptic weighting term (Default: 1.) - - Returns: - an update/adjustment matrix, an update adjustment vector (for biases) - """ - _pre = pre * pre_wght - _post = post * post_wght - dW = jnp.matmul(_pre.T, _post) - db = jnp.sum(_post, axis=0, keepdims=True) - dW_reg = 0. - - if w_bound > 0.: - dW = dW * (w_bound - jnp.abs(W)) - - if prior_type == "l2" or prior_type == "ridge": - dW_reg = W - if prior_type == "l1" or prior_type == "lasso": - dW_reg = jnp.sign(W) - if prior_type == "l1l2" or prior_type == "elastic_net": - l1_ratio = prior_lmbda[1] - prior_lmbda = prior_lmbda[0] - dW_reg = jnp.sign(W) * l1_ratio + W * (1-l1_ratio)/2 - - dW = dW + prior_lmbda * dW_reg - return dW * signVal, db * signVal - -@partial(jit, static_argnums=[1,2]) -def _enforce_constraints(W, w_bound, is_nonnegative=True): - """ - Enforces constraints that the (synaptic) efficacies/values within matrix - `W` must adhere to. - - Args: - W: synaptic weight values (at time t) - - w_bound: maximum value to enforce over newly computed efficacies - - is_nonnegative: ensure updated value matrix is strictly non-negative - - Returns: - the newly evolved synaptic weight value matrix - """ - _W = W - if w_bound > 0.: - if is_nonnegative == True: - _W = jnp.clip(_W, 0., w_bound) - else: - _W = jnp.clip(_W, -w_bound, w_bound) - return _W - - -class HebbianSynapse(DenseSynapse): - """ - A synaptic cable that adjusts its efficacies via a two-factor Hebbian - adjustment rule. 
- - | --- Synapse Compartments: --- - | inputs - input (takes in external signals) - | outputs - output signals (transformation induced by synapses) - | weights - current value matrix of synaptic efficacies - | biases - current value vector of synaptic bias values - | key - JAX PRNG key - | --- Synaptic Plasticity Compartments: --- - | pre - pre-synaptic signal to drive first term of Hebbian update (takes in external signals) - | post - post-synaptic signal to drive 2nd term of Hebbian update (takes in external signals) - | dWeights - current delta matrix containing changes to be applied to synaptic efficacies - | dBiases - current delta vector containing changes to be applied to bias values - | opt_params - locally-embedded optimizer statisticis (e.g., Adam 1st/2nd moments if adam is used) - - Args: - name: the string name of this cell - - shape: tuple specifying shape of this synaptic cable (usually a 2-tuple - with number of inputs by number of outputs) - - eta: global learning rate - - weight_init: a kernel to drive initialization of this synaptic cable's values; - typically a tuple with 1st element as a string calling the name of - initialization to use - - bias_init: a kernel to drive initialization of biases for this synaptic cable - (Default: None, which turns off/disables biases) - - w_bound: maximum weight to softly bound this cable's value matrix to; if - set to 0, then no synaptic value bounding will be applied - - is_nonnegative: enforce that synaptic efficacies are always non-negative - after each synaptic update (if False, no constraint will be applied) - - prior: a kernel to drive prior of this synaptic cable's values; - typically a tuple with 1st element as a string calling the name of - prior to use and 2nd element as a floating point number - calling the prior parameter lambda (Default: (None, 0.)) - currently it supports "l1" or "lasso" or "l2" or "ridge" or "l1l2" or "elastic_net". - usage guide: - prior = ('l1', 0.01) or prior = ('lasso', lmbda) - prior = ('l2', 0.01) or prior = ('ridge', lmbda) - prior = ('l1l2', (0.01, 0.01)) or prior = ('elastic_net', (lmbda, l1_ratio)) - - - - sign_value: multiplicative factor to apply to final synaptic update before - it is applied to synapses; this is useful if gradient descent style - optimization is required (as Hebbian rules typically yield - adjustments for ascent) - - optim_type: optimization scheme to physically alter synaptic values - once an update is computed (Default: "sgd"); supported schemes - include "sgd" and "adam" - - :Note: technically, if "sgd" or "adam" is used but `signVal = 1`, - then the ascent form of each rule is employed (signVal = -1) or - a negative learning rate will mean a descent form of the - `optim_scheme` is being employed - - pre_wght: pre-synaptic weighting factor (Default: 1.) - - post_wght: post-synaptic weighting factor (Default: 1.) - - resist_scale: a fixed scaling factor to apply to synaptic transform - (Default: 1.), i.e., yields: out = ((W * Rscale) * in) + b - - p_conn: probability of a connection existing (default: 1.); setting - this to < 1. 
will result in a sparser synaptic structure - """ - - # Define Functions - @deprecate_args(_rebind=False, w_decay='prior') - def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None, - w_bound=1., is_nonnegative=False, prior=(None, 0.), w_decay=0., sign_value=1., - optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1., - resist_scale=1., batch_size=1, **kwargs): - super().__init__(name, shape, weight_init, bias_init, resist_scale, - p_conn, batch_size=batch_size, **kwargs) - - if w_decay > 0.: - prior = ('l2', w_decay) - - prior_type, prior_lmbda = prior - ## synaptic plasticity properties and characteristics - self.shape = shape - self.Rscale = resist_scale - self.prior_type = prior_type - self.prior_lmbda = prior_lmbda - self.w_bound = w_bound - self.pre_wght = pre_wght - self.post_wght = post_wght - self.eta = eta - self.is_nonnegative = is_nonnegative - self.sign_value = sign_value - - ## optimization / adjustment properties (given learning dynamics above) - self.opt = get_opt_step_fn(optim_type, eta=self.eta) - - # compartments (state of the cell, parameters, will be updated through stateless calls) - self.preVals = jnp.zeros((self.batch_size, shape[0])) - self.postVals = jnp.zeros((self.batch_size, shape[1])) - self.pre = Compartment(self.preVals) - self.post = Compartment(self.postVals) - self.dWeights = Compartment(jnp.zeros(shape)) - self.dBiases = Compartment(jnp.zeros(shape[1])) - - #key, subkey = random.split(self.key.value) - self.opt_params = Compartment(get_opt_init_fn(optim_type)( - [self.weights.value, self.biases.value] - if bias_init else [self.weights.value])) - - @staticmethod - def _compute_update(w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, - post_wght, pre, post, weights): - ## calculate synaptic update values - dW, db = _calc_update( - pre, post, weights, w_bound, is_nonnegative=is_nonnegative, - signVal=sign_value, prior_type=prior_type, prior_lmbda=prior_lmbda, pre_wght=pre_wght, - post_wght=post_wght) - return dW, db - - @staticmethod - def _evolve(opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, - post_wght, bias_init, pre, post, weights, biases, opt_params): - ## calculate synaptic update values - dWeights, dBiases = HebbianSynapse._compute_update( - w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, post_wght, - pre, post, weights - ) - ## conduct a step of optimization - get newly evolved synaptic weight value matrix - if bias_init != None: - opt_params, [weights, biases] = opt(opt_params, [weights, biases], [dWeights, dBiases]) - else: - # ignore db since no biases configured - opt_params, [weights] = opt(opt_params, [weights], [dWeights]) - ## ensure synaptic efficacies adhere to constraints - weights = _enforce_constraints(weights, w_bound, is_nonnegative=is_nonnegative) - return opt_params, weights, biases, dWeights, dBiases - - @resolver(_evolve) - def evolve(self, opt_params, weights, biases, dWeights, dBiases): - self.opt_params.set(opt_params) - self.weights.set(weights) - self.biases.set(biases) - self.dWeights.set(dWeights) - self.dBiases.set(dBiases) - - @staticmethod - def _reset(batch_size, shape): - preVals = jnp.zeros((batch_size, shape[0])) - postVals = jnp.zeros((batch_size, shape[1])) - return ( - preVals, # inputs - postVals, # outputs - preVals, # pre - postVals, # post - jnp.zeros(shape), # dW - jnp.zeros(shape[1]), # db - ) - - @resolver(_reset) - def reset(self, inputs, outputs, pre, post, dWeights, dBiases): - self.inputs.set(inputs) - 
self.outputs.set(outputs) - self.pre.set(pre) - self.post.set(post) - self.dWeights.set(dWeights) - self.dBiases.set(dBiases) - - @classmethod - def help(cls): ## component help function - properties = { - "synapse_type": "HebbianSynapse - performs an adaptable synaptic " - "transformation of inputs to produce output signals; " - "synapses are adjusted via two-term/factor Hebbian adjustment" - } - compartment_props = { - "inputs": - {"inputs": "Takes in external input signal values", - "pre": "Pre-synaptic statistic for Hebb rule (z_j)", - "post": "Post-synaptic statistic for Hebb rule (z_i)"}, - "states": - {"weights": "Synapse efficacy/strength parameter values", - "biases": "Base-rate/bias parameter values", - "key": "JAX PRNG key"}, - "analytics": - {"dWeights": "Synaptic weight value adjustment matrix produced at time t", - "dBiases": "Synaptic bias/base-rate value adjustment vector produced at time t"}, - "outputs": - {"outputs": "Output of synaptic transformation"}, - } - hyperparams = { - "shape": "Shape of synaptic weight value matrix; number inputs x number outputs", - "batch_size": "Batch size dimension of this component", - "weight_init": "Initialization conditions for synaptic weight (W) values", - "bias_init": "Initialization conditions for bias/base-rate (b) values", - "resist_scale": "Resistance level scaling factor (applied to output of transformation)", - "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)", - "is_nonnegative": "Should synapses be constrained to be non-negative post-updates?", - "sign_value": "Scalar `flipping` constant -- changes direction to Hebbian descent if < 0", - "eta": "Global (fixed) learning rate", - "pre_wght": "Pre-synaptic weighting coefficient (q_pre)", - "post_wght": "Post-synaptic weighting coefficient (q_post)", - "w_bound": "Soft synaptic bound applied to synapses post-update", - "prior": "prior name and value for synaptic updating prior", - "optim_type": "Choice of optimizer to adjust synaptic weights" - } - info = {cls.__name__: properties, - "compartments": compartment_props, - "dynamics": "outputs = [(W * Rscale) * inputs] + b ;" - "dW_{ij}/dt = eta * [(z_j * q_pre) * (z_i * q_post)] - g(W_{ij}) * prior_lmbda", - "hyperparameters": hyperparams} - return info - - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - -if __name__ == '__main__': - from ngcsimlib.context import Context - with Context("Bar") as bar: - Wab = HebbianSynapse("Wab", (2, 3), 0.0004, optim_type='adam', - sign_value=-1.0, prior=("l1l2", 0.001)) - print(Wab) \ No newline at end of file From 53a1e5aede746305c2cebb9c799aecbebdd01d1c Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Sun, 13 Apr 2025 01:22:35 -0400 Subject: [PATCH 4/4] minor update to installation doc --- docs/installation.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/installation.md b/docs/installation.md index cda569a2..b70db6c6 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -6,13 +6,13 @@ without a GPU. 
Setup: ngc-learn, in its entirety (including its supporting utilities), requires that you ensure that you have installed the following base dependencies in -your system. Note that this library was developed and tested on Ubuntu 22.04 (and 18.04). +your system. Note that this library was developed and tested on Ubuntu 22.04 (and earlier versions on 18.04/20.04). Specifically, ngc-learn requires: * Python (>=3.10) -* ngcsimlib (>=0.3.b4), (official page) +* ngcsimlib (>=1.0.0), (official page) * NumPy (>=1.26.0) * SciPy (>=1.7.0) -* JAX (>= 0.4.18; and jaxlib>=0.4.18) +* JAX (>= 0.4.28; and jaxlib>=0.4.28) * Matplotlib (>=3.4.2), (for `ngclearn.utils.viz`) * Scikit-learn (>=1.3.1), (for `ngclearn.utils.patch_utils` and `ngclearn.utils.density`) @@ -45,7 +45,7 @@ $ git clone https://github.com/NACLab/ngc-learn.git $ cd ngc-learn ``` -2. (Optional; only for GPU version) Install JAX for either CUDA 11 or 12 , depending +2. (Optional; only for GPU version) Install JAX for CUDA 12, depending on your system setup. Follow the installation instructions on the official JAX page to properly install the CUDA 12 version.