Skip to content

Commit 312f400

Browse files
committed
Fix vocab padding in generator
1 parent 7feed80 commit 312f400

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

exllamav2/generator/streaming.py

+3-2
Original file line number | Diff line number | Diff line change
@@ -62,7 +62,7 @@ def __init__(self, model, cache, tokenizer, draft_model = None, draft_cache = No
6262

6363
self.no_tokens = torch.empty((1, 0), dtype = torch.long)
6464
self.no_probs = torch.empty((1, 0), dtype = torch.float)
65-
self.no_logits = torch.empty((0, self.model.config.vocab_size), dtype = torch.float)
65+
self.no_logits = torch.empty((0, ((self.model.config.vocab_size + 31) // 32) * 32), dtype = torch.float)
6666

6767
if draft_model:
6868
self.draft_model = draft_model
@@ -193,7 +193,8 @@ def _stream(self) -> (str, bool, torch.Tensor, torch.Tensor, torch.Tensor):
193193
self.held_text += new_text
194194
self.held_tokens = torch.cat([self.held_tokens, next_token], dim = -1)
195195
self.held_probs = torch.cat([self.held_probs, next_prob], dim = -1)
196-
self.held_logits = torch.cat([self.held_logits, next_logits], dim = 0)
196+
if self.return_logits:
197+
self.held_logits = torch.cat([self.held_logits, next_logits], dim = 0)
197198

198199
# Return now if newly added token ends a filter
199200

0 commit comments

Comments (0)