
Implements dual-chunk-flash-attn backend for dual chunk attention with sparse attention support #11844

Merged: 1 commit merged into vllm-project:main from sighingnow's dev/dual-chunk-attn branch on May 13, 2025

Conversation

@sighingnow (Contributor) commented on Jan 8, 2025

This PR implements dual-chunk flash attention, a training-free method to extend model context length (see also #6139), with sparse attention support (https://github.com/microsoft/MInference).

This PR requires the sparse attention kernel from vllm-flash-attention. Qwen models with 1M context length support will be open-sourced in the next one or two weeks, and unit tests will be added later.

FIX #12452
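For readers who want to try the backend once those models are out, here is a minimal usage sketch. The checkpoint name is an assumption (the 1M-context Qwen models had not been released when this PR was opened), and enforce_eager=True reflects the CUDA-graph limitation discussed later in this thread; treat it as an illustration, not the PR's documented interface.

```python
from vllm import LLM, SamplingParams

# Hypothetical checkpoint name; substitute whichever 1M-context Qwen model you use.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct-1M",
    max_model_len=131072,     # scale up as GPU memory allows
    enforce_eager=True,       # dual chunk attention does not support CUDA graphs
    tensor_parallel_size=1,
)

outputs = llm.generate(
    ["A very long document ..."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```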

github-actions bot commented on Jan 8, 2025

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run the other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Jan 8, 2025
@sighingnow sighingnow force-pushed the dev/dual-chunk-attn branch 2 times, most recently from 82b5a4c to 4c4a33e Compare January 9, 2025 06:17
@jacob-crux

I see that you have enforce_eager=True set, so it looks like there are still compatibility issues with cudagraph.
Do you plan to fix this in the future?

mergify bot commented on Jan 13, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sighingnow.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@sighingnow (Contributor, Author)

> I see that you have enforce_eager=True set, so it looks like there are still compatibility issues with cudagraph. Do you plan to fix this in the future?

All conflicts fixed, could you please take another look? thanks!

st] = decode_metadata.block_tables[i, st:ed]
decode_metadata.block_tables_intra = block_tables_intra

seq_lens_succ = (chunk_num_curr -


When I try the Needle in a Haystack test with qwen-7b and llama-8b (I modified the code to support llama), there is a bug that produces a negative number once the context goes beyond roughly 13k~15k tokens.
I modified the code as below and confirmed that it works.

seq_lens_succ = ((chunk_num_curr - (chunk_num_curr - 1).clip(min=0)) * chunk_len)
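As a quick sanity check, the fixed expression can be exercised with made-up values (names copied from the snippet above); the .clip(min=0) pins the successive-chunk length to zero for sequences that have not yet crossed their first chunk boundary, so it can never go negative:

```python
import torch

chunk_len = 4096
# Number of fully covered chunks per sequence; 0 means the sequence is still
# inside its first chunk.
chunk_num_curr = torch.tensor([0, 1, 2, 5])

seq_lens_succ = (chunk_num_curr - (chunk_num_curr - 1).clip(min=0)) * chunk_len
print(seq_lens_succ)  # tensor([   0, 4096, 4096, 4096])
```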

mergify bot commented on Jan 15, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sighingnow.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 15, 2025
@jacob-crux

> I see that you have enforce_eager=True set, so it looks like there are still compatibility issues with cudagraph. Do you plan to fix this in the future?
>
> All conflicts fixed, could you please take another look? thanks!

I tested it because I thought it was fixed, but I still hit the same problem, shown below.
Are you saying that CUDA graph capture is possible (enforce_eager=False)?

Capturing CUDA graph shapes:   0%|                                                                                                                                                                                                               | 0/35 [00:00<?, ?it/s]
[rank0]: Traceback (most recent call last):
[rank0]:   File "/data/lme-storage_810/jacob/needle/NeedleInAHaystack-lme/run_needle_in_haystack.py", line 435, in <module>
[rank0]:     ht = LLMNeedleHaystackTester(
[rank0]:          ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/lme-storage_810/jacob/needle/NeedleInAHaystack-lme/run_needle_in_haystack.py", line 94, in __init__
[rank0]:     self.model_to_test = LLM(model=model_name)
[rank0]:                          ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/utils.py", line 1044, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/entrypoints/llm.py", line 228, in __init__
[rank0]:     self.llm_engine = self.engine_class.from_engine_args(
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/engine/llm_engine.py", line 517, in from_engine_args
[rank0]:     engine = cls(
[rank0]:              ^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/engine/llm_engine.py", line 276, in __init__
[rank0]:     self._initialize_kv_caches()
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/engine/llm_engine.py", line 429, in _initialize_kv_caches
[rank0]:     self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/executor/gpu_executor.py", line 83, in initialize_cache
[rank0]:     self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/worker/worker.py", line 274, in initialize_cache
[rank0]:     self._warm_up_model()
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/worker/worker.py", line 292, in _warm_up_model
[rank0]:     self.model_runner.capture_model(self.gpu_cache)
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/worker/model_runner.py", line 1533, in capture_model
[rank0]:     graph_runner.capture(**capture_inputs)
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/worker/model_runner.py", line 1885, in capture
[rank0]:     self.model(
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/model_executor/models/qwen2.py", line 496, in forward
[rank0]:     hidden_states = self.model(input_ids, positions, kv_caches,
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/compilation/decorators.py", line 170, in __call__
[rank0]:     return self.forward(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/model_executor/models/qwen2.py", line 359, in forward
[rank0]:     hidden_states, residual = layer(
[rank0]:                               ^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/model_executor/models/qwen2.py", line 267, in forward
[rank0]:     hidden_states = self.self_attn(
[rank0]:                     ^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/model_executor/models/qwen2.py", line 189, in forward
[rank0]:     attn_output = self.attn(q,
[rank0]:                   ^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/attention/layer.py", line 185, in forward
[rank0]:     return torch.ops.vllm.unified_attention(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1116, in __call__
[rank0]:     return self._op(*args, **(kwargs or {}))
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/attention/layer.py", line 280, in unified_attention
[rank0]:     return self.impl.forward(query, key, value, kv_cache, attn_metadata,
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/attention/backends/dual_chunk_flash_attn.py", line 373, in forward
[rank0]:     assert decode_meta.scaling_factor is not None
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: AssertionError

@mergify mergify bot removed the needs-rebase label Jan 16, 2025
@sighingnow (Contributor, Author)

> I tested it because I thought it was fixed, but I still hit the same problem, shown below.
> Are you saying that CUDA graph capture is possible (enforce_eager=False)?

Dual chunk attention doesn't support CUDA graphs, and I have added an assertion in arg_utils.py.

> When I try the Needle in a Haystack test with qwen-7b and llama-8b (I modified the code to support llama), there is a bug that produces a negative number once the context goes beyond roughly 13k~15k tokens.

It is indeed a bug introduced while preparing this PR; fixed. Thanks!

@sighingnow (Contributor, Author) commented on Jan 19, 2025

Rebased against main.

Hi @youkaichao @simon-mo @WoosukKwon, do you folks think there are still things that need to be improved in this pull request?

Thanks!

@tlrmchlsmth (Collaborator) left a comment

Spotted a few bits of commented-out code that look like debug cruft or are otherwise mysterious. Could you clean those up, and any other similar spots?

mergify bot commented on Jan 20, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sighingnow.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 20, 2025
qc_freqs = torch.einsum("i,j -> ij", qc_t, inv_freq)
k_freqs = torch.einsum("i,j -> ij", k_t, inv_freq)
qc_no_clamp_freqs = torch.einsum("i,j -> ij", qc_no_clamp_t, inv_freq)
q_inter_freqs = torch.einsum("i,j -> ij", q_inter_t, inv_freq)
A collaborator commented:

nit: I think these einsums are still slower on CUDA than (a * b).sum(-1); not on the hot path though, so not critical.

pytorch/pytorch#101249

ran bench_einsum.py from that issue on an H100 and got:

python einsum_bench.py 
[-------------------------------------  -------------------------------------]
                                  |  mul/sum  |  torch.einsum  |  numpy.einsum
1 threads: -------------------------------------------------------------------
      Nc,Nc->N cpu (1048576, 2)   |    5000   |      3100      |      4000    
      Nc,Nc->N cuda (1048576, 2)  |      20   |       747      |      3300    

Times are in microseconds (us).
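Since the einsums in this snippet are plain outer products ("i,j -> ij"), the non-einsum rewrite is just broadcasting; a small sketch with made-up shapes:

```python
import torch

qc_t = torch.arange(8, dtype=torch.float32)  # stand-in for the 1-D position tensor
inv_freq = torch.rand(4)                     # stand-in for the 1-D inverse-frequency tensor

qc_freqs_einsum = torch.einsum("i,j -> ij", qc_t, inv_freq)
qc_freqs_bcast = qc_t[:, None] * inv_freq[None, :]  # same (8, 4) result, no einsum dispatch

assert torch.allclose(qc_freqs_einsum, qc_freqs_bcast)
```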

Comment on lines 102 to 115
logits_soft_cap, attn_type, **{
    "dual_chunk_attention_config": dual_chunk_attention_config,
    "prefix": prefix,
} if dual_chunk_attention_config is not None else {})
A collaborator commented:

I feel like this is messy; I think we should maybe do something like:

def __init__(..., **extra_attn_kwargs):
    self.impl = impl_cls(..., **extra_attn_kwargs)

The challenge here is that prefix would not be captured by extra_attn_kwargs but is only (currently) used by DualChunkFlashAttentionImpl. I do think it would be less messy, though, to do this and make prefix a standard arg for attention impls, given that it is pretty generic. Thoughts @WoosukKwon?
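A rough sketch of that suggestion (hypothetical signature, not code from this PR): backend-specific options are forwarded untouched, and prefix becomes a standard argument for every impl.

```python
class Attention:
    """Simplified stand-in for vLLM's attention layer (hypothetical signature)."""

    def __init__(self, num_heads, head_size, scale, impl_cls, prefix="",
                 **extra_attn_kwargs):
        # Backend-specific options (e.g. dual_chunk_attention_config) flow through
        # unmodified instead of being assembled into a dict at the call site.
        self.impl = impl_cls(num_heads, head_size, scale, prefix=prefix,
                             **extra_attn_kwargs)
```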

Comment on lines 148 to 158
if self.dual_chunk_attention_config:
    assert query_succ_and_inter is not None
    dca_kwargs = {
        "query_succ": query_succ_and_inter[0],
        "query_inter": query_succ_and_inter[1],
        "query_succ_critical": query_succ_and_inter[2],
        "query_inter_critical": query_succ_and_inter[3],
    } if query_succ_and_inter else {}
else:
    dca_kwargs = {}

A collaborator commented:

I think we should try hard to see if there is a cleaner way of passing these; maybe they can be bundled into a single q tensor that gets reinterpreted as components via a combination of slicing and .view calls in the attn impl?

@sighingnow (Contributor, Author) replied:

I will take a try and see if it can be simplified.
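One possible shape of such a simplification, sketched under the assumption that all four variants share the query's shape (this is not necessarily the approach the PR finally uses):

```python
import torch

num_tokens, num_heads, head_size = 16, 8, 64
query_variants = [torch.randn(num_tokens, num_heads * head_size) for _ in range(4)]

# Caller side: pack the variants so only a single q-like tensor crosses the
# attention API boundary.
packed_q = torch.cat(query_variants, dim=-1)  # (num_tokens, 4 * num_heads * head_size)

# Impl side: recover the components with view/unbind; no copies are made.
query_succ, query_inter, query_succ_critical, query_inter_critical = packed_q.view(
    num_tokens, 4, num_heads * head_size).unbind(dim=1)

assert torch.equal(query_succ, query_variants[0])
```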

@mergify mergify bot removed the needs-rebase label Apr 12, 2025
@sighingnow (Contributor, Author)

> I think this is getting very close, thanks for rebasing it! My main concern right now is the large text files in the repo. Also there appear to still be unaddressed review comments from before; please ping us when this is ready for final review.

Hi @LucasWilkinson, thanks for these comments. I have rebased this branch onto current main, removed those example prompts and provided them as URLs, and addressed the reviewer comments above in this PR. I now think it should be ready for landing.

Before landing, a bugfix in flash-attention should be merged first: vllm-project/flash-attention#60. After that, I will revise the dependency version of vllm-flash-attention in this PR.

@tlrmchlsmth (Collaborator) left a comment

Thank you for rebasing on current main! The code looks pretty clean to me now.

Before this lands, I think we should make sure there's a plan in place to get support for this in vLLM V1. vLLM has switched to V1 by default and we are trying to deprecate V0.

A couple of questions:

  • What will happen with this PR when running Qwen2 on systems where the dual-chunk attention backend is not supported? (e.g. AMD GPUs, TPUs, etc)
  • Does vLLM automatically fall back to V0 when using dual-chunk attention?

Comment on lines 19 to 20
with urlopen("https://qianwen-res.oss-cn-beijing.aliyuncs.com"
             "/Qwen2.5-1M/test-data/600k.txt") as response:
A collaborator commented:
Please add a timeout to this

Suggested change
-with urlopen("https://qianwen-res.oss-cn-beijing.aliyuncs.com"
-             "/Qwen2.5-1M/test-data/600k.txt") as response:
+with urlopen("https://qianwen-res.oss-cn-beijing.aliyuncs.com"
+             "/Qwen2.5-1M/test-data/600k.txt", timeout=5) as response:

@sighingnow (Contributor, Author) replied:

Fixed.

@sighingnow (Contributor, Author)

> A couple of questions:
>
>   • What will happen with this PR when running Qwen2 on systems where the dual-chunk attention backend is not supported? (e.g. AMD GPUs, TPUs, etc.)
>   • Does vLLM automatically fall back to V0 when using dual-chunk attention?

We have started migrating the Qwen-related changes in our internal repo to V1, since V1 has become the default option in vLLM. The dual-chunk-attn backend will be adapted to V1 as well, and most of the changesets can be reused.

I have added an assertion in arg_utils.py to check that the current platform is CUDA (the sparse_attn_func is only available in vllm-project/flash-attention for CUDA) and that the current engine is V0.
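For illustration, such a guard could look roughly like the sketch below (hedged: the exact check in arg_utils.py may differ; current_platform.is_cuda() and the VLLM_USE_V1 switch are assumed from vLLM's existing platform/engine plumbing).

```python
from vllm.platforms import current_platform  # assumed helper exposing platform checks


def verify_dual_chunk_attention_support(use_v1_engine: bool) -> None:
    """Reject configurations the dual-chunk backend cannot serve."""
    if not current_platform.is_cuda():
        raise ValueError(
            "Dual chunk attention requires the sparse_attn_func from "
            "vllm-flash-attention, which is only built for CUDA.")
    if use_v1_engine:
        raise ValueError(
            "Dual chunk attention is currently only supported on the V0 "
            "engine; run with VLLM_USE_V1=0.")
```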

@LucasWilkinson (Collaborator)

vllm-project/flash-attention#60 has landed; can you please update this PR?

@sighingnow sighingnow force-pushed the dev/dual-chunk-attn branch from 795190d to a3efe49 Compare April 15, 2025 18:18
@sighingnow (Contributor, Author)

> vllm-project/flash-attention#60 has landed; can you please update this PR?

Done, and rebased to main.

@sighingnow sighingnow force-pushed the dev/dual-chunk-attn branch 2 times, most recently from 5bbbfe2 to e30cd11 Compare April 21, 2025 01:55
@sighingnow (Contributor, Author)

@LucasWilkinson I have rebased onto current main again. Could you please take another look at this PR? Thanks!

mergify bot commented on May 1, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sighingnow.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 1, 2025
A collaborator commented:

This can actually be left as 0a721daebe4fa7149f06ecf3d3eabeb6dcd0f1fa since that includes the PR you need

@LucasWilkinson (Collaborator) left a comment

Apologies; overall this looks good now! Thanks for all the updates. The only things left to see on my end would be:

@sighingnow sighingnow force-pushed the dev/dual-chunk-attn branch from e30cd11 to 0dbb63f Compare May 9, 2025 08:50
@mergify mergify bot removed the needs-rebase label May 9, 2025
@sighingnow (Contributor, Author)

@LucasWilkinson (Collaborator)

@sighingnow Thanks for the update! Looking into the CI failure, it does not appear to be related (it is V1 code, and this PR does not touch V1), but this is a bit outside my area of expertise; asking around (cc @russellb).

mergify bot commented on May 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sighingnow.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 9, 2025
Commit: Implements dual-chunk-flash-attn backend for dual chunk attention with sparse attention support.

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
@sighingnow sighingnow force-pushed the dev/dual-chunk-attn branch from 0dbb63f to 44bf2ab Compare May 10, 2025 07:46
@mergify mergify bot removed the needs-rebase label May 10, 2025
@sighingnow (Contributor, Author)

> @sighingnow Thanks for the update! Looking into the CI failure, it does not appear to be related (it is V1 code, and this PR does not touch V1), but this is a bit outside my area of expertise; asking around (cc @russellb).

Rebased against main again. The failed test cases shouldn't be caused by this PR; the failure is in a speculative decoding case, and it seems that case is not executed for all PRs.

@LucasWilkinson LucasWilkinson enabled auto-merge (squash) May 12, 2025 19:09
@simon-mo simon-mo merged commit 60f7624 into vllm-project:main May 13, 2025
86 of 91 checks passed
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request May 14, 2025
@sighingnow sighingnow deleted the dev/dual-chunk-attn branch May 16, 2025 08:34
Labels: ci/build, documentation (Improvements or additions to documentation), ready (ONLY add when PR is ready to merge / full CI is needed)
Projects: None yet
Development

Successfully merging this pull request may close these issues.

[Feature]: Support Qwen/Qwen2.5-14B-Instruct-1M