|
| 1 | +from jax import random, numpy as jnp, jit |
| 2 | +from ngcsimlib.compilers.process import transition |
| 3 | +from ngcsimlib.component import Component |
| 4 | +from ngcsimlib.compartment import Compartment |
| 5 | + |
| 6 | +from ngclearn.utils.weight_distribution import initialize_params |
| 7 | +from ngcsimlib.logger import info |
| 8 | +from ngclearn.components.synapses import DenseSynapse |
| 9 | +from ngclearn.utils import tensorstats |
| 10 | + |
| 11 | +class ExponentialSynapse(DenseSynapse): ## dynamic exponential synapse cable |
| 12 | + """ |
| 13 | + A dynamic exponential synaptic cable; this synapse evolves according to exponential synaptic conductance dynamics. |
| 14 | + Specifically, the dynamics are as follows: |
| 15 | +
|
| 16 | + | g = g + (weight * gbase) // on the occurrence of a pulse |
| 17 | + | i = g * (erev - v), where: d g /dt = -g / tauDecay |
| 18 | +
|
| 19 | +
|
| 20 | + | --- Synapse Compartments: --- |
| 21 | + | inputs - input (takes in external signals) |
| 22 | + | outputs - output signals |
| 23 | + | weights - current value matrix of synaptic efficacies |
| 24 | + | biases - current value vector of synaptic bias values |
| 25 | + | --- Short-Term Plasticity Compartments: --- |
| 26 | + | resources - fixed value matrix of synaptic resources (U) |
| 27 | + | u - release probability; fraction of resources ready for use |
| 28 | + | x - fraction of resources available after neurotransmitter depletion |
| 29 | +
|
| 30 | + | Dynamics note: |
| 31 | + | If tau_d >> tau_f and resources U are large, then synapse is STD-dominated |
| 32 | + | If tau_d << tau_f and resources U are small, then synases is STF-dominated |
| 33 | +
|
| 34 | + Args: |
| 35 | + name: the string name of this cell |
| 36 | +
|
| 37 | + shape: tuple specifying shape of this synaptic cable (usually a 2-tuple |
| 38 | + with number of inputs by number of outputs) |
| 39 | +
|
| 40 | + weight_init: a kernel to drive initialization of this synaptic cable's values; |
| 41 | + typically a tuple with 1st element as a string calling the name of |
| 42 | + initialization to use |
| 43 | +
|
| 44 | + bias_init: a kernel to drive initialization of biases for this synaptic cable |
| 45 | + (Default: None, which turns off/disables biases) |
| 46 | +
|
| 47 | + resist_scale: a fixed (resistance) scaling factor to apply to synaptic |
| 48 | + transform (Default: 1.), i.e., yields: out = ((W * Rscale) * in) |
| 49 | +
|
| 50 | + p_conn: probability of a connection existing (default: 1.); setting |
| 51 | + this to < 1 and > 0. will result in a sparser synaptic structure |
| 52 | + (lower values yield sparse structure) |
| 53 | +
|
| 54 | + tau_f: short-term facilitation (STF) time constant (default: `750` ms); note |
| 55 | + that setting this to `0` ms will disable STF |
| 56 | +
|
| 57 | + tau_d: shoft-term depression time constant (default: `50` ms); note |
| 58 | + that setting this to `0` ms will disable STD |
| 59 | +
|
| 60 | + resources_int: initialization kernel for synaptic resources matrix |
| 61 | + """ |
| 62 | + |
| 63 | + # Define Functions |
| 64 | + def __init__(self, name, shape, weight_init=None, bias_init=None, |
| 65 | + resist_scale=1., p_conn=1., tau_f=750., tau_d=50., |
| 66 | + resources_init=None, **kwargs): |
| 67 | + super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs) |
| 68 | + ## STP meta-parameters |
| 69 | + self.resources_init = resources_init |
| 70 | + self.tau_f = tau_f |
| 71 | + self.tau_d = tau_d |
| 72 | + |
| 73 | + ## Set up short-term plasticity / dynamic synapse compartment values |
| 74 | + tmp_key, *subkeys = random.split(self.key.value, 4) |
| 75 | + preVals = jnp.zeros((self.batch_size, shape[0])) |
| 76 | + self.i = Compartment(preVals) ## electrical current output |
| 77 | + self.g = Compartment(preVals) ## conductance variable |
| 78 | + |
| 79 | + |
| 80 | + @transition(output_compartments=["outputs", "i", "g"]) |
| 81 | + @staticmethod |
| 82 | + def advance_state( |
| 83 | + tau_f, tau_d, Rscale, inputs, weights, biases, i, g |
| 84 | + ): |
| 85 | + s = inputs |
| 86 | + |
| 87 | + outputs = None #jnp.matmul(inputs, Wdyn * Rscale) + biases |
| 88 | + return outputs |
| 89 | + |
| 90 | + @transition(output_compartments=["inputs", "outputs", "i", "g"]) |
| 91 | + @staticmethod |
| 92 | + def reset(batch_size, shape): |
| 93 | + preVals = jnp.zeros((batch_size, shape[0])) |
| 94 | + postVals = jnp.zeros((batch_size, shape[1])) |
| 95 | + inputs = preVals |
| 96 | + outputs = postVals |
| 97 | + i = preVals |
| 98 | + g = preVals |
| 99 | + return inputs, outputs, i, g |
| 100 | + |
| 101 | + def save(self, directory, **kwargs): |
| 102 | + file_name = directory + "/" + self.name + ".npz" |
| 103 | + if self.bias_init != None: |
| 104 | + jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value) |
| 105 | + else: |
| 106 | + jnp.savez(file_name, weights=self.weights.value) |
| 107 | + |
| 108 | + def load(self, directory, **kwargs): |
| 109 | + file_name = directory + "/" + self.name + ".npz" |
| 110 | + data = jnp.load(file_name) |
| 111 | + self.weights.set(data['weights']) |
| 112 | + if "biases" in data.keys(): |
| 113 | + self.biases.set(data['biases']) |
| 114 | + |
| 115 | + @classmethod |
| 116 | + def help(cls): ## component help function |
| 117 | + properties = { |
| 118 | + "synapse_type": "STPDenseSynapse - performs a synaptic transformation of inputs to produce " |
| 119 | + "output signals (e.g., a scaled linear multivariate transformation); " |
| 120 | + "this synapse is dynamic, adapting via a form of short-term plasticity" |
| 121 | + } |
| 122 | + compartment_props = { |
| 123 | + "inputs": |
| 124 | + {"inputs": "Takes in external input signal values"}, |
| 125 | + "states": |
| 126 | + {"weights": "Synapse efficacy/strength parameter values", |
| 127 | + "biases": "Base-rate/bias parameter values", |
| 128 | + "key": "JAX PRNG key"}, |
| 129 | + "outputs": |
| 130 | + {"outputs": "Output of synaptic transformation"}, |
| 131 | + } |
| 132 | + hyperparams = { |
| 133 | + "shape": "Shape of synaptic weight value matrix; number inputs x number outputs", |
| 134 | + "weight_init": "Initialization conditions for synaptic weight (W) values", |
| 135 | + "bias_init": "Initialization conditions for bias/base-rate (b) values", |
| 136 | + "resist_scale": "Resistance level scaling factor (applied to output of transformation)", |
| 137 | + "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)", |
| 138 | + } |
| 139 | + info = {cls.__name__: properties, |
| 140 | + "compartments": compartment_props, |
| 141 | + "dynamics": "outputs = [(W * Rscale) * inputs] + b; " |
| 142 | + "dg/dt = ", |
| 143 | + "hyperparameters": hyperparams} |
| 144 | + return info |
| 145 | + |
| 146 | + def __repr__(self): |
| 147 | + comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] |
| 148 | + maxlen = max(len(c) for c in comps) + 5 |
| 149 | + lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" |
| 150 | + for c in comps: |
| 151 | + stats = tensorstats(getattr(self, c).value) |
| 152 | + if stats is not None: |
| 153 | + line = [f"{k}: {v}" for k, v in stats.items()] |
| 154 | + line = ", ".join(line) |
| 155 | + else: |
| 156 | + line = "None" |
| 157 | + lines += f" {f'({c})'.ljust(maxlen)}{line}\n" |
| 158 | + return lines |
0 commit comments