Skip to content

Commit 18f4a05

Browse files
authored
Miscellaneous bug fixes (#540)
* Implement ConcatDataset.sample_rate * Update environment_gpu.yml * Implement nx, ny, sample_rate safely in ConcatDataset
1 parent e200409 commit 18f4a05

File tree

3 files changed

+74
-9
lines changed

3 files changed

+74
-9
lines changed

environments/environment_gpu.yml

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,14 +17,14 @@ dependencies:
1717
- numpy<2
1818
- pip
1919
- pre-commit
20-
- pydantic>=2.0.0
20+
- pydantic>=2
2121
- pytest
2222
- pytest-mock
23-
- pytorch
23+
- pytorch::pytorch
2424
# If your GPU isn't being detected, you may need a different version.
2525
# You're going to need to look at Table 3 here:
2626
# https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions
27-
- pytorch-cuda=12.1 # GPU
27+
- pytorch::pytorch-cuda=12.1 # GPU
2828
- scipy
2929
- semver
3030
- tensorboard

nam/data.py

Lines changed: 39 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -375,6 +375,10 @@ def __len__(self) -> int:
375375
single_pairs = n - self._nx + 1
376376
return single_pairs // self._ny
377377

378+
@property
379+
def nx(self) -> int:
380+
return self._nx
381+
378382
@property
379383
def ny(self) -> int:
380384
return self._ny
@@ -695,6 +699,14 @@ def _validate_preceding_silence(
695699
)
696700

697701

702+
class ConcatDatasetValidationError(ValueError):
703+
"""
704+
Error raised when a ConcatDataset fails validation
705+
"""
706+
707+
pass
708+
709+
698710
class ConcatDataset(AbstractDataset, _InitializableFromConfig):
699711
def __init__(self, datasets: _Sequence[Dataset], flatten=True):
700712
if flatten:
@@ -717,6 +729,21 @@ def __len__(self) -> int:
717729
def datasets(self):
718730
return self._datasets
719731

732+
@property
733+
def nx(self) -> int:
734+
# Validated at initialization
735+
return self.datasets[0].nx
736+
737+
@property
738+
def ny(self) -> int:
739+
# Validated at initialization
740+
return self.datasets[0].ny
741+
742+
@property
743+
def sample_rate(self) -> _Optional[float]:
744+
# This is validated to be consistent across datasets during initialization
745+
return self.datasets[0].sample_rate
746+
720747
@classmethod
721748
def parse_config(cls, config):
722749
init = _dataset_init_registry[config.get("type", "dataset")]
@@ -767,14 +794,20 @@ def _make_lookup(self):
767794

768795
@classmethod
769796
def _validate_datasets(cls, datasets: _Sequence[Dataset]):
797+
# Ensure that a couple attrs are consistent across the sub-datasets.
770798
Reference = _namedtuple("Reference", ("index", "val"))
771-
ref_keys, ref_ny = None, None
799+
references = {name: None for name in ("nx", "ny", "sample_rate")}
772800
for i, d in enumerate(datasets):
773-
ref_ny = Reference(i, d.ny) if ref_ny is None else ref_ny
774-
if d.ny != ref_ny.val:
775-
raise ValueError(
776-
f"Mismatch between ny of datasets {ref_ny.index} ({ref_ny.val}) and {i} ({d.ny})"
777-
)
801+
for name in references.keys():
802+
this_val = getattr(d, name)
803+
if references[name] is None:
804+
references[name] = Reference(i, this_val)
805+
806+
if this_val != references[name].val:
807+
raise ConcatDatasetValidationError(
808+
f"Mismatch between {name} of datasets {references[name].index} "
809+
f"({references[name].val}) and {i} ({this_val})"
810+
)
778811

779812

780813
_dataset_init_registry = {"dataset": Dataset.init_from_config}

tests/test_nam/test_data.py

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -385,6 +385,38 @@ def test_sample_widths(self, sample_width: int):
385385
assert info.sampwidth == sample_width
386386

387387

388+
class TestConcatDataset(object):
389+
@pytest.mark.parametrize("attrname", ("nx", "ny", "sample_rate"))
390+
def test_valiation_sample_rate_fail(self, attrname: str):
391+
"""
392+
Assert failed validation for datasets with different nx, ny, sample rates
393+
"""
394+
nx, ny, sample_rate = 1, 2, 48_000.0
395+
396+
n1 = 16
397+
ds1_kwargs = dict(
398+
x=torch.zeros((n1,)),
399+
y=torch.zeros((n1,)),
400+
nx=nx,
401+
ny=ny,
402+
sample_rate=sample_rate,
403+
)
404+
ds1 = data.Dataset(**ds1_kwargs)
405+
n2 = 7
406+
ds2_kwargs = dict(
407+
x=torch.zeros((n2,)),
408+
y=torch.zeros((n2,)),
409+
nx=nx,
410+
ny=ny,
411+
sample_rate=sample_rate,
412+
)
413+
# Cause the error by modifying the named attr:
414+
ds2_kwargs[attrname] += 1
415+
ds2 = data.Dataset(**ds2_kwargs)
416+
with pytest.raises(data.ConcatDatasetValidationError):
417+
data.ConcatDataset([ds1, ds2])
418+
419+
388420
def test_audio_mismatch_shapes_in_order():
389421
"""
390422
https://github.com/sdatkinson/neural-amp-modeler/issues/257

0 commit comments

Comments (0)