1
1
import torch
2
- import numpy as np
3
2
from ldm_patched .ldm .modules .diffusionmodules .util import make_beta_schedule
4
3
import math
4
+ import numpy as np
5
5
6
6
class EPS :
7
7
def calculate_input (self , sigma , noise ):
@@ -12,12 +12,28 @@ def calculate_denoised(self, sigma, model_output, model_input):
12
12
sigma = sigma .view (sigma .shape [:1 ] + (1 ,) * (model_output .ndim - 1 ))
13
13
return model_input - model_output * sigma
14
14
15
+ def noise_scaling (self , sigma , noise , latent_image , max_denoise = False ):
16
+ if max_denoise :
17
+ noise = noise * torch .sqrt (1.0 + sigma ** 2.0 )
18
+ else :
19
+ noise = noise * sigma
20
+
21
+ noise += latent_image
22
+ return noise
23
+
24
+ def inverse_noise_scaling (self , sigma , latent ):
25
+ return latent
15
26
16
27
class V_PREDICTION (EPS ):
17
28
def calculate_denoised (self , sigma , model_output , model_input ):
18
29
sigma = sigma .view (sigma .shape [:1 ] + (1 ,) * (model_output .ndim - 1 ))
19
30
return model_input * self .sigma_data ** 2 / (sigma ** 2 + self .sigma_data ** 2 ) - model_output * sigma * self .sigma_data / (sigma ** 2 + self .sigma_data ** 2 ) ** 0.5
20
31
32
+ class EDM (V_PREDICTION ):
33
+ def calculate_denoised (self , sigma , model_output , model_input ):
34
+ sigma = sigma .view (sigma .shape [:1 ] + (1 ,) * (model_output .ndim - 1 ))
35
+ return model_input * self .sigma_data ** 2 / (sigma ** 2 + self .sigma_data ** 2 ) + model_output * sigma * self .sigma_data / (sigma ** 2 + self .sigma_data ** 2 ) ** 0.5
36
+
21
37
22
38
class ModelSamplingDiscrete (torch .nn .Module ):
23
39
def __init__ (self , model_config = None ):
@@ -42,21 +58,25 @@ def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps
42
58
else :
43
59
betas = make_beta_schedule (beta_schedule , timesteps , linear_start = linear_start , linear_end = linear_end , cosine_s = cosine_s )
44
60
alphas = 1. - betas
45
- alphas_cumprod = torch .tensor (np .cumprod (alphas , axis = 0 ), dtype = torch .float32 )
46
- # alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
61
+ alphas_cumprod = torch .cumprod (alphas , dim = 0 )
47
62
48
63
timesteps , = betas .shape
49
64
self .num_timesteps = int (timesteps )
50
65
self .linear_start = linear_start
51
66
self .linear_end = linear_end
52
67
68
+ # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
69
+ # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
70
+ # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
71
+
53
72
sigmas = ((1 - alphas_cumprod ) / alphas_cumprod ) ** 0.5
73
+ alphas_cumprod = torch .tensor (np .cumprod (alphas , axis = 0 ), dtype = torch .float32 )
54
74
self .set_sigmas (sigmas )
55
75
self .set_alphas_cumprod (alphas_cumprod .float ())
56
76
57
77
def set_sigmas (self , sigmas ):
58
- self .register_buffer ('sigmas' , sigmas )
59
- self .register_buffer ('log_sigmas' , sigmas .log ())
78
+ self .register_buffer ('sigmas' , sigmas . float () )
79
+ self .register_buffer ('log_sigmas' , sigmas .log (). float () )
60
80
61
81
def set_alphas_cumprod (self , alphas_cumprod ):
62
82
self .register_buffer ("alphas_cumprod" , alphas_cumprod .float ())
@@ -94,18 +114,18 @@ def percent_to_sigma(self, percent):
94
114
class ModelSamplingContinuousEDM (torch .nn .Module ):
95
115
def __init__ (self , model_config = None ):
96
116
super ().__init__ ()
97
- self .sigma_data = 1.0
98
-
99
117
if model_config is not None :
100
118
sampling_settings = model_config .sampling_settings
101
119
else :
102
120
sampling_settings = {}
103
121
104
122
sigma_min = sampling_settings .get ("sigma_min" , 0.002 )
105
123
sigma_max = sampling_settings .get ("sigma_max" , 120.0 )
106
- self .set_sigma_range (sigma_min , sigma_max )
124
+ sigma_data = sampling_settings .get ("sigma_data" , 1.0 )
125
+ self .set_parameters (sigma_min , sigma_max , sigma_data )
107
126
108
- def set_sigma_range (self , sigma_min , sigma_max ):
127
+ def set_parameters (self , sigma_min , sigma_max , sigma_data ):
128
+ self .sigma_data = sigma_data
109
129
sigmas = torch .linspace (math .log (sigma_min ), math .log (sigma_max ), 1000 ).exp ()
110
130
111
131
self .register_buffer ('sigmas' , sigmas ) #for compatibility with some schedulers
@@ -134,3 +154,56 @@ def percent_to_sigma(self, percent):
134
154
135
155
log_sigma_min = math .log (self .sigma_min )
136
156
return math .exp ((math .log (self .sigma_max ) - log_sigma_min ) * percent + log_sigma_min )
157
+
158
+ class StableCascadeSampling (ModelSamplingDiscrete ):
159
+ def __init__ (self , model_config = None ):
160
+ super ().__init__ ()
161
+
162
+ if model_config is not None :
163
+ sampling_settings = model_config .sampling_settings
164
+ else :
165
+ sampling_settings = {}
166
+
167
+ self .set_parameters (sampling_settings .get ("shift" , 1.0 ))
168
+
169
+ def set_parameters (self , shift = 1.0 , cosine_s = 8e-3 ):
170
+ self .shift = shift
171
+ self .cosine_s = torch .tensor (cosine_s )
172
+ self ._init_alpha_cumprod = torch .cos (self .cosine_s / (1 + self .cosine_s ) * torch .pi * 0.5 ) ** 2
173
+
174
+ #This part is just for compatibility with some schedulers in the codebase
175
+ self .num_timesteps = 10000
176
+ sigmas = torch .empty ((self .num_timesteps ), dtype = torch .float32 )
177
+ for x in range (self .num_timesteps ):
178
+ t = (x + 1 ) / self .num_timesteps
179
+ sigmas [x ] = self .sigma (t )
180
+
181
+ self .set_sigmas (sigmas )
182
+
183
+ def sigma (self , timestep ):
184
+ alpha_cumprod = (torch .cos ((timestep + self .cosine_s ) / (1 + self .cosine_s ) * torch .pi * 0.5 ) ** 2 / self ._init_alpha_cumprod )
185
+
186
+ if self .shift != 1.0 :
187
+ var = alpha_cumprod
188
+ logSNR = (var / (1 - var )).log ()
189
+ logSNR += 2 * torch .log (1.0 / torch .tensor (self .shift ))
190
+ alpha_cumprod = logSNR .sigmoid ()
191
+
192
+ alpha_cumprod = alpha_cumprod .clamp (0.0001 , 0.9999 )
193
+ return ((1 - alpha_cumprod ) / alpha_cumprod ) ** 0.5
194
+
195
+ def timestep (self , sigma ):
196
+ var = 1 / ((sigma * sigma ) + 1 )
197
+ var = var .clamp (0 , 1.0 )
198
+ s , min_var = self .cosine_s .to (var .device ), self ._init_alpha_cumprod .to (var .device )
199
+ t = (((var * min_var ) ** 0.5 ).acos () / (torch .pi * 0.5 )) * (1 + s ) - s
200
+ return t
201
+
202
+ def percent_to_sigma (self , percent ):
203
+ if percent <= 0.0 :
204
+ return 999999999.9
205
+ if percent >= 1.0 :
206
+ return 0.0
207
+
208
+ percent = 1.0 - percent
209
+ return self .sigma (torch .tensor (percent ))
0 commit comments