Skip to content

Commit 9ee6ca9

Browse files
bebebe666bebebe666
and
bebebe666
authored
add_optimalsteps (#7584)
Co-authored-by: bebebe666 <jianningpei@tencent.com>
1 parent bb495cc commit 9ee6ca9

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

comfy_extras/nodes_optimalsteps.py

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# from https://github.com/bebebe666/OptimalSteps
2+
3+
4+
import numpy as np
5+
import torch
6+
7+
def loglinear_interp(t_steps, num_steps):
8+
"""
9+
Performs log-linear interpolation of a given array of decreasing numbers.
10+
"""
11+
xs = np.linspace(0, 1, len(t_steps))
12+
ys = np.log(t_steps[::-1])
13+
14+
new_xs = np.linspace(0, 1, num_steps)
15+
new_ys = np.interp(new_xs, xs, ys)
16+
17+
interped_ys = np.exp(new_ys)[::-1].copy()
18+
return interped_ys
19+
20+
21+
NOISE_LEVELS = {"FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0.8287, 0.5512, 0.2808, 0.001],
22+
"Wan":[1.0, 0.997, 0.995, 0.993, 0.991, 0.989, 0.987, 0.985, 0.98, 0.975, 0.973, 0.968, 0.96, 0.946, 0.927, 0.902, 0.864, 0.776, 0.539, 0.208, 0.001],
23+
}
24+
25+
class OptimalStepsScheduler:
26+
@classmethod
27+
def INPUT_TYPES(s):
28+
return {"required":
29+
{"model_type": (["FLUX", "Wan"], ),
30+
"steps": ("INT", {"default": 20, "min": 3, "max": 1000}),
31+
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
32+
}
33+
}
34+
RETURN_TYPES = ("SIGMAS",)
35+
CATEGORY = "sampling/custom_sampling/schedulers"
36+
37+
FUNCTION = "get_sigmas"
38+
39+
def get_sigmas(self, model_type, steps, denoise):
40+
total_steps = steps
41+
if denoise < 1.0:
42+
if denoise <= 0.0:
43+
return (torch.FloatTensor([]),)
44+
total_steps = round(steps * denoise)
45+
46+
sigmas = NOISE_LEVELS[model_type][:]
47+
if (steps + 1) != len(sigmas):
48+
sigmas = loglinear_interp(sigmas, steps + 1)
49+
50+
sigmas = sigmas[-(total_steps + 1):]
51+
sigmas[-1] = 0
52+
return (torch.FloatTensor(sigmas), )
53+
54+
NODE_CLASS_MAPPINGS = {
55+
"OptimalStepsScheduler": OptimalStepsScheduler,
56+
}

nodes.py

+1
Original file line numberDiff line numberDiff line change
@@ -2280,6 +2280,7 @@ def init_builtin_extra_nodes():
22802280
"nodes_hunyuan3d.py",
22812281
"nodes_primitive.py",
22822282
"nodes_cfg.py",
2283+
"nodes_optimalsteps.py"
22832284
]
22842285

22852286
import_failed = []

0 commit comments

Comments
 (0)