Skip to content

Commit 34ce944

Browse files
authored
Merge branch 'main' into main
2 parents 926511c + 277c597 commit 34ce944

19 files changed

+1653
-21
lines changed

mlx_audio/codec/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .models import Encodec, Mimi, Vocos
1+
from .models import DAC, Encodec, Mimi, Vocos

mlx_audio/codec/models/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from .descript import DAC
12
from .encodec import Encodec
23
from .mimi import Mimi
4+
from .snac import SNAC
35
from .vocos import Vocos
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .dac import DAC
+228
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
import math
2+
from dataclasses import dataclass
3+
from pathlib import Path
4+
from typing import Union
5+
6+
import mlx.core as mx
7+
import mlx.nn as nn
8+
import numpy as np
9+
import soundfile as sf
10+
from einops.array_api import rearrange
11+
12+
SUPPORTED_VERSIONS = ["1.0.0"]
13+
14+
15+
@dataclass
16+
class DACFile:
17+
codes: mx.array
18+
19+
# Metadata
20+
chunk_length: int
21+
original_length: int
22+
input_db: float
23+
channels: int
24+
sample_rate: int
25+
padding: bool
26+
dac_version: str
27+
28+
def save(self, path):
29+
artifacts = {
30+
"codes": np.array(self.codes).astype(np.uint16),
31+
"metadata": {
32+
"input_db": self.input_db,
33+
"original_length": self.original_length,
34+
"sample_rate": self.sample_rate,
35+
"chunk_length": self.chunk_length,
36+
"channels": self.channels,
37+
"padding": self.padding,
38+
"dac_version": SUPPORTED_VERSIONS[-1],
39+
},
40+
}
41+
path = Path(path).with_suffix(".dac")
42+
with open(path, "wb") as f:
43+
np.save(f, artifacts)
44+
return path
45+
46+
@classmethod
47+
def load(cls, path):
48+
artifacts = np.load(path, allow_pickle=True)[()]
49+
codes = mx.array(artifacts["codes"], dtype=mx.int32)
50+
if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
51+
raise RuntimeError(
52+
f"Given file {path} can't be loaded with this version of descript-audio-codec."
53+
)
54+
return cls(codes=codes, **artifacts["metadata"])
55+
56+
57+
class CodecMixin:
58+
@property
59+
def padding(self):
60+
if not hasattr(self, "_padding"):
61+
self._padding = True
62+
return self._padding
63+
64+
@padding.setter
65+
def padding(self, value):
66+
assert isinstance(value, bool)
67+
68+
layers = [
69+
layer
70+
for layer in self.modules()
71+
if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d))
72+
]
73+
74+
for layer in layers:
75+
if value:
76+
if hasattr(layer, "original_padding"):
77+
layer.padding = layer.original_padding
78+
else:
79+
layer.original_padding = layer.padding
80+
layer.padding = tuple(0 for _ in range(len(layer.padding)))
81+
82+
self._padding = value
83+
84+
def get_delay(self):
85+
l_out = self.get_output_length(0)
86+
L = l_out
87+
88+
layers = []
89+
for layer in self.modules():
90+
if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
91+
layers.append(layer)
92+
93+
for layer in reversed(layers):
94+
d = layer.dilation
95+
k = layer.weight.shape[1]
96+
s = layer.stride
97+
98+
if isinstance(layer, nn.ConvTranspose1d):
99+
L = ((L - d * (k - 1) - 1) / s) + 1
100+
elif isinstance(layer, nn.Conv1d):
101+
L = (L - 1) * s + d * (k - 1) + 1
102+
103+
L = math.ceil(L)
104+
105+
l_in = L
106+
107+
return (l_in - l_out) // 2
108+
109+
def get_output_length(self, input_length):
110+
L = input_length
111+
for layer in self.modules():
112+
if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
113+
d = layer.dilation
114+
k = layer.weight.shape[1]
115+
s = layer.stride
116+
117+
if isinstance(layer, nn.Conv1d):
118+
L = ((L - d * (k - 1) - 1) / s) + 1
119+
elif isinstance(layer, nn.ConvTranspose1d):
120+
L = (L - 1) * s + d * (k - 1) + 1
121+
122+
L = math.floor(L)
123+
return L
124+
125+
def compress(
126+
self,
127+
audio_path: Union[str, Path],
128+
win_duration: float = 1.0,
129+
normalize_db: float = -16,
130+
n_quantizers: int = None,
131+
) -> DACFile:
132+
audio_signal, original_sr = sf.read(audio_path)
133+
signal_duration = audio_signal.shape[-1] / original_sr
134+
135+
original_padding = self.padding
136+
if original_sr != self.sample_rate:
137+
raise ValueError(
138+
f"Sample rate of the audio signal ({original_sr}) does not match the sample rate of the model ({self.sample_rate})."
139+
)
140+
141+
audio_data = mx.array(audio_signal)
142+
143+
rms = mx.sqrt(mx.mean(mx.power(audio_data, 2), axis=-1) + 1e-12)
144+
input_db = 20 * mx.log10(rms / 1.0 + 1e-12)
145+
146+
if normalize_db is not None:
147+
audio_data = audio_data * mx.power(10, (normalize_db - input_db) / 20)
148+
149+
audio_data = rearrange(audio_data, "n -> 1 1 n")
150+
nb, nac, nt = audio_data.shape
151+
audio_data = rearrange(audio_data, "nb nac nt -> (nb nac) 1 nt")
152+
153+
win_duration = signal_duration if win_duration is None else win_duration
154+
155+
if signal_duration <= win_duration:
156+
self.padding = True
157+
n_samples = nt
158+
hop = nt
159+
else:
160+
self.padding = False
161+
audio_data = mx.pad(audio_data, [(0, 0), (0, 0), (self.delay, self.delay)])
162+
163+
n_samples = int(win_duration * self.sample_rate)
164+
n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
165+
hop = self.get_output_length(n_samples)
166+
167+
codes = []
168+
for i in range(0, nt, hop):
169+
x = audio_data[..., i : i + n_samples]
170+
x = mx.pad(x, [(0, 0), (0, 0), (0, max(0, n_samples - x.shape[-1]))])
171+
172+
x = self.preprocess(x, self.sample_rate)
173+
_, c, _, _, _ = self.encode(x, n_quantizers)
174+
codes.append(c)
175+
chunk_length = c.shape[-1]
176+
177+
codes = mx.concatenate(codes, axis=-1)
178+
179+
dac_file = DACFile(
180+
codes=codes,
181+
chunk_length=chunk_length,
182+
original_length=signal_duration,
183+
input_db=input_db,
184+
channels=nac,
185+
sample_rate=original_sr,
186+
padding=self.padding,
187+
dac_version=SUPPORTED_VERSIONS[-1],
188+
)
189+
190+
if n_quantizers is not None:
191+
codes = codes[:, :n_quantizers, :]
192+
193+
self.padding = original_padding
194+
return dac_file
195+
196+
def decompress(self, obj: Union[str, Path, DACFile]) -> mx.array:
197+
if isinstance(obj, (str, Path)):
198+
obj = DACFile.load(obj)
199+
200+
if self.sample_rate != obj.sample_rate:
201+
raise ValueError(
202+
f"Sample rate of the audio signal ({obj.sample_rate}) does not match the sample rate of the model ({self.sample_rate})."
203+
)
204+
205+
original_padding = self.padding
206+
self.padding = obj.padding
207+
208+
codes = obj.codes
209+
chunk_length = obj.chunk_length
210+
recons = []
211+
212+
for i in range(0, codes.shape[-1], chunk_length):
213+
c = codes[..., i : i + chunk_length]
214+
z = self.quantizer.from_codes(c)[0]
215+
r = self.decode(z)
216+
recons.append(r)
217+
218+
recons = mx.concatenate(recons, axis=1)
219+
recons = rearrange(recons, "1 n 1 -> 1 n")
220+
221+
target_db = obj.input_db
222+
normalize_db = -16
223+
224+
if normalize_db is not None:
225+
recons = recons * mx.power(10, (target_db - normalize_db) / 20)
226+
227+
self.padding = original_padding
228+
return recons

0 commit comments

Comments
 (0)