Doc mask returns negative sparsity #93

staghado opened this issue Dec 21, 2024 · 3 comments

staghado commented Dec 21, 2024

The displayed sparsity for the block mask is negative when the block size is larger than the sequence length of the mask. This is similar to issue #68.

Torch nightly: 2.6.0.dev20241221+cu124
Repro code:

import torch
from torch.nn.attention.flex_attention import create_block_mask

# per-token document ids: two 10-token documents followed by four 20-token documents
document_id = torch.zeros(100, dtype=torch.int, device="cuda")
document_id[:10] = 0
document_id[10:20] = 1
for i in range(20, 100, 20):
    document_id[i : i + 20] = i // 20 + 1

# causal attention restricted to tokens within the same document
def document_causal_mask(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    document_mask = document_id[q_idx] == document_id[kv_idx]
    return causal_mask & document_mask

mask = create_block_mask(document_causal_mask, 1, 1, 100, 100, "cuda")
print(mask)

Output:

BlockMask(shape=(1, 1, 100, 100), sparsity=-63.84%, 
(0, 0)
██
)
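
For what it's worth, -63.84% is exactly what you would get if the kept elements were counted in whole 128x128 blocks: 128 * 128 = 16384 kept elements against 100 * 100 = 10000 actual positions is a density of 163.84%, i.e. a sparsity of 1 - 1.6384 = -63.84%.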

Does this have an effect on the results obtained with such a mask?
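
One way to sanity-check the actual mask contents, independently of the displayed sparsity string, is to build the dense boolean mask with create_mask from the same module and measure it directly (a minimal sketch; shapes match the repro above):

import torch
from torch.nn.attention.flex_attention import create_mask

# dense boolean mask, True = attend; same mask_mod and shapes as the repro
dense = create_mask(document_causal_mask, 1, 1, 100, 100, device="cuda")
actual_sparsity = 1.0 - dense.float().mean().item()
print(f"actual sparsity: {actual_sparsity:.2%}")  # expected to be non-negative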

drisspg commented Dec 21, 2024

This should not have any effect on the result; the sparsity display is an independent code path from what is used to run with flex attention.
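
For example, the flex_attention output with the BlockMask can be checked against SDPA with the equivalent dense mask (a minimal sketch with made-up q/k/v shapes, not part of the fix):

import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention, create_mask

# illustrative shapes: batch 1, 1 head, 100 tokens, head dim 64
q, k, v = (torch.randn(1, 1, 100, 64, device="cuda") for _ in range(3))

out_flex = flex_attention(q, k, v, block_mask=mask)  # `mask` from the repro above
dense = create_mask(document_causal_mask, 1, 1, 100, 100, device="cuda")
out_ref = F.scaled_dot_product_attention(q, k, v, attn_mask=dense)
print(torch.allclose(out_flex, out_ref, atol=1e-3, rtol=1e-3))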

I noticed this as well and know the root cause. I am working on a fix here:

pytorch/pytorch#143534

Cc @Chillee

@NuanBaobao

When running the above code, I encountered the following error. My PyTorch version is 2.5.1+cu121. Does FlexAttention support Q, K, V whose sequence length is not a multiple of the 128 block size? Any help or insights would be greatly appreciated.

2.5.1+cu121
Traceback (most recent call last):
  File "/data/zhangjinhua/VAR-0320-flexattn/flex.py", line 26, in <module>
    mask = create_block_mask(document_causal_mask, 1, 1, 100, 100, "cuda")
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/zhangjinhua/anaconda3/envs/attn/lib/python3.11/site-packages/torch/nn/attention/flex_attention.py", line 853, in create_block_mask
    partial_block_mask, full_block_mask = inner_func(
                                          ^^^^^^^^^^^
  File "/data/zhangjinhua/anaconda3/envs/attn/lib/python3.11/site-packages/torch/nn/attention/flex_attention.py", line 775, in _create_block_mask_inner
    mask_tensor = create_mask(mask_mod, B, H, Q_LEN, KV_LEN, device, _compile=True)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/zhangjinhua/anaconda3/envs/attn/lib/python3.11/site-packages/torch/nn/attention/flex_attention.py", line 755, in create_mask
    mask = mask_mod(b, h, m, n)
           ^^^^^^^^^^^^^^^^^^^^
  File "/data/zhangjinhua/anaconda3/envs/attn/lib/python3.11/site-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
           ^^^^^^^^^^
  File "/data/zhangjinhua/anaconda3/envs/attn/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
           ^^^^^^^^^^^
  File "/data/zhangjinhua/anaconda3/envs/attn/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/zhangjinhua/anaconda3/envs/attn/lib/python3.11/site-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
           ^^^^^^^^^^
  File "/data/zhangjinhua/anaconda3/envs/attn/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
           ^^^^^^^^^^^
  File "/data/zhangjinhua/anaconda3/envs/attn/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/zhangjinhua/anaconda3/envs/attn/lib/python3.11/site-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
           ^^^^^^^^^^
  File "/data/zhangjinhua/anaconda3/envs/attn/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
           ^^^^^^^^^^^
  File "/data/zhangjinhua/anaconda3/envs/attn/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/zhangjinhua/anaconda3/envs/attn/lib/python3.11/site-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
           ^^^^^^^^^^
  File "/data/zhangjinhua/anaconda3/envs/attn/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
           ^^^^^^^^^^^
  File "/data/zhangjinhua/anaconda3/envs/attn/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/zhangjinhua/VAR-0320-flexattn/flex.py", line 22, in document_causal_mask
    document_mask = document_id[q_idx] == document_id[kv_idx]
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/zhangjinhua/anaconda3/envs/attn/lib/python3.11/site-packages/torch/_higher_order_ops/flex_attention.py", line 85, in __torch_function__
    return func(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [100,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [101,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [102,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [103,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [104,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [105,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [106,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [107,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [108,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [109,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [110,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [111,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [112,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [113,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [114,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [115,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [116,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [117,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [118,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [119,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [120,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [121,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [122,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [123,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [124,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [125,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [126,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [127,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.


drisspg commented Mar 22, 2025

Can you try upgrading to a newer version of PyTorch? We have fixed many bugs since 2.5.1.
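
If upgrading isn't immediately possible, one possible workaround on 2.5.1 (a sketch based on the assumption that the out-of-bounds indexing comes from create_mask generating indices up to the next multiple of the default 128 block size) is to pad the document_id lookup table so every generated index stays in bounds:

import torch
from torch.nn.attention.flex_attention import create_block_mask

SEQ_LEN = 100
BLOCK = 128
PADDED = ((SEQ_LEN + BLOCK - 1) // BLOCK) * BLOCK  # round up to 128

# per-token document ids, padded so indices up to PADDED - 1 are valid
document_id = torch.zeros(PADDED, dtype=torch.int, device="cuda")
document_id[:10] = 0
document_id[10:20] = 1
for i in range(20, SEQ_LEN, 20):
    document_id[i : i + 20] = i // 20 + 1
document_id[SEQ_LEN:] = -1  # padding positions get a dummy document id

def document_causal_mask(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    document_mask = document_id[q_idx] == document_id[kv_idx]
    return causal_mask & document_mask

# the mask itself is still built for the real 100-token length;
# only the lookup table is padded
mask = create_block_mask(document_causal_mask, 1, 1, SEQ_LEN, SEQ_LEN, "cuda")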
