
Commit ac609ea

Merge pull request #13 from Stability-AI/nar
Add NAR support
2 parents ffb105a + 36366ef

10 files changed (+897, -8 lines)

Diff for: .gitignore (+4, -5)

@@ -1,8 +1,3 @@
-# Repo-specific
-# wav file created as part of the gradio demo
-output.wav
-
-
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
@@ -163,3 +158,7 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+*.ckpt
+*.wav
+wandb/*

Diff for: stable_audio_tools/models/factory.py (+3)

@@ -23,6 +23,9 @@ def create_model_from_config(model_config):
     elif model_type == 'lm':
         from .lm import create_audio_lm_from_config
         return create_audio_lm_from_config(model_config)
+    elif model_type == 'nar':
+        from .nar import create_audio_nar_from_config
+        return create_audio_nar_from_config(model_config)
     else:
         raise NotImplementedError(f'Unknown model type: {model_type}')
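For orientation, a minimal sketch of how the new branch is reached, assuming a config JSON whose "model_type" field is "nar". The file name below is a placeholder, and the rest of the NAR config schema is defined by create_audio_nar_from_config, which is not part of this diff:

    import json

    from stable_audio_tools.models.factory import create_model_from_config

    # Placeholder path; any config with "model_type": "nar" now takes the new branch.
    with open("nar_model_config.json") as f:
        model_config = json.load(f)

    model = create_model_from_config(model_config)  # dispatches to .nar.create_audio_nar_from_config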

Diff for: stable_audio_tools/models/lm_backbone.py (+90, -1)

@@ -1,9 +1,12 @@
 import torch
+import math
 from torch import nn
 from x_transformers import ContinuousTransformerWrapper, Decoder
+from functools import partial
 
 from mamba_ssm.utils.generation import InferenceParams
 from .transformer import ContinuousTransformer
+from .mambaplus.mamba import MambaPlus, MambaPlusConfig
 
 # Interface for backbone of a language model
 # Handles conditioning and cross-attention
@@ -253,4 +256,90 @@ def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross
             self.cuda_graph.replay()
             return self.captured_logits.clone()
 
-        return self.model(x, inference_params=self.inference_params if use_cache else None)[:, prepend_length:, :]
+        return self.model(x, inference_params=self.inference_params if use_cache else None)[:, prepend_length:, :]
+
+
+def _init_weights(
+    module,
+    n_layer,
+    initializer_range=0.02,  # Now only used for the embedding layer.
+    rescale_prenorm_residual=True,
+    n_residuals_per_layer=1,  # Change to 2 if we have an MLP
+):
+    if isinstance(module, nn.Linear):
+        if module.bias is not None:
+            if not getattr(module.bias, "_no_reinit", False):
+                nn.init.zeros_(module.bias)
+    elif isinstance(module, nn.Embedding):
+        nn.init.normal_(module.weight, std=initializer_range)
+
+    if rescale_prenorm_residual:
+        # Reinitialize selected weights subject to the OpenAI GPT-2 paper scheme:
+        # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+        # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+        # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+        #
+        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+        for name, p in module.named_parameters():
+            if name in ["out_proj.weight", "fc2.weight"]:
+                # Special scaled initialization --> there are 2 layer norms per Transformer block.
+                # Following PyTorch init, except scale by 1/sqrt(2 * n_layer).
+                # We need to reinit p since this code could be called multiple times;
+                # having just p *= scale would repeatedly scale it down.
+                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
+                with torch.no_grad():
+                    p /= math.sqrt(n_residuals_per_layer * n_layer)
+
+class MambaPlusAudioLMBackbone(AudioLMBackbone):
+    def __init__(self,
+                 embed_dim: int = 512,
+                 n_layers: int = 32,
+                 d_state: int = 1,
+                 bidirectional: bool = False,
+                 num_mod_groups: int = 128,
+                 cross_attn_cond_dim: int = 0,
+                 prepend_cond_dim: int = 0,
+                 **kwargs):
+        super().__init__(embed_dim=embed_dim)
+
+        self.config = MambaPlusConfig(d_model=embed_dim,
+                                      n_layers=n_layers,
+                                      d_state=d_state,
+                                      expand_factor=2,
+                                      num_mod_groups=num_mod_groups,
+                                      complex=True,
+                                      mamba_plus_enabled=True,
+                                      bidirectional=bidirectional,
+                                      **kwargs)
+
+        # Embeddings are done in the AudioLanguageModel, so the backbone takes continuous inputs
+        self.model = MambaPlus(
+            config=self.config,
+            **kwargs
+        )
+        self.apply(
+            partial(
+                _init_weights,
+                n_layer=self.config.n_layers,
+            )
+        )
+
+        if prepend_cond_dim > 0:
+            # Prepend conditioning
+            self.to_prepend_embed = nn.Sequential(
+                nn.Linear(prepend_cond_dim, embed_dim, bias=False)
+            )
+
+        assert cross_attn_cond_dim == 0, "Cross-attention conditioning not supported for MambaPlus"
+
+    def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, use_cache=False):
+
+        prepend_length = 0
+        if prepend_cond is not None:
+            # Project the prepend conditioning to the embedding dimension
+            prepend_cond = self.to_prepend_embed(prepend_cond)
+            prepend_length = prepend_cond.shape[1]
+
+            x = torch.cat([prepend_cond, x], dim=1)
+
+        return self.model(x)[:, prepend_length:, :]
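As a sanity check on the new _init_weights hook, here is a minimal sketch that applies it to a toy stack of blocks; ToyBlock and its layer names are illustrative, not part of the commit, and the import assumes the repository and its Mamba dependencies are installed. Any parameter named out_proj.weight (or fc2.weight) is re-drawn with Kaiming-uniform init and divided by sqrt(n_residuals_per_layer * n_layer), the GPT-2/Megatron-LM residual rescaling cited in the comments, so its standard deviation ends up roughly 1/sqrt(n_layer) of a freshly initialized nn.Linear of the same shape:

    from functools import partial

    from torch import nn

    from stable_audio_tools.models.lm_backbone import _init_weights

    class ToyBlock(nn.Module):
        # Illustrative block whose output projection matches the "out_proj.weight"
        # parameter name that _init_weights rescales.
        def __init__(self, dim):
            super().__init__()
            self.in_proj = nn.Linear(dim, dim)
            self.out_proj = nn.Linear(dim, dim)

    n_layer = 32
    blocks = nn.ModuleList([ToyBlock(512) for _ in range(n_layer)])

    # Same pattern as MambaPlusAudioLMBackbone.__init__:
    blocks.apply(partial(_init_weights, n_layer=n_layer))

    # out_proj weights are scaled down by ~1/sqrt(32); in_proj keeps the default init.
    print(blocks[0].out_proj.weight.std(), blocks[0].in_proj.weight.std())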
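And a hedged usage sketch for the new backbone itself, assuming the .mambaplus package and its kernels are available on a CUDA device (MambaPlus and MambaPlusConfig are not shown in this diff) and that MambaPlus accepts batched (batch, sequence, embed_dim) inputs as the forward pass here implies; the prepend_cond_dim value is an illustrative choice, not something this commit fixes:

    import torch

    from stable_audio_tools.models.lm_backbone import MambaPlusAudioLMBackbone

    backbone = MambaPlusAudioLMBackbone(
        embed_dim=512,
        n_layers=32,
        d_state=1,
        bidirectional=False,
        prepend_cond_dim=768,  # illustrative conditioning width, e.g. a text-encoder embedding
    ).cuda()

    x = torch.randn(2, 100, 512, device="cuda")      # (batch, seq, embed_dim) continuous inputs
    prepend = torch.randn(2, 1, 768, device="cuda")  # conditioning to project and prepend

    # Conditioning is projected to embed_dim, prepended, run through MambaPlus,
    # then the prepended positions are stripped from the output.
    out = backbone(x, prepend_cond=prepend)
    print(out.shape)  # torch.Size([2, 100, 512])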
