 from typing import Optional, Tuple

 import torch
+import torch.nn.functional as F
 import transformers.models.llama.modeling_llama
 from torch import nn
@@ -38,21 +39,48 @@ def xformers_forward(
     # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()

-    query_states = (
-        self.q_proj(hidden_states)
-        .view(bsz, q_len, self.num_heads, self.head_dim)
-        .transpose(1, 2)
-    )
-    key_states = (
-        self.k_proj(hidden_states)
-        .view(bsz, q_len, self.num_heads, self.head_dim)
-        .transpose(1, 2)
-    )
-    value_states = (
-        self.v_proj(hidden_states)
-        .view(bsz, q_len, self.num_heads, self.head_dim)
-        .transpose(1, 2)
-    )
+    if not hasattr(self, "pretraining_tp"):
+        self.pretraining_tp = 1
+
+    if self.pretraining_tp > 1:
+        key_value_slicing = (
+            self.num_key_value_heads * self.head_dim
+        ) // self.pretraining_tp
+        query_slices = self.q_proj.weight.split(
+            (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
+        )
+        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+        query_states = [
+            F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
+        ]
+        query_states = torch.cat(query_states, dim=-1)
+
+        key_states = [
+            F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
+        ]
+        key_states = torch.cat(key_states, dim=-1)
+
+        value_states = [
+            F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
+        ]
+        value_states = torch.cat(value_states, dim=-1)
+
+    else:
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+    query_states = query_states.view(
+        bsz, q_len, self.num_heads, self.head_dim
+    ).transpose(1, 2)
+    key_states = key_states.view(
+        bsz, q_len, self.num_key_value_heads, self.head_dim
+    ).transpose(1, 2)
+    value_states = value_states.view(
+        bsz, q_len, self.num_key_value_heads, self.head_dim
+    ).transpose(1, 2)

     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
@@ -73,6 +101,14 @@ def xformers_forward(

     past_key_value = (key_states, value_states) if use_cache else None

+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = transformers.models.llama.modeling_llama.repeat_kv(
+        key_states, self.num_key_value_groups
+    )
+    value_states = transformers.models.llama.modeling_llama.repeat_kv(
+        value_states, self.num_key_value_groups
+    )
+
     # We only apply xformers optimizations if we don't need to output the whole attention matrix
     if not output_attentions:
         query_states = query_states.transpose(1, 2)
@@ -128,10 +164,23 @@ def xformers_forward(
                 f" {attn_output.size()}"
             )

-        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        # end x-formers vs. not x-formers if-else block

     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-    attn_output = self.o_proj(attn_output)
+
+    if self.pretraining_tp > 1:
+        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
+        o_proj_slices = self.o_proj.weight.split(
+            self.hidden_size // self.pretraining_tp, dim=1
+        )
+        attn_output = sum(
+            F.linear(attn_output[i], o_proj_slices[i])
+            for i in range(self.pretraining_tp)
+        )
+    else:
+        attn_output = self.o_proj(attn_output)
+
     return attn_output, attn_weights, past_key_value
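Note (not part of the diff above): the added repeat_kv calls come from transformers.models.llama.modeling_llama and expand the grouped key/value heads so every query head has a matching key/value head before attention is computed. A minimal sketch of the idea, assuming key/value tensors shaped (bsz, num_key_value_heads, seq_len, head_dim):

# Sketch of what repeat_kv does for grouped-query attention (GQA); the real
# implementation lives in transformers.models.llama.modeling_llama.
import torch

def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (bsz, num_key_value_heads, seq_len, head_dim)
    # -> (bsz, num_key_value_heads * n_rep, seq_len, head_dim)
    bsz, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        bsz, num_kv_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(bsz, num_kv_heads * n_rep, slen, head_dim)

# Example: 8 key/value heads shared by 32 query heads (num_key_value_groups = 4).
kv = torch.randn(1, 8, 16, 128)
assert repeat_kv_sketch(kv, 4).shape == (1, 32, 16, 128)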
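Similarly, the pretraining_tp > 1 branches are numerically equivalent to calling the q/k/v/o projections directly; they split each weight matrix the way it was sharded for tensor parallelism during pretraining and then concatenate (or sum, for o_proj) the partial results. A rough, self-contained check of that equivalence (the names and sizes below are illustrative, not from the patch):

# Illustrative check: splitting a projection weight along dim=0 and concatenating
# the partial F.linear outputs along the last dim matches the full projection.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden_size, num_heads, head_dim, pretraining_tp = 64, 8, 8, 2
q_proj = torch.nn.Linear(hidden_size, num_heads * head_dim, bias=False)
hidden_states = torch.randn(2, 5, hidden_size)  # (bsz, q_len, hidden_size)

full = q_proj(hidden_states)

query_slices = q_proj.weight.split((num_heads * head_dim) // pretraining_tp, dim=0)
partial = [F.linear(hidden_states, query_slices[i]) for i in range(pretraining_tp)]
sliced = torch.cat(partial, dim=-1)

assert torch.allclose(full, sliced, atol=1e-6)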