
Commit

update
Ubospica committed Feb 26, 2025
1 parent 7f9d6c4 commit 18309d8
Showing 5 changed files with 2 additions and 42 deletions.
12 changes: 0 additions & 12 deletions cpp/grammar_matcher.cc
@@ -112,18 +112,6 @@ void ApplyTokenBitmaskInplaceCPU(
   std::vector<int> indices_value;
   if (indices.has_value()) {
     indices_value = indices.value();
-    std::sort(indices_value.begin(), indices_value.end());
-    indices_value.erase(
-        std::unique(indices_value.begin(), indices_value.end()), indices_value.end()
-    );
-    XGRAMMAR_CHECK(indices_value.front() >= 0)
-        << "The provided indices is negative: " << indices_value.front();
-    XGRAMMAR_CHECK(indices_value.back() < logits_shape.first)
-        << "The provided indices is larger than the logits's batch size: " << indices_value.back()
-        << " >= " << logits_shape.first;
-    XGRAMMAR_CHECK(static_cast<int>(indices_value.size()) <= bitmask_shape.first)
-        << "The provided indices is larger than the bitmask's batch size: " << indices_value.size()
-        << " >= " << bitmask_shape.first;
   } else {
     XGRAMMAR_CHECK(logits_shape.first == bitmask_shape.first)
         << "When indices is not provided, the logits's batch size should be equal to the "
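With the sorting, deduplication, and bounds checks removed from ApplyTokenBitmaskInplaceCPU, the CPU path now uses whatever indices the caller provides. Below is a minimal caller-side sketch that reproduces those checks before invoking the kernel; the xgrammar.apply_token_bitmask_inplace entry point, the shapes, and the int32 bitmask layout are illustrative assumptions rather than something stated in this commit.

import torch
import xgrammar  # assumed public package; apply_token_bitmask_inplace is expected to reach the CPU kernel for CPU tensors

batch_size, vocab_size = 3, 128
logits = torch.randn(batch_size, vocab_size)  # CPU tensors
# Assumed layout: one int32 word per 32 tokens, all bits set (nothing masked).
bitmask = torch.full((batch_size, (vocab_size + 31) // 32), -1, dtype=torch.int32)

# Reproduce on the caller side what the kernel no longer does itself:
# sorted, deduplicated, in-range row indices.
indices = sorted(set([2, 0, 2]))
assert indices[0] >= 0, "negative index"
assert indices[-1] < logits.shape[0], "index exceeds logits batch size"
assert len(indices) <= bitmask.shape[0], "more indices than bitmask rows"

xgrammar.apply_token_bitmask_inplace(logits, bitmask, indices=indices)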
12 changes: 0 additions & 12 deletions python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.cu
@@ -222,18 +222,6 @@ void ApplyTokenBitmaskInplace(
     TORCH_CHECK(indices->is_contiguous(), "indices must be contiguous.");
     TORCH_CHECK(indices->dim() == 1, "indices must be a 1D tensor.");
     TORCH_CHECK(indices->dtype() == torch::kInt32, "indices must be of type int32.");
-    TORCH_CHECK(
-        indices->size(0) <= bitmask_shape.first,
-        "indices must have the batch size no larger than bitmask's batch size."
-    );
-    TORCH_CHECK(
-        indices->index({0}).item<int32_t>() >= 0,
-        "indices must have the minimum value no less than 0."
-    );
-    TORCH_CHECK(
-        indices->index({indices->size(0) - 1}).item<int32_t>() < logits_shape.first,
-        "indices must have the maximum value no larger than logits's batch size."
-    );
     num_rows = indices->size(0);
     indices_ptr = indices->data_ptr<int32_t>();
   } else {
2 changes: 1 addition & 1 deletion python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.py
@@ -73,5 +73,5 @@ def apply_token_bitmask_inplace_cuda(
     if isinstance(indices, list):
         indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
     if indices is not None:
-        indices = torch.unique(indices.to(logits.device))
+        indices = indices.to(logits.device)
     torch.ops.xgrammar.apply_token_bitmask_inplace_cuda(logits, bitmask, indices)
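Because the wrapper no longer calls torch.unique, the indices tensor now reaches the registered CUDA op exactly as passed in. A minimal usage sketch follows, assuming a CUDA build where importing xgrammar registers torch.ops.xgrammar.* and using an illustrative vocabulary size of 128.

import torch
import xgrammar  # assumed to register the torch.ops.xgrammar custom ops on import

logits = torch.randn(4, 128, device="cuda")
# Assumed int32 bitmask layout: one word per 32 tokens; all bits set.
bitmask = torch.full((4, 128 // 32), -1, dtype=torch.int32, device="cuda")

# The op now receives indices as-is, so deduplicate and range-check them on
# the caller side if duplicates or out-of-range values are possible.
indices = torch.unique(torch.tensor([0, 2, 2], dtype=torch.int32, device="cuda"))
assert int(indices.min()) >= 0 and int(indices.max()) < logits.shape[0]

torch.ops.xgrammar.apply_token_bitmask_inplace_cuda(logits, bitmask, indices)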
10 changes: 0 additions & 10 deletions python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
@@ -110,16 +110,6 @@ def apply_token_bitmask_inplace_triton(
     num_rows = None
     if isinstance(indices, list) or isinstance(indices, torch.Tensor):
         indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
-        indices = torch.unique(indices)
-
-        assert (
-            indices.shape[0] <= bitmask_shape[0]
-        ), f"indices count ({indices.shape[0]}) exceeds bitmask batch size ({bitmask_shape[0]})"
-        assert indices.min() >= 0, f"negative index found: {indices.min()}"
-        assert (
-            indices.max() < logits_shape[0]
-        ), f"index {indices.max()} out of bounds for logits batch size {logits_shape[0]}"
-
         num_rows = indices.shape[0]
     else:
         assert (
8 changes: 1 addition & 7 deletions tests/python/test_token_bitmask_operations.py
@@ -253,13 +253,7 @@ def test_apply_token_bitmask_inplace_select_indices(
     torch.testing.assert_close(logits, logits_expected)
 
 
-logits_shape__bitmask_shape__indices = [
-    ((2, 128), (1, 4), None),
-    ((2, 128), (2, 5), None),
-    ((2, 128), (1, 4), [0, 1]),
-    ((2, 128), (2, 5), [-1]),
-    ((2, 128), (3, 4), [2]),
-]
+logits_shape__bitmask_shape__indices = [((2, 128), (1, 4), None), ((2, 128), (2, 5), None)]
 
 
 @pytest.mark.parametrize(
