Skip to content

Commit 6d47547

Browse files
committed
Added limited support for non-SDXL models
1 parent 1422719 commit 6d47547

8 files changed

+65
-18
lines changed

fooocus_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
version = '2.0.78.2 MRE'
1+
version = '2.0.78.3 MRE'
22
full_version = 'Fooocus ' + version

launch.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,9 @@ def prepare_environment():
8585

8686
vae_approx_filenames = [
8787
('xlvaeapp.pth',
88-
'https://huggingface.co/lllyasviel/misc/resolve/main/xlvaeapp.pth')
88+
'https://huggingface.co/lllyasviel/misc/resolve/main/xlvaeapp.pth'),
89+
('taesd_decoder.pth',
90+
'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth')
8991
]
9092

9193

modules/async_worker.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,13 @@ def handler(task):
310310
s1=freeu_s1,
311311
s2=freeu_s2)
312312

313+
is_sdxl = pipeline.is_base_sdxl()
314+
if not is_sdxl:
315+
print('WARNING: using non-SDXL base model (supported in limited scope).')
316+
control_lora_canny = False
317+
control_lora_depth = False
318+
revision_mode = False
319+
313320
pipeline.set_clip_skips(base_clip_skip, refiner_clip_skip)
314321
if revision_mode:
315322
pipeline.refresh_clip_vision()
@@ -456,10 +463,12 @@ def callback(step, x0, x, total_steps, y):
456463
input_image = None
457464
if input_image_path != None:
458465
img2img_megapixels = width * height * img2img_scale ** 2 / 2**20
459-
if img2img_megapixels < constants.MIN_MEGAPIXELS:
460-
img2img_megapixels = constants.MIN_MEGAPIXELS
461-
elif img2img_megapixels > constants.MAX_MEGAPIXELS:
462-
img2img_megapixels = constants.MAX_MEGAPIXELS
466+
min_mp = constants.MIN_MEGAPIXELS if is_sdxl else constants.MIN_MEGAPIXELS_SD
467+
max_mp = constants.MAX_MEGAPIXELS if is_sdxl else constants.MAX_MEGAPIXELS_SD
468+
if img2img_megapixels < min_mp:
469+
img2img_megapixels = min_mp
470+
elif img2img_megapixels > max_mp:
471+
img2img_megapixels = max_mp
463472
input_image = get_image(input_image_path, img2img_megapixels)
464473

465474
try:

modules/constants.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
# exclusive, needed by modules\expansion.py -> transformers\trainer_utils.py -> np.random.seed()
1010
SEED_LIMIT_NUMPY = 2**32
1111

12-
# min - native SDXL resolution, max - determined by SDXL context size (2048)
12+
# min - native SDXL resolution (1024x1024), max - determined by SDXL context size (2048)
1313
MIN_MEGAPIXELS = 1.0
1414
MAX_MEGAPIXELS = 4.0
15+
16+
# min - native SD 1.5 resolution (512x512), max - determined by SD 2.x context size (1024)
17+
MIN_MEGAPIXELS_SD = 0.25
18+
MAX_MEGAPIXELS_SD = 1.0

modules/core.py

+26-7
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from comfy_extras.nodes_post_processing import ImageScaleToTotalPixels
2020
from comfy_extras.nodes_canny import Canny
2121
from comfy_extras.nodes_freelunch import FreeU
22-
from comfy.model_base import SDXLRefiner
22+
from comfy.model_base import SDXL, SDXLRefiner
2323
from comfy.lora import model_lora_keys_unet, model_lora_keys_clip, load_lora
2424
from modules.samplers_advanced import KSamplerBasic, KSamplerWithRefiner
2525
from modules.path import embeddings_path
@@ -236,14 +236,15 @@ def forward(self, x):
236236

237237

238238
VAE_approx_model = None
239+
taesd = None
239240

240241

241242
@torch.no_grad()
242243
@torch.inference_mode()
243-
def get_previewer(device, latent_format):
244-
global VAE_approx_model
244+
def get_previewer(device, latent_format, is_sdxl=True):
245+
global VAE_approx_model, taesd
245246

246-
if VAE_approx_model is None:
247+
if VAE_approx_model is None and is_sdxl:
247248
from modules.path import vae_approx_path
248249
vae_approx_filename = os.path.join(vae_approx_path, 'xlvaeapp.pth')
249250
sd = torch.load(vae_approx_filename, map_location='cpu')
@@ -271,8 +272,26 @@ def preview_function(x0, step, total_steps):
271272
x_sample = x_sample.cpu().numpy().clip(0, 255).astype(np.uint8)
272273
return x_sample
273274

274-
return preview_function
275+
if taesd is None and not is_sdxl:
276+
from latent_preview import TAESD, TAESDPreviewerImpl
277+
taesd_decoder_path = os.path.abspath(os.path.realpath(os.path.join("models", "vae_approx", latent_format.taesd_decoder_name)))
275278

