@@ -210,6 +210,7 @@ def __init__(
         upcast_softmax: bool = False,
         cross_attention_norm: Optional[str] = None,
         cross_attention_norm_num_groups: int = 32,
+        qk_norm: Optional[str] = None,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
         spatial_norm_dim: Optional[int] = None,
@@ -223,7 +224,7 @@ def __init__(
         processor: Optional["AttnProcessor2_0"] = None,
         out_dim: int = None,
     ):
-        super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, cross_attention_norm, cross_attention_norm_num_groups,
+        super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, cross_attention_norm, cross_attention_norm_num_groups, qk_norm,
                                               added_kv_proj_dim, norm_num_groups, spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection, _from_deprecated_attn_block, processor, out_dim)
         processor = CausalAttnProcessor2_0()
         self.set_processor(processor)
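
Note on this hunk: the new qk_norm argument is simply threaded through, positionally, to the parent diffusers Attention constructor. Recent diffusers releases accept a qk_norm option that normalizes the query and key projections before the dot product; whether that applies here depends on the pinned diffusers version. A minimal, self-contained sketch of the idea (not the diffusers implementation) is:

import torch
import torch.nn.functional as F
from torch import nn

def qk_normed_attention(q, k, v, norm_q: nn.LayerNorm, norm_k: nn.LayerNorm):
    # q, k, v: (batch, heads, time, head_dim)
    # Normalizing q and k bounds the scale of the attention logits, which
    # helps numerical stability, especially in reduced precision.
    q = norm_q(q)
    k = norm_k(k)
    return F.scaled_dot_product_attention(q, k, v)

# shapes chosen purely for illustration
b, h, t, d = 2, 8, 16, 64
norm_q, norm_k = nn.LayerNorm(d), nn.LayerNorm(d)
out = qk_normed_attention(torch.randn(b, h, t, d), torch.randn(b, h, t, d),
                          torch.randn(b, h, t, d), norm_q, norm_k)
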
@@ -505,7 +506,7 @@ def initialize_weights(self):
                 if m.bias is not None:
                     nn.init.constant_(m.bias, 0)
 
-    def forward(self, x, mask, mu, t, spks=None, cond=None):
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
         """Forward pass of the UNet1DConditional model.
 
         Args:
@@ -540,7 +541,7 @@ def forward(self, x, mask, mu, t, spks=None, cond=None):
             mask_down = masks[-1]
             x = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = (torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down) == 1)
+            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
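
The removed expression and its replacement here (and in the two analogous hunks below) are both meant to build a full-context attention mask: the old line multiplies the (batch, 1, time) padding mask by its transpose to get a (batch, time, time) boolean mask, while add_optional_chunk_mask, assuming it follows the usual WeNet-style semantics of returning the plain padding mask when static_chunk_size=0 and num_decoding_left_chunks=-1, yields (batch, 1, time), which .repeat(1, x.size(1), 1) then tiles to (batch, time, time). A small sketch of that relationship, with illustrative helper names rather than the project's utilities:

import torch

def full_context_mask_matmul(pad_mask: torch.Tensor) -> torch.Tensor:
    # pad_mask: (batch, 1, time) with 1 for real frames, 0 for padding
    return torch.matmul(pad_mask.transpose(1, 2).contiguous(), pad_mask) == 1

def full_context_mask_repeat(pad_mask: torch.Tensor) -> torch.Tensor:
    # Tile the (batch, 1, time) key mask over every query position.
    return pad_mask.bool().repeat(1, pad_mask.size(2), 1)

pad_mask = torch.tensor([[[1., 1., 1., 0.]]])   # one utterance: 3 real frames + 1 pad
a = full_context_mask_matmul(pad_mask)
b = full_context_mask_repeat(pad_mask)
# The two masks agree on all non-padded query rows; they differ only on the
# padded query rows (all-zero in the matmul version, key mask in the repeat
# version), which correspond to padded frames.
assert torch.equal(a[:, :3, :], b[:, :3, :])
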
@@ -558,7 +559,7 @@ def forward(self, x, mask, mu, t, spks=None, cond=None):
         for resnet, transformer_blocks in self.mid_blocks:
             x = resnet(x, mask_mid, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = (torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid) == 1)
+            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
@@ -574,7 +575,7 @@ def forward(self, x, mask, mu, t, spks=None, cond=None):
             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
             x = resnet(x, mask_up, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = (torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up) == 1)
+            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
@@ -700,7 +701,7 @@ def __init__(
         self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
         self.initialize_weights()
 
-    def forward(self, x, mask, mu, t, spks=None, cond=None):
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
         """Forward pass of the UNet1DConditional model.
 
         Args:
@@ -735,7 +736,10 @@ def forward(self, x, mask, mu, t, spks=None, cond=None):
             mask_down = masks[-1]
             x, _, _ = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x, _ = transformer_block(
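
In the causal decoder, the new streaming flag selects between the original chunked mask (built from self.static_chunk_size and self.num_decoding_left_chunks, which, if the helper matches the WeNet convention, lets each frame attend to its own chunk plus a limited number of left chunks) and the full-context mask used in the non-streaming branch. A rough, self-contained sketch of what such a chunk mask looks like; the helper name below is made up for illustration and is not the project's add_optional_chunk_mask:

import torch

def chunk_attention_mask(time: int, chunk_size: int, num_left_chunks: int) -> torch.Tensor:
    # (time, time) boolean mask: frame i may attend to frames in its own chunk
    # and in up to num_left_chunks previous chunks (-1 means unlimited history).
    mask = torch.zeros(time, time, dtype=torch.bool)
    for i in range(time):
        chunk_idx = i // chunk_size
        start = 0 if num_left_chunks < 0 else max(0, (chunk_idx - num_left_chunks) * chunk_size)
        end = min(time, (chunk_idx + 1) * chunk_size)
        mask[i, start:end] = True
    return mask

print(chunk_attention_mask(time=6, chunk_size=2, num_left_chunks=1).int())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1, 1]])
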
@@ -753,7 +757,10 @@ def forward(self, x, mask, mu, t, spks=None, cond=None):
         for resnet, transformer_blocks in self.mid_blocks:
             x, _, _ = resnet(x, mask_mid, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x, _ = transformer_block(
@@ -769,7 +776,10 @@ def forward(self, x, mask, mu, t, spks=None, cond=None):
             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
             x, _, _ = resnet(x, mask_up, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x, _ = transformer_block(