@@ -38,44 +38,56 @@ class NotaGenRun:
38
38
# Run on GPU when available; checkpoints and the model are moved to this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# NOTE(review): these rebind module-level names as class attributes — presumably
# resolved at import time earlier in this file; verify they exist at class-creation time.
nota_model_path = nota_model_path
node_dir = node_dir
# Lazily populated checkpoint cache, reused across node executions
# (reset to None when the node runs with unload_model enabled).
model_cache = None
@classmethod
def INPUT_TYPES(s):
    """Declare the ComfyUI input widgets for this node.

    Widget order follows insertion order of the dict below, so it is
    preserved exactly as the UI expects.
    """
    required = {
        "model": (s.model_names, {"default": "notagenx.pth"}),
        "period": (s.periods, {"default": "Romantic"}),
        "composer": (s.composers, {"default": "Bach, Johann Sebastian"}),
        "instrumentation": (s.instrumentations, {"default": "Keyboard"}),
        "custom_prompt": ("STRING", {
            "default": "Romantic | Bach, Johann Sebastian | Keyboard",
            "multiline": True,
            "tooltip": "Custom prompt must follow format: <period>|<composer>|<instrumentation>"
        }),
        "unload_model": ("BOOLEAN", {"default": False}),
        # Sampling controls forwarded to nota_model.generate().
        "temperature": ("FLOAT", {"default": 0.8, "min": 0, "max": 5, "step": 0.1}),
        "top_k": ("INT", {"default": 50, "min": 0}),
        "top_p": ("FLOAT", {"default": 0.95, "min": 0, "max": 1, "step": 0.01}),
        "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
        "comfy_python_path": ("STRING", {
            "default": "",
            "multiline": False,
            "tooltip": "Absolute path of python.exe in ComfyUI environment"
        }),
        "musescore4_path": ("STRING", {
            "default": "",
            "tooltip": r"Absolute path e.g. D:\APP\MuseScorePortable\App\MuseScore\bin\MuseScore4.exe"
        }),
    }
    return {"required": required}
