
Commit 9b0e3ec

[Kernel][LoRA] Add assertion for punica sgmv kernels (#7585)
1 parent: 86e9c8d

8 files changed (+64 -38 lines)

tests/lora/test_punica_sizes.py

Lines changed: 5 additions & 0 deletions

@@ -169,6 +169,7 @@ def test_punica_sgmv(
         device,
     )
     max_seq_length = seq_len_tensor.max()
+    token_nums = seq_len_tensor.sum().item()
     if isinstance(max_seq_length, tuple):
         max_seq_length = max_seq_length[0].item()
     else:
@@ -183,6 +184,7 @@ def test_punica_sgmv(
             lora_indices_tensor,
             batches,
             max_seq_length,
+            token_nums,
             scaling,
         )
     else:
@@ -195,6 +197,7 @@ def test_punica_sgmv(
             lora_indices_tensor,
             batches,
             max_seq_length,
+            token_nums,
             add_inputs=True,
         )
     ref_torch_groupgemm(
@@ -347,6 +350,7 @@ def test_punica_expand_nslices(
         device,
     )
     max_seq_length = seq_len_tensor.max()
+    token_nums = seq_len_tensor.sum().item()
     if isinstance(max_seq_length, tuple):
         max_seq_length = max_seq_length[0].item()
     else:
@@ -364,6 +368,7 @@ def test_punica_expand_nslices(
                 lora_indices_tensor,
                 batches,
                 max_seq_length,
+                token_nums,
                 slice_offset,
                 hidden_size,
                 add_inputs=True,

tests/lora/test_punica_variation.py

Lines changed: 5 additions & 0 deletions

@@ -84,6 +84,7 @@ def test_punica_sgmv(
         device,
     )
     max_seq_length = seq_len_tensor.max()
+    token_nums = seq_len_tensor.sum().item()
     if isinstance(max_seq_length, tuple):
         max_seq_length = max_seq_length[0].item()
     else:
@@ -98,6 +99,7 @@ def test_punica_sgmv(
             lora_indices_tensor,
             batches,
             max_seq_length,
+            token_nums,
             scaling,
         )
     else:
@@ -110,6 +112,7 @@ def test_punica_sgmv(
             lora_indices_tensor,
             batches,
             max_seq_length,
+            token_nums,
             add_inputs=True,
         )
     ref_torch_groupgemm(
@@ -262,6 +265,7 @@ def test_punica_expand_nslices(
         device,
     )
     max_seq_length = seq_len_tensor.max()
+    token_nums = seq_len_tensor.sum().item()
     if isinstance(max_seq_length, tuple):
         max_seq_length = max_seq_length[0].item()
     else:
@@ -279,6 +283,7 @@ def test_punica_expand_nslices(
                 lora_indices_tensor,
                 batches,
                 max_seq_length,
+                token_nums,
                 slice_offset,
                 hidden_size,
                 add_inputs=True,

vllm/lora/ops/bgmv_expand.py

Lines changed: 1 addition & 1 deletion

@@ -100,7 +100,7 @@ def _bgmv_expand(
             corresponding to each batch, An index of -1 means no lora should be
             applied.
         batches (int): batch size
-        add_inputs (bool, optional): Defaults to False. adds the final lora
+        add_inputs (bool, optional): Defaults to False, adds the final lora
             results to the output.
     """
     assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]

vllm/lora/ops/bgmv_expand_slice.py

Lines changed: 1 addition & 1 deletion

@@ -104,7 +104,7 @@ def _bgmv_expand_slice(
         lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
             corresponding to each batch, An index of -1 means no lora should be
             applied.
-        slice_offst (int): output_tensor's offst
+        slice_offset (int): output_tensor's offset
         slice_size (int): current output_tensor's size
         batches (int): batch size
         add_inputs (bool, optional): Defaults to False.

vllm/lora/ops/sgmv_expand.py

Lines changed: 10 additions & 6 deletions

@@ -106,6 +106,7 @@ def _sgmv_expand(
     lora_indices_tensor: torch.Tensor,
     batches: int,
     max_seq_length: int,
+    token_nums: int,
     add_inputs: bool = False,
 ) -> None:
     """
@@ -115,17 +116,19 @@ def _sgmv_expand(
         output_tensor (torch.Tensor): output tensor
         b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
             sequence lengths of the sequences in the batch, used to index
-            into sequence. E.g.,if the sequence length is [4, 6], it is
+            into sequence. E.g., if the sequence length is [4, 6], it is
             [0, 4, 10].
-        seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
-            length of the sequences in the batch
+        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
+            length of the sequences in the batch.
         lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
             corresponding to each batch. An index of -1 means no lora should be
             applied.
         batches (int): batch size
-        max_seq_length (int): The max sequence lengths of the sequences
-            in the batch
-        add_inputs (bool, optional): Defaults to False. adds the final lora
+        max_seq_length (int): The max sequence lengths of the sequences in the
+            batch.
+        token_nums (int): The token numbers in the batch. Used to verify if the
+            token numbers in the inputs matches the one in the metadata.
+        add_inputs (bool, optional): Defaults to False, adds the final lora
             results to the output.
     """
 
@@ -134,6 +137,7 @@ def _sgmv_expand(
         torch.float16,
         torch.bfloat16,
     ]
+    assert inputs.size(0) == token_nums
     assert inputs.size(1) == lora_b_weights.size(-1)
     assert b_seq_start_loc.size(0) == batches
     assert lora_indices_tensor.size(0) == batches
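
The new token_nums argument gives the sgmv wrappers a cheap consistency check: the flattened input activations must have exactly as many rows as the metadata says there are tokens. A minimal sketch of the idea, with illustrative shapes only and a bare assert standing in for the real _sgmv_expand call:

import torch

# Two sequences of 4 and 6 tokens, matching the docstring example above.
seq_len_tensor = torch.tensor([4, 6], dtype=torch.long)
b_seq_start_loc = torch.tensor([0, 4], dtype=torch.long)  # cumulative starts
max_seq_length = int(seq_len_tensor.max().item())         # 6
token_nums = int(seq_len_tensor.sum().item())             # 10

hidden_size = 16  # arbitrary size for illustration
inputs = torch.randn(token_nums, hidden_size, dtype=torch.float16)

# Mirrors the assertion added in this commit: if the metadata and the input
# rows disagree, fail fast instead of launching the Triton kernel.
assert inputs.size(0) == token_nums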

vllm/lora/ops/sgmv_expand_slice.py

Lines changed: 11 additions & 7 deletions

@@ -112,6 +112,7 @@ def _sgmv_expand_slice(
     lora_indices_tensor: torch.Tensor,
     batches: int,
     max_seq_length: int,
+    token_nums: int,
     slice_offset: int,
     slice_size: int,
     add_inputs: bool = False,
@@ -124,27 +125,30 @@ def _sgmv_expand_slice(
         output_tensor (torch.Tensor): output tensor
         b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
             sequence lengths of the sequences in the batch, used to index
-            into sequence. E.g.,if the sequence length is [4, 6], it is
+            into sequence. E.g., if the sequence length is [4, 6], it is
             [0, 4, 10].
-        seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
-            length of the sequences in the batch
+        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
+            length of the sequences in the batch
         lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
             corresponding to each batch. An index of -1 means no lora should be
             applied.
         batches (int): batch size
-        max_seq_length (int): The max sequence lengths of the sequences
+        max_seq_length (int): The max sequence lengths of the sequences
             in the batch
-        slice_offst (int): output_tensor's offst
+        token_nums (int): The token numbers in the batch. Used to verify if the
+            token numbers in the inputs matches the one in the metadata.
+        slice_offset (int): output_tensor's offset
         slice_size (int): current output_tensor's size
-        add_inputs (bool, optional): Defaults to False. adds the final lora
-            results to the output..
+        add_inputs (bool, optional): Defaults to False, adds the final lora
+            results to the output.
     """
 
     assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
     assert lora_b_weights.dtype in [
         torch.float16,
         torch.bfloat16,
     ]
+    assert inputs.size(0) == token_nums
     assert inputs.size(1) == lora_b_weights.size(-1)
     assert b_seq_start_loc.size(0) == batches
     assert lora_indices_tensor.size(0) == batches

vllm/lora/ops/sgmv_shrink.py

Lines changed: 10 additions & 6 deletions

@@ -110,6 +110,7 @@ def _sgmv_shrink(
     lora_indices_tensor: torch.Tensor,
     batches: int,
     max_seq_length: int,
+    token_nums: int,
     scaling: float,
 ) -> None:
     """
@@ -120,24 +121,27 @@ def _sgmv_shrink(
         output_tensor (torch.Tensor): output tensor
         b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
             sequence lengths of the sequences in the batch, used to index
-            into sequence. E.g.,if the sequence length is [4, 6], it is
+            into sequence. E.g., if the sequence length is [4, 6], it is
             [0, 4].
-        seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
-            length of the sequences in the batch
+        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
+            length of the sequences in the batch.
         lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
             corresponding to each batch. An index of -1 means no lora should be
             applied.
         batches (int): batch size
-        max_seq_length (int): The max sequence lengths of the sequences
-            in the batch
-        scaling (float): Scaling factor.
+        max_seq_length (int): The max sequence lengths of the sequences in the
+            batch.
+        token_nums (int): The token numbers in the batch. Used to verify if the
+            token numbers in the inputs matches the one in the metadata.
+        scaling (float): Scaling factor.
     """
     assert inputs.dtype == lora_a_weights.dtype
     assert inputs.dtype in [torch.float16, torch.bfloat16]
     assert lora_a_weights.dtype in [
         torch.float16,
         torch.bfloat16,
     ]
+    assert inputs.size(0) == token_nums
     assert inputs.size(1) == lora_a_weights.size(-1)
     assert b_seq_start_loc.size(0) == batches
     assert lora_indices_tensor.size(0) == batches

vllm/lora/punica.py

Lines changed: 21 additions & 17 deletions

@@ -27,7 +27,7 @@
 
 def compute_meta(
     token_lora_tensor: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
     """
     Get the information required for the sgmv kernel. With the features:
     1. If consecutive requests in the batch use the same LoRA, this function
@@ -43,7 +43,7 @@ def compute_meta(
     b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
     b_seq_start_tensor[1:].copy_(cum_result[:-1])
     max_length = seq_length_tensor.max().item()
-
+    token_nums = seq_length_tensor.sum().item()
     batch_size = lora_indices_tensor.size(0)
     no_lora = False
     # -1 means no lora should be applied. Use `no_lora` to determine whether
@@ -52,7 +52,7 @@ def compute_meta(
     if batch_size == 1 and lora_indices_tensor == -1:
         no_lora = True
     return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
-            batch_size, max_length, no_lora)
+            batch_size, max_length, token_nums, no_lora)
 
 
 # TODO see if this can be vectorized
@@ -178,7 +178,7 @@ def convert_mapping(
 class PunicaWrapper:
     """
     PunicaWrapper is designed to manage and provide metadata for the punica
-    kernel. The main function is to maintain the state information for 
+    kernel. The main function is to maintain the state information for
     Multi-LoRA, and to provide the interface for the punica kernel.
     """
 
@@ -216,6 +216,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int,
                                               dtype=torch.long,
                                               device=device)
         self.max_length: int = 0
+        self.token_nums: int = 0
         self.batch_size: int = -1
         self.is_prefill = False
         self.no_lora = False
@@ -276,13 +277,13 @@ def _update_base_metadata(
                 long_lora_offsets_tensor)
         else:
             self._long_lora_indices.zero_()
-
         self.indices_len[:] = indices_len
 
     def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
 
         (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
-         batch_size, max_length, no_lora) = compute_meta(token_lora_tensor)
+         batch_size, max_length, token_nums,
+         no_lora) = compute_meta(token_lora_tensor)
 
         self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
             b_seq_start_tensor)
@@ -291,25 +292,28 @@ def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
             lora_indices_tensor)
         self.batch_size = batch_size
         self.max_length = max_length
+        self.token_nums = token_nums
         self.no_lora = no_lora
 
     @property
     def prefill_metadata(
-            self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
+        self
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
         """
         This property provides a convenient way to access the necessary
         metadata for prefill-related kernel computations.
-        1. seq_start_locs: Tensor of sequence start positions
-        2. seq_lengths: Tensor of sequence lengths
+        1. seq_start_locs: Tensor of sequence start positions.
+        2. seq_lengths: Tensor of sequence lengths.
        3. lora_indices_per_batch: Tensor of lora indices, and an index of
            -1 means no lora should be applied.
-        4. batch_size: batch size after clustering identical lora indices
-        5. max_length: The maximum sequence length in the batch
+        4. batch_size: Batch size after clustering identical lora indices.
+        5. max_length: The maximum sequence length in the batch.
+        6. token_nums: The token numbers in the batch.
         """
         return (self._seq_start_locs[:self.batch_size],
                 self._seq_lengths[:self.batch_size],
                 self._lora_indices_per_batch[:self.batch_size],
-                self.batch_size, self.max_length)
+                self.batch_size, self.max_length, self.token_nums)
 
     @property
     def token_lora_indices(self) -> torch.Tensor:
@@ -324,15 +328,15 @@ def token_lora_indices(self) -> torch.Tensor:
     def sampler_indices(self) -> torch.Tensor:
         """
         This property is used to access the lora indices specifically for
-        LogitsProcessorWithLoRA
+        LogitsProcessorWithLoRA.
         """
         sampler_indices_len = self.indices_len[1]
         return self._sampler_indices[:sampler_indices_len]
 
     @property
     def sampler_indices_padded(self) -> torch.Tensor:
         """
-        This property provides access to padded sampler indices
+        This property provides access to padded sampler indices.
         """
         indices_padded_len = self.indices_len[2]
         return self._sampler_indices_padded[:indices_padded_len]
@@ -341,7 +345,7 @@ def sampler_indices_padded(self) -> torch.Tensor:
     def embeddings_indices(self) -> torch.Tensor:
         """
         This property provides access to the indices used for lora embeddings,
-        specifically for VocabParallelEmbeddingWithLoRA
+        specifically for VocabParallelEmbeddingWithLoRA.
         """
         embeddings_indices_len = self.indices_len[3]
         return self._embeddings_indices[:, :embeddings_indices_len]
@@ -350,7 +354,7 @@ def embeddings_indices(self) -> torch.Tensor:
     def long_lora_indices(self) -> torch.Tensor:
         """
         This property provides access to the indices used for long context
-        lora, specifically for LinearScalingRotaryEmbeddingWithLora
+        lora, specifically for LinearScalingRotaryEmbeddingWithLora.
         """
         long_lora_len = self.indices_len[4]
         return self._long_lora_indices[:long_lora_len]
@@ -524,7 +528,7 @@ def add_lora(self,
         scale (float): Scaling factor.
         y_offset (Optional[int], optional): Offset to apply to the starting
             column of y.
-        y_slice_size (Optional[int], optional): Size of the y column slice..
+        y_slice_size (Optional[int], optional): Size of the y column slice.
         buffer (Optional[torch.Tensor], optional): Defaults to None.
     """
     y_org = y
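
On the wrapper side, compute_meta now also reports the total token count, and PunicaWrapper stores it and exposes it through prefill_metadata so the sgmv ops can run the new shape check. A hedged sketch of what a caller sees, using an illustrative per-token LoRA mapping (three tokens on LoRA 0, then two on LoRA 1):

import torch

from vllm.lora.punica import compute_meta

# Per-token LoRA indices; consecutive tokens with the same index are grouped
# into one batch entry by compute_meta.
token_lora_tensor = torch.tensor([0, 0, 0, 1, 1])

(b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
 batch_size, max_length, token_nums, no_lora) = compute_meta(token_lora_tensor)

# With the grouping described in compute_meta's docstring, one would expect:
#   batch_size == 2, max_length == 3, token_nums == 5, no_lora is False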
