1
1
import torch
2
- import numpy as np
3
2
from ldm_patched .ldm .modules .diffusionmodules .util import make_beta_schedule
4
3
import math
5
4
@@ -12,12 +11,28 @@ def calculate_denoised(self, sigma, model_output, model_input):
12
11
sigma = sigma .view (sigma .shape [:1 ] + (1 ,) * (model_output .ndim - 1 ))
13
12
return model_input - model_output * sigma
14
13
14
+ def noise_scaling (self , sigma , noise , latent_image , max_denoise = False ):
15
+ if max_denoise :
16
+ noise = noise * torch .sqrt (1.0 + sigma ** 2.0 )
17
+ else :
18
+ noise = noise * sigma
19
+
20
+ noise += latent_image
21
+ return noise
22
+
23
+ def inverse_noise_scaling (self , sigma , latent ):
24
+ return latent
15
25
16
26
class V_PREDICTION (EPS ):
17
27
def calculate_denoised (self , sigma , model_output , model_input ):
18
28
sigma = sigma .view (sigma .shape [:1 ] + (1 ,) * (model_output .ndim - 1 ))
19
29
return model_input * self .sigma_data ** 2 / (sigma ** 2 + self .sigma_data ** 2 ) - model_output * sigma * self .sigma_data / (sigma ** 2 + self .sigma_data ** 2 ) ** 0.5
20
30
31
+ class EDM (V_PREDICTION ):
32
+ def calculate_denoised (self , sigma , model_output , model_input ):
33
+ sigma = sigma .view (sigma .shape [:1 ] + (1 ,) * (model_output .ndim - 1 ))
34
+ return model_input * self .sigma_data ** 2 / (sigma ** 2 + self .sigma_data ** 2 ) + model_output * sigma * self .sigma_data / (sigma ** 2 + self .sigma_data ** 2 ) ** 0.5
35
+
21
36
22
37
class ModelSamplingDiscrete (torch .nn .Module ):
23
38
def __init__ (self , model_config = None ):
@@ -42,24 +57,23 @@ def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps
42
57
else :
43
58
betas = make_beta_schedule (beta_schedule , timesteps , linear_start = linear_start , linear_end = linear_end , cosine_s = cosine_s )
44
59
alphas = 1. - betas
45
- alphas_cumprod = torch .tensor (np .cumprod (alphas , axis = 0 ), dtype = torch .float32 )
46
- # alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
60
+ alphas_cumprod = torch .cumprod (alphas , dim = 0 )
47
61
48
62
timesteps , = betas .shape
49
63
self .num_timesteps = int (timesteps )
50
64
self .linear_start = linear_start
51
65
self .linear_end = linear_end
52
66
67
+ # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
68
+ # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
69
+ # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
70
+
53
71
sigmas = ((1 - alphas_cumprod ) / alphas_cumprod ) ** 0.5
54
72
self .set_sigmas (sigmas )
55
- self .set_alphas_cumprod (alphas_cumprod .float ())
56
73
57
74
def set_sigmas (self , sigmas ):
58
- self .register_buffer ('sigmas' , sigmas )
59
- self .register_buffer ('log_sigmas' , sigmas .log ())
60
-
61
- def set_alphas_cumprod (self , alphas_cumprod ):
62
- self .register_buffer ("alphas_cumprod" , alphas_cumprod .float ())
75
+ self .register_buffer ('sigmas' , sigmas .float ())
76
+ self .register_buffer ('log_sigmas' , sigmas .log ().float ())
63
77
64
78
@property
65
79
def sigma_min (self ):
@@ -94,18 +108,18 @@ def percent_to_sigma(self, percent):
94
108
class ModelSamplingContinuousEDM (torch .nn .Module ):
95
109
def __init__ (self , model_config = None ):
96
110
super ().__init__ ()
97
- self .sigma_data = 1.0
98
-
99
111
if model_config is not None :
100
112
sampling_settings = model_config .sampling_settings
101
113
else :
102
114
sampling_settings = {}
103
115
104
116
sigma_min = sampling_settings .get ("sigma_min" , 0.002 )
105
117
sigma_max = sampling_settings .get ("sigma_max" , 120.0 )
106
- self .set_sigma_range (sigma_min , sigma_max )
118
+ sigma_data = sampling_settings .get ("sigma_data" , 1.0 )
119
+ self .set_parameters (sigma_min , sigma_max , sigma_data )
107
120
108
- def set_sigma_range (self , sigma_min , sigma_max ):
121
+ def set_parameters (self , sigma_min , sigma_max , sigma_data ):
122
+ self .sigma_data = sigma_data
109
123
sigmas = torch .linspace (math .log (sigma_min ), math .log (sigma_max ), 1000 ).exp ()
110
124
111
125
self .register_buffer ('sigmas' , sigmas ) #for compatibility with some schedulers
@@ -134,3 +148,56 @@ def percent_to_sigma(self, percent):
134
148
135
149
log_sigma_min = math .log (self .sigma_min )
136
150
return math .exp ((math .log (self .sigma_max ) - log_sigma_min ) * percent + log_sigma_min )
151
+
152
+ class StableCascadeSampling (ModelSamplingDiscrete ):
153
+ def __init__ (self , model_config = None ):
154
+ super ().__init__ ()
155
+
156
+ if model_config is not None :
157
+ sampling_settings = model_config .sampling_settings
158
+ else :
159
+ sampling_settings = {}
160
+
161
+ self .set_parameters (sampling_settings .get ("shift" , 1.0 ))
162
+
163
+ def set_parameters (self , shift = 1.0 , cosine_s = 8e-3 ):
164
+ self .shift = shift
165
+ self .cosine_s = torch .tensor (cosine_s )
166
+ self ._init_alpha_cumprod = torch .cos (self .cosine_s / (1 + self .cosine_s ) * torch .pi * 0.5 ) ** 2
167
+
168
+ #This part is just for compatibility with some schedulers in the codebase
169
+ self .num_timesteps = 10000
170
+ sigmas = torch .empty ((self .num_timesteps ), dtype = torch .float32 )
171
+ for x in range (self .num_timesteps ):
172
+ t = (x + 1 ) / self .num_timesteps
173
+ sigmas [x ] = self .sigma (t )
174
+
175
+ self .set_sigmas (sigmas )
176
+
177
+ def sigma (self , timestep ):
178
+ alpha_cumprod = (torch .cos ((timestep + self .cosine_s ) / (1 + self .cosine_s ) * torch .pi * 0.5 ) ** 2 / self ._init_alpha_cumprod )
179
+
180
+ if self .shift != 1.0 :
181
+ var = alpha_cumprod
182
+ logSNR = (var / (1 - var )).log ()
183
+ logSNR += 2 * torch .log (1.0 / torch .tensor (self .shift ))
184
+ alpha_cumprod = logSNR .sigmoid ()
185
+
186
+ alpha_cumprod = alpha_cumprod .clamp (0.0001 , 0.9999 )
187
+ return ((1 - alpha_cumprod ) / alpha_cumprod ) ** 0.5
188
+
189
+ def timestep (self , sigma ):
190
+ var = 1 / ((sigma * sigma ) + 1 )
191
+ var = var .clamp (0 , 1.0 )
192
+ s , min_var = self .cosine_s .to (var .device ), self ._init_alpha_cumprod .to (var .device )
193
+ t = (((var * min_var ) ** 0.5 ).acos () / (torch .pi * 0.5 )) * (1 + s ) - s
194
+ return t
195
+
196
+ def percent_to_sigma (self , percent ):
197
+ if percent <= 0.0 :
198
+ return 999999999.9
199
+ if percent >= 1.0 :
200
+ return 0.0
201
+
202
+ percent = 1.0 - percent
203
+ return self .sigma (torch .tensor (percent ))
0 commit comments