Skip to content

Commit 9aac21f

Browse files
Fix issues with new hunyuan img2vid model and bump version to v0.3.26
1 parent 528d1b3 commit 9aac21f

File tree

4 files changed

+16
-14
lines changed

4 files changed

+16
-14
lines changed

comfy/ldm/flux/layers.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -159,20 +159,20 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias:
159159
)
160160
self.flipped_img_txt = flipped_img_txt
161161

162-
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None):
162+
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
163163
img_mod1, img_mod2 = self.img_mod(vec)
164164
txt_mod1, txt_mod2 = self.txt_mod(vec)
165165

166166
# prepare image for attention
167167
img_modulated = self.img_norm1(img)
168-
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims)
168+
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
169169
img_qkv = self.img_attn.qkv(img_modulated)
170170
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
171171
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
172172

173173
# prepare txt for attention
174174
txt_modulated = self.txt_norm1(txt)
175-
txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims)
175+
txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
176176
txt_qkv = self.txt_attn.qkv(txt_modulated)
177177
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
178178
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@@ -195,12 +195,12 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=N
195195
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
196196

197197
# calculate the img bloks
198-
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims)
199-
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims)), img_mod2.gate, None, modulation_dims)
198+
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
199+
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
200200

201201
# calculate the txt bloks
202-
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims)
203-
txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims)), txt_mod2.gate, None, modulation_dims)
202+
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
203+
txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
204204

205205
if txt.dtype == torch.float16:
206206
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)

comfy/ldm/hunyuan_video/model.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -244,9 +244,11 @@ def forward_orig(
244244
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
245245
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
246246
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
247+
modulation_dims_txt = [(0, None, 1)]
247248
else:
248249
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
249250
modulation_dims = None
251+
modulation_dims_txt = None
250252

251253
if self.params.guidance_embed:
252254
if guidance is not None:
@@ -273,14 +275,14 @@ def forward_orig(
273275
if ("double_block", i) in blocks_replace:
274276
def block_wrap(args):
275277
out = {}
276-
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
278+
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
277279
return out
278280

279-
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
281+
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
280282
txt = out["txt"]
281283
img = out["img"]
282284
else:
283-
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
285+
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
284286

285287
if control is not None: # Controlnet
286288
control_i = control.get("input")
@@ -295,10 +297,10 @@ def block_wrap(args):
295297
if ("single_block", i) in blocks_replace:
296298
def block_wrap(args):
297299
out = {}
298-
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
300+
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
299301
return out
300302

301-
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
303+
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
302304
img = out["img"]
303305
else:
304306
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)

comfyui_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# This file is automatically generated by the build process when version is
22
# updated in pyproject.toml.
3-
__version__ = "0.3.25"
3+
__version__ = "0.3.26"

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "ComfyUI"
3-
version = "0.3.25"
3+
version = "0.3.26"
44
readme = "README.md"
55
license = { file = "LICENSE" }
66
requires-python = ">=3.9"

0 commit comments

Comments
 (0)