@@ -159,20 +159,20 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias:
         )
         self.flipped_img_txt = flipped_img_txt

-    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None):
+    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
         img_mod1, img_mod2 = self.img_mod(vec)
         txt_mod1, txt_mod2 = self.txt_mod(vec)

         # prepare image for attention
         img_modulated = self.img_norm1(img)
-        img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims)
+        img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
         img_qkv = self.img_attn.qkv(img_modulated)
         img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
-        txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims)
+        txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
         txt_qkv = self.txt_attn.qkv(txt_modulated)
         txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@@ -195,12 +195,12 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=N
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

         # calculate the img bloks
-        img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims)
-        img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims)), img_mod2.gate, None, modulation_dims)
+        img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
+        img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)

         # calculate the txt bloks
-        txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims)
-        txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims)), txt_mod2.gate, None, modulation_dims)
+        txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
+        txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)

         if txt.dtype == torch.float16:
             txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)