---
layout: post
title: "Training Sparse Neural Networks with L0 Regularisation"
date: 2023-04-06
mathjax: true
status: [Code samples, Instructional]
tldr: Explores L0 norm regularisation for training sparse neural networks, where weights are encouraged to be entirely 0. It discusses overcoming non-differentiability issues by using a soft form of counting and reparameterisation tricks. The post also delves into Concrete distributions and introduces a method to make the continuous distribution more suitable for regularisation purposes.
categories: [Compression, Machine Learning]
---


$L_0$ norm regularisation[^fn1] is a pretty fascinating technique for neural network pruning, or for training sparse networks where weights are encouraged to be exactly 0. It is easy to implement in only a few lines of code (see below), but getting there conceptually is not so easy.


Several ML tricks are needed to achieve gradient flow through the network, because the default $L_0$ regularisation loss is non-differentiable for *evolving* reasons: while solving one problem we introduce another, and the loss function "evolves".

The form of the loss is $\mathcal{L}(f(x; \tilde{\theta} \odot z), y) + \mathcal{L}\_{\mathrm{reg}}$, where $z$ is a discrete binary mask on the parameters $\tilde{\theta}$. We write $\tilde{\theta}$ because the parameters we ultimately care about are not $\tilde{\theta}$ exactly, but $\theta = \tilde{\theta} \odot z$.


The final solution involves sampling from a *hard-concrete* distribution,
which is obtained by stretching a *binary-concrete* distribution and then transforming the
samples with a *hard-sigmoid*.

<br>


### Preliminaries

#### <u>$L_p$ regularisation</u>

Regularisation adds a term $\mathcal{L}_{\mathrm{reg}}$ to the loss function which penalises the complexity of the solution (the $\theta$ weights), typically to avoid overfitting and reduce generalisation error. The regularised maximum likelihood estimate of the model parameters $\theta$ is given by

$$
\hat{\theta}_{\mathrm{MLE}} = \mathrm{argmin}_{\theta} \frac{1}{N} \sum_{i=1}^N \mathcal{L}(f(x_i; \theta), y_i) + \lambda \mathcal{L}_{\mathrm{reg}}
$$

$L_p$ regularisation is a penalty based on the $p$-norm of the $\theta$
vector, $\mathcal{L}_{\mathrm{reg}} = \mid \mid \theta \mid \mid_p$, where $\mid \mid \theta \mid \mid_p = (\mid\theta_1 \mid^p + \mid\theta_2 \mid^p + \cdots)^{\frac{1}{p}}$. $L_1$ and $L_2$ regularisation are typically used in gradient-based methods, but $L_0$ regularisation involves counting the non-zero weights, and is non-differentiable.

Note: the squared $L_2$ norm (as used in weight decay) is continuously differentiable, but the $L_1$ norm is not continuously differentiable (at $\theta_j=0$).
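
As a quick illustration, here is a minimal PyTorch sketch of these penalties (the tensor and variable names are my own, not from the post's code):

{% highlight python %}
import torch

theta = torch.randn(5, requires_grad=True)

l1 = theta.abs().sum()       # L1: differentiable everywhere except theta_j = 0
l2 = theta.pow(2).sum()      # squared L2 (weight decay form): smooth everywhere
l0 = (theta != 0).sum()      # L0: an integer count -- no useful gradient
{% endhighlight %}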


#### <u>Reparameterisation Trick</u>

The reparameterisation trick is used when we want to sample from a distribution and learn the parameters of that distribution. The "trick" is to reparameterise the distribution such that a sample has
a deterministic (differentiable) component and a noise (non-differentiable) component.[^fn2] This means
re-expressing the sampling function as dependent on trainable parameters and some independent
noise.

For example, a sample from $\mathcal{N}(\mu, \sigma^2)$ can be obtained by sampling $u$ from the standard normal distribution, $u \sim \mathcal{N}(0, 1)$, and then transforming it
using $\mu + \sigma u$. This reparameterisation reduces the problem of
estimating gradients w.r.t. the parameters of a distribution to estimating gradients w.r.t. the parameters
of a deterministic function.
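
A minimal sketch of the Gaussian case in PyTorch (variable names are mine):

{% highlight python %}
import torch

mu = torch.zeros(3, requires_grad=True)         # trainable location
log_sigma = torch.zeros(3, requires_grad=True)  # trainable log-scale

u = torch.randn(3)                  # independent noise, u ~ N(0, 1)
sample = mu + log_sigma.exp() * u   # deterministic, differentiable transform

sample.sum().backward()             # gradients reach mu and log_sigma
{% endhighlight %}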


#### <u>Concrete Distributions</u>

The class of "Concrete" distributions was invented to enable **discrete** distributions to use
the **reparameterisation trick**, by approximating discrete distributions with continuous
ones.[^fn3] The high-level strategy is: first, relax the state of a discrete variable into a probability vector by adding noise; second, use a softmax (or logistic, in the binary case)
instead of an argmax over the probabilities. Sampling from the Concrete distribution
then amounts to taking the softmax of the logits, perturbed by additive noise from a fixed distribution (Gumbel noise).

*Note: Don't overthink the semantics of "Concrete"; it's just a (in my opinion poor) name that stands for a "CONtinuous relaxation of disCRETE random variables".*
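
A minimal sketch of sampling from a (categorical) Concrete distribution, often called Gumbel-softmax; the 4-state setup and names are my own, and PyTorch also ships `torch.nn.functional.gumbel_softmax`, which does essentially this:

{% highlight python %}
import torch
import torch.nn.functional as F

logits = torch.zeros(4, requires_grad=True)  # unnormalised log-probs of 4 discrete states
beta = 0.5                                   # temperature: lower => closer to one-hot

u = torch.rand(4).clamp(1e-9, 1.0)           # avoid log(0)
gumbel = -torch.log(-torch.log(u))           # Gumbel(0, 1) noise
y = F.softmax((logits + gumbel) / beta, dim=-1)  # "soft" one-hot sample, differentiable
{% endhighlight %}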

<br><br>
### Method

> **Problem:** $L_0$ Regularisation Cost is Non-differentiable \\
> **Solution:** Use the *probability*, rather than the count, of the weights being non-zero

Writing out $L_0$ regularisation, the maximum likelihood estimate is given by

$$
\hat{\theta} = \mathrm{argmin}_{\theta} \frac{1}{N}(\sum_{i=1}^N \mathcal{L}(f(x_i; \theta), y_i)) + \lambda \mid \mid \theta \mid \mid_0
\tag{eq 1}\label{eq:1}
$$


where $\mid \mid \theta \mid \mid_0 = \sum_{j=1}^{\mid \theta \mid} \mathbb{1} [\theta_j \neq 0]$. This loss is non-differentiable because the
counting of parameters is non-differentiable.

To work around this, a soft form of counting is required, i.e., the *probability* of the
weights being non-zero. We thus consider $\theta = \tilde{\theta} \odot z$, where $\odot$ is
element-wise multiplication. The variable $z \sim \mathrm{Bernoulli}(\pi)$ can be
viewed as a set of $\\{ 0,1 \\}$ gates, which determine whether each parameter $\theta_j$ is effectively present
or absent. The probability of $z$ being 0 or 1 is controlled by the parameter $\pi$, which we therefore need to learn.

$$
\pi^* = \mathrm{argmin}_{\pi} \mathbb{E}_{z \sim \mathrm{Bern}(\pi)} \frac{1}{N} \sum_{i=1}^N \mathcal{L}
(f(x_i; \tilde{\theta} \odot z), y_i) + \lambda \sum_{j=1}^{\mid \theta \mid} \pi_j
\tag{eq 2}\label{eq:2}
$$


The regularisation cost is now differentiable because instead of the raw count of non-zero $\theta$ in
\eqref{eq:1}, we are
summing the probabilities $\pi_j$ of the gates $z_j$ being non-zero, and thus of the parameters
$\theta=\tilde{\theta} \odot z$ being non-zero. Each $\pi_j$ is the parameter of the Bernoulli
distribution that corresponds to one binary gate.

At this point, we have solved the problem of parameter counting, but still cannot use gradient-based optimisation for $\pi$ because the $z$ we introduced is a discrete stochastic random variable.
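
To make this concrete, a tiny sketch (parameterising $\pi$ through logits is my assumption):

{% highlight python %}
import torch

logits = torch.zeros(10, requires_grad=True)  # one logit per gate (assumed parameterisation)
pi = torch.sigmoid(logits)                    # pi_j = P(z_j = 1)
reg_cost = pi.sum()                           # E[ ||theta||_0 ] = sum_j pi_j
reg_cost.backward()                           # differentiable: gradients flow into the logits

# ...but the data term still needs samples of z, and sampling
# is not differentiable w.r.t. pi -- this is the next problem.
z = torch.bernoulli(pi.detach())              # gradient stops here
{% endhighlight %}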

<br>

> **Problem 2:** The gated parameters $\tilde{\theta}\odot z$ are non-differentiable because the masks $z \in \\{0, 1\\}$ are i) discrete, ii) stochastic \\
> **Solution 2i:** Sample the random variables from a Binary [Concrete Distribution](#concrete-distributions) \\
> **Solution 2ii:** Apply the [Reparameterisation Trick](#reparameterisation-trick)


We have solved the first problem, the regularisation term $\mathcal{L}_{\mathrm{reg}}$ being
non-differentiable, by reformulating $\mid \mid \theta \mid \mid_0 \rightarrow \sum_{j=1}^{\mid \theta \mid} \pi_j$. But
in doing so, we rewrote the term $f(x; \theta) \rightarrow f(x; \tilde{\theta}\odot z)$. Since
$z$ is stochastic, gradient does not flow, and we would like to employ the [reparameterisation
trick](#reparameterisation-trick). However, we cannot reparameterise the discrete distribution directly, due to the
discontinuous nature of discrete states. Therefore, we first approximate the Bernoulli
with a Binary [Concrete distribution](#concrete-distributions).

Next we apply the reparameterisation trick to the Binary Concrete distribution, resulting in a learnable parameter ($\log \alpha$) plus some logistic-distributed noise (the difference of two Gumbel variables). The noise takes the form $\log u - \log(1-u)$, where $u \sim \mathrm{Uniform}(0,1)$.

Let $s$ be a random variable on the $(0, 1)$ interval, sampled from a Binary Concrete
distribution. After applying the reparameterisation trick (details in Louizos et al., 2017), we can sample

$$s = \mathrm{Sigmoid}((\log u - \log (1-u) + \log \alpha) / \beta)$$

where $u \sim \mathrm{Uniform}(0, 1)$. Here $\log\alpha$ is the location parameter and
$\beta$ is the temperature, which controls the degree of approximation: as $\beta
\rightarrow 0$ we recover the original Bernoulli r.v. (but lose the differentiable properties). $\alpha$
and $\beta$ are now trainable parameters, while the stochasticity comes from $u \sim U(0, 1)$.

<br>
> **Problem:** The continuous distribution places no probability mass at exactly 0 and 1, so the gates are never fully closed or open. \\
> **Solution:** "Stretch" the distribution beyond $(0,1)$, then "fold" the samples back.

We can "stretch" the samples from the distribution to the $(\gamma, \zeta)$ interval, where $\gamma
<0$ and $\zeta>1$: $\tilde{s} = s(\zeta - \gamma) + \gamma$. We then apply a *hard-sigmoid* to
fold the samples back to the interval $[0, 1]$: $z=\mathrm{min}(1, \mathrm{max}(0, \tilde{s}))$.


{% highlight python %}
def sample_z(self):
    """Sample hard-concrete gates z in [0, 1]."""
    if self.training:
        # Sample s from the binary concrete distribution,
        # clamping u away from 0 and 1 to avoid log(0).
        u = torch.rand(self.num_heads, device=self.log_alpha.device)
        u = u.clamp(1e-6, 1.0 - 1e-6)
        s_ = torch.sigmoid((torch.log(u) - torch.log(1 - u) + self.log_alpha) / self.beta)
    else:
        # Test time: use the distribution's location, without noise.
        s_ = torch.sigmoid(self.log_alpha)

    # Stretch the values to (gamma, zeta), then fold them back to [0, 1].
    s_ = s_ * (self.zeta - self.gamma) + self.gamma
    z = torch.clip(s_, min=0, max=1)
    return z
{% endhighlight %}

<br>

> **Problem:** $z$ is no longer drawn from a Bernoulli, so what should the new regularisation term be? \\
> **Solution:** Compute the probability of $z$ being non-zero under the CDF of $s$.


Recall that the regularisation term $\mathcal{L}_{\mathrm{reg}}$ has evolved from the number of non-zero parameters
\eqref{eq:1} to the probability of the gates being active under a Bernoulli distribution \eqref{eq:2}.


We still want to penalise the probability of a gate being non-zero, but since we now have a continuous distribution instead
of a discrete Bernoulli, we need its cumulative distribution function (CDF) $Q(s \mid \alpha,
\beta)$: the gate is non-zero whenever the stretched sample is positive, which happens with probability $1 - Q(s \leq 0 \mid \alpha, \beta)$.

$$
\alpha^*, \beta^* = \mathrm{argmin}_{\alpha, \beta} \mathbb{E}_{u \sim U(0,1)} \frac{1}{N} \sum_{i=1}^N \mathcal{L}
(f(x_i; \tilde{\theta} \odot z), y_i) + \lambda \sum_{j=1}^{\mid \theta \mid} (1-Q(s_j \leq 0
\mid \alpha_j, \beta_j))
\tag{eq 3}\label{eq:3}
$$


The regularisation cost works out to be

$$
\sum_{j=1}^{\mid \theta \mid}(1-Q_{s_j}(0 \mid \alpha_j, \beta)) = \sum_{j=1}^{\mid \theta \mid} \mathrm{sigmoid}(\log \alpha_j - \beta\times \log\frac{-\gamma}{\zeta})
$$

{% highlight python %}
# Precomputed once, e.g. in __init__:
self.log_ratio_ = math.log(-self.gamma / self.zeta)

def get_reg_cost(self):
    # P(gate != 0) = sigmoid(log_alpha - beta * log(-gamma / zeta)), summed over gates
    return torch.sigmoid(self.log_alpha - self.beta * self.log_ratio_).sum()
{% endhighlight %}
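
Putting the pieces together, a hedged sketch of a training step; the `model` wiring, the gating interface, and the value of `lam` are my assumptions, not from the paper:

{% highlight python %}
import torch.nn.functional as F

lam = 0.01                            # assumed regularisation strength
for x, y in loader:                   # `loader`, `model`, `optimizer` assumed to exist
    z = model.sample_z()              # stochastic hard-concrete gates (training mode)
    logits = model(x, gate=z)         # gates mask the relevant parameters / heads
    loss = F.cross_entropy(logits, y) + lam * model.get_reg_cost()
    optimizer.zero_grad()
    loss.backward()                   # gradients reach log_alpha through both terms
    optimizer.step()
{% endhighlight %}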


<br>

#### Concluding Notes (mostly for implementation)

1. When someone writes "Hard Concrete", they mean hard-sigmoid clamping on a continuous relaxation of a Bernoulli (Concrete) distribution.

2. $\alpha$ and $\beta$ are the parameters that we need to train.

3. Start with the gates initialised near 1, not 0 or 0.5. I find that this is the only
   initialisation where the gates can be trained to a reasonable value.

4. Disable early stopping callbacks, or increase the patience level for early stopping.
   Unlike training a model from scratch, where we expect the performance to continuously
increase, here we expect the performance to drop rather than increase; as long as it doesn't drop too
far we're happy.

5. Consider scaling the $L_0$ regularisation loss to be in a similar range as the task objective,
   e.g., normalise by batch size and total number of heads (see the sketch below).

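One possible normalisation for point 5, purely as an assumption on my part:

{% highlight python %}
# Bring the L0 cost into the same ballpark as a per-example task loss
# (batch_size and num_heads are whatever your setup defines).
reg_cost = model.get_reg_cost() / (batch_size * num_heads)
loss = task_loss + lam * reg_cost
{% endhighlight %}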

<br>

#### **References**
[^fn1]: Louizos, Welling and Kingma. (2017). [Learning Sparse Neural Networks Through L0 Regularization](https://arxiv.org/pdf/1712.01312.pdf)
[^fn2]: Kingma and Welling. (2013). [Auto-Encoding Variational Bayes](https://arxiv.org/pdf/1312.6114.pdf). Note: the reparameterisation trick was popularised in ML by this paper, but not invented by these authors.
[^fn3]: Maddison, Mnih and Teh. (2016). [The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables](https://arxiv.org/pdf/1611.00712.pdf)