Skip to content

Commit b08b588

Browse files
fix batch rotary embedding test
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
1 parent 0b34593 commit b08b588

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

tests/kernels/core/test_pos_encoding.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,10 @@ def test_batched_rotary_embedding(
152152
query = torch.randn(query_shape, dtype=dtype)
153153
key = torch.randn_like(query) if use_key else None
154154

155+
# slice tensor if required, noop otherwise
156+
query = query[..., :head_size]
157+
key = key[..., :head_size] if use_key else None
158+
155159
# NOTE(woosuk): The reference implementation should be executed first
156160
# because the custom kernel is in-place.
157161
ref_query, ref_key = rope.forward_native(positions, query, key)

0 commit comments

Comments
 (0)