
Commit 8be6c6a

Merge pull request #84 from huggingface/kv-default
Make KV cache the only strategy for inference
2 parents a173568 + 9d2a556 commit 8be6c6a

File tree

1 file changed: +10 −23 lines changed

models/vision_language_model.py

Lines changed: 10 additions & 23 deletions
@@ -60,7 +60,7 @@ def forward(self, input_ids, image, attention_mask=None, targets=None):
         return logits, loss

     @torch.inference_mode()
-    def generate(self, input_ids, image, attention_mask=None, max_new_tokens=5, top_k=50, top_p=0.9, temperature=0.5, greedy=False, use_kv_cache: bool = True):
+    def generate(self, input_ids, image, attention_mask=None, max_new_tokens=5, top_k=50, top_p=0.9, temperature=0.5, greedy=False):

         # 1. Process image
         image_embd = self.vision_encoder(image) # [B, T_img, D_model]
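For callers, the visible change is the removed keyword argument: KV caching is no longer optional. A minimal usage sketch follows; generate() and its remaining parameters come from this diff, while the model instance and the preprocessed inputs (model, input_ids, image, attention_mask) are assumed to exist and are purely illustrative.

# Usage sketch: only the generate(...) signature is taken from this diff;
# how `model` and the inputs are built is assumed, not shown here.
generated_ids = model.generate(
    input_ids,                      # [B, T_prompt] prompt token ids
    image,                          # [B, C, H, W] preprocessed image tensor
    attention_mask=attention_mask,  # [B, T_prompt]
    max_new_tokens=32,
    top_k=50,
    top_p=0.9,
    temperature=0.5,
    greedy=False,                   # use_kv_cache no longer exists; passing it raises TypeError
)                                   # -> [B, N] newly generated token ids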
@@ -122,27 +122,14 @@ def generate(self, input_ids, image, attention_mask=None, max_new_tokens=5, top_
             if attention_mask is not None:
                 attention_mask = torch.cat((attention_mask, torch.ones((batch_size, 1), device=attention_mask.device, dtype=attention_mask.dtype)), dim=1)

-            if use_kv_cache:
-                # With KV cache: only process the new token
-                decode_step_output, kv_cache_list = self.decoder(
-                    next_token_embed,
-                    attention_mask=attention_mask,
-                    kv_cache=kv_cache_list,
-                    start_pos=current_token_start_pos
-                )
-            else:
-                # Without KV cache: process the entire sequence from scratch
-                # Reconstruct the full sequence: image + prompt + generated tokens so far
-                generated_token_embeds = torch.cat([self.decoder.token_embedding(tid) for tid in newly_generated_ids_list], dim=1)
-                full_sequence_embeds = torch.cat([initial_combined_embeds, generated_token_embeds], dim=1)
-
-                decode_step_output, _ = self.decoder(
-                    full_sequence_embeds,
-                    attention_mask=attention_mask,
-                    kv_cache=None,
-                    start_pos=0
-                )
-
+            # With KV cache: only process the new token
+            decode_step_output, kv_cache_list = self.decoder(
+                next_token_embed,
+                attention_mask=attention_mask,
+                kv_cache=kv_cache_list,
+                start_pos=current_token_start_pos
+            )
+
             last_token_output = decode_step_output[:, -1, :]

             # Apply head to get logits (if model is in embedding mode)
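To show the mechanism the loop now relies on exclusively, here is a small self-contained sketch of one cached-attention decode step: the image/prompt embeddings are prefilled once, and every later step feeds only the newest token embedding plus the cached keys/values and its absolute start position, mirroring the kv_cache/start_pos contract of self.decoder above. The toy layer, its dimensions, and the single-head attention are assumptions for illustration, not the repo's decoder.

import torch

torch.manual_seed(0)
D = 16  # toy model width

# Toy projection matrices standing in for one attention layer (illustrative only).
Wq, Wk, Wv = (torch.randn(D, D) * D**-0.5 for _ in range(3))

def decode_step(x_new, kv_cache, start_pos):
    # start_pos is where positional encodings (e.g. RoPE) would use the absolute
    # offset; this toy only uses it for the causal mask.
    k_new, v_new = x_new @ Wk, x_new @ Wv
    if kv_cache is None:
        k, v = k_new, v_new                          # prefill: the cache starts here
    else:
        k = torch.cat([kv_cache[0], k_new], dim=1)   # append new K to the cache
        v = torch.cat([kv_cache[1], v_new], dim=1)   # append new V to the cache
    q = x_new @ Wq                                   # queries only for the new token(s)
    scores = (q @ k.transpose(1, 2)) / D**0.5        # [B, T_new, T_total]
    # Causal mask: the query at absolute position start_pos+i sees cache positions <= start_pos+i.
    pos_q = start_pos + torch.arange(q.shape[1]).unsqueeze(1)
    pos_k = torch.arange(k.shape[1]).unsqueeze(0)
    scores = scores.masked_fill(pos_k > pos_q, float("-inf"))
    out = torch.softmax(scores, dim=-1) @ v          # [B, T_new, D]
    return out, (k, v)

B = 1
prompt_embeds = torch.randn(B, 5, D)                 # stands in for image + prompt embeddings
out, cache = decode_step(prompt_embeds, None, start_pos=0)   # prefill once

for step in range(3):
    next_embed = torch.randn(B, 1, D)                # stands in for next_token_embed
    out, cache = decode_step(next_embed, cache, start_pos=5 + step)
    print(step, out.shape, cache[0].shape)           # output stays [1, 1, 16]; cache grows by one position

Because each step attends a single new query against the cached keys, per-token cost stays roughly constant, whereas the deleted else-branch re-embedded and reprocessed the whole image + prompt + generated sequence on every step.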
@@ -152,7 +139,7 @@ def generate(self, input_ids, image, attention_mask=None, max_new_tokens=5, top_
                 current_logits = last_token_output

         if not newly_generated_ids_list: # Handle case where max_new_tokens might be 0
-            return torch.empty((B,0), dtype=torch.long, device=input_ids.device)
+            return torch.empty((batch_size,0), dtype=torch.long, device=input_ids.device)

         return torch.cat(newly_generated_ids_list, dim=1)

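The rename fixes a stale reference: batch_size is the name already used earlier in generate (e.g. in the attention-mask update), while B does not appear to be defined in this scope, so the max_new_tokens=0 path would likely have raised a NameError. A quick standalone check of the two return shapes, with illustrative values:

import torch

batch_size = 2  # illustrative value

# max_new_tokens == 0 path: nothing was generated, so return an empty [B, 0] id tensor.
empty_out = torch.empty((batch_size, 0), dtype=torch.long, device="cpu")
print(empty_out.shape)                                      # torch.Size([2, 0])

# Normal path: each decode step appends a [B, 1] id tensor to newly_generated_ids_list.
newly_generated_ids_list = [torch.randint(0, 100, (batch_size, 1)) for _ in range(4)]
print(torch.cat(newly_generated_ids_list, dim=1).shape)     # torch.Size([2, 4])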