64
72
65
73
RETURN_TYPES = ("AUDIO" , "IMAGE" , "STRING" )
66
74
RETURN_NAMES = ("audio" , "score" , "message" )
67
75
FUNCTION = "inference_patch"
68
- CATEGORY = "MW-NotaGen"
76
+ CATEGORY = "MW/MW -NotaGen"
69
77
70
- # Note_list = Note_list + ['z', 'x']
71
78
def inference_patch (self , model , period , composer , instrumentation ,
72
79
custom_prompt ,
73
- # num_samples,
74
- # abc2xml,
75
80
comfy_python_path ,
76
- # audio_sheet_music,
77
81
musescore4_path ,
82
+ unload_model ,
83
+ temperature ,
84
+ top_k ,
85
+ top_p ,
78
86
seed ):
87
+ if seed != 0 :
88
+ torch .manual_seed (seed )
89
+ torch .cuda .manual_seed (seed )
90
+
79
91
if model == "notagenx.pth" or model == "notagen_large.pth" :
80
92
cf = nota_lx
81
93
elif model == "notagen_small.pth" :
@@ -106,7 +118,11 @@ def inference_patch(self, model, period, composer, instrumentation,
106
118
print ("Parameter Number: " + str (sum (p .numel () for p in nota_model .parameters () if p .requires_grad )))
107
119
108
120
nota_model_path = os.path.join(self.nota_model_path, model)
# Cache the loaded checkpoint so repeated runs skip disk I/O.
# BUG FIX: the previous code cached the first checkpoint unconditionally and
# never checked which model was requested, so switching `model` in the UI
# kept serving the stale cached weights. Key the cache by checkpoint path.
if self.model_cache is None or self.model_cache[0] != nota_model_path:
    checkpoint = torch.load(nota_model_path, map_location=torch.device(self.device))
    self.model_cache = (nota_model_path, checkpoint)
else:
    checkpoint = self.model_cache[1]
nota_model.load_state_dict(checkpoint['model'])
nota_model = nota_model.to(self.device)
nota_model.eval()
@@ -148,17 +164,17 @@ def inference_patch(self, model, period, composer, instrumentation,
148
164
tunebody_flag = False
149
165
while True :
150
166
predicted_patch = nota_model .generate (input_patches .unsqueeze (0 ),
151
- top_k = 9 ,
152
- top_p = 0.9 ,
153
- temperature = 1.2 )
167
+ top_k = top_k ,
168
+ top_p = top_p ,
169
+ temperature = temperature )
154
170
if not tunebody_flag and patchilizer .decode ([predicted_patch ]).startswith ('[r:' ): # start with [r:0/
155
171
tunebody_flag = True
156
172
r0_patch = torch .tensor ([ord (c ) for c in '[r:0/' ]).unsqueeze (0 ).to (self .device )
157
173
temp_input_patches = torch .concat ([input_patches , r0_patch ], axis = - 1 )
158
174
predicted_patch = nota_model .generate (temp_input_patches .unsqueeze (0 ),
159
- top_k = 9 ,
160
- top_p = 0.9 ,
161
- temperature = 1.2 )
175
+ top_k = top_k ,
176
+ top_p = top_p ,
177
+ temperature = temperature )
162
178
predicted_patch = [ord (c ) for c in '[r:0/' ] + predicted_patch
163
179
if predicted_patch [0 ] == patchilizer .bos_token_id and predicted_patch [1 ] == patchilizer .eos_token_id :
164
180
end_flag = True
@@ -319,23 +335,36 @@ def inference_patch(self, model, period, composer, instrumentation,
319
335
else :
320
336
image1 = self .get_empty_image ()
321
337
338
+ if unload_model :
339
+ del patchilizer
340
+ del nota_model
341
+ del checkpoint
342
+ torch .cuda .empty_cache ()
343
+ self .model_cache = None
344
+
322
345
return (
323
346
audio ,
324
347
image1 ,
325
348
f"Saved to { INTERLEAVED_OUTPUT_FOLDER } and { ORIGINAL_OUTPUT_FOLDER } " ,
326
349
)
327
350
328
351
else :
352
+ if unload_model :
353
+ del patchilizer
354
+ del nota_model
355
+ del checkpoint
356
+ torch .cuda .empty_cache ()
357
+ self .model_cache = None
329
358
print (f".abc and .xml was saved to { INTERLEAVED_OUTPUT_FOLDER } and { ORIGINAL_OUTPUT_FOLDER } " )
330
359
raise Exception ("Conversion of .mp3 and .png failed, try again or check if MuseScore4 installation was successful." )
331
360
332
361
333
362
def get_empty_audio(self):
    """Return a silent placeholder AUDIO dict: 1 batch, 2 channels, 1 sample at 44.1 kHz."""
    silence = torch.zeros(1, 2, 1)
    return {"waveform": silence, "sample_rate": 44100}
def get_empty_image(self):
    """Return a black placeholder IMAGE tensor of shape (1, 64, 64, 3), float32."""
    import numpy as np
    blank = np.zeros((1, 64, 64, 3), dtype=np.float32)
    return torch.from_numpy(blank)
341
370
@@ -429,7 +458,7 @@ def rest_unreduce(self, abc_lines):
429
458
return unreduced_lines
430
459
431
460
def wait_for_file (self , file_path , timeout = 15 , check_interval = 0.3 ):
432
- """等待文件生成完成 """
461
+ """Wait for file generation to complete """
433
462
start_time = time .time ()
434
463
435
464
while time .time () - start_time < timeout :
@@ -446,7 +475,7 @@ def wait_for_file(self, file_path, timeout=15, check_interval=0.3):
446
475
return False
447
476
448
477
def wait_for_png_sequence (self , base_path , timeout = 15 , check_interval = 0.3 ):
449
- """等待PNG序列生成完成 """
478
+ """Wait for PNG sequence generation to complete """
450
479
import glob
451
480
452
481
start_time = time .time ()
0 commit comments