Skip to content

Commit 50df283

Browse files
authored
[BREAKING] Remove support for non-integer latency compensation (#547)
* Remove support for non-integer delay, Black * Remove scipy dependency, clean up environment definitions, make pytest verbose on workflow
1 parent 758e24e commit 50df283

File tree

9 files changed

+9
-105
lines changed

9 files changed

+9
-105
lines changed

.github/workflows/python-package.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,4 @@ jobs:
4242
- name: Test with pytest
4343
run: |
4444
python -m pip install pytest pytest-mock
45-
xvfb-run -a pytest
45+
xvfb-run -a pytest -v

environments/environment_cpu_apple.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ name: nam
88
channels:
99
- conda-forge # pytest-mock
1010
- pytorch
11+
- defaults
1112
dependencies:
1213
- python>=3.9
1314
- black
@@ -24,7 +25,6 @@ dependencies:
2425
# Performance note:
2526
# https://github.com/sdatkinson/neural-amp-modeler/issues/505
2627
- pytorch
27-
- scipy
2828
- semver
2929
- tensorboard
3030
- tqdm

environments/environment_gpu.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ channels:
77
- conda-forge # pytest-mock
88
- pytorch
99
- nvidia # GPU
10+
- defaults
1011
dependencies:
1112
- python>=3.9
1213
- black
@@ -25,7 +26,6 @@ dependencies:
2526
# You're going to need to look at Table 3 here:
2627
# https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions
2728
- pytorch::pytorch-cuda=12.1 # GPU
28-
- scipy
2929
- semver
3030
- tensorboard
3131
- tqdm

nam/cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def nam_hello_world():
9292
was installed successfully
9393
"""
9494
from nam import __version__
95+
9596
msg = f"""
9697
Neural Amp Modeler
9798

nam/data.py

Lines changed: 2 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import numpy as _np
2727
import torch as _torch
2828
import wavio as _wavio
29-
from scipy.interpolate import interp1d as _interp1d
3029
from torch.utils.data import Dataset as _Dataset
3130
from tqdm import tqdm as _tqdm
3231

@@ -176,37 +175,6 @@ def __getitem__(self, idx: int):
176175
pass
177176

178177

179-
class _DelayInterpolationMethod(_Enum):
180-
"""
181-
:param LINEAR: Linear interpolation
182-
:param CUBIC: Cubic spline interpolation
183-
"""
184-
185-
# Note: these match scipy.interpolate.interp1d kwarg "kind"
186-
LINEAR = "linear"
187-
CUBIC = "cubic"
188-
189-
190-
def _interpolate_delay(
191-
x: _torch.Tensor, delay: float, method: _DelayInterpolationMethod
192-
) -> _np.ndarray:
193-
"""
194-
NOTE: This breaks the gradient tape!
195-
"""
196-
if delay == 0.0:
197-
return x
198-
t_in = _np.arange(len(x))
199-
n_out = len(x) - int(_np.ceil(_np.abs(delay)))
200-
if delay > 0:
201-
t_out = _np.arange(n_out) + delay
202-
elif delay < 0:
203-
t_out = _np.arange(len(x) - n_out, len(x)) - _np.abs(delay)
204-
205-
return _torch.Tensor(
206-
_interp1d(t_in, x.detach().cpu().numpy(), kind=method.value)(t_out)
207-
)
208-
209-
210178
class XYError(ValueError, DataError):
211179
"""
212180
Exceptions related to invalid x and y provided for data sets
@@ -268,9 +236,6 @@ def __init__(
268236
start_seconds: _Optional[_Union[int, float]] = None,
269237
stop_seconds: _Optional[_Union[int, float]] = None,
270238
delay: _Optional[_Union[int, float]] = None,
271-
delay_interpolation_method: _Union[
272-
str, _DelayInterpolationMethod
273-
] = _DelayInterpolationMethod.CUBIC,
274239
y_scale: float = 1.0,
275240
x_path: _Optional[_Union[str, _Path]] = None,
276241
y_path: _Optional[_Union[str, _Path]] = None,
@@ -305,8 +270,7 @@ def __init__(
305270
the end of the audio. Requires providing `sample_rate`.
306271
:param delay: In samples. Positive means we get rid of the start of x, end of y
307272
(i.e. we are correcting for an alignment error in which y is delayed behind
308-
x). If a non-integer delay is provided, then y is interpolated, with
309-
the extra sample removed.
273+
x). Only integer delays are supported.
310274
:param y_scale: Multiplies the output signal by a factor (e.g. if the data are
311275
too quiet).
312276
:param input_gain: In dB. If the input signal wasn't fed to the amp at unity
@@ -335,17 +299,13 @@ def __init__(
335299
stop_seconds,
336300
self.sample_rate,
337301
)
338-
if not isinstance(delay_interpolation_method, _DelayInterpolationMethod):
339-
delay_interpolation_method = _DelayInterpolationMethod(
340-
delay_interpolation_method
341-
)
342302
if require_input_pre_silence is not None:
343303
self._validate_preceding_silence(
344304
x, start, require_input_pre_silence, self.sample_rate
345305
)
346306
x, y = [z[start:stop] for z in (x, y)]
347307
if delay is not None and delay != 0:
348-
x, y = self._apply_delay(x, y, delay, delay_interpolation_method)
308+
x, y = self._apply_delay(x, y, delay)
349309
x_scale = 10.0 ** (input_gain / 20.0)
350310
x = x * x_scale
351311
y = y * y_scale
@@ -477,15 +437,12 @@ def _apply_delay(
477437
x: _torch.Tensor,
478438
y: _torch.Tensor,
479439
delay: _Union[int, float],
480-
method: _DelayInterpolationMethod,
481440
) -> _Tuple[_torch.Tensor, _torch.Tensor]:
482441
# Check for floats that could be treated like ints (simpler algorithm)
483442
if isinstance(delay, float) and int(delay) == delay:
484443
delay = int(delay)
485444
if isinstance(delay, int):
486445
return cls._apply_delay_int(x, y, delay)
487-
elif isinstance(delay, float):
488-
return cls._apply_delay_float(x, y, delay, method)
489446
else:
490447
raise TypeError(type(delay))
491448

@@ -501,22 +458,6 @@ def _apply_delay_int(
501458
y = y[:delay]
502459
return x, y
503460

504-
@classmethod
505-
def _apply_delay_float(
506-
cls,
507-
x: _torch.Tensor,
508-
y: _torch.Tensor,
509-
delay: float,
510-
method: _DelayInterpolationMethod,
511-
) -> _Tuple[_torch.Tensor, _torch.Tensor]:
512-
n_out = len(y) - int(_np.ceil(_np.abs(delay)))
513-
if delay > 0:
514-
x = x[:n_out]
515-
elif delay < 0:
516-
x = x[-n_out:]
517-
y = _interpolate_delay(y, delay, method)
518-
return x, y
519-
520461
@classmethod
521462
def _validate_start_stop(
522463
cls,

nam/models/recurrent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def _export_config(self):
212212
"hidden_size": self._core.hidden_size,
213213
"num_layers": self._core.num_layers,
214214
}
215+
215216
def _export_weights(self):
216217
"""
217218
* Loop over cells:

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ dependencies = [
1717
"numpy<2",
1818
"pydantic>=2.0.0",
1919
"pytorch_lightning",
20-
"scipy",
2120
"sounddevice",
2221
"tensorboard",
2322
"torch",

tests/test_nam/test_data.py

Lines changed: 2 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -36,30 +36,10 @@ def test_apply_delay_zero(self):
3636
no change.
3737
"""
3838
x, y = self._create_xy()
39-
x_out, y_out = data.Dataset._apply_delay(
40-
x, y, 0, data._DelayInterpolationMethod.CUBIC
41-
)
39+
x_out, y_out = data.Dataset._apply_delay(x, y, 0)
4240
assert torch.all(x == x_out)
4341
assert torch.all(y == y_out)
4442

45-
@pytest.mark.parametrize("method", (data._DelayInterpolationMethod))
46-
def test_apply_delay_float_negative(self, method):
47-
n = 7
48-
delay = -2.5
49-
x_out, y_out = self._t_apply_delay_float(n, delay, method)
50-
51-
assert torch.all(x_out == torch.Tensor([3, 4, 5, 6]))
52-
assert torch.all(y_out == torch.Tensor([0.5, 1.5, 2.5, 3.5]))
53-
54-
@pytest.mark.parametrize("method", (data._DelayInterpolationMethod))
55-
def test_apply_delay_float_positive(self, method):
56-
n = 7
57-
delay = 2.5
58-
x_out, y_out = self._t_apply_delay_float(n, delay, method)
59-
60-
assert torch.all(x_out == torch.Tensor([0, 1, 2, 3]))
61-
assert torch.all(y_out == torch.Tensor([2.5, 3.5, 4.5, 5.5]))
62-
6343
def test_apply_delay_int_negative(self):
6444
"""
6545
Assert proper function of Dataset._apply_delay() when a positive integer delay
@@ -299,29 +279,12 @@ def _create_xy(
299279
torch.tile((torch.linspace(0.0, 1.0, n) > 0.5)[None, :], (2, 1))
300280
)
301281

302-
def _t_apply_delay_float(
303-
self, n: int, delay: int, method: data._DelayInterpolationMethod
304-
):
305-
x, y = self._create_xy(
306-
n=n, method=_XYMethod.ARANGE, must_be_in_valid_range=False
307-
)
308-
309-
x_out, y_out = data.Dataset._apply_delay(x, y, delay, method)
310-
# 7, +/-2.5 -> 4
311-
n_out = n - int(np.ceil(np.abs(delay)))
312-
assert len(x_out) == n_out
313-
assert len(y_out) == n_out
314-
315-
return x_out, y_out
316-
317282
def _t_apply_delay_int(self, n: int, delay: int):
318283
x, y = self._create_xy(
319284
n=n, method=_XYMethod.ARANGE, must_be_in_valid_range=False
320285
)
321286

322-
x_out, y_out = data.Dataset._apply_delay(
323-
x, y, delay, data._DelayInterpolationMethod.CUBIC
324-
)
287+
x_out, y_out = data.Dataset._apply_delay(x, y, delay)
325288
n_out = n - np.abs(delay)
326289
assert len(x_out) == n_out
327290
assert len(y_out) == n_out

tests/test_nam/test_models/test_base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def pad_start_default(self) -> bool:
3131
def receptive_field(self) -> int:
3232
return 1
3333

34-
3534
def _export_config(self):
3635
pass
3736

0 commit comments

Comments
 (0)