Commit ad091cc

committed
make kv cache default
1 parent a173568 commit ad091cc

File tree

1 file changed: +9 -22 lines changed

models/vision_language_model.py

Lines changed: 9 additions & 22 deletions
@@ -60,7 +60,7 @@ def forward(self, input_ids, image, attention_mask=None, targets=None):
         return logits, loss

     @torch.inference_mode()
-    def generate(self, input_ids, image, attention_mask=None, max_new_tokens=5, top_k=50, top_p=0.9, temperature=0.5, greedy=False, use_kv_cache: bool = True):
+    def generate(self, input_ids, image, attention_mask=None, max_new_tokens=5, top_k=50, top_p=0.9, temperature=0.5, greedy=False):

         # 1. Process image
         image_embd = self.vision_encoder(image) # [B, T_img, D_model]
@@ -122,27 +122,14 @@ def generate(self, input_ids, image, attention_mask=None, max_new_tokens=5, top_
             if attention_mask is not None:
                 attention_mask = torch.cat((attention_mask, torch.ones((batch_size, 1), device=attention_mask.device, dtype=attention_mask.dtype)), dim=1)

-            if use_kv_cache:
-                # With KV cache: only process the new token
-                decode_step_output, kv_cache_list = self.decoder(
-                    next_token_embed,
-                    attention_mask=attention_mask,
-                    kv_cache=kv_cache_list,
-                    start_pos=current_token_start_pos
-                )
-            else:
-                # Without KV cache: process the entire sequence from scratch
-                # Reconstruct the full sequence: image + prompt + generated tokens so far
-                generated_token_embeds = torch.cat([self.decoder.token_embedding(tid) for tid in newly_generated_ids_list], dim=1)
-                full_sequence_embeds = torch.cat([initial_combined_embeds, generated_token_embeds], dim=1)
-
-                decode_step_output, _ = self.decoder(
-                    full_sequence_embeds,
-                    attention_mask=attention_mask,
-                    kv_cache=None,
-                    start_pos=0
-                )
-
+            # With KV cache: only process the new token
+            decode_step_output, kv_cache_list = self.decoder(
+                next_token_embed,
+                attention_mask=attention_mask,
+                kv_cache=kv_cache_list,
+                start_pos=current_token_start_pos
+            )
+
             last_token_output = decode_step_output[:, -1, :]

             # Apply head to get logits (if model is in embedding mode)
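With this change the uncached fallback is gone and `generate` always decodes with a KV cache: each step feeds only the newest token embedding into the decoder and reuses the keys and values cached for earlier positions. The sketch below is a minimal toy illustration of why that is safe, not this repo's `decoder` API; `D`, `Wq`, `Wk`, `Wv`, and the shapes are made-up assumptions. It checks that cached one-token-at-a-time decoding matches recomputing attention over the full prefix, which is what the deleted `else` branch did.

```python
# Toy sketch, not the repository's decoder: one attention "layer" decoded two
# ways. D, Wq, Wk, Wv and all shapes are assumptions for illustration only.
import torch

torch.manual_seed(0)
D = 16
Wq, Wk, Wv = (torch.randn(D, D) for _ in range(3))

def attend(q, K, V):
    # Scaled dot-product attention of the newest query over all positions.
    scores = (q @ K.transpose(-2, -1)) / D**0.5   # [B, 1, T]
    return torch.softmax(scores, dim=-1) @ V      # [B, 1, D]

def decode_step(x_new, kv_cache):
    # x_new: [B, 1, D] -- only the newest token's embedding, as in the cached path.
    q, k, v = x_new @ Wq, x_new @ Wk, x_new @ Wv
    if kv_cache is None:
        K, V = k, v
    else:
        K = torch.cat([kv_cache[0], k], dim=1)    # append new key to the cache
        V = torch.cat([kv_cache[1], v], dim=1)    # append new value to the cache
    return attend(q, K, V), (K, V)

tokens = torch.randn(1, 5, D)                     # fake sequence of embeddings

# Cached decoding: one token per step, carrying the cache forward.
cache, cached_outs = None, []
for t in range(tokens.size(1)):
    out, cache = decode_step(tokens[:, t:t+1, :], cache)
    cached_outs.append(out)

# Uncached reference: recompute K/V over the whole prefix at every step,
# which is what the removed else-branch did with full_sequence_embeds.
for t in range(tokens.size(1)):
    x = tokens[:, :t + 1, :]
    q, K, V = x @ Wq, x @ Wk, x @ Wv
    full = attend(q[:, -1:, :], K, V)             # output for the last position
    assert torch.allclose(full, cached_outs[t], atol=1e-5)
print("cached and uncached decoding agree")
```

Making the cached path the only path keeps each decode step's decoder input at a single token instead of regrowing the whole image-plus-prompt sequence each iteration, which is what the +9/-22 simplification above buys.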
