
Commit 4c82741

Support official SD3.5 Controlnets.
1 parent 15c39ea commit 4c82741

2 files changed: +210 -3 lines changed

comfy/cldm/dit_embedder.py (+122)
@@ -0,0 +1,122 @@
import math
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor

from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch


class ControlNetEmbedder(nn.Module):

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_chans: int,
        attention_head_dim: int,
        num_attention_heads: int,
        adm_in_channels: int,
        num_layers: int,
        main_model_double: int,
        double_y_emb: bool,
        device: torch.device,
        dtype: torch.dtype,
        pos_embed_max_size: Optional[int] = None,
        operations = None,
    ):
        super().__init__()
        self.main_model_double = main_model_double
        self.dtype = dtype
        self.hidden_size = num_attention_heads * attention_head_dim
        self.patch_size = patch_size
        self.x_embedder = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=pos_embed_max_size is None,
            device=device,
            dtype=dtype,
            operations=operations,
        )

        self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)

        self.double_y_emb = double_y_emb
        if self.double_y_emb:
            self.orig_y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )
            self.y_embedder = VectorEmbedder(
                self.hidden_size, self.hidden_size, dtype, device, operations=operations
            )
        else:
            self.y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )

        self.transformer_blocks = nn.ModuleList(
            DismantledBlock(
                hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
                dtype=dtype, device=device, operations=operations
            )
            for _ in range(num_layers)
        )

        # self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
        # TODO double check this logic when 8b
        self.use_y_embedder = True

        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            self.controlnet_blocks.append(controlnet_block)

        self.pos_embed_input = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=False,
            device=device,
            dtype=dtype,
            operations=operations,
        )

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        hint = None,
    ) -> Tuple[Tensor, List[Tensor]]:
        x_shape = list(x.shape)
        x = self.x_embedder(x)
        if not self.double_y_emb:
            h = (x_shape[-2] + 1) // self.patch_size
            w = (x_shape[-1] + 1) // self.patch_size
            x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
        c = self.t_embedder(timesteps, dtype=x.dtype)
        if y is not None and self.y_embedder is not None:
            if self.double_y_emb:
                y = self.orig_y_embedder(y)
            y = self.y_embedder(y)
            c = c + y

        x = x + self.pos_embed_input(hint)

        block_out = ()

        repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
        for i in range(len(self.transformer_blocks)):
            out = self.transformer_blocks[i](x, c)
            if not self.double_y_emb:
                x = out
            block_out += (self.controlnet_blocks[i](out),) * repeat

        return {"output": block_out}
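Note (not part of the commit): a rough sketch of how ControlNetEmbedder could be instantiated and called, using hyperparameters in the style of what load_controlnet_sd35() in comfy/controlnet.py derives from a checkpoint. The concrete sizes, the CPU device, the 12-layer depth, and the use of comfy.ops.disable_weight_init as the operations backend are assumptions for illustration only.

import torch
import comfy.ops
from comfy.cldm.dit_embedder import ControlNetEmbedder

# Assumed model geometry for the sketch: 24 double blocks, hidden size 24 * 64 = 1536,
# 12 controlnet layers (so each controlnet output is repeated twice in forward()).
embedder = ControlNetEmbedder(
    img_size=None,
    patch_size=2,
    in_chans=16,                 # SD3 latent channels
    attention_head_dim=64,
    num_attention_heads=24,
    adm_in_channels=2048,        # pooled text embedding width
    num_layers=12,
    main_model_double=24,
    double_y_emb=False,
    device=torch.device("cpu"),
    dtype=torch.float32,
    operations=comfy.ops.disable_weight_init,
)

x = torch.zeros(1, 16, 128, 128)      # noisy latent
hint = torch.zeros(1, 16, 128, 128)   # VAE-encoded control image, same layout as x
t = torch.zeros(1)                    # timesteps
y = torch.zeros(1, 2048)              # pooled conditioning

out = embedder(x, t, y=y, hint=hint)
# out["output"] is a tuple of per-layer residuals, each repeated math.ceil(24 / 12) = 2
# times so one entry lines up with each double block of the main model.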

comfy/controlnet.py (+88 -3)
@@ -35,7 +35,7 @@
import comfy.cldm.mmdit
import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet
-
+import comfy.cldm.dit_embedder

def broadcast_image_to(tensor, target_batch_size, batched_number):
    current_batch_size = tensor.shape[0]
@@ -78,6 +78,7 @@ def __init__(self):
        self.concat_mask = False
        self.extra_concat_orig = []
        self.extra_concat = None
+        self.preprocess_image = lambda a: a

    def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
        self.cond_hint_original = cond_hint
@@ -129,6 +130,7 @@ def copy_to(self, c):
        c.strength_type = self.strength_type
        c.concat_mask = self.concat_mask
        c.extra_concat_orig = self.extra_concat_orig.copy()
+        c.preprocess_image = self.preprocess_image

    def inference_memory_requirements(self, dtype):
        if self.previous_controlnet is not None:
@@ -181,7 +183,7 @@ def set_extra_arg(self, argument, value=None):


class ControlNet(ControlBase):
-    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
+    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, preprocess_image=lambda a: a):
        super().__init__()
        self.control_model = control_model
        self.load_device = load_device
@@ -196,6 +198,7 @@ def __init__(self, control_model=None, global_average_pooling=False, compression
        self.extra_conds += extra_conds
        self.strength_type = strength_type
        self.concat_mask = concat_mask
+        self.preprocess_image = preprocess_image

    def get_control(self, x_noisy, t, cond, batched_number):
        control_prev = None
@@ -224,6 +227,7 @@ def get_control(self, x_noisy, t, cond, batched_number):
                if self.latent_format is not None:
                    raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
            self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
+            self.cond_hint = self.preprocess_image(self.cond_hint)
            if self.vae is not None:
                loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
                self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
@@ -427,6 +431,7 @@ def controlnet_load_state_dict(control_model, sd):
        logging.debug("unexpected controlnet keys: {}".format(unexpected))
    return control_model

+
def load_controlnet_mmdit(sd, model_options={}):
    new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
@@ -448,6 +453,83 @@ def load_controlnet_mmdit(sd, model_options={}):
    return control


+class ControlNetSD35(ControlNet):
+    def pre_run(self, model, percent_to_timestep_function):
+        if self.control_model.double_y_emb:
+            missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
+        else:
+            missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
+        super().pre_run(model, percent_to_timestep_function)
+
+    def copy(self):
+        c = ControlNetSD35(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+        c.control_model = self.control_model
+        c.control_model_wrapped = self.control_model_wrapped
+        self.copy_to(c)
+        return c
+
+def load_controlnet_sd35(sd, model_options={}):
+    control_type = -1
+    if "control_type" in sd:
+        control_type = round(sd.pop("control_type").item())
+
+    # blur_cnet = control_type == 0
+    canny_cnet = control_type == 1
+    depth_cnet = control_type == 2
+
+    print(control_type, canny_cnet, depth_cnet)
+    new_sd = {}
+    for k in comfy.utils.MMDIT_MAP_BASIC:
+        if k[1] in sd:
+            new_sd[k[0]] = sd.pop(k[1])
+    for k in sd:
+        new_sd[k] = sd[k]
+    sd = new_sd
+
+    y_emb_shape = sd["y_embedder.mlp.0.weight"].shape
+    depth = y_emb_shape[0] // 64
+    hidden_size = 64 * depth
+    num_heads = depth
+    head_dim = hidden_size // num_heads
+    num_blocks = comfy.model_detection.count_blocks(new_sd, 'transformer_blocks.{}.')
+
+    load_device = comfy.model_management.get_torch_device()
+    offload_device = comfy.model_management.unet_offload_device()
+    unet_dtype = comfy.model_management.unet_dtype(model_params=-1)
+
+    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+
+    operations = model_options.get("custom_operations", None)
+    if operations is None:
+        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
+
+    control_model = comfy.cldm.dit_embedder.ControlNetEmbedder(img_size=None,
+                                                               patch_size=2,
+                                                               in_chans=16,
+                                                               num_layers=num_blocks,
+                                                               main_model_double=depth,
+                                                               double_y_emb=y_emb_shape[0] == y_emb_shape[1],
+                                                               attention_head_dim=head_dim,
+                                                               num_attention_heads=num_heads,
+                                                               adm_in_channels=2048,
+                                                               device=offload_device,
+                                                               dtype=unet_dtype,
+                                                               operations=operations)
+
+    control_model = controlnet_load_state_dict(control_model, sd)
+
+    latent_format = comfy.latent_formats.SD3()
+    preprocess_image = lambda a: a
+    if canny_cnet:
+        preprocess_image = lambda a: (a * 255 * 0.5 + 0.5)
+    elif depth_cnet:
+        preprocess_image = lambda a: 1.0 - a
+
+    control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
+    return control
+
+
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)

@@ -560,7 +642,10 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
    if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
        return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
    elif "pos_embed_input.proj.weight" in controlnet_data:
-        return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
+        if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data:
+            return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
+        else:
+            return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
    elif "controlnet_x_embedder.weight" in controlnet_data:
        return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
    elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux

0 commit comments
