From 18309d8586924f34c1a5e050dd5c217bd62d78ca Mon Sep 17 00:00:00 2001
From: Ubospica
Date: Wed, 26 Feb 2025 08:27:07 -0500
Subject: [PATCH] update

---
 cpp/grammar_matcher.cc                             | 12 ------------
 .../kernels/apply_token_bitmask_inplace_cuda.cu    | 12 ------------
 .../kernels/apply_token_bitmask_inplace_cuda.py    |  2 +-
 .../kernels/apply_token_bitmask_inplace_triton.py  | 10 ----------
 tests/python/test_token_bitmask_operations.py      |  8 +-------
 5 files changed, 2 insertions(+), 42 deletions(-)

diff --git a/cpp/grammar_matcher.cc b/cpp/grammar_matcher.cc
index 4cb596e..f7ebf43 100644
--- a/cpp/grammar_matcher.cc
+++ b/cpp/grammar_matcher.cc
@@ -112,18 +112,6 @@ void ApplyTokenBitmaskInplaceCPU(
   std::vector<int> indices_value;
   if (indices.has_value()) {
     indices_value = indices.value();
-    std::sort(indices_value.begin(), indices_value.end());
-    indices_value.erase(
-        std::unique(indices_value.begin(), indices_value.end()), indices_value.end()
-    );
-    XGRAMMAR_CHECK(indices_value.front() >= 0)
-        << "The provided indices is negative: " << indices_value.front();
-    XGRAMMAR_CHECK(indices_value.back() < logits_shape.first)
-        << "The provided indices is larger than the logits's batch size: " << indices_value.back()
-        << " >= " << logits_shape.first;
-    XGRAMMAR_CHECK(static_cast<int64_t>(indices_value.size()) <= bitmask_shape.first)
-        << "The provided indices is larger than the bitmask's batch size: " << indices_value.size()
-        << " >= " << bitmask_shape.first;
   } else {
     XGRAMMAR_CHECK(logits_shape.first == bitmask_shape.first)
         << "When indices is not provided, the logits's batch size should be equal to the "
diff --git a/python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.cu b/python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.cu
index 02c6750..3f58ab4 100644
--- a/python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.cu
+++ b/python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.cu
@@ -222,18 +222,6 @@ void ApplyTokenBitmaskInplace(
     TORCH_CHECK(indices->is_contiguous(), "indices must be contiguous.");
     TORCH_CHECK(indices->dim() == 1, "indices must be a 1D tensor.");
     TORCH_CHECK(indices->dtype() == torch::kInt32, "indices must be of type int32.");
-    TORCH_CHECK(
-        indices->size(0) <= bitmask_shape.first,
-        "indices must have the batch size no larger than bitmask's batch size."
-    );
-    TORCH_CHECK(
-        indices->index({0}).item<int32_t>() >= 0,
-        "indices must have the minimum value no less than 0."
-    );
-    TORCH_CHECK(
-        indices->index({indices->size(0) - 1}).item<int32_t>() < logits_shape.first,
-        "indices must have the maximum value no larger than logits's batch size."
-    );
     num_rows = indices->size(0);
     indices_ptr = indices->data_ptr<int32_t>();
   } else {
diff --git a/python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.py b/python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.py
index 401fac5..0058546 100644
--- a/python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.py
+++ b/python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.py
@@ -73,5 +73,5 @@ def apply_token_bitmask_inplace_cuda(
     if isinstance(indices, list):
         indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
     if indices is not None:
-        indices = torch.unique(indices.to(logits.device))
+        indices = indices.to(logits.device)
     torch.ops.xgrammar.apply_token_bitmask_inplace_cuda(logits, bitmask, indices)
diff --git a/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py b/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
index 93364c8..0b1fb8c 100644
--- a/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
+++ b/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
@@ -110,16 +110,6 @@ def apply_token_bitmask_inplace_triton(
     num_rows = None
     if isinstance(indices, list) or isinstance(indices, torch.Tensor):
         indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
-        indices = torch.unique(indices)
-
-        assert (
-            indices.shape[0] <= bitmask_shape[0]
-        ), f"indices count ({indices.shape[0]}) exceeds bitmask batch size ({bitmask_shape[0]})"
-        assert indices.min() >= 0, f"negative index found: {indices.min()}"
-        assert (
-            indices.max() < logits_shape[0]
-        ), f"index {indices.max()} out of bounds for logits batch size {logits_shape[0]}"
-
         num_rows = indices.shape[0]
     else:
         assert (
diff --git a/tests/python/test_token_bitmask_operations.py b/tests/python/test_token_bitmask_operations.py
index 92d13f7..0af50e7 100644
--- a/tests/python/test_token_bitmask_operations.py
+++ b/tests/python/test_token_bitmask_operations.py
@@ -253,13 +253,7 @@ def test_apply_token_bitmask_inplace_select_indices(
     torch.testing.assert_close(logits, logits_expected)
 
 
-logits_shape__bitmask_shape__indices = [
-    ((2, 128), (1, 4), None),
-    ((2, 128), (2, 5), None),
-    ((2, 128), (1, 4), [0, 1]),
-    ((2, 128), (2, 5), [-1]),
-    ((2, 128), (3, 4), [2]),
-]
+logits_shape__bitmask_shape__indices = [((2, 128), (1, 4), None), ((2, 128), (2, 5), None)]
 
 
 @pytest.mark.parametrize(
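Note: with the sorting, deduplication, and bounds checks removed from the CPU, CUDA, and
Triton paths, the kernels now assume `indices` is already sorted, unique, and in range.
A caller that cannot guarantee this may want to pre-validate on its side. Below is a
minimal Python sketch replicating the removed checks; the helper name
`validate_bitmask_indices` is hypothetical and not part of xgrammar's API.

import torch

def validate_bitmask_indices(
    indices: torch.Tensor, logits_batch_size: int, bitmask_batch_size: int
) -> torch.Tensor:
    """Hypothetical caller-side helper mirroring the validation this patch removes."""
    # torch.unique returns a sorted, deduplicated tensor, matching the old
    # std::sort + std::unique behavior on the CPU path.
    indices = torch.unique(indices.to(torch.int32))
    assert indices.numel() > 0, "indices must be non-empty"
    # Row count must fit within the bitmask's batch size.
    assert indices.numel() <= bitmask_batch_size, (
        f"indices count ({indices.numel()}) exceeds bitmask batch size ({bitmask_batch_size})"
    )
    # After sorting, the first/last elements are the min/max indices.
    assert indices[0] >= 0, f"negative index found: {int(indices[0])}"
    assert indices[-1] < logits_batch_size, (
        f"index {int(indices[-1])} out of bounds for logits batch size ({logits_batch_size})"
    )
    return indices

A caller could then run, e.g.,
indices = validate_bitmask_indices(indices, logits.shape[0], bitmask.shape[0])
before invoking apply_token_bitmask_inplace.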