@@ -43,7 +43,7 @@ def __init__(self, config):
             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                         .view(1, 1, config.block_size, config.block_size))
 
-    def forward(self, x):
+    def forward(self, x, past_kv=None, use_cache=False):
         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
 
         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
@@ -52,22 +52,44 @@ def forward(self, x):
         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
 
+        if past_kv is not None:
+            past_key = past_kv[0]
+            past_value = past_kv[1]
+            k = torch.cat((past_key, k), dim=-2)
+            v = torch.cat((past_value, v), dim=-2)
+
+        FULL_T = k.shape[-2]
+
+        if use_cache is True:
+            present = (k, v)
+        else:
+            present = None
+
         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         if self.flash:
             # efficient attention using Flash Attention CUDA kernels
-            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
+            if past_kv is not None:
+                # When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains
+                # the query for the last token. scaled_dot_product_attention interprets this as the first token in the
+                # sequence, so if is_causal=True it will mask out all attention from it. This is not what we want, so
+                # to work around this we set is_causal=False.
+                is_causal = False
+            else:
+                is_causal = True
+
+            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=is_causal)
         else:
             # manual implementation of attention
             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
-            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
+            att = att.masked_fill(self.bias[:,:,FULL_T-T:FULL_T,:FULL_T] == 0, float('-inf'))
             att = F.softmax(att, dim=-1)
             att = self.attn_dropout(att)
             y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
         y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
 
         # output projection
         y = self.resid_dropout(self.c_proj(y))
-        return y
+        return (y, present)
 
 
 class MLP(nn.Module):
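Note on the hunk above: the cache for each layer is a `(k, v)` pair shaped `(B, n_head, T_cached, head_size)`, the new keys/values are concatenated along the time dimension, and the non-flash path then slices the causal mask so only the rows belonging to the new queries are used. The following is a minimal shape-bookkeeping sketch, not code from this PR; the sizes and variable names are illustrative assumptions.

```python
import torch

# Illustrative sizes, not taken from the repo.
B, n_head, head_size, block_size = 1, 2, 4, 8
bias = torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)

past_k = torch.randn(B, n_head, 5, head_size)  # 5 tokens already cached
new_k = torch.randn(B, n_head, 1, head_size)   # incremental step: T == 1
k = torch.cat((past_k, new_k), dim=-2)         # time dimension grows to 6

T = new_k.shape[-2]
FULL_T = k.shape[-2]
# The new queries occupy the last T rows of the mask and may attend to all
# FULL_T positions, matching bias[:,:,FULL_T-T:FULL_T,:FULL_T] in the diff.
mask = bias[:, :, FULL_T - T:FULL_T, :FULL_T]
print(k.shape, mask.shape)  # torch.Size([1, 2, 6, 4]) torch.Size([1, 1, 1, 6])
```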
@@ -95,10 +117,11 @@ def __init__(self, config, layer_idx):
         self.mlp = MLP(config)
         self.layer_idx = layer_idx
 
-    def forward(self, x):
-        x = x + self.attn(self.ln_1(x))
+    def forward(self, x, past_kv=None, use_cache=False):
+        attn_output, prev_kvs = self.attn(self.ln_1(x), past_kv=past_kv, use_cache=use_cache)
+        x = x + attn_output
         x = x + self.mlp(self.ln_2(x))
-        return x
+        return (x, prev_kvs)
 
 @dataclass
 class GPTConfig:
@@ -142,33 +165,55 @@ def get_num_params(self, non_embedding=True):
             n_params -= self.transformer.wpe.weight.numel()
         return n_params
 
-    def forward(self, idx, merge_context=False):
+    def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
         device = idx.device
         b, t = idx.size()
-        if merge_context:
-            assert(idx.shape[1] >= 256+256+1)
-            t = idx.shape[1] - 256
+        if past_kv is not None:
+            assert t == 1
+            tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
         else:
-            assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
-
-        # forward the GPT model itself
-        if merge_context:
-            tok_emb = torch.cat([
-                self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]),
-                self.transformer.wte(idx[:,256+256:])
-            ], dim=1)
+            if merge_context:
+                assert(idx.shape[1] >= 256+256+1)
+                t = idx.shape[1] - 256
+            else:
+                assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+
+            # forward the GPT model itself
+            if merge_context:
+                tok_emb = torch.cat([
+                    self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]),
+                    self.transformer.wte(idx[:,256+256:])
+                ], dim=1)
+            else:
+                tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
+
+        if past_kv is None:
+            past_length = 0
+            past_kv = tuple([None] * len(self.transformer.h))
         else:
-            tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
+            past_length = past_kv[0][0].size(-2)
+
+        if position_ids is None:
+            position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0) # shape (1, t)
+            assert position_ids.shape == (1, t)
+
+        pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd)
 
-        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
-        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
 
         x = self.transformer.drop(tok_emb + pos_emb)
-        for block in self.transformer.h:
-            x = block(x)
+
+        new_kv = () if use_cache else None
+
+        for i, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)):
+            x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache)
+
+            if use_cache:
+                new_kv = new_kv + (kv,)
+
         x = self.transformer.ln_f(x)
 
         # inference-time mini-optimization: only forward the lm_head on the very last position
         logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
 
-        return logits
+        return (logits, new_kv)
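Putting the pieces together, the model now returns `(logits, new_kv)` and accepts `past_kv` on later calls, so incremental decoding can feed one token at a time. Below is a minimal usage sketch under that assumption; `model`, `idx`, and `generate_with_cache` are placeholders for illustration (greedy sampling), not code from this PR.

```python
import torch

@torch.no_grad()
def generate_with_cache(model, idx, max_new_tokens):
    # First call runs the full prompt and returns the per-layer (k, v) cache.
    logits, kv_cache = model(idx, use_cache=True)
    for _ in range(max_new_tokens):
        # logits has shape (b, 1, vocab_size): only the last position is returned.
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        idx = torch.cat((idx, next_token), dim=1)
        # Subsequent calls pass only the newest token plus the cache (t == 1).
        logits, kv_cache = model(next_token, past_kv=kv_cache, use_cache=True)
    return idx
```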