35
35
import comfy .cldm .mmdit
36
36
import comfy .ldm .hydit .controlnet
37
37
import comfy .ldm .flux .controlnet
38
-
38
+ import comfy . cldm . dit_embedder
39
39
40
40
def broadcast_image_to (tensor , target_batch_size , batched_number ):
41
41
current_batch_size = tensor .shape [0 ]
@@ -78,6 +78,7 @@ def __init__(self):
78
78
self .concat_mask = False
79
79
self .extra_concat_orig = []
80
80
self .extra_concat = None
81
+ self .preprocess_image = lambda a : a
81
82
82
83
def set_cond_hint (self , cond_hint , strength = 1.0 , timestep_percent_range = (0.0 , 1.0 ), vae = None , extra_concat = []):
83
84
self .cond_hint_original = cond_hint
@@ -129,6 +130,7 @@ def copy_to(self, c):
129
130
c .strength_type = self .strength_type
130
131
c .concat_mask = self .concat_mask
131
132
c .extra_concat_orig = self .extra_concat_orig .copy ()
133
+ c .preprocess_image = self .preprocess_image
132
134
133
135
def inference_memory_requirements (self , dtype ):
134
136
if self .previous_controlnet is not None :
@@ -181,7 +183,7 @@ def set_extra_arg(self, argument, value=None):
181
183
182
184
183
185
class ControlNet (ControlBase ):
184
- def __init__ (self , control_model = None , global_average_pooling = False , compression_ratio = 8 , latent_format = None , load_device = None , manual_cast_dtype = None , extra_conds = ["y" ], strength_type = StrengthType .CONSTANT , concat_mask = False ):
186
+ def __init__ (self , control_model = None , global_average_pooling = False , compression_ratio = 8 , latent_format = None , load_device = None , manual_cast_dtype = None , extra_conds = ["y" ], strength_type = StrengthType .CONSTANT , concat_mask = False , preprocess_image = lambda a : a ):
185
187
super ().__init__ ()
186
188
self .control_model = control_model
187
189
self .load_device = load_device
@@ -196,6 +198,7 @@ def __init__(self, control_model=None, global_average_pooling=False, compression
196
198
self .extra_conds += extra_conds
197
199
self .strength_type = strength_type
198
200
self .concat_mask = concat_mask
201
+ self .preprocess_image = preprocess_image
199
202
200
203
def get_control (self , x_noisy , t , cond , batched_number ):
201
204
control_prev = None
@@ -224,6 +227,7 @@ def get_control(self, x_noisy, t, cond, batched_number):
224
227
if self .latent_format is not None :
225
228
raise ValueError ("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it." )
226
229
self .cond_hint = comfy .utils .common_upscale (self .cond_hint_original , x_noisy .shape [3 ] * compression_ratio , x_noisy .shape [2 ] * compression_ratio , self .upscale_algorithm , "center" )
230
+ self .cond_hint = self .preprocess_image (self .cond_hint )
227
231
if self .vae is not None :
228
232
loaded_models = comfy .model_management .loaded_models (only_currently_used = True )
229
233
self .cond_hint = self .vae .encode (self .cond_hint .movedim (1 , - 1 ))
@@ -427,6 +431,7 @@ def controlnet_load_state_dict(control_model, sd):
427
431
logging .debug ("unexpected controlnet keys: {}" .format (unexpected ))
428
432
return control_model
429
433
434
+
430
435
def load_controlnet_mmdit (sd , model_options = {}):
431
436
new_sd = comfy .model_detection .convert_diffusers_mmdit (sd , "" )
432
437
model_config , operations , load_device , unet_dtype , manual_cast_dtype , offload_device = controlnet_config (new_sd , model_options = model_options )
@@ -448,6 +453,83 @@ def load_controlnet_mmdit(sd, model_options={}):
448
453
return control
449
454
450
455
456
+ class ControlNetSD35 (ControlNet ):
457
+ def pre_run (self , model , percent_to_timestep_function ):
458
+ if self .control_model .double_y_emb :
459
+ missing , unexpected = self .control_model .orig_y_embedder .load_state_dict (model .diffusion_model .y_embedder .state_dict (), strict = False )
460
+ else :
461
+ missing , unexpected = self .control_model .x_embedder .load_state_dict (model .diffusion_model .x_embedder .state_dict (), strict = False )
462
+ super ().pre_run (model , percent_to_timestep_function )
463
+
464
+ def copy (self ):
465
+ c = ControlNetSD35 (None , global_average_pooling = self .global_average_pooling , load_device = self .load_device , manual_cast_dtype = self .manual_cast_dtype )
466
+ c .control_model = self .control_model
467
+ c .control_model_wrapped = self .control_model_wrapped
468
+ self .copy_to (c )
469
+ return c
470
+
471
+ def load_controlnet_sd35 (sd , model_options = {}):
472
+ control_type = - 1
473
+ if "control_type" in sd :
474
+ control_type = round (sd .pop ("control_type" ).item ())
475
+
476
+ # blur_cnet = control_type == 0
477
+ canny_cnet = control_type == 1
478
+ depth_cnet = control_type == 2
479
+
480
+ print (control_type , canny_cnet , depth_cnet )
481
+ new_sd = {}
482
+ for k in comfy .utils .MMDIT_MAP_BASIC :
483
+ if k [1 ] in sd :
484
+ new_sd [k [0 ]] = sd .pop (k [1 ])
485
+ for k in sd :
486
+ new_sd [k ] = sd [k ]
487
+ sd = new_sd
488
+
489
+ y_emb_shape = sd ["y_embedder.mlp.0.weight" ].shape
490
+ depth = y_emb_shape [0 ] // 64
491
+ hidden_size = 64 * depth
492
+ num_heads = depth
493
+ head_dim = hidden_size // num_heads
494
+ num_blocks = comfy .model_detection .count_blocks (new_sd , 'transformer_blocks.{}.' )
495
+
496
+ load_device = comfy .model_management .get_torch_device ()
497
+ offload_device = comfy .model_management .unet_offload_device ()
498
+ unet_dtype = comfy .model_management .unet_dtype (model_params = - 1 )
499
+
500
+ manual_cast_dtype = comfy .model_management .unet_manual_cast (unet_dtype , load_device )
501
+
502
+ operations = model_options .get ("custom_operations" , None )
503
+ if operations is None :
504
+ operations = comfy .ops .pick_operations (unet_dtype , manual_cast_dtype , disable_fast_fp8 = True )
505
+
506
+ control_model = comfy .cldm .dit_embedder .ControlNetEmbedder (img_size = None ,
507
+ patch_size = 2 ,
508
+ in_chans = 16 ,
509
+ num_layers = num_blocks ,
510
+ main_model_double = depth ,
511
+ double_y_emb = y_emb_shape [0 ] == y_emb_shape [1 ],
512
+ attention_head_dim = head_dim ,
513
+ num_attention_heads = num_heads ,
514
+ adm_in_channels = 2048 ,
515
+ device = offload_device ,
516
+ dtype = unet_dtype ,
517
+ operations = operations )
518
+
519
+ control_model = controlnet_load_state_dict (control_model , sd )
520
+
521
+ latent_format = comfy .latent_formats .SD3 ()
522
+ preprocess_image = lambda a : a
523
+ if canny_cnet :
524
+ preprocess_image = lambda a : (a * 255 * 0.5 + 0.5 )
525
+ elif depth_cnet :
526
+ preprocess_image = lambda a : 1.0 - a
527
+
528
+ control = ControlNetSD35 (control_model , compression_ratio = 1 , latent_format = latent_format , load_device = load_device , manual_cast_dtype = manual_cast_dtype , preprocess_image = preprocess_image )
529
+ return control
530
+
531
+
532
+
451
533
def load_controlnet_hunyuandit (controlnet_data , model_options = {}):
452
534
model_config , operations , load_device , unet_dtype , manual_cast_dtype , offload_device = controlnet_config (controlnet_data , model_options = model_options )
453
535
@@ -560,7 +642,10 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
560
642
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data :
561
643
return load_controlnet_flux_xlabs_mistoline (controlnet_data , model_options = model_options )
562
644
elif "pos_embed_input.proj.weight" in controlnet_data :
563
- return load_controlnet_mmdit (controlnet_data , model_options = model_options ) #SD3 diffusers controlnet
645
+ if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data :
646
+ return load_controlnet_sd35 (controlnet_data , model_options = model_options ) #Stability sd3.5 format
647
+ else :
648
+ return load_controlnet_mmdit (controlnet_data , model_options = model_options ) #SD3 diffusers controlnet
564
649
elif "controlnet_x_embedder.weight" in controlnet_data :
565
650
return load_controlnet_flux_instantx (controlnet_data , model_options = model_options )
566
651
elif "controlnet_blocks.0.linear.weight" in controlnet_data : #mistoline flux
0 commit comments