Skip to content

Commit 78029f0

Browse files
committed
update v2.2.1
1 parent 8b3aa71 commit 78029f0

File tree

6 files changed

+66
-32
lines changed

6 files changed

+66
-32
lines changed

Diff for: NotaGenNode.py

+60-31
Original file line numberDiff line numberDiff line change
@@ -38,44 +38,56 @@ class NotaGenRun:
3838
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
3939
nota_model_path = nota_model_path
4040
node_dir = node_dir
41+
42+
model_cache = None
4143
@classmethod
4244
def INPUT_TYPES(s):
43-
4445
return {
4546
"required": {
4647
"model": (s.model_names, {"default": "notagenx.pth"}),
47-
"period": (s.periods, {"default": "Romantic", }),
48-
"composer": (s.composers, {"default": "Bach, Johann Sebastian", }),
49-
"instrumentation": (s.instrumentations, {"default": "Keyboard", }),
50-
"custom_prompt": ("STRING", {"default": "Romantic | Bach, Johann Sebastian | Keyboard",
51-
"multiline": True,
52-
"tooltip": "Custom prompt must <period>|<composer>|<instrumentation>."}),
53-
# "num_samples": ("INT", {"default": 1, "min": 1}),
54-
# "temperature": ("FLOAT", {"default": 0.8, "min": 0, "max": 1, "step": 0.1}),
55-
# "top_k": ("INT", {"default": 50, "min": 0}),
56-
# "top_p": ("FLOAT", {"default": 0.95, "min": 0, "max": 1, "step": 0.01}),
48+
"period": (s.periods, {"default": "Romantic"}),
49+
"composer": (s.composers, {"default": "Bach, Johann Sebastian"}),
50+
"instrumentation": (s.instrumentations, {"default": "Keyboard"}),
51+
"custom_prompt": ("STRING", {
52+
"default": "Romantic | Bach, Johann Sebastian | Keyboard",
53+
"multiline": True,
54+
"tooltip": "Custom prompt must follow format: <period>|<composer>|<instrumentation>"
55+
}),
56+
"unload_model":("BOOLEAN", {"default": False}),
57+
"temperature": ("FLOAT", {"default": 0.8, "min": 0, "max": 5, "step": 0.1}),
58+
"top_k": ("INT", {"default": 50, "min": 0}),
59+
"top_p": ("FLOAT", {"default": 0.95, "min": 0, "max": 1, "step": 0.01}),
5760
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
58-
"comfy_python_path": ("STRING", {"default": "", "multiline": False, "tooltip": "The absolute path of python.exe in the Comfyui environment"}),
59-
# "audio_sheet_music": ("BOOLEAN", {"default": True}),
60-
"musescore4_path": ("STRING", {"default": "", "tooltip": r"The absolute path as `D:\APP\MuseScorePortable\App\MuseScore\bin\MuseScore4.exe`"}),
61-
# "abc2xml": ("BOOLEAN", {"default": True}),
61+
"comfy_python_path": ("STRING", {
62+
"default": "",
63+
"multiline": False,
64+
"tooltip": "Absolute path of python.exe in ComfyUI environment"
65+
}),
66+
"musescore4_path": ("STRING", {
67+
"default": "",
68+
"tooltip": r"Absolute path e.g. D:\APP\MuseScorePortable\App\MuseScore\bin\MuseScore4.exe"
69+
}),
6270
},
6371
}
6472

6573
RETURN_TYPES = ("AUDIO", "IMAGE", "STRING")
6674
RETURN_NAMES = ("audio", "score", "message")
6775
FUNCTION = "inference_patch"
68-
CATEGORY = "MW-NotaGen"
76+
CATEGORY = "MW/MW-NotaGen"
6977

