Commit c371f18

[FEATURE] Implement loading WaveNet model from .nam file (#565)

1 parent 9e39835 commit c371f18
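
Taken together, this change adds a public entry point, nam.models.init_from_nam, that rebuilds a model from the parsed JSON contents of a .nam file. A minimal usage sketch, following the docstring in _from_nam.py below ("model.nam" is a placeholder path):

import json

from nam.models import init_from_nam

# A .nam file is JSON: parse it first...
with open("model.nam", "r") as fp:
    config = json.load(fp)

# ...then rebuild the model. The architecture is read from
# config["architecture"] (only "WaveNet" so far) and the weights
# from config["weights"].
model = init_from_nam(config)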

4 files changed (+125, -0 lines)


nam/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 from . import base  # noqa F401
 from . import exportable  # noqa F401
 from . import losses  # noqa F401
+from ._from_nam import init_from_nam  # noqa F401
 from .conv_net import ConvNet  # noqa F401
 from .linear import Linear  # noqa F401
 from .recurrent import LSTM  # noqa F401

nam/models/_from_nam.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+# File: _from_nam.py
+# Created Date: Tuesday May 27th 2025
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+"""
+Initialize models from .nam files
+"""
+
+import torch as _torch
+
+from .base import BaseNet as _BaseNet
+from .wavenet import WaveNet as _WaveNet
+
+
+def _init_wavenet(config) -> _WaveNet:
+    return _WaveNet(layers_configs=config["layers"], head_config=config["head"], head_scale=config["head_scale"])
+
+
+def init_from_nam(config) -> _BaseNet:
+    """
+    Taking the contents of a .nam file, initialize a model
+
+    E.g.
+    >>> with open("model.nam", "r") as fp:
+    ...     config = json.load(fp)
+    ...     model = init_from_nam(config)
+    """
+    model = {
+        "WaveNet": _init_wavenet
+    }[config["architecture"]](config["config"])
+    model.import_weights(_torch.Tensor(config["weights"]))
+    return model
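
For reference, a sketch of the top-level keys init_from_nam reads from the parsed file, inferred from the code above; the values shown are placeholders, and the full schema is whatever the exporter writes:

# Illustrative shape of the parsed .nam contents consumed above:
nam_contents = {
    "architecture": "WaveNet",  # dispatches to _init_wavenet
    "config": {
        "layers": [...],  # forwarded as layers_configs
        "head": None,  # forwarded as head_config
        "head_scale": 0.02,  # forwarded as head_scale
    },
    "weights": [0.1, -0.2],  # flat list, wrapped in a torch.Tensor
}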

nam/models/wavenet.py

Lines changed: 2 additions & 0 deletions
@@ -327,6 +327,8 @@ def export_weights(self) -> _np.ndarray:
         return weights.detach().cpu().numpy()
 
     def import_weights(self, weights: _torch.Tensor):
+        if self._head is not None:
+            raise NotImplementedError("Head importing isn't implemented yet.")
         i = 0
         for layer in self._layers:
             i = layer.import_weights(weights, i)
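
import_weights walks the flat weight vector in the same order export_weights wrote it: each layer consumes its slice and returns the cursor where the next layer should start. A hypothetical sketch of the per-layer contract the loop assumes (the names num_weights and set_weights are illustrative, not the actual layer API):

# Hypothetical per-layer contract assumed by the loop above:
def import_weights(self, weights, i):
    n = self.num_weights  # illustrative: this layer's parameter count
    self.set_weights(weights[i : i + n])  # illustrative setter
    return i + n  # cursor for the next layer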
Lines changed: 91 additions & 0 deletions

@@ -0,0 +1,91 @@
+"""
+Test loading from a .nam file
+"""
+
+import json as _json
+from pathlib import Path as _Path
+from tempfile import TemporaryDirectory as _TemporaryDirectory
+import pytest as _pytest
+
+from nam.models import _from_nam
+from nam.models.wavenet import WaveNet as _WaveNet
+
+
+@_pytest.mark.parametrize(
+    "factory,kwargs",
+    (
+        # A standard WaveNet
+        (
+            _WaveNet,  # i.e. .__init__()
+            {
+                "layers_configs": [
+                    {
+                        "condition_size": 1,
+                        "input_size": 1,
+                        "channels": 16,
+                        "head_size": 8,
+                        "kernel_size": 3,
+                        "dilations": [
+                            1,
+                            2,
+                            4,
+                            8,
+                            16,
+                            32,
+                            64,
+                            128,
+                            256,
+                            512
+                        ],
+                        "activation": "Tanh",
+                        "gated": False,
+                        "head_bias": False
+                    },
+                    {
+                        "condition_size": 1,
+                        "input_size": 16,
+                        "channels": 8,
+                        "head_size": 1,
+                        "kernel_size": 3,
+                        "dilations": [
+                            1,
+                            2,
+                            4,
+                            8,
+                            16,
+                            32,
+                            64,
+                            128,
+                            256,
+                            512
+                        ],
+                        "activation": "Tanh",
+                        "gated": False,
+                        "head_bias": True
+                    }
+                ],
+                "head_scale": 0.02
+            }
+        ),
+    ),
+)
+def test_load_from_nam(factory, kwargs):
+    """
+    Assert that loading from a .nam file works by exporting the model twice and comparing the results
+    """
+    model = factory(**kwargs)
+    with _TemporaryDirectory() as tmpdir:
+        model.export(_Path(tmpdir), basename="model")
+        with open(_Path(tmpdir, "model.nam"), "r") as fp:
+            nam_file_contents = _json.load(fp)
+        model2 = _from_nam.init_from_nam(nam_file_contents)
+        model2.export(_Path(tmpdir), basename="model2")
+        with open(_Path(tmpdir, "model2.nam"), "r") as fp:
+            nam_file_contents2 = _json.load(fp)
+
+    # Metadata isn't preserved. At least the creation time will be slightly different.
+    # Could improve this to preserve metadata on load.
+    nam_file_contents.pop("metadata")
+    nam_file_contents2.pop("metadata")
+    assert nam_file_contents == nam_file_contents2
