Commit b0b496a

Alexander Ororbia committed: integrated bernoulli err-cell
1 parent e055d95 commit b0b496a

File tree

4 files changed, +157 -0 lines changed


ngclearn/components/__init__.py (+1)

@@ -5,6 +5,7 @@
 from .neurons.graded.rateCell import RateCell
 from .neurons.graded.gaussianErrorCell import GaussianErrorCell
 from .neurons.graded.laplacianErrorCell import LaplacianErrorCell
+from .neurons.graded.bernoulliErrorCell import BernoulliErrorCell
 from .neurons.graded.rewardErrorCell import RewardErrorCell
ngclearn/components/neurons/__init__.py (+1)

@@ -2,6 +2,7 @@
 from .graded.rateCell import RateCell
 from .graded.gaussianErrorCell import GaussianErrorCell
 from .graded.laplacianErrorCell import LaplacianErrorCell
+from .graded.bernoulliErrorCell import BernoulliErrorCell
 from .graded.rewardErrorCell import RewardErrorCell
 ## point to standard spiking cell component types
 from .spiking.sLIFCell import SLIFCell

ngclearn/components/neurons/graded/__init__.py (+1)

@@ -2,4 +2,5 @@
 from .rateCell import RateCell
 from .gaussianErrorCell import GaussianErrorCell
 from .laplacianErrorCell import LaplacianErrorCell
+from .bernoulliErrorCell import BernoulliErrorCell
 from .rewardErrorCell import RewardErrorCell
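
Taken together, these three __init__.py edits re-export the new component at each package level. Assuming the package layout shown in the diffs, any of the following imports should then resolve (a minimal usage sketch, not part of the commit itself):

from ngclearn.components import BernoulliErrorCell
from ngclearn.components.neurons import BernoulliErrorCell
from ngclearn.components.neurons.graded.bernoulliErrorCell import BernoulliErrorCell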
ngclearn/components/neurons/graded/bernoulliErrorCell.py (new file, +154)

@@ -0,0 +1,154 @@
from ngclearn import resolver, Component, Compartment
from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, jit
from ngclearn.utils import tensorstats

class BernoulliErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
    """
    A simple (non-spiking) Bernoulli error cell - this is a fixed-point solution
    of a mismatch signal. Specifically, this cell operates as a factorized
    multivariate Bernoulli distribution.

    | --- Cell Input Compartments: ---
    | p - predicted probability of a positive trial (takes in external signals)
    | target - desired/goal value (takes in external signals)
    | modulator - modulation signal (takes in optional external signals)
    | mask - binary/gating mask to apply to error neuron calculations
    | --- Cell Output Compartments: ---
    | L - local loss function embodied by this cell
    | dp - derivative of L w.r.t. p
    | dtarget - derivative of L w.r.t. target

    Args:
        name: the string name of this cell

        n_units: number of cellular entities (neural population size)

        batch_size: batch size dimension of this cell (Default: 1)

        shape: optional multi-dimensional shape for this cell; if provided, it
            overrides the default 2D (batch_size, n_units) layout (Default: None)
    """
    def __init__(self, name, n_units, batch_size=1, shape=None, **kwargs):
        super().__init__(name, **kwargs)

        ## Layer Size Setup
        _shape = (batch_size, n_units) ## default shape is 2D/matrix
        if shape is None:
            shape = (n_units,) ## we set shape to be equal to n_units if nothing is provided
        else:
            _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is a 4D tensor
        self.shape = shape
        self.n_units = n_units
        self.batch_size = batch_size

        ## Convolution shape setup
        self.width = self.height = n_units
        ## Compartment setup
        restVals = jnp.zeros(_shape)
        self.L = Compartment(0., display_name="Bernoulli log likelihood", units="nats") ## loss compartment
        self.p = Compartment(restVals, display_name="Bernoulli prob for B(X=1; p)") ## positive-trial probability; input wire
        self.dp = Compartment(restVals) ## derivative w.r.t. positive-trial probability
        self.target = Compartment(restVals, display_name="Bernoulli data/target variable") ## target; input wire
        self.dtarget = Compartment(restVals) ## derivative w.r.t. target
        self.modulator = Compartment(restVals + 1.0) ## to be set/consumed
        self.mask = Compartment(restVals + 1.0)
    @staticmethod
    def _advance_state(dt, p, target, modulator, mask): ## compute Bernoulli error cell output
        ## Moves Bernoulli error cell dynamics one step forward; specifically, this
        ## routine emulates the error unit behavior of the local cost functional.
        eps = 0.001
        _p = jnp.clip(p, eps, 1. - eps) ## to prevent division by 0 later on
        x = target
        sum_x = jnp.sum(x) ## Sum^N_{n=1} x_n (n is the n-th datapoint)
        sum_1mx = jnp.sum(1. - x) ## Sum^N_{n=1} (1 - x_n)
        log_p = jnp.log(_p) ## log(p)
        log_1mp = jnp.log(1. - _p) ## log(1 - p)
        L = log_p * sum_x + log_1mp * sum_1mx ## Bernoulli log likelihood
        dL_dp = sum_x / _p - sum_1mx / (1. - _p) ## d(Bern LL)/dp
        dL_dx = log_p - log_1mp ## d(Bern LL)/dx

        dp = dL_dp * modulator * mask ## not sure how mask will apply to a full covariance...
        dtarget = dL_dx * modulator * mask
        mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
        return dp, dtarget, jnp.squeeze(L), mask
    @resolver(_advance_state)
    def advance_state(self, dp, dtarget, L, mask):
        self.dp.set(dp)
        self.dtarget.set(dtarget)
        self.L.set(L)
        self.mask.set(mask)
    @staticmethod
    def _reset(batch_size, shape): ## reset core compartments/statistics
        _shape = (batch_size, shape[0])
        if len(shape) > 1:
            _shape = (batch_size, shape[0], shape[1], shape[2])
        restVals = jnp.zeros(_shape)
        dp = restVals
        dtarget = restVals
        target = restVals
        p = restVals
        modulator = restVals + 1.
        L = 0. #jnp.zeros((1, 1))
        mask = jnp.ones(_shape)
        return dp, dtarget, target, p, modulator, L, mask
    @resolver(_reset)
    def reset(self, dp, dtarget, target, p, modulator, L, mask):
        self.dp.set(dp)
        self.dtarget.set(dtarget)
        self.target.set(target)
        self.p.set(p)
        self.modulator.set(modulator)
        self.L.set(L)
        self.mask.set(mask)
    @classmethod
    def help(cls): ## component help function
        properties = {
            "cell_type": "BernoulliErrorCell - computes mismatch/error signals at "
                         "each time step t (between a `target` and a predicted "
                         "probability `p`)"
        }
        compartment_props = {
            "inputs":
                {"p": "External input positive-trial probability value(s)",
                 "target": "External input target signal value(s)",
                 "modulator": "External input modulatory/scaling signal(s)",
                 "mask": "External binary/gating mask to apply to signals"},
            "outputs":
                {"L": "Local loss value computed/embodied by this error-cell",
                 "dp": "First derivative of loss w.r.t. positive-trial probability value(s)",
                 "dtarget": "First derivative of loss w.r.t. target value(s)"},
        }
        hyperparams = {
            "n_units": "Number of neuronal cells to model in this layer",
            "batch_size": "Batch size dimension of this component",
            "shape": "Optional multi-dimensional layer shape (overrides `n_units`)"
        }
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "Bernoulli(x=target; p), where target is a binary variable",
                "hyperparameters": hyperparams}
        return info
    def __repr__(self):
        comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
        maxlen = max(len(c) for c in comps) + 5
        lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
        for c in comps:
            stats = tensorstats(getattr(self, c).value)
            if stats is not None:
                line = [f"{k}: {v}" for k, v in stats.items()]
                line = ", ".join(line)
            else:
                line = "None"
            lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
        return lines
if __name__ == '__main__':
    from ngcsimlib.context import Context
    with Context("Bar") as bar:
        X = BernoulliErrorCell("X", 9)
        print(X)
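
For reference, the fixed-point quantities computed in _advance_state follow directly from the factorized Bernoulli log-likelihood. In the code's notation (x_n the n-th binary target, p the predicted positive-trial probability, N datapoints):

$$\mathcal{L}(p) = \log(p) \sum_{n=1}^{N} x_n + \log(1 - p) \sum_{n=1}^{N} (1 - x_n)$$

$$\frac{\partial \mathcal{L}}{\partial p} = \frac{\sum_n x_n}{p} - \frac{\sum_n (1 - x_n)}{1 - p}, \qquad \frac{\partial \mathcal{L}}{\partial x_n} = \log(p) - \log(1 - p)$$

These are exactly the quantities L, dL_dp, and dL_dx above (with p clipped to [eps, 1 - eps] before the logs and divisions).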

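As a quick sanity check of those closed-form derivatives, here is a minimal standalone sketch (pure JAX, outside ngclearn's simulation machinery; all values are hypothetical) comparing them against autodiff:

from jax import numpy as jnp, grad

def bern_ll(p, x):
    ## factorized Bernoulli log likelihood, written as in _advance_state
    return jnp.log(p) * jnp.sum(x) + jnp.log(1. - p) * jnp.sum(1. - x)

p = jnp.asarray(0.3) ## hypothetical predicted positive-trial probability
x = jnp.asarray([1., 0., 1., 1.]) ## hypothetical binary targets

dL_dp = jnp.sum(x) / p - jnp.sum(1. - x) / (1. - p) ## closed-form d(LL)/dp
print(dL_dp, grad(bern_ll, argnums=0)(p, x)) ## both print ~8.5714

dL_dx = jnp.log(p) - jnp.log(1. - p) ## closed-form d(LL)/dx_n
print(dL_dx, grad(bern_ll, argnums=1)(p, x)) ## ~-0.8473, constant across datapoints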