@@ -60,7 +60,7 @@ def forward(self, input_ids, image, attention_mask=None, targets=None):
         return logits, loss
 
     @torch.inference_mode()
-    def generate(self, input_ids, image, attention_mask=None, max_new_tokens=5, top_k=50, top_p=0.9, temperature=0.5, greedy=False, use_kv_cache: bool = True):
+    def generate(self, input_ids, image, attention_mask=None, max_new_tokens=5, top_k=50, top_p=0.9, temperature=0.5, greedy=False):
 
         # 1. Process image
         image_embd = self.vision_encoder(image)  # [B, T_img, D_model]
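Since `use_kv_cache` is removed from the signature outright rather than deprecated, any caller still passing it will break at call time. A toy stand-in (not the real method, just a plain function mirroring the new signature) shows the failure mode:

def generate(input_ids, image, attention_mask=None, max_new_tokens=5,
             top_k=50, top_p=0.9, temperature=0.5, greedy=False):
    # Stand-in with the new signature; the real method lives on the model.
    return input_ids

try:
    generate([1, 2, 3], None, use_kv_cache=False)
except TypeError as err:
    print(err)  # generate() got an unexpected keyword argument 'use_kv_cache'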
@@ -122,27 +122,14 @@ def generate(self, input_ids, image, attention_mask=None, max_new_tokens=5, top_
             if attention_mask is not None:
                 attention_mask = torch.cat((attention_mask, torch.ones((batch_size, 1), device=attention_mask.device, dtype=attention_mask.dtype)), dim=1)
 
-            if use_kv_cache:
-                # With KV cache: only process the new token
-                decode_step_output, kv_cache_list = self.decoder(
-                    next_token_embed,
-                    attention_mask=attention_mask,
-                    kv_cache=kv_cache_list,
-                    start_pos=current_token_start_pos
-                )
-            else:
-                # Without KV cache: process the entire sequence from scratch
-                # Reconstruct the full sequence: image + prompt + generated tokens so far
-                generated_token_embeds = torch.cat([self.decoder.token_embedding(tid) for tid in newly_generated_ids_list], dim=1)
-                full_sequence_embeds = torch.cat([initial_combined_embeds, generated_token_embeds], dim=1)
-
-                decode_step_output, _ = self.decoder(
-                    full_sequence_embeds,
-                    attention_mask=attention_mask,
-                    kv_cache=None,
-                    start_pos=0
-                )
-
+            # With KV cache: only process the new token
+            decode_step_output, kv_cache_list = self.decoder(
+                next_token_embed,
+                attention_mask=attention_mask,
+                kv_cache=kv_cache_list,
+                start_pos=current_token_start_pos
+            )
+
             last_token_output = decode_step_output[:, -1, :]
 
             # Apply head to get logits (if model is in embedding mode)
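Why the deleted `else` branch was safe to drop: for a causal decoder, recomputing attention over the full sequence and reading off the last position is mathematically identical to feeding only the newest token's query against cached keys/values, just slower. A self-contained sketch of that equivalence (single head, random weights; none of these names come from this repo's decoder):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, D = 2, 5, 16
Wq, Wk, Wv = (torch.randn(D, D) for _ in range(3))
x = torch.randn(B, T, D)  # stands in for image + prompt + generated embeddings

# Full recompute (the removed branch): causal attention over everything,
# then keep only the last position's output.
full = F.scaled_dot_product_attention(x @ Wq, x @ Wk, x @ Wv, is_causal=True)[:, -1]

# Cached decode (the kept branch): prefix keys/values are computed once,
# and the step feeds only the newest token against the cache.
k_all = torch.cat([x[:, :-1] @ Wk, x[:, -1:] @ Wk], dim=1)  # cache + new K
v_all = torch.cat([x[:, :-1] @ Wv, x[:, -1:] @ Wv], dim=1)  # cache + new V
step = F.scaled_dot_product_attention(x[:, -1:] @ Wq, k_all, v_all)[:, -1]

print(torch.allclose(full, step, atol=1e-5))  # True: identical outputs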
@@ -152,7 +139,7 @@ def generate(self, input_ids, image, attention_mask=None, max_new_tokens=5, top_
                 current_logits = last_token_output
 
         if not newly_generated_ids_list:  # Handle case where max_new_tokens might be 0
-            return torch.empty((B, 0), dtype=torch.long, device=input_ids.device)
+            return torch.empty((batch_size, 0), dtype=torch.long, device=input_ids.device)
 
         return torch.cat(newly_generated_ids_list, dim=1)
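The last hunk also fixes a latent `NameError`: `B` is never defined inside `generate` (the method uses `batch_size`, as in the `attention_mask` update above), so the `max_new_tokens=0` early-return path would have crashed. A standalone check that the `(batch_size, 0)` empty tensor is a harmless sentinel (shapes here are assumptions):

import torch

batch_size = 2
prompt_ids = torch.randint(0, 100, (batch_size, 7))  # hypothetical prompt ids
empty = torch.empty((batch_size, 0), dtype=torch.long, device=prompt_ids.device)
print(torch.cat([prompt_ids, empty], dim=1).shape)  # torch.Size([2, 7])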