diff --git a/docs/installation.md b/docs/installation.md
index cda569a2..b70db6c6 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -6,13 +6,13 @@ without a GPU.
Setup: ngc-learn, in its entirety (including its supporting utilities), requires
that you ensure that you have installed the following base dependencies in
-your system. Note that this library was developed and tested on Ubuntu 22.04 (and 18.04).
+your system. Note that this library was developed and tested on Ubuntu 22.04 (earlier releases were tested on 18.04/20.04).
Specifically, ngc-learn requires:
* Python (>=3.10)
-* ngcsimlib (>=0.3.b4), (official page)
+* ngcsimlib (>=1.0.0), (official page)
* NumPy (>=1.26.0)
* SciPy (>=1.7.0)
-* JAX (>= 0.4.18; and jaxlib>=0.4.18)
+* JAX (>= 0.4.28; and jaxlib>=0.4.28)
* Matplotlib (>=3.4.2), (for `ngclearn.utils.viz`)
* Scikit-learn (>=1.3.1), (for `ngclearn.utils.patch_utils` and `ngclearn.utils.density`)
@@ -45,7 +45,7 @@
$ git clone https://github.com/NACLab/ngc-learn.git
$ cd ngc-learn
```
-2. (Optional; only for GPU version) Install JAX for either CUDA 11 or 12 , depending
+2. (Optional; only for GPU version) Install JAX for CUDA 12, depending
on your system setup. Follow the installation instructions on the official JAX page
-to properly install the CUDA 11 or 12 version.
+to properly install the CUDA 12 version.
diff --git a/docs/museum/sindy.md b/docs/museum/sindy.md
index 17ef4fee..04426d70 100644
--- a/docs/museum/sindy.md
+++ b/docs/museum/sindy.md
@@ -8,8 +8,7 @@
Flow diagrams lack clear directional indicators
Inconsistent color schemes across visualizations
-->
-
-# Sparse Identification of Non-linear Dynamical Systems (SINDy)[1]
+# Sparse Identification of Non-linear Dynamical Systems (SINDy)
In this section, we will study, create, simulate, and visualize a model known as the sparse identification of non-linear dynamical systems (SINDy) [1], implementing it in NGC-Learn and JAX. After going through this demonstration, you will:
@@ -28,19 +27,24 @@ SINDy is a data-driven algorithm that discovers the governing behavior of a dyna
### SINDy Dynamics
-If $\mathbf{X}$ is a system that only depends on variable $t$, a very small change in the independent variable ($dt$) can cause a change in the system by $dX$ amount.
-$$$
+If $\mathbf{X}$ is a system that only depends on variable $t$, a very small change in the independent variable ($dt$) can cause a change in the system by $dX$ amount:
+
+$$
d\mathbf{X} = \mathbf{Ẋ}(t)~dt
-$$$
+$$
+
SINDy models the derivative[^1] (a linear operation) as a linear transformation:
[^1]: The derivative is a linear operation that acts on $dt$ and gives a differential that is the linearized approximation of the Taylor series of the function.
-$$$
+
+$$
\frac{d\mathbf{X}(t)}{dt} = \mathbf{Ẋ}(t) = \mathbf{f}(\mathbf{X}(t))
-$$$
+$$
+
SINDy assumes that this linear operation, i.e., $\mathbf{f}(\mathbf{X}(t))$, is a matrix multiplication that linearly combines the relevant predictors in order to describe the system's equation.
-$$$
+
+$$
\mathbf{f}(\mathbf{X}(t)) = \mathbf{\Theta}(\mathbf{X})~\mathbf{W}
-$$$
+$$
Given a group of candidate functions within the library $\mathbf{\Theta}(\mathbf{X})$, the coefficients in $\mathbf{W}$ that choose the library terms are to be **sparse**. In other words, there are only a few functions that exist in the system's differential equation. Given these assumptions, SINDy solves a sparse regression problem in order to find the $\mathbf{W}$ that maps the library of selected terms to each feature of the system being identified.
SINDy imposes parsimony constraints over the resulting symbolic regression (i.e., genetic programming) to describe a dynamical system's behavior with as few terms as possible. In order to select a sparse set of the given features, the model either adds a LASSO regularization penalty (i.e., an L1 norm constraint) to the regression problem and solves the resulting sparse regression, or it solves the regression problem via STLSQ. We will describe STLSQ in the third step of the SINDy dynamics/process.
@@ -48,206 +52,101 @@ In essence, SINDy's dynamics can be presented in three main phases, visualized i
------------------------------------------------------------------------------------------

**Figure 1:** **The flow of the three phases in SINDy.** **Phase-1)** Data collection: capturing system states that are changing in time and creating the state vector. **Phase-2A)** Library formation: manually creating the library of candidate predictors that could appear in the model. **Phase-2B)** Derivative computation: using the data collected in phase 1 to compute its derivative with respect to time. **Phase-3)** Solving the sparse regression problem.

------------------------------------------------------------------------------------------
## Phase 1: Collecting Dataset → $\mathbf{X}_{(m \times n)}$

This phase involves gathering the raw data points representing the system's states across time. In this example, this means capturing the $x$, $y$, and $z$ coordinates of the system's states. Here, $m$ represents the number of data points (number of snapshots/length of time) and $n$ is the system's dimensionality.

*(Figure: dataset collection showing x, y, z coordinates.)*
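To make phase 1 concrete, below is a small, self-contained sketch (not taken from the exhibit's own code, whose data-generation utilities may differ) that builds a snapshot matrix $\mathbf{X}$ for the Lorenz system, one of the benchmark systems identified in the results further below. The step size, number of snapshots, and initial condition are arbitrary illustrative choices.

```python
import numpy as np

def lorenz(state, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    ## right-hand side of the Lorenz system: returns (dx/dt, dy/dt, dz/dt)
    x, y, z = state
    return np.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])

dt = 0.01                          ## integration step size (also the sampling interval)
m = 1000                           ## number of snapshots (rows of X)
X = np.zeros((m, 3))               ## state matrix: one row per time point, one column per state variable
X[0] = np.array([1.0, 1.0, 1.0])   ## arbitrary initial condition
for k in range(1, m):
    X[k] = X[k - 1] + dt * lorenz(X[k - 1])   ## forward-Euler integration step
```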
## Phase 2: Processing
### 2.A: Making the Library → $\mathbf{\Theta}_{(m \times p)}$

In this step, using the dataset collected in phase 1 and a set of pre-defined function terms, we construct a dictionary of candidate predictors for identifying the target system's differential equations. These functions form the columns of our library matrix $\mathbf{\Theta}(\mathbf{X})$, where $p$ is the number of candidate predictors. To identify the dynamical structure of the system, this library of candidate functions enters the regression problem and proposes the model's structure; the regression then yields a coefficient matrix that weights these functions. We assume that a sparse model will be sufficient to identify the system and enforce this through sparsification (LASSO or thresholding of the weights) in order to decide which structure best describes the system's behavior using the predictors. Given a set of time-series measurements of a dynamical system's state variables ($\mathbf{X}_{(m \times n)}$), we construct the following:

Library of Candidate Functions:
$\Theta(\mathbf{X}) = [\mathbf{1} \quad \mathbf{X} \quad \mathbf{X}^2 \quad \mathbf{X}^3 \quad \sin(\mathbf{X}) \quad \cos(\mathbf{X}) \quad ...]$
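As a rough illustration of what $\mathbf{\Theta}(\mathbf{X})$ looks like in code, the sketch below builds a purely polynomial library (constant, linear, and quadratic terms only) by hand. Note that ngc-learn provides a `PolynomialLibrary` utility, which the exhibit's code imports further below; the hypothetical `build_library` helper here is only meant to expose the idea, and its column ordering need not match that utility.

```python
import jax.numpy as jnp

def build_library(X):
    ## X: (m, n) snapshot matrix; returns Theta: (m, p) library of candidate predictors
    m, n = X.shape
    columns = [jnp.ones((m, 1))]      ## constant term
    columns.append(X)                 ## linear terms x_1, ..., x_n
    for i in range(n):                ## quadratic terms x_i * x_j (with i <= j)
        for j in range(i, n):
            columns.append((X[:, i] * X[:, j])[:, None])
    return jnp.concatenate(columns, axis=1)
```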

### 2.B: Compute State Derivatives → $\mathbf{Ẋ}_{(m \times n)}$

Given a set of time-series measurements of a dynamical system's state variables $\mathbf{X}_{(m \times n)}$, we next construct the derivative matrix $\mathbf{Ẋ}_{(m \times n)}$ (computed numerically). In this step, using the dataset collected in phase 1, we compute the derivatives of each state variable with respect to time. In this example, we compute $ẋ$, $ẏ$, and $ż$ in order to capture how the system evolves over time.
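One common way to compute $\mathbf{Ẋ}$ numerically, assuming uniformly spaced samples (the exhibit itself may use a different differentiation scheme), is a central finite difference; the `finite_difference` helper below is illustrative rather than part of the library.

```python
import jax.numpy as jnp

def finite_difference(X, dt):
    ## X: (m, n) snapshots sampled every dt time units; returns dX of the same shape
    dX = jnp.zeros_like(X)
    dX = dX.at[1:-1].set((X[2:] - X[:-2]) / (2.0 * dt))  ## central differences in the interior
    dX = dX.at[0].set((X[1] - X[0]) / dt)                ## forward difference at the first sample
    dX = dX.at[-1].set((X[-1] - X[-2]) / dt)             ## backward difference at the last sample
    return dX
```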

## Phase 3: Solving the Sparse Regression Problem → $\mathbf{W_s}_{(p \times n)}$

The sparse regression (SR) problem that results from the phases/steps above can be solved using various methods, such as LASSO, STLSQ, and Elastic Net, as well as many other schemes. Here, we describe the STLSQ approach to solving the SR problem according to the SINDy process.
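To make the objective concrete, the LASSO form of this sparse regression (mentioned earlier) can be written, for each column $\mathbf{w}_k$ of $\mathbf{W}$ (i.e., for each state variable $k$), as:

$$
\mathbf{w}_k = \arg\min_{\mathbf{w}} \left\lVert \mathbf{Ẋ}_{:,k} - \mathbf{\Theta}(\mathbf{X})\,\mathbf{w} \right\rVert^2_2 + \lambda \left\lVert \mathbf{w} \right\rVert_1
$$

where $\lambda$ controls how strongly sparsity is encouraged. STLSQ, described next, pursues the same goal without the L1 penalty: it alternates ordinary least-squares fits with hard thresholding of small coefficients.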
### Solving Sparse Regression by Sequential Thresholding Least Squares (STLSQ)


------------------------------------------------------------------------------------------
*(Figure: state derivatives visualization.)*
#### 3.A: Least Squares Method (LSQ) → $\mathbf{W}$

This step entails finding the library coefficients by solving the regression problem $\mathbf{Ẋ} = \mathbf{\Theta}\mathbf{W}$ analytically:

$$
\mathbf{W} = (\mathbf{\Theta}^T \mathbf{\Theta})^{-1} \mathbf{\Theta}^T \mathbf{Ẋ}
$$
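A minimal sketch of this step, reusing the hypothetical `build_library` and `finite_difference` helpers (and the `X`, `dt` values) from the sketches above, and using `jnp.linalg.lstsq` rather than forming the normal equations explicitly, which is numerically better behaved:

```python
import jax.numpy as jnp

Theta = build_library(jnp.asarray(X))        ## (m, p) library of candidate functions
dX = finite_difference(jnp.asarray(X), dt)   ## (m, n) numerical state derivatives
W = jnp.linalg.lstsq(Theta, dX)[0]           ## (p, n) least-squares coefficients, one column per state variable
```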

#### 3.B: Thresholding → $\mathbf{W_s}$

This step entails sparsifying $\mathbf{W}$ by keeping only the terms within $\mathbf{W}$ whose magnitudes exceed a small threshold, i.e., those that correspond to the effective terms in the library.
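The thresholding itself is a one-liner; the cutoff value below is illustrative and mirrors the `jnp.where(jnp.abs(coef) >= threshold, coef, 0.)` pattern that appears in the exhibit's code further below.

```python
import jax.numpy as jnp

threshold = 0.1                                    ## illustrative cutoff; tune per problem
W_s = jnp.where(jnp.abs(W) >= threshold, W, 0.0)   ## keep only coefficients with large enough magnitude
```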

#### 3.C: Masking → $\mathbf{\Theta_s}$

This step sparsifies $\mathbf{\Theta}$ by keeping only the library columns that correspond to the terms in $\mathbf{W}$ that remain (from the prior step).

#### 3.D: Repeat A → B → C until convergence

We continue to solve LSQ with the sparse matrices $\mathbf{\Theta_s}$ and $\mathbf{W_s}$ to find a new $\mathbf{W}$, repeating steps B and C until convergence.
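Putting steps 3.A through 3.D together, a compact, hand-rolled STLSQ loop might look like the sketch below. It is written with NumPy for simplicity and follows the standard STLSQ recipe, not the exhibit's exact ngc-learn implementation; all names are illustrative.

```python
import numpy as np

def stlsq(Theta, dX, threshold=0.1, max_iter=10):
    ## Theta: (m, p) library, dX: (m, n) derivatives; returns a sparse (p, n) coefficient matrix
    Theta, dX = np.asarray(Theta), np.asarray(dX)
    W = np.linalg.lstsq(Theta, dX, rcond=None)[0]      ## 3.A: initial dense least-squares fit
    for _ in range(max_iter):
        W[np.abs(W) < threshold] = 0.0                 ## 3.B: zero out small coefficients
        for k in range(dX.shape[1]):                   ## treat each state variable separately
            big = np.abs(W[:, k]) >= threshold         ## 3.C: mask of surviving library terms for dimension k
            if not big.any():
                continue
            ## 3.D: re-solve least squares restricted to the surviving library columns
            W[big, k] = np.linalg.lstsq(Theta[:, big], dX[:, k], rcond=None)[0]
    return W
```

With the Lorenz data from the phase-1 sketch (and the corresponding library and derivatives), `stlsq(Theta, dX)` should, for a reasonable threshold, zero out all but the handful of terms that appear in the true equations.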

## Code: Simulating SINDy

We finally present ngc-learn code below for using and simulating the SINDy process to identify several dynamical systems.

```python
import numpy as np
import jax.numpy as jnp
from ngclearn.utils.feature_dictionaries.polynomialLibrary import PolynomialLibrary

@@ -335,31 +224,14 @@

for dim in range(dX.shape[1]):
    coeff = jnp.where(jnp.abs(coef) >= threshold, coef, 0.)
    print(f"coefficients for dimension {dim+1}: \n", coeff.T)
```

## Results: System Identification

Running the above code should produce results similar to the findings we present next.
- ## Oscillator +## Oscillator True model's equation \ $\mathbf{ẋ} = \mu_1\mathbf{x} + \sigma \mathbf{xy}$ \ @@ -380,19 +252,13 @@ $\mathbf{ż} = \mu_2\mathbf{z} - (\omega + \alpha \mathbf{y} + \beta \mathbf{z}) [ 0. -0.009 0. -2.99 4.99 1.99 0. 0. 0. 0.]] ``` - -

+which should produce the following results: + + + + - ## Lorenz +## Lorenz True model's equation \ $\mathbf{ẋ} = 10(\mathbf{y} - \mathbf{x})$ \ @@ -400,7 +266,6 @@ $\mathbf{ẏ} = \mathbf{x}(28 - \mathbf{z}) - \mathbf{y}$ \ $\mathbf{ż} = \mathbf{xy} - \frac{8}{3}~\mathbf{z}$ - ```python --- SINDy results ---- ẋ = 9.969 𝑦 -9.966 𝑥 @@ -413,19 +278,12 @@ $\mathbf{ż} = \mathbf{xy} - \frac{8}{3}~\mathbf{z}$ [-2.656 0. 0. 0. 0. 0. 0. 0.996 0.]] ``` - -

- - ## Linear-2D +which should produce the following results: + + + + +## Linear-2D True model's equation \ $\mathbf{ẋ} = -0.1\mathbf{x} + 2.0\mathbf{y}$ \ @@ -440,21 +298,14 @@ $\mathbf{ẏ} = -2.0\mathbf{x} - 0.1\mathbf{y}$ [[ 1.999 0. -0.100 0. 0.] [-0.099 0. -1.999 0. 0.]] ``` + +which should produce the following results: + + - -

- ## Linear-3D +## Linear-3D True model's equation \ $\mathbf{ẋ} = -0.1\mathbf{x} + 2\mathbf{y}$ \ @@ -473,22 +324,13 @@ $\mathbf{ż} = -0.3\mathbf{z}$ [-0.299 0. 0. 0. 0. 0. 0. 0. 0.]] ``` - -

+ + + - ## Cubic-2D +## Cubic-2D True model's equation \ $\mathbf{ẋ} = -0.1\mathbf{x}^3 + 2.0\mathbf{y}^3$ \ @@ -504,16 +346,10 @@ $\mathbf{ẏ} = -2.0\mathbf{x}^3 - 0.1\mathbf{y}^3$ [ 0. 0. -0.099 0. 0. 0. 0. 0. -1.99]] ``` - -

+which should produce the following results: + + + ## References diff --git a/docs/tutorials/neurocog/hodgkin_huxley_cell.md b/docs/tutorials/neurocog/hodgkin_huxley_cell.md index 1580ab1c..44e3b0a7 100755 --- a/docs/tutorials/neurocog/hodgkin_huxley_cell.md +++ b/docs/tutorials/neurocog/hodgkin_huxley_cell.md @@ -83,9 +83,9 @@ Formally, the core dynamics of the H-H cell can be written out as follows: $$ \tau_v \frac{\partial \mathbf{v}_t}{\partial t} &= \mathbf{j}_t - g_Na * \mathbf{m}^3_t * \mathbf{h}_t * (\mathbf{v}_t - v_Na) - g_K * \mathbf{n}^4_t * (\mathbf{v}_t - v_K) - g_L * (\mathbf{v}_t - v_L) \\ -\frac{\partial \mathbf{n}_t}{\partial t} &= alpha_n(\mathbf{v}_t) * (1 - \mathbf{n}_t) - beta_n(\mathbf{v}_t) * \mathbf{n}_t \\ -\frac{\partial \mathbf{m}_t}{\partial t} &= alpha_m(\mathbf{v}_t) * (1 - \mathbf{m}_t) - beta_m(\mathbf{v}_t) * \mathbf{m}_t \\ -\frac{\partial \mathbf{h}_t}{\partial t} &= alpha_h(\mathbf{v}_t) * (1 - \mathbf{h}_t) - beta_h(\mathbf{v}_t) * \mathbf{h}_t +\frac{\partial \mathbf{n}_t}{\partial t} &= \alpha_n(\mathbf{v}_t) * (1 - \mathbf{n}_t) - \beta_n(\mathbf{v}_t) * \mathbf{n}_t \\ +\frac{\partial \mathbf{m}_t}{\partial t} &= \alpha_m(\mathbf{v}_t) * (1 - \mathbf{m}_t) - \beta_m(\mathbf{v}_t) * \mathbf{m}_t \\ +\frac{\partial \mathbf{h}_t}{\partial t} &= \alpha_h(\mathbf{v}_t) * (1 - \mathbf{h}_t) - \beta_h(\mathbf{v}_t) * \mathbf{h}_t $$ where we observe that the above four-dimensional set of dynamics is composed of nonlinear ODEs. Notice that, in each gate or channel probability ODE, there are two generator functions (each of which is a function of the membrane potential $\mathbf{v}_t$) that produces the necessary dynamic coefficients at time $t$; $\alpha_x(\mathbf{v}_t)$ and $\beta_x(\mathbf{v}_t)$ produce different biopphysical weighting values depending on which channel $x = \{n, m, h\}$ they are related to. diff --git a/ngclearn/components/neurons/graded/rateCellOld.py b/ngclearn/components/neurons/graded/rateCellOld.py deleted file mode 100644 index 6962810c..00000000 --- a/ngclearn/components/neurons/graded/rateCellOld.py +++ /dev/null @@ -1,350 +0,0 @@ -from jax import numpy as jnp, random, jit -from functools import partial -from ngclearn.utils import tensorstats -from ngclearn import resolver, Component, Compartment -from ngclearn.components.jaxComponent import JaxComponent -from ngclearn.utils.model_utils import create_function, threshold_soft, \ - threshold_cauchy -from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ - step_euler, step_rk2, step_rk4 - -def _dfz_internal_gaussian(z, j, j_td, tau_m, leak_gamma): - z_leak = z # * 2 ## Default: assume Gaussian - dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) - return dz_dt - -def _dfz_internal_laplacian(z, j, j_td, tau_m, leak_gamma): - z_leak = jnp.sign(z) ## d/dx of Laplace is signum - dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) - return dz_dt - -def _dfz_internal_cauchy(z, j, j_td, tau_m, leak_gamma): - z_leak = (z * 2)/(1. 
+ jnp.square(z)) - dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) - return dz_dt - -def _dfz_internal_exp(z, j, j_td, tau_m, leak_gamma): - z_leak = jnp.exp(-jnp.square(z)) * z * 2 - dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) - return dz_dt - - -def _dfz_gaussian(t, z, params): ## diff-eq dynamics wrapper - j, j_td, tau_m, leak_gamma = params - dz_dt = _dfz_internal_gaussian(z, j, j_td, tau_m, leak_gamma) - return dz_dt - -def _dfz_laplacian(t, z, params): ## diff-eq dynamics wrapper - j, j_td, tau_m, leak_gamma = params - dz_dt = _dfz_internal_laplacian(z, j, j_td, tau_m, leak_gamma) - return dz_dt - -def _dfz_cauchy(t, z, params): ## diff-eq dynamics wrapper - j, j_td, tau_m, leak_gamma = params - dz_dt = _dfz_internal_cauchy(z, j, j_td, tau_m, leak_gamma) - return dz_dt - -def _dfz_exp(t, z, params): ## diff-eq dynamics wrapper - j, j_td, tau_m, leak_gamma = params - dz_dt = _dfz_internal_exp(z, j, j_td, tau_m, leak_gamma) - return dz_dt - -@jit -def _modulate(j, dfx_val): - """ - Apply a signal modulator to j (typically of the form of a derivative/dampening function) - - Args: - j: current/stimulus value to modulate - - dfx_val: modulator signal - - Returns: - modulated j value - """ - return j * dfx_val - -def _run_cell(dt, j, j_td, z, tau_m, leak_gamma=0., integType=0, priorType=0): - """ - Runs leaky rate-coded state dynamics one step in time. - - Args: - dt: integration time constant - - j: input (bottom-up) electrical/stimulus current - - j_td: modulatory (top-down) electrical/stimulus pressure - - z: current value of membrane/state - - tau_m: membrane/state time constant - - leak_gamma: strength of leak to apply to membrane/state - - integType: integration type to use (0 --> Euler/RK1, 1 --> Midpoint/RK2, 2 --> RK4) - - priorType: scale-shift prior distribution to impose over neural dynamics - - Returns: - New value of membrane/state for next time step - """ - _dfz = { - 0: _dfz_gaussian, - 1: _dfz_laplacian, - 2: _dfz_cauchy, - 3: _dfz_exp - }.get(priorType, _dfz_gaussian) - if integType == 1: - params = (j, j_td, tau_m, leak_gamma) - _, _z = step_rk2(0., z, _dfz, dt, params) - elif integType == 2: - params = (j, j_td, tau_m, leak_gamma) - _, _z = step_rk4(0., z, _dfz, dt, params) - else: - params = (j, j_td, tau_m, leak_gamma) - _, _z = step_euler(0., z, _dfz, dt, params) - return _z - -@jit -def _run_cell_stateless(j): - """ - A simplification of running a stateless set of dynamics over j (an identity - functional form of dynamics). - - Args: - j: stimulus to do nothing to - - Returns: - the stimulus - """ - return j + 0 - -class RateCell(JaxComponent): ## Rate-coded/real-valued cell - """ - A non-spiking cell driven by the gradient dynamics of neural generative - coding-driven predictive processing. 
- - The specific differential equation that characterizes this cell - is (for adjusting v, given current j, over time) is: - - | tau_m * dz/dt = lambda * prior(z) + (j + j_td) - | where j is the set of general incoming input signals (e.g., message-passed signals) - | and j_td is taken to be the set of top-down pressure signals - - | --- Cell Input Compartments: --- - | j - input pressure (takes in external signals) - | j_td - input/top-down pressure input (takes in external signals) - | --- Cell State Compartments --- - | z - rate activity - | --- Cell Output Compartments: --- - | zF - post-activation function activity, i.e., fx(z) - - Args: - name: the string name of this cell - - n_units: number of cellular entities (neural population size) - - tau_m: membrane/state time constant (milliseconds) - - prior: a kernel for specifying the type of centered scale-shift distribution - to impose over neuronal dynamics, applied to each neuron or - dimension within this component (Default: ("gaussian", 0)); this is - a tuple with 1st element containing a string name of the distribution - one wants to use while the second value is a `leak rate` scalar - that controls the influence/weighting that this distribution - has on the dynamics; for example, ("laplacian, 0.001") means that a - centered laplacian distribution scaled by `0.001` will be injected - into this cell's dynamics ODE each step of simulated time - - :Note: supported scale-shift distributions include "laplacian", - "cauchy", "exp", and "gaussian" - - act_fx: string name of activation function/nonlinearity to use - - integration_type: type of integration to use for this cell's dynamics; - current supported forms include "euler" (Euler/RK-1 integration) - and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler") - - :Note: setting the integration type to the midpoint method will - increase the accuray of the estimate of the cell's evolution - at an increase in computational cost (and simulation time) - - resist_scale: a scaling factor applied to incoming pressure `j` (default: 1) - """ - - # Define Functions - def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identity", - threshold=("none", 0.), integration_type="euler", - batch_size=1, resist_scale=1., shape=None, is_stateful=True, **kwargs): - super().__init__(name, **kwargs) - - ## membrane parameter setup (affects ODE integration) - self.tau_m = tau_m ## membrane time constant -- setting to 0 triggers "stateless" mode - self.is_stateful = is_stateful - if isinstance(tau_m, float): - if tau_m <= 0: ## trigger stateless mode - self.is_stateful = False - priorType, leakRate = prior - self.priorType = { - "gaussian": 0, - "laplacian": 1, - "cauchy": 2, - "exp": 3 - }.get(priorType, 0) ## type of scale-shift prior to impose over the leak - self.priorLeakRate = leakRate ## degree to which rate neurons leak (according to prior) - thresholdType, thr_lmbda = threshold - self.thresholdType = thresholdType ## type of thresholding function to use - self.thr_lmbda = thr_lmbda ## scale to drive thresholding dynamics - self.resist_scale = resist_scale ## a "resistance" scaling factor - - ## integration properties - self.integrationType = integration_type - self.intgFlag = get_integrator_code(self.integrationType) - - ## Layer size setup - _shape = (batch_size, n_units) ## default shape is 2D/matrix - if shape is None: - shape = (n_units,) ## we set shape to be equal to n_units if nothing provided - else: - _shape = (batch_size, shape[0], shape[1], shape[2]) 
## shape is 4D tensor - self.shape = shape - self.n_units = n_units - self.batch_size = batch_size - - omega_0 = None - if act_fx == "sine": - omega_0 = kwargs["omega_0"] - self.fx, self.dfx = create_function(fun_name=act_fx, args=omega_0) - - # compartments (state of the cell & parameters will be updated through stateless calls) - restVals = jnp.zeros(_shape) - self.j = Compartment(restVals, display_name="Input Stimulus Current", units="mA") # electrical current - self.zF = Compartment(restVals, display_name="Transformed Rate Activity") # rate-coded output - activity - self.j_td = Compartment(restVals, display_name="Modulatory Stimulus Current", units="mA") # top-down electrical current - pressure - self.z = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity - - @staticmethod - def _advance_state(dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType, - resist_scale, thresholdType, thr_lmbda, is_stateful, j, j_td, z): - #if tau_m > 0.: - if is_stateful: - ### run a step of integration over neuronal dynamics - ## Notes: - ## self.pressure <-- "top-down" expectation / contextual pressure - ## self.current <-- "bottom-up" data-dependent signal - dfx_val = dfx(z) - j = _modulate(j, dfx_val) - j = j * resist_scale - tmp_z = _run_cell(dt, j, j_td, z, - tau_m, leak_gamma=priorLeakRate, - integType=intgFlag, priorType=priorType) - ## apply optional thresholding sub-dynamics - if thresholdType == "soft_threshold": - tmp_z = threshold_soft(tmp_z, thr_lmbda) - elif thresholdType == "cauchy_threshold": - tmp_z = threshold_cauchy(tmp_z, thr_lmbda) - z = tmp_z ## pre-activation function value(s) - zF = fx(z) ## post-activation function value(s) - else: - ## run in "stateless" mode (when no membrane time constant provided) - j_total = j + j_td - z = _run_cell_stateless(j_total) - zF = fx(z) - return j, j_td, z, zF - - @resolver(_advance_state) - def advance_state(self, j, j_td, z, zF): - self.j.set(j) - self.j_td.set(j_td) - self.z.set(z) - self.zF.set(zF) - - @staticmethod - def _reset(batch_size, shape): #n_units - _shape = (batch_size, shape[0]) - if len(shape) > 1: - _shape = (batch_size, shape[0], shape[1], shape[2]) - restVals = jnp.zeros(_shape) - return tuple([restVals for _ in range(4)]) - - @resolver(_reset) - def reset(self, j, zF, j_td, z): - self.j.set(j) # electrical current - self.zF.set(zF) # rate-coded output - activity - self.j_td.set(j_td) # top-down electrical current - pressure - self.z.set(z) # rate activity - - def save(self, directory, **kwargs): - ## do a protected save of constants, depending on whether they are floats or arrays - tau_m = (self.tau_m if isinstance(self.tau_m, float) - else jnp.ones([[self.tau_m]])) - priorLeakRate = (self.priorLeakRate if isinstance(self.priorLeakRate, float) - else jnp.ones([[self.priorLeakRate]])) - resist_scale = (self.resist_scale if isinstance(self.resist_scale, float) - else jnp.ones([[self.resist_scale]])) - - file_name = directory + "/" + self.name + ".npz" - jnp.savez(file_name, - tau_m=tau_m, priorLeakRate=priorLeakRate, - resist_scale=resist_scale) #, key=self.key.value) - - def load(self, directory, seeded=False, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - ## constants loaded in - self.tau_m = data['tau_m'] - self.priorLeakRate = data['priorLeakRate'] - self.resist_scale = data['resist_scale'] - #if seeded: - # self.key.set(data['key']) - - @classmethod - def help(cls): ## component help function - properties = { - "cell_type": "RateCell - evolves neurons 
according to rate-coded/" - "continuous dynamics " - } - compartment_props = { - "inputs": - {"j": "External input stimulus value(s)", - "j_td": "External top-down input stimulus value(s); these get " - "multiplied by the derivative of f(x), i.e., df(x)"}, - "states": - {"z": "Update to rate-coded continuous dynamics; value at time t"}, - "outputs": - {"zF": "Nonlinearity/function applied to rate-coded dynamics; f(z)"}, - } - hyperparams = { - "n_units": "Number of neuronal cells to model in this layer", - "batch_size": "Batch size dimension of this component", - "tau_m": "Cell state/membrane time constant", - "prior": "What kind of kurtotic prior to place over neuronal dynamics?", - "act_fx": "Elementwise activation function to apply over cell state `z`", - "threshold": "What kind of iterative thresholding function to place over neuronal dynamics?", - "integration_type": "Type of numerical integration to use for the cell dynamics", - } - info = {cls.__name__: properties, - "compartments": compartment_props, - "dynamics": "tau_m * dz/dt = Prior(z; gamma) + (j + j_td)", - "hyperparameters": hyperparams} - return info - - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - -if __name__ == '__main__': - from ngcsimlib.context import Context - with Context("Bar") as bar: - X = RateCell("X", 9, 0.03) - print(X) \ No newline at end of file diff --git a/ngclearn/components/synapses/hebbian/hebbianSynapseOld.py b/ngclearn/components/synapses/hebbian/hebbianSynapseOld.py deleted file mode 100644 index 04ebd4cb..00000000 --- a/ngclearn/components/synapses/hebbian/hebbianSynapseOld.py +++ /dev/null @@ -1,326 +0,0 @@ -from jax import random, numpy as jnp, jit -from functools import partial -from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn -from ngclearn import resolver, Component, Compartment -from ngclearn.components.synapses import DenseSynapse -from ngclearn.utils import tensorstats -from ngcsimlib.deprecators import deprecate_args - -@partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9]) -def _calc_update(pre, post, W, w_bound, is_nonnegative=True, signVal=1., - prior_type=None, prior_lmbda=0., - pre_wght=1., post_wght=1.): - """ - Compute a tensor of adjustments to be applied to a synaptic value matrix. - - Args: - pre: pre-synaptic statistic to drive Hebbian update - - post: post-synaptic statistic to drive Hebbian update - - W: synaptic weight values (at time t) - - w_bound: maximum value to enforce over newly computed efficacies - - is_nonnegative: (Unused) - - signVal: multiplicative factor to modulate final update by (good for - flipping the signs of a computed synaptic change matrix) - - prior_type: prior type or name (Default: None) - - prior_lmbda: prior parameter (Default: 0.0) - - pre_wght: pre-synaptic weighting term (Default: 1.) - - post_wght: post-synaptic weighting term (Default: 1.) - - Returns: - an update/adjustment matrix, an update adjustment vector (for biases) - """ - _pre = pre * pre_wght - _post = post * post_wght - dW = jnp.matmul(_pre.T, _post) - db = jnp.sum(_post, axis=0, keepdims=True) - dW_reg = 0. 
- - if w_bound > 0.: - dW = dW * (w_bound - jnp.abs(W)) - - if prior_type == "l2" or prior_type == "ridge": - dW_reg = W - if prior_type == "l1" or prior_type == "lasso": - dW_reg = jnp.sign(W) - if prior_type == "l1l2" or prior_type == "elastic_net": - l1_ratio = prior_lmbda[1] - prior_lmbda = prior_lmbda[0] - dW_reg = jnp.sign(W) * l1_ratio + W * (1-l1_ratio)/2 - - dW = dW + prior_lmbda * dW_reg - return dW * signVal, db * signVal - -@partial(jit, static_argnums=[1,2]) -def _enforce_constraints(W, w_bound, is_nonnegative=True): - """ - Enforces constraints that the (synaptic) efficacies/values within matrix - `W` must adhere to. - - Args: - W: synaptic weight values (at time t) - - w_bound: maximum value to enforce over newly computed efficacies - - is_nonnegative: ensure updated value matrix is strictly non-negative - - Returns: - the newly evolved synaptic weight value matrix - """ - _W = W - if w_bound > 0.: - if is_nonnegative == True: - _W = jnp.clip(_W, 0., w_bound) - else: - _W = jnp.clip(_W, -w_bound, w_bound) - return _W - - -class HebbianSynapse(DenseSynapse): - """ - A synaptic cable that adjusts its efficacies via a two-factor Hebbian - adjustment rule. - - | --- Synapse Compartments: --- - | inputs - input (takes in external signals) - | outputs - output signals (transformation induced by synapses) - | weights - current value matrix of synaptic efficacies - | biases - current value vector of synaptic bias values - | key - JAX PRNG key - | --- Synaptic Plasticity Compartments: --- - | pre - pre-synaptic signal to drive first term of Hebbian update (takes in external signals) - | post - post-synaptic signal to drive 2nd term of Hebbian update (takes in external signals) - | dWeights - current delta matrix containing changes to be applied to synaptic efficacies - | dBiases - current delta vector containing changes to be applied to bias values - | opt_params - locally-embedded optimizer statisticis (e.g., Adam 1st/2nd moments if adam is used) - - Args: - name: the string name of this cell - - shape: tuple specifying shape of this synaptic cable (usually a 2-tuple - with number of inputs by number of outputs) - - eta: global learning rate - - weight_init: a kernel to drive initialization of this synaptic cable's values; - typically a tuple with 1st element as a string calling the name of - initialization to use - - bias_init: a kernel to drive initialization of biases for this synaptic cable - (Default: None, which turns off/disables biases) - - w_bound: maximum weight to softly bound this cable's value matrix to; if - set to 0, then no synaptic value bounding will be applied - - is_nonnegative: enforce that synaptic efficacies are always non-negative - after each synaptic update (if False, no constraint will be applied) - - prior: a kernel to drive prior of this synaptic cable's values; - typically a tuple with 1st element as a string calling the name of - prior to use and 2nd element as a floating point number - calling the prior parameter lambda (Default: (None, 0.)) - currently it supports "l1" or "lasso" or "l2" or "ridge" or "l1l2" or "elastic_net". 
- usage guide: - prior = ('l1', 0.01) or prior = ('lasso', lmbda) - prior = ('l2', 0.01) or prior = ('ridge', lmbda) - prior = ('l1l2', (0.01, 0.01)) or prior = ('elastic_net', (lmbda, l1_ratio)) - - - - sign_value: multiplicative factor to apply to final synaptic update before - it is applied to synapses; this is useful if gradient descent style - optimization is required (as Hebbian rules typically yield - adjustments for ascent) - - optim_type: optimization scheme to physically alter synaptic values - once an update is computed (Default: "sgd"); supported schemes - include "sgd" and "adam" - - :Note: technically, if "sgd" or "adam" is used but `signVal = 1`, - then the ascent form of each rule is employed (signVal = -1) or - a negative learning rate will mean a descent form of the - `optim_scheme` is being employed - - pre_wght: pre-synaptic weighting factor (Default: 1.) - - post_wght: post-synaptic weighting factor (Default: 1.) - - resist_scale: a fixed scaling factor to apply to synaptic transform - (Default: 1.), i.e., yields: out = ((W * Rscale) * in) + b - - p_conn: probability of a connection existing (default: 1.); setting - this to < 1. will result in a sparser synaptic structure - """ - - # Define Functions - @deprecate_args(_rebind=False, w_decay='prior') - def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None, - w_bound=1., is_nonnegative=False, prior=(None, 0.), w_decay=0., sign_value=1., - optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1., - resist_scale=1., batch_size=1, **kwargs): - super().__init__(name, shape, weight_init, bias_init, resist_scale, - p_conn, batch_size=batch_size, **kwargs) - - if w_decay > 0.: - prior = ('l2', w_decay) - - prior_type, prior_lmbda = prior - ## synaptic plasticity properties and characteristics - self.shape = shape - self.Rscale = resist_scale - self.prior_type = prior_type - self.prior_lmbda = prior_lmbda - self.w_bound = w_bound - self.pre_wght = pre_wght - self.post_wght = post_wght - self.eta = eta - self.is_nonnegative = is_nonnegative - self.sign_value = sign_value - - ## optimization / adjustment properties (given learning dynamics above) - self.opt = get_opt_step_fn(optim_type, eta=self.eta) - - # compartments (state of the cell, parameters, will be updated through stateless calls) - self.preVals = jnp.zeros((self.batch_size, shape[0])) - self.postVals = jnp.zeros((self.batch_size, shape[1])) - self.pre = Compartment(self.preVals) - self.post = Compartment(self.postVals) - self.dWeights = Compartment(jnp.zeros(shape)) - self.dBiases = Compartment(jnp.zeros(shape[1])) - - #key, subkey = random.split(self.key.value) - self.opt_params = Compartment(get_opt_init_fn(optim_type)( - [self.weights.value, self.biases.value] - if bias_init else [self.weights.value])) - - @staticmethod - def _compute_update(w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, - post_wght, pre, post, weights): - ## calculate synaptic update values - dW, db = _calc_update( - pre, post, weights, w_bound, is_nonnegative=is_nonnegative, - signVal=sign_value, prior_type=prior_type, prior_lmbda=prior_lmbda, pre_wght=pre_wght, - post_wght=post_wght) - return dW, db - - @staticmethod - def _evolve(opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, - post_wght, bias_init, pre, post, weights, biases, opt_params): - ## calculate synaptic update values - dWeights, dBiases = HebbianSynapse._compute_update( - w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, post_wght, - pre, post, 
weights - ) - ## conduct a step of optimization - get newly evolved synaptic weight value matrix - if bias_init != None: - opt_params, [weights, biases] = opt(opt_params, [weights, biases], [dWeights, dBiases]) - else: - # ignore db since no biases configured - opt_params, [weights] = opt(opt_params, [weights], [dWeights]) - ## ensure synaptic efficacies adhere to constraints - weights = _enforce_constraints(weights, w_bound, is_nonnegative=is_nonnegative) - return opt_params, weights, biases, dWeights, dBiases - - @resolver(_evolve) - def evolve(self, opt_params, weights, biases, dWeights, dBiases): - self.opt_params.set(opt_params) - self.weights.set(weights) - self.biases.set(biases) - self.dWeights.set(dWeights) - self.dBiases.set(dBiases) - - @staticmethod - def _reset(batch_size, shape): - preVals = jnp.zeros((batch_size, shape[0])) - postVals = jnp.zeros((batch_size, shape[1])) - return ( - preVals, # inputs - postVals, # outputs - preVals, # pre - postVals, # post - jnp.zeros(shape), # dW - jnp.zeros(shape[1]), # db - ) - - @resolver(_reset) - def reset(self, inputs, outputs, pre, post, dWeights, dBiases): - self.inputs.set(inputs) - self.outputs.set(outputs) - self.pre.set(pre) - self.post.set(post) - self.dWeights.set(dWeights) - self.dBiases.set(dBiases) - - @classmethod - def help(cls): ## component help function - properties = { - "synapse_type": "HebbianSynapse - performs an adaptable synaptic " - "transformation of inputs to produce output signals; " - "synapses are adjusted via two-term/factor Hebbian adjustment" - } - compartment_props = { - "inputs": - {"inputs": "Takes in external input signal values", - "pre": "Pre-synaptic statistic for Hebb rule (z_j)", - "post": "Post-synaptic statistic for Hebb rule (z_i)"}, - "states": - {"weights": "Synapse efficacy/strength parameter values", - "biases": "Base-rate/bias parameter values", - "key": "JAX PRNG key"}, - "analytics": - {"dWeights": "Synaptic weight value adjustment matrix produced at time t", - "dBiases": "Synaptic bias/base-rate value adjustment vector produced at time t"}, - "outputs": - {"outputs": "Output of synaptic transformation"}, - } - hyperparams = { - "shape": "Shape of synaptic weight value matrix; number inputs x number outputs", - "batch_size": "Batch size dimension of this component", - "weight_init": "Initialization conditions for synaptic weight (W) values", - "bias_init": "Initialization conditions for bias/base-rate (b) values", - "resist_scale": "Resistance level scaling factor (applied to output of transformation)", - "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)", - "is_nonnegative": "Should synapses be constrained to be non-negative post-updates?", - "sign_value": "Scalar `flipping` constant -- changes direction to Hebbian descent if < 0", - "eta": "Global (fixed) learning rate", - "pre_wght": "Pre-synaptic weighting coefficient (q_pre)", - "post_wght": "Post-synaptic weighting coefficient (q_post)", - "w_bound": "Soft synaptic bound applied to synapses post-update", - "prior": "prior name and value for synaptic updating prior", - "optim_type": "Choice of optimizer to adjust synaptic weights" - } - info = {cls.__name__: properties, - "compartments": compartment_props, - "dynamics": "outputs = [(W * Rscale) * inputs] + b ;" - "dW_{ij}/dt = eta * [(z_j * q_pre) * (z_i * q_post)] - g(W_{ij}) * prior_lmbda", - "hyperparameters": hyperparams} - return info - - def __repr__(self): - comps = [varname for varname in dir(self) if 
Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - -if __name__ == '__main__': - from ngcsimlib.context import Context - with Context("Bar") as bar: - Wab = HebbianSynapse("Wab", (2, 3), 0.0004, optim_type='adam', - sign_value=-1.0, prior=("l1l2", 0.001)) - print(Wab) \ No newline at end of file