|
| 1 | +from ngclearn import resolver, Component, Compartment |
| 2 | +from ngclearn.components.jaxComponent import JaxComponent |
| 3 | +from jax import numpy as jnp, jit |
| 4 | +from ngclearn.utils import tensorstats |
| 5 | + |
| 6 | +class BernoulliErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell |
| 7 | + """ |
| 8 | + A simple (non-spiking) Bernoulli error cell - this is a fixed-point solution |
| 9 | + of a mismatch signal. Specifically, this cell operates as a factorized multivariate |
| 10 | + Bernoulli distribution. |
| 11 | +
|
| 12 | + | --- Cell Input Compartments: --- |
| 13 | + | p - predicted probability of positive trial (takes in external signals) |
| 14 | + | target - desired/goal value (takes in external signals) |
| 15 | + | modulator - modulation signal (takes in optional external signals) |
| 16 | + | mask - binary/gating mask to apply to error neuron calculations |
| 17 | + | --- Cell Output Compartments: --- |
| 18 | + | L - local loss function embodied by this cell |
| 19 | + | dp - derivative of L w.r.t. p |
| 20 | + | dtarget - derivative of L w.r.t. target |
| 21 | +
|
| 22 | + Args: |
| 23 | + name: the string name of this cell |
| 24 | +
|
| 25 | + n_units: number of cellular entities (neural population size) |
| 26 | +
|
| 27 | + batch_size: batch size dimension of this cell (Default: 1) |
| 28 | +
|
| 29 | + """ |
| 30 | + def __init__(self, name, n_units, batch_size=1, shape=None, **kwargs): |
| 31 | + super().__init__(name, **kwargs) |
| 32 | + |
| 33 | + ## Layer Size Setup |
| 34 | + _shape = (batch_size, n_units) ## default shape is 2D/matrix |
| 35 | + if shape is None: |
| 36 | + shape = (n_units,) ## we set shape to be equal to n_units if nothing provided |
| 37 | + else: |
| 38 | + _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor |
| 39 | + self.shape = shape |
| 40 | + self.n_units = n_units |
| 41 | + self.batch_size = batch_size |
| 42 | + |
| 43 | + ## Convolution shape setup |
| 44 | + self.width = self.height = n_units |
| 45 | + |
| 46 | + ## Compartment setup |
| 47 | + restVals = jnp.zeros(_shape) |
| 48 | + self.L = Compartment(0., display_name="Bernoulli Log likelihood", units="nats") # loss compartment |
| 49 | + self.p = Compartment(restVals, display_name="Bernoulli prob for B(X=1; p)") # pos trial prob name. input wire |
| 50 | + self.dp = Compartment(restVals) # derivative of positive trial prob |
| 51 | + self.target = Compartment(restVals, display_name="Bernoulli data/target variable") # target. input wire |
| 52 | + self.dtarget = Compartment(restVals) # derivative target |
| 53 | + self.modulator = Compartment(restVals + 1.0) # to be set/consumed |
| 54 | + self.mask = Compartment(restVals + 1.0) |
| 55 | + |
| 56 | + @staticmethod |
| 57 | + def _advance_state(dt, p, target, modulator, mask): ## compute Bernoulli error cell output |
| 58 | + # Moves Bernoulli error cell dynamics one step forward. Specifically, this routine emulates the error unit |
| 59 | + # behavior of the local cost functional |
| 60 | + eps = 0.001 |
| 61 | + _p = jnp.clip(p, eps, 1. - eps) ## to prevent division by 0 later on |
| 62 | + x = target |
| 63 | + sum_x = jnp.sum(x) ## Sum^N_{n=1} x_n (n is n-th datapoint) |
| 64 | + sum_1mx = jnp.sum(1. - x) ## Sum^N_{n=1} (1 - x_n) |
| 65 | + log_p = jnp.log(_p) ## log(p) |
| 66 | + log_1mp = jnp.log(1. - _p) ## log(1 - p) |
| 67 | + L = log_p * sum_x + log_1mp * sum_1mx ## Bern LL |
| 68 | + dL_dp = sum_x/log_p - sum_1mx/log_1mp ## d(Bern LL)/dp |
| 69 | + dL_dx = log_p - log_1mp ## d(Bern LL)/dx |
| 70 | + |
| 71 | + dp = dL_dp * modulator * mask ## not sure how mask will apply to a full covariance... |
| 72 | + dtarget = dL_dx * modulator * mask |
| 73 | + mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t |
| 74 | + return dp, dtarget, jnp.squeeze(L), mask |
| 75 | + |
| 76 | + @resolver(_advance_state) |
| 77 | + def advance_state(self, dp, dtarget, L, mask): |
| 78 | + self.dp.set(dp) |
| 79 | + self.dtarget.set(dtarget) |
| 80 | + self.L.set(L) |
| 81 | + self.mask.set(mask) |
| 82 | + |
| 83 | + @staticmethod |
| 84 | + def _reset(batch_size, shape): ## reset core components/statistics |
| 85 | + _shape = (batch_size, shape[0]) |
| 86 | + if len(shape) > 1: |
| 87 | + _shape = (batch_size, shape[0], shape[1], shape[2]) |
| 88 | + restVals = jnp.zeros(_shape) |
| 89 | + dp = restVals |
| 90 | + dtarget = restVals |
| 91 | + target = restVals |
| 92 | + p = restVals |
| 93 | + modulator = mu + 1. |
| 94 | + L = 0. #jnp.zeros((1, 1)) |
| 95 | + mask = jnp.ones(_shape) |
| 96 | + return dp, dtarget, target, p, modulator, L, mask |
| 97 | + |
| 98 | + @resolver(_reset) |
| 99 | + def reset(self, dp, dtarget, target, p, modulator, L, mask): |
| 100 | + self.dp.set(dp) |
| 101 | + self.dtarget.set(dtarget) |
| 102 | + self.target.set(target) |
| 103 | + self.p.set(p) |
| 104 | + self.modulator.set(modulator) |
| 105 | + self.L.set(L) |
| 106 | + self.mask.set(mask) |
| 107 | + |
| 108 | + @classmethod |
| 109 | + def help(cls): ## component help function |
| 110 | + properties = { |
| 111 | + "cell_type": "GaussianErrorcell - computes mismatch/error signals at " |
| 112 | + "each time step t (between a `target` and a prediction `mu`)" |
| 113 | + } |
| 114 | + compartment_props = { |
| 115 | + "inputs": |
| 116 | + {"p": "External input positive probability value(s)", |
| 117 | + "target": "External input target signal value(s)", |
| 118 | + "modulator": "External input modulatory/scaling signal(s)", |
| 119 | + "mask": "External binary/gating mask to apply to signals"}, |
| 120 | + "outputs": |
| 121 | + {"L": "Local loss value computed/embodied by this error-cell", |
| 122 | + "dp": "first derivative of loss w.r.t. positive probability value(s)", |
| 123 | + "dtarget": "first derivative of loss w.r.t. target value(s)"}, |
| 124 | + } |
| 125 | + hyperparams = { |
| 126 | + "n_units": "Number of neuronal cells to model in this layer", |
| 127 | + "batch_size": "Batch size dimension of this component", |
| 128 | + "sigma": "External input variance value (currently fixed and not learnable)" |
| 129 | + } |
| 130 | + info = {cls.__name__: properties, |
| 131 | + "compartments": compartment_props, |
| 132 | + "dynamics": "Bernoulli(x=target; p) where target is binary variable", |
| 133 | + "hyperparameters": hyperparams} |
| 134 | + return info |
| 135 | + |
| 136 | + def __repr__(self): |
| 137 | + comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] |
| 138 | + maxlen = max(len(c) for c in comps) + 5 |
| 139 | + lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" |
| 140 | + for c in comps: |
| 141 | + stats = tensorstats(getattr(self, c).value) |
| 142 | + if stats is not None: |
| 143 | + line = [f"{k}: {v}" for k, v in stats.items()] |
| 144 | + line = ", ".join(line) |
| 145 | + else: |
| 146 | + line = "None" |
| 147 | + lines += f" {f'({c})'.ljust(maxlen)}{line}\n" |
| 148 | + return lines |
| 149 | + |
| 150 | +if __name__ == '__main__': |
| 151 | + from ngcsimlib.context import Context |
| 152 | + with Context("Bar") as bar: |
| 153 | + X = GaussianErrorCell("X", 9) |
| 154 | + print(X) |
0 commit comments