
Commit 3247106

Merge pull request serp-ai#27 from zygi/main
Add key/value caching for autoregressive generation
2 parents 874af1b + acfd65b commit 3247106
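
In practice the change is exposed through a single boolean flag on the public API. A minimal usage sketch, assuming the package exposes preload_models and generate_audio at the top level as upstream bark does:

    from bark import preload_models, generate_audio  # assumed upstream-style exports

    preload_models()

    # With use_kv_caching=True each autoregressive step feeds only the newest token
    # to the model and reuses cached keys/values for all earlier positions.
    audio_array = generate_audio("Hello, world!", use_kv_caching=True)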

File tree

4 files changed: +94 -29 lines

.gitignore
bark/api.py
bark/generation.py
bark/model.py


.gitignore  (-1)

@@ -1,2 +1 @@
 __pycache__/
-

bark/api.py  (+7 -1)

@@ -10,6 +10,7 @@ def text_to_semantic(
     history_prompt: Optional[str] = None,
     temp: float = 0.7,
     silent: bool = False,
+    use_kv_caching = False,
 ):
     """Generate semantic array from text.
 
@@ -27,6 +28,7 @@ def text_to_semantic(
         history_prompt=history_prompt,
         temp=temp,
         silent=silent,
+        use_kv_caching=use_kv_caching
     )
     return x_semantic
 
@@ -37,6 +39,7 @@ def semantic_to_waveform(
     temp: float = 0.7,
     silent: bool = False,
     output_full: bool = False,
+    use_kv_caching = False
 ):
     """Generate audio array from semantic input.
 
@@ -55,6 +58,7 @@ def semantic_to_waveform(
         history_prompt=history_prompt,
         temp=temp,
         silent=silent,
+        use_kv_caching=use_kv_caching
     )
     fine_tokens = generate_fine(
         coarse_tokens,
@@ -88,6 +92,7 @@ def generate_audio(
     waveform_temp: float = 0.7,
     silent: bool = False,
     output_full: bool = False,
+    use_kv_caching = False
 ):
     """Generate audio array from input text.
 
@@ -103,14 +108,15 @@ def generate_audio(
         numpy audio array at sample frequency 24khz
     """
     semantic_tokens = text_to_semantic(
-        text, history_prompt=history_prompt, temp=text_temp, silent=silent,
+        text, history_prompt=history_prompt, temp=text_temp, silent=silent, use_kv_caching=use_kv_caching
     )
     out = semantic_to_waveform(
         semantic_tokens,
         history_prompt=history_prompt,
         temp=waveform_temp,
         silent=silent,
         output_full=output_full,
+        use_kv_caching=use_kv_caching
     )
     if output_full:
         full_generation, audio_arr = out
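
The flag defaults to False, so existing callers are unaffected. It can also be passed when running the two stages separately; a short sketch using the bark.api functions shown above:

    from bark.api import text_to_semantic, semantic_to_waveform

    semantic_tokens = text_to_semantic("A quick test sentence.", use_kv_caching=True)
    audio_array = semantic_to_waveform(semantic_tokens, use_kv_caching=True)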

bark/generation.py  (+17 -2)

@@ -359,6 +359,7 @@ def generate_text_semantic(
     max_gen_duration_s=None,
     allow_early_stop=True,
     model=None,
+    use_kv_caching=False
 ):
     """Generate semantic tokens from text."""
     assert isinstance(text, str)
@@ -420,8 +421,14 @@ def generate_text_semantic(
         pbar = tqdm.tqdm(disable=silent, total=100)
         pbar_state = 0
         tot_generated_duration_s = 0
+        kv_cache = None
         for n in range(n_tot_steps):
-            logits = model(x, merge_context=True)
+            if use_kv_caching and kv_cache is not None:
+                x_input = x[:, [-1]]
+            else:
+                x_input = x
+
+            logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache)
             relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
             if allow_early_stop:
                 relevant_logits = torch.hstack(
@@ -498,6 +505,7 @@ def generate_coarse(
     max_coarse_history=630,  # min 60 (faster), max 630 (more context)
     sliding_window_len=60,
     model=None,
+    use_kv_caching=False
 ):
     """Generate coarse audio codes from semantic tokens."""
     assert (
@@ -592,11 +600,18 @@ def generate_coarse(
                     x_coarse_in[:, -max_coarse_history:],
                 ]
             )
+            kv_cache = None
             for _ in range(sliding_window_len):
                 if n_step >= n_steps:
                     continue
                 is_major_step = n_step % N_COARSE_CODEBOOKS == 0
-                logits = model(x_in)
+
+                if use_kv_caching and kv_cache is not None:
+                    x_input = x_in[:, [-1]]
+                else:
+                    x_input = x_in
+
+                logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
                 logit_start_idx = (
                     SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
                 )
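
Both generation loops follow the same pattern: the full token history stays in the accumulator tensor, but once a cache exists only the most recent token is handed to the model. A standalone sketch of that pattern with illustrative names, where model is any callable with the (input_ids, use_cache, past_kv) -> (logits, new_kv) contract from bark/model.py:

    import torch

    def greedy_sample_loop(model, x, n_steps, use_kv_caching=True):
        kv_cache = None
        for _ in range(n_steps):
            if use_kv_caching and kv_cache is not None:
                x_input = x[:, [-1]]   # only the newest token; earlier keys/values come from the cache
            else:
                x_input = x            # full prefix on the first step, or when caching is disabled
            logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
            next_token = torch.argmax(logits[0, -1]).view(1, 1)
            x = torch.cat((x, next_token), dim=1)
        return x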

bark/model.py  (+70 -25)

@@ -43,7 +43,7 @@ def __init__(self, config):
             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                         .view(1, 1, config.block_size, config.block_size))
 
-    def forward(self, x):
+    def forward(self, x, past_kv=None, use_cache=False):
         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
 
         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
@@ -52,22 +52,44 @@ def forward(self, x):
         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
 
+        if past_kv is not None:
+            past_key = past_kv[0]
+            past_value = past_kv[1]
+            k = torch.cat((past_key, k), dim=-2)
+            v = torch.cat((past_value, v), dim=-2)
+
+        FULL_T = k.shape[-2]
+
+        if use_cache is True:
+            present = (k, v)
+        else:
+            present = None
+
         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         if self.flash:
             # efficient attention using Flash Attention CUDA kernels
-            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
+            if past_kv is not None:
+                # When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains
+                # the query for the last token. scaled_dot_product_attention interprets this as the first token in the
+                # sequence, so if is_causal=True it will mask out all attention from it. This is not what we want, so
+                # to work around this we set is_causal=False.
+                is_causal = False
+            else:
+                is_causal = True
+
+            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=is_causal)
         else:
             # manual implementation of attention
             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
-            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
+            att = att.masked_fill(self.bias[:,:,FULL_T-T:FULL_T,:FULL_T] == 0, float('-inf'))
             att = F.softmax(att, dim=-1)
             att = self.attn_dropout(att)
             y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
         y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
 
         # output projection
         y = self.resid_dropout(self.c_proj(y))
-        return y
+        return (y, present)
 
 class MLP(nn.Module):
 
@@ -95,10 +117,11 @@ def __init__(self, config, layer_idx):
         self.mlp = MLP(config)
         self.layer_idx = layer_idx
 
-    def forward(self, x):
-        x = x + self.attn(self.ln_1(x))
+    def forward(self, x, past_kv=None, use_cache=False):
+        attn_output, prev_kvs = self.attn(self.ln_1(x), past_kv=past_kv, use_cache=use_cache)
+        x = x + attn_output
         x = x + self.mlp(self.ln_2(x))
-        return x
+        return (x, prev_kvs)
 
 @dataclass
 class GPTConfig:
@@ -142,33 +165,55 @@ def get_num_params(self, non_embedding=True):
             n_params -= self.transformer.wpe.weight.numel()
         return n_params
 
-    def forward(self, idx, merge_context=False):
+    def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
         device = idx.device
         b, t = idx.size()
-        if merge_context:
-            assert(idx.shape[1] >= 256+256+1)
-            t = idx.shape[1] - 256
+        if past_kv is not None:
+            assert t == 1
+            tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
         else:
-            assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
-
-        # forward the GPT model itself
-        if merge_context:
-            tok_emb = torch.cat([
-                self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]),
-                self.transformer.wte(idx[:,256+256:])
-            ], dim=1)
+            if merge_context:
+                assert(idx.shape[1] >= 256+256+1)
+                t = idx.shape[1] - 256
+            else:
+                assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+
+            # forward the GPT model itself
+            if merge_context:
+                tok_emb = torch.cat([
+                    self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]),
+                    self.transformer.wte(idx[:,256+256:])
+                ], dim=1)
+            else:
+                tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
+
+        if past_kv is None:
+            past_length = 0
+            past_kv = tuple([None] * len(self.transformer.h))
         else:
-            tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
+            past_length = past_kv[0][0].size(-2)
+
+        if position_ids is None:
+            position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0) # shape (1, t)
+            assert position_ids.shape == (1, t)
+
+        pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd)
 
-        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
-        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
 
         x = self.transformer.drop(tok_emb + pos_emb)
-        for block in self.transformer.h:
-            x = block(x)
+
+        new_kv = () if use_cache else None
+
+        for i, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)):
+            x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache)
+
+            if use_cache:
+                new_kv = new_kv + (kv,)
+
         x = self.transformer.ln_f(x)
 
         # inference-time mini-optimization: only forward the lm_head on the very last position
         logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
 
-        return logits
+        return (logits, new_kv)
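
The cache returned by GPT.forward when use_cache=True is a tuple with one (k, v) pair per transformer block, each tensor shaped (batch, n_head, tokens_seen_so_far, head_dim), and it grows by one position per decoding step. A small inspection sketch; the GPTConfig field names below follow upstream bark's model.py and are an assumption here:

    import torch
    from bark.model import GPT, GPTConfig

    # Tiny, randomly initialized model purely to inspect cache shapes.
    config = GPTConfig(block_size=64, input_vocab_size=128, output_vocab_size=128,
                       n_layer=2, n_head=2, n_embd=32, dropout=0.0, bias=True)
    model = GPT(config).eval()

    prefix = torch.randint(0, 128, (1, 8))
    with torch.no_grad():
        logits, kv_cache = model(prefix, use_cache=True)
    print(len(kv_cache), kv_cache[0][0].shape)   # 2 layers, keys of shape (1, 2, 8, 16)

    # Subsequent steps feed only the newest token together with the cache.
    next_token = torch.randint(0, 128, (1, 1))
    with torch.no_grad():
        logits, kv_cache = model(next_token, use_cache=True, past_kv=kv_cache)
    print(kv_cache[0][0].shape)                  # cache has grown to (1, 2, 9, 16)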