70-
# Note_list = Note_list + ['z', 'x']
7178
def inference_patch(self, model, period, composer, instrumentation,
7279
custom_prompt,
73-
# num_samples,
74-
# abc2xml,
7580
comfy_python_path,
76-
# audio_sheet_music,
7781
musescore4_path,
82+
unload_model,
83+
temperature,
84+
top_k,
85+
top_p,
7886
seed):
87+
if seed != 0:
88+
torch.manual_seed(seed)
89+
torch.cuda.manual_seed(seed)
90+
7991
if model == "notagenx.pth" or model == "notagen_large.pth":
8092
cf = nota_lx
8193
elif model == "notagen_small.pth":
@@ -106,7 +118,11 @@ def inference_patch(self, model, period, composer, instrumentation,
106118
print("Parameter Number: " + str(sum(p.numel() for p in nota_model.parameters() if p.requires_grad)))
107119

108120
nota_model_path = os.path.join(self.nota_model_path, model)
109-
checkpoint = torch.load(nota_model_path, map_location=torch.device(self.device))
121+
if self.model_cache is None:
122+
checkpoint = torch.load(nota_model_path, map_location=torch.device(self.device))
123+
self.model_cache = checkpoint
124+
else:
125+
checkpoint = self.model_cache
110126
nota_model.load_state_dict(checkpoint['model'])
111127
nota_model = nota_model.to(self.device)
112128
nota_model.eval()
@@ -148,17 +164,17 @@ def inference_patch(self, model, period, composer, instrumentation,
148164
tunebody_flag = False
149165
while True:
150166
predicted_patch = nota_model.generate(input_patches.unsqueeze(0),
151-
top_k=9,
152-
top_p=0.9,
153-
temperature=1.2)
167+
top_k=top_k,
168+
top_p=top_p,
169+
temperature=temperature)
154170
if not tunebody_flag and patchilizer.decode([predicted_patch]).startswith('[r:'): # start with [r:0/
155171
tunebody_flag = True
156172
r0_patch = torch.tensor([ord(c) for c in '[r:0/']).unsqueeze(0).to(self.device)
157173
temp_input_patches = torch.concat([input_patches, r0_patch], axis=-1)
158174
predicted_patch = nota_model.generate(temp_input_patches.unsqueeze(0),
159-
top_k=9,
160-
top_p=0.9,
161-
temperature=1.2)
175+
top_k=top_k,
176+
top_p=top_p,
177+
temperature=temperature)
162178
predicted_patch = [ord(c) for c in '[r:0/'] + predicted_patch
163179
if predicted_patch[0] == patchilizer.bos_token_id and predicted_patch[1] == patchilizer.eos_token_id:
164180
end_flag = True
@@ -319,23 +335,36 @@ def inference_patch(self, model, period, composer, instrumentation,
319335
else:
320336
image1 = self.get_empty_image()
321337

338+
if unload_model:
339+
del patchilizer
340+
del nota_model
341+
del checkpoint
342+
torch.cuda.empty_cache()
343+
self.model_cache = None
344+
322345
return (
323346
audio,
324347
image1,
325348
f"Saved to {INTERLEAVED_OUTPUT_FOLDER} and {ORIGINAL_OUTPUT_FOLDER}",
326349
)
327350

328351
else:
352+
if unload_model:
353+
del patchilizer
354+
del nota_model
355+
del checkpoint
356+
torch.cuda.empty_cache()
357+
self.model_cache = None
329358
print(f".abc and .xml was saved to {INTERLEAVED_OUTPUT_FOLDER} and {ORIGINAL_OUTPUT_FOLDER}")
330359
raise Exception("Conversion of .mp3 and .png failed, try again or check if MuseScore4 installation was successful.")
331360

332361

333362
def get_empty_audio(self):
334-
"""返回空音频"""
363+
"""Return empty audio"""
335364
return {"waveform": torch.zeros(1, 2, 1), "sample_rate": 44100}
336365

337366
def get_empty_image(self):
338-
"""返回空图片"""
367+
"""Return empty image"""
339368
import numpy as np
340369
return torch.from_numpy(np.zeros((1, 64, 64, 3), dtype=np.float32))
341370

@@ -429,7 +458,7 @@ def rest_unreduce(self, abc_lines):
429458
return unreduced_lines
430459

431460
def wait_for_file(self, file_path, timeout=15, check_interval=0.3):
432-
"""等待文件生成完成"""
461+
"""Wait for file generation to complete"""
433462
start_time = time.time()
434463

435464
while time.time() - start_time < timeout:
@@ -446,7 +475,7 @@ def wait_for_file(self, file_path, timeout=15, check_interval=0.3):
446475
return False
447476

448477
def wait_for_png_sequence(self, base_path, timeout=15, check_interval=0.3):
449-
"""等待PNG序列生成完成"""
478+
"""Wait for PNG sequence generation to complete"""
450479
import glob
451480

452481
start_time = time.time()

Diff for: README-CN.md

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ https://github.com/user-attachments/assets/0671657f-e66b-4000-a0aa-48520f15b782
99

1010
## 📣 更新
1111

12+
[2025-03-21]⚒️: 增加更多可调参数, 更自由畅玩. 可选是否卸载模型.
13+
1214
[2025-03-15]⚒️: 支持 Linux Ubuntu/Debian 系列, 以及服务器, 其他未测试.
1315

1416
本地 Linux 电脑, 安装 `musescore` 等:

Diff for: README.md

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ https://github.com/user-attachments/assets/0671657f-e66b-4000-a0aa-48520f15b782
88

99
## 📣 Updates
1010

11+
[2025-03-21] ⚒️: Added more tunable parameters for more creative freedom. Optional model unloading.
12+
1113
[2025-03-15]⚒️: Supports Linux Ubuntu/Debian series, as well as servers, others untested, as well as servers.
1214

1315
For local Linux computers, install `musescore` etc.:

Diff for: images/2025-03-10_06-24-03.png

-78.3 KB
Loading

Diff for: pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "notagen-mw"
33
description = "Symbolic Music Generation, NotaGen node for ComfyUI."
4-
version = "2.1.2"
4+
version = "2.2.1"
55
license = {file = "LICENSE"}
66

77
[project.urls]

Diff for: requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ wandb>=0.17.2
22
abctoolkit>=0.0.6
33
samplings>=0.1.7
44
pyparsing>=3.2.1
5+
transformers>=4.40.0

0 commit comments

Comments
 (0)