279+
if not os.path.exists(taesd_decoder_path):
280+
print(f"Warning: TAESD previews enabled, but could not find {taesd_decoder_path}")
281+
return None
282+
283+
taesd = TAESD(None, taesd_decoder_path).to(device)
284+
285+
@torch.no_grad()
286+
@torch.inference_mode()
287+
def preview_function_sd(x0, step, total_steps):
288+
with torch.no_grad():
289+
x_sample = taesd.decoder(torch.nn.functional.avg_pool2d(x0, kernel_size=(2, 2))).detach() * 255.0
290+
x_sample = einops.rearrange(x_sample, 'b c h w -> b h w c')
291+
x_sample = x_sample.cpu().numpy().clip(0, 255).astype(np.uint8)
292+
return x_sample[0]
293+
294+
return preview_function if is_sdxl else preview_function_sd
276295

277296
@torch.no_grad()
278297
@torch.inference_mode()
@@ -299,7 +318,7 @@ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sa
299318
if "noise_mask" in latent:
300319
noise_mask = latent["noise_mask"]
301320

302-
previewer = get_previewer(device, model.model.latent_format)
321+
previewer = get_previewer(device, model.model.latent_format, isinstance(model.model, SDXL))
303322

304323
pbar = comfy.utils.ProgressBar(steps)
305324

@@ -372,7 +391,7 @@ def ksampler_with_refiner(model, positive, negative, refiner, refiner_positive,
372391
if "noise_mask" in latent:
373392
noise_mask = latent["noise_mask"]
374393

375-
previewer = get_previewer(device, model.model.latent_format)
394+
previewer = get_previewer(device, model.model.latent_format, isinstance(model.model, SDXL))
376395

377396
pbar = comfy.utils.ProgressBar(steps)
378397

modules/default_pipeline.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import modules.virtual_memory as virtual_memory
88
import comfy.model_management
99

10-
from comfy.model_base import SDXL, SDXLRefiner
10+
from comfy.model_base import BaseModel, SDXL, SDXLRefiner
1111
from modules.settings import default_settings
1212
from modules.patch import set_comfy_adm_encoding, set_fooocus_adm_encoding, cfg_patched, patched_model_function
1313
from modules.expansion import FooocusExpansion
@@ -48,8 +48,8 @@ def refresh_base_model(name):
4848
xl_base = None
4949

5050
xl_base = core.load_model(filename)
51-
if not isinstance(xl_base.unet.model, SDXL):
52-
print('Model not supported. Fooocus only support SDXL model as the base model.')
51+
if not isinstance(xl_base.unet.model, BaseModel):
52+
print(f'Model not supported: {name}, using default base model instead.')
5353
xl_base = None
5454
xl_base_hash = ''
5555
refresh_base_model(modules.path.default_base_model_name)
@@ -58,13 +58,21 @@ def refresh_base_model(name):
5858
xl_base_patched_hash = ''
5959
return
6060

61+
if not isinstance(xl_base.unet.model, SDXL):
62+
print('WARNING: loading non-SDXL base model.')
63+
6164
xl_base_hash = model_hash
6265
xl_base_patched = xl_base
6366
xl_base_patched_hash = ''
6467
print(f'Base model loaded: {model_hash}')
6568
return
6669

6770

71+
def is_base_sdxl():
72+
assert xl_base is not None
73+
return isinstance(xl_base.unet.model, SDXL)
74+
75+
6876
@torch.no_grad()
6977
@torch.inference_mode()
7078
def refresh_refiner_model(name):
@@ -369,7 +377,7 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
369377
positive_conditions, negative_conditions = core.apply_controlnet(positive_conditions, negative_conditions,
370378
controlnet_depth, input_image, depth_strength, depth_start, depth_stop)
371379

372-
if xl_refiner is not None:
380+
if xl_refiner is not None and is_base_sdxl():
373381
positive_conditions_refiner = positive_cond[1]
374382
negative_conditions_refiner = negative_cond[1]
375383

readme.md

+1
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ Below things are already inside the software, and **users do not need to do anyt
190190
25. Support for authentication in --share mode (credentials loaded from auth.json - use auth-example.json as a template).
191191
26. Support for wildcards (ported from RuinedFooocus - put them in wildcards folder, then try prompts like `__color__ sports car` with different seeds).
192192
27. Support for [FreeU](https://chenyangsi.top/FreeU/).
193+
28. Limited support for non-SDXL base models (the refiner, Control-LoRAs, Revision, inpainting, and outpainting are not available with them).
193194

194195
## Thanks
195196

update_log_mre.md

+4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
### 2.0.78.3 MRE
2+
3+
* Added limited support for non-SDXL base models (the refiner, Control-LoRAs, Revision, inpainting, and outpainting are not available with them).
4+
15
### 2.0.78.2 MRE
26

37
* Added support for FreeU.

0 commit comments

Comments
 (0)