import torch
+ import math
from torch import nn
from x_transformers import ContinuousTransformerWrapper, Decoder
+ from functools import partial

from mamba_ssm.utils.generation import InferenceParams
from .transformer import ContinuousTransformer
+ from .mambaplus.mamba import MambaPlus, MambaPlusConfig

# Interface for backbone of a language model
# Handles conditioning and cross-attention
@@ -253,4 +256,90 @@ def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross
            self.cuda_graph.replay()
            return self.captured_logits.clone()

-         return self.model(x, inference_params=self.inference_params if use_cache else None)[:, prepend_length:, :]
+         return self.model(x, inference_params=self.inference_params if use_cache else None)[:, prepend_length:, :]
+
+
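+ # Module-wise weight re-initialization: zero Linear biases, normal-init Embedding weights,
+ # and rescale residual output projections (out_proj / fc2) by 1/sqrt(n_residuals_per_layer * n_layer).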
+ def _init_weights(
+     module,
+     n_layer,
+     initializer_range=0.02,  # Now only used for embedding layer.
+     rescale_prenorm_residual=True,
+     n_residuals_per_layer=1,  # Change to 2 if we have MLP
+ ):
+     if isinstance(module, nn.Linear):
+         if module.bias is not None:
+             if not getattr(module.bias, "_no_reinit", False):
+                 nn.init.zeros_(module.bias)
+     elif isinstance(module, nn.Embedding):
+         nn.init.normal_(module.weight, std=initializer_range)
+
+     if rescale_prenorm_residual:
+         # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+         # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+         # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+         # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+         #
+         # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+         for name, p in module.named_parameters():
+             if name in ["out_proj.weight", "fc2.weight"]:
+                 # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                 # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+                 # We need to reinit p since this code could be called multiple times
+                 # Having just p *= scale would repeatedly scale it down
+                 nn.init.kaiming_uniform_(p, a=math.sqrt(5))
+                 with torch.no_grad():
+                     p /= math.sqrt(n_residuals_per_layer * n_layer)
+
+ class MambaPlusAudioLMBackbone(AudioLMBackbone):
+     def __init__(self,
+                  embed_dim: int = 512,
+                  n_layers: int = 32,
+                  d_state: int = 1,
+                  bidirectional: bool = False,
+                  num_mod_groups: int = 128,
+                  cross_attn_cond_dim: int = 0,
+                  prepend_cond_dim: int = 0,
+                  **kwargs):
+         super().__init__(embed_dim=embed_dim)
+
+         self.config = MambaPlusConfig(d_model=embed_dim,
+                                       n_layers=n_layers,
+                                       d_state=d_state,
+                                       expand_factor=2,
+                                       num_mod_groups=num_mod_groups,
+                                       complex=True,
+                                       mamba_plus_enabled=True,
+                                       bidirectional=bidirectional,
+                                       **kwargs)
+
+         # Embeddings are done in the AudioLanguageModel, so the backbone operates directly on continuous features
+         self.model = MambaPlus(
+             config=self.config,
+             **kwargs
+         )
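+         # Re-initialize weights with the depth-aware scheme defined in _init_weights above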
+         self.apply(
+             partial(
+                 _init_weights,
+                 n_layer=self.config.n_layers
+             )
+         )
+
+         if prepend_cond_dim > 0:
+             # Prepend conditioning
+             self.to_prepend_embed = nn.Sequential(
+                 nn.Linear(prepend_cond_dim, embed_dim, bias=False)
+             )
+
+         assert cross_attn_cond_dim == 0, "Cross-attention conditioning not supported for MambaPlus"
+
+     def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, use_cache=False):
+
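+         # mask, prepend_cond_mask, cross_attn_cond and use_cache are accepted for interface compatibility but are not used here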
+         prepend_length = 0
+         if prepend_cond is not None:
+             # Project the prepend conditioning to the embedding dimension
+             prepend_cond = self.to_prepend_embed(prepend_cond)
+             prepend_length = prepend_cond.shape[1]
+
+             x = torch.cat([prepend_cond, x], dim=1)
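+         # Run the MambaPlus backbone and strip the prepended conditioning frames from the output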
+         return self.model(x)[:, prepend_length:, :]
+
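For reference, a rough usage sketch of the new backbone (illustrative only: the batch size, sequence length, layer count and prepend_cond_dim below are made up, and it assumes this module's context, i.e. that torch and the in-repo MambaPlus / MambaPlusConfig modules imported at the top of the file are available):

    backbone = MambaPlusAudioLMBackbone(embed_dim=512, n_layers=8, prepend_cond_dim=64)

    x = torch.randn(2, 100, 512)        # (batch, seq_len, embed_dim) continuous input embeddings
    prepend = torch.randn(2, 4, 64)     # (batch, prepend_len, prepend_cond_dim) conditioning

    out = backbone(x, prepend_cond=prepend)
    # The prepended conditioning frames are stripped, so out.shape[1] == x.shape[1]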