Implements dual-chunk-flash-attn backend for dual chunk attention with sparse attention support #11844
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can do one of these:
🚀
Force-pushed from 82b5a4c to 4c4a33e
I see that you have
Force-pushed from 4c4a33e to 6b7c49e
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from 6b7c49e to 35aac26
Force-pushed from 35aac26 to 91d5476
All conflicts fixed, could you please take another look? Thanks!
st] = decode_metadata.block_tables[i, st:ed]
decode_metadata.block_tables_intra = block_tables_intra

seq_lens_succ = (chunk_num_curr -
When I ran the Needle-in-a-Haystack test with Qwen-7B and Llama-8B (code modified to support Llama), there is a bug that produces a negative number once the context goes beyond roughly 13k–15k tokens.
I modified the code as below and confirmed that it works.
seq_lens_succ = ((chunk_num_curr - (chunk_num_curr - 1).clip(min=0)) * chunk_len)
This pull request has merge conflicts that must be resolved before it can be merged.
I thought it had been fixed, so I tested it, but I still have the same problem as below.
Force-pushed from 91d5476 to c8781cd
The dual chunk attention doesn't support CUDA graph, and I have added an assertion in
It is indeed a bug introduced while preparing this PR; fixed. Thanks!
Force-pushed from c8781cd to 8648b1e
Rebased against main. Hi @youkaichao @simon-mo @WoosukKwon, do you folks think there are still things that need to be improved in this pull request? Thanks!
Spotted a few bits of commented-out code that look like debug cruft or are otherwise mysterious. Could you clean those up, and any other similar spots?
This pull request has merge conflicts that must be resolved before it can be merged.
qc_freqs = torch.einsum("i,j -> ij", qc_t, inv_freq)
k_freqs = torch.einsum("i,j -> ij", k_t, inv_freq)
qc_no_clamp_freqs = torch.einsum("i,j -> ij", qc_no_clamp_t, inv_freq)
q_inter_freqs = torch.einsum("i,j -> ij", q_inter_t, inv_freq)
nit: I think these einsums are still slower on CUDA than (a * b).sum(-1); not on the hot path though, so not critical.
ran bench_einsum.py from that issue on an H100 and got:

python einsum_bench.py
                                  |  mul/sum  |  torch.einsum  |  numpy.einsum
1 threads: ---------------------------------------------------------------------
  Nc,Nc->N cpu  (1048576, 2)      |    5000   |      3100      |      4000
  Nc,Nc->N cuda (1048576, 2)      |      20   |       747      |      3300

Times are in microseconds (us).
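For reference, the einsums in the hunk above are plain outer products, so they can also be written with broadcasting (or torch.outer) instead of einsum. A minimal sketch, assuming qc_t and inv_freq are 1-D float tensors as in the rotary-embedding code (the shapes below are made up for illustration):

import torch

# Hypothetical shapes: qc_t is a 1-D tensor of positions, inv_freq a 1-D tensor
# of inverse frequencies, mirroring the rotary-embedding usage above.
qc_t = torch.arange(16, dtype=torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 8, 2, dtype=torch.float32) / 8))

# "i,j -> ij" is an outer product, so all three forms give the same result.
freqs_einsum = torch.einsum("i,j -> ij", qc_t, inv_freq)
freqs_broadcast = qc_t[:, None] * inv_freq[None, :]  # avoids einsum dispatch overhead
freqs_outer = torch.outer(qc_t, inv_freq)            # dedicated op for the same thing

assert torch.allclose(freqs_einsum, freqs_broadcast)
assert torch.allclose(freqs_einsum, freqs_outer)

As noted, this code is not on the hot path, so whether the change is worth making is a judgment call.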
vllm/attention/layer.py (outdated)
logits_soft_cap, attn_type, **{
    "dual_chunk_attention_config": dual_chunk_attention_config,
    "prefix": prefix,
} if dual_chunk_attention_config is not None else {})
I feel like this is messy; I think we should maybe do something like:

def __init__(..., **extra_attn_kwargs):
    self.impl = impl_cls(..., **extra_attn_kwargs)

The challenge here is that prefix would not be captured by extra_attn_kwargs but is only (currently) used by DualChunkFlashAttentionImpl. I do think it would be less messy to do this and make prefix a standard arg for attention impls, given that it is pretty generic. Thoughts @WoosukKwon?
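To make the proposal concrete, a rough sketch of the suggested pattern (class names, signatures, and the backend-selection step below are illustrative, not vLLM's actual API):

# Illustrative sketch only; these are hypothetical stand-ins for vLLM's
# Attention layer and AttentionImpl classes.

class DualChunkFlashAttentionImpl:
    def __init__(self, num_heads, head_size, scale, prefix="",
                 dual_chunk_attention_config=None):
        # `prefix` identifies the layer (e.g. "model.layers.0.self_attn.attn")
        # and is treated as a standard impl argument here.
        self.prefix = prefix
        self.dual_chunk_attention_config = dual_chunk_attention_config or {}


class Attention:
    def __init__(self, num_heads, head_size, scale, prefix="", **extra_impl_args):
        # Anything the layer itself does not consume (here, the dual-chunk
        # config) is forwarded unchanged to the backend implementation.
        self.impl = DualChunkFlashAttentionImpl(
            num_heads, head_size, scale, prefix=prefix, **extra_impl_args)


attn = Attention(8, 128, 0.088,
                 prefix="model.layers.0.self_attn.attn",
                 dual_chunk_attention_config={"chunk_size": 8192})

This keeps the call site free of the inline dict-unpacking above: backends that don't need the extra arguments simply never receive them.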
vllm/attention/layer.py (outdated)
if self.dual_chunk_attention_config:
    assert query_succ_and_inter is not None
    dca_kwargs = {
        "query_succ": query_succ_and_inter[0],
        "query_inter": query_succ_and_inter[1],
        "query_succ_critical": query_succ_and_inter[2],
        "query_inter_critical": query_succ_and_inter[3],
    } if query_succ_and_inter else {}
else:
    dca_kwargs = {}
I think we should try hard to see if there is a cleaner way of passing these; maybe they can be bundled into a single q tensor that gets reinterpreted as components via a combination of slicing and .view calls in the attn impl?
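As an illustration of the bundling idea (shapes and variable names below are assumptions, not this PR's actual layout), the four query variants could be stacked into one tensor on the caller side and recovered as views inside the attention impl:

import torch

num_tokens, num_heads, head_size = 4, 8, 64

# The four query variants produced by dual chunk attention (assumed shapes).
query_succ = torch.randn(num_tokens, num_heads, head_size)
query_inter = torch.randn(num_tokens, num_heads, head_size)
query_succ_critical = torch.randn(num_tokens, num_heads, head_size)
query_inter_critical = torch.randn(num_tokens, num_heads, head_size)

# Caller side: bundle everything into a single tensor along a new leading dim.
bundled = torch.stack(
    [query_succ, query_inter, query_succ_critical, query_inter_critical], dim=0)

# Impl side: recover the components without copies; unbind returns views of `bundled`.
q_succ, q_inter, q_succ_crit, q_inter_crit = bundled.unbind(dim=0)
assert torch.equal(q_succ, query_succ)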
I will give it a try and see if it can be simplified.
Hi @LucasWilkinson, thanks for these comments. I have rebased this branch onto current main, removed those example prompts and provided them as URLs instead, and addressed the reviewer comments above in this PR. Now I think it should be ready for landing. Before landing, a bugfix in flash-attention should be merged first: vllm-project/flash-attention#60. After that, I will revise the dependency version of vllm-flash-attention in this PR.
Thank you for rebasing on current main! The code looks pretty clean to me now.
Before this lands, I think we should make sure there's a plan in place to get support for this in vLLM V1. vLLM has switched to V1 by default and we are trying to deprecate V0.
A couple of questions:
- What will happen with this PR when running Qwen2 on systems where the dual-chunk attention backend is not supported? (e.g. AMD GPUs, TPUs, etc)
- Does vLLM automatically fall back to V0 when using dual-chunk attention?
with urlopen("https://qianwen-res.oss-cn-beijing.aliyuncs.com"
             "/Qwen2.5-1M/test-data/600k.txt") as response:
Please add a timeout to this
Suggested change:
- with urlopen("https://qianwen-res.oss-cn-beijing.aliyuncs.com"
-              "/Qwen2.5-1M/test-data/600k.txt") as response:
+ with urlopen("https://qianwen-res.oss-cn-beijing.aliyuncs.com"
+              "/Qwen2.5-1M/test-data/600k.txt", timeout=5) as response:
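For completeness, a minimal standalone sketch of the fetch with the suggested timeout (same URL as the example script; the 5-second value follows the suggestion above):

from urllib.request import urlopen

URL = ("https://qianwen-res.oss-cn-beijing.aliyuncs.com"
       "/Qwen2.5-1M/test-data/600k.txt")

# The timeout applies to the connection and to blocking socket reads, so a
# stalled download raises an error instead of hanging the test indefinitely.
with urlopen(URL, timeout=5) as response:
    prompt = response.read().decode("utf-8")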
Fixed.
We have started migrating the Qwen-related changes in our internal repo to V1, since V1 has become the default option in vLLM. The dual-chunk-attn backend will be adapted to V1 too, and most of the changesets can be reused. I have added an assertion in arg_utils.py to check that the current platform is CUDA (the sparse_attn_func is only available in vllm-project/flash-attention for CUDA) and that the current engine is V0.
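A rough sketch of what such a guard could look like (the function name and checks below are illustrative; the actual assertion lives in arg_utils.py and uses vLLM's own platform and engine abstractions):

import os
import torch

def check_dual_chunk_attention_supported() -> None:
    # Illustrative guard only; not the real arg_utils.py code.
    if not torch.cuda.is_available():
        raise ValueError(
            "Dual chunk attention requires CUDA: sparse_attn_func is only "
            "provided by vllm-flash-attention for CUDA platforms.")
    if os.environ.get("VLLM_USE_V1", "0") == "1":
        raise ValueError(
            "Dual chunk attention is currently only supported on the V0 "
            "engine; set VLLM_USE_V1=0.")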
vllm-project/flash-attention#60 has landed; can you please update this PR?
Force-pushed from 795190d to a3efe49
Done, and rebased to main.
Force-pushed from 5bbbfe2 to e30cd11
@LucasWilkinson I have rebased to current main again. Could you please take another look at this PR? Thanks!
This pull request has merge conflicts that must be resolved before it can be merged.
This can actually be left as 0a721daebe4fa7149f06ecf3d3eabeb6dcd0f1fa
since that includes the PR you need
Apologies, overall this looks good now! Thanks for all the updates; the only things left to see on my end would be:
Force-pushed from e30cd11 to 0dbb63f
Hi @LucasWilkinson, thanks for the feedback. The first three comments have been addressed.
@sighingnow Thanks for the update! Looking into the CI failure, it does not appear to be related (V1 code; this PR does not touch V1), but this is a bit out of my area of expertise, so I'm asking around (cc @russellb).
This pull request has merge conflicts that must be resolved before it can be merged.
…h sparse attention support. Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Force-pushed from 0dbb63f to 44bf2ab
Rebased against main again. The failed test cases shouldn't be caused by this PR; they failed on a speculative decoding case, and it seems that case is not executed for all PRs.
…h sparse attention support (vllm-project#11844)
This PR implements the dual-chunk flash attention, a training-free method to extend model context length (see also #6139), with sparse attention (https://github.com/microsoft/MInference) support.
This PR requires the sparse attention kernel from vllm-flash-attention. Qwen models with 1M context length support will be open-sourced in the next one or two weeks, and unit tests will be added later.
FIX #12452
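For context, a hedged usage sketch of how the backend might be enabled once this lands (the backend name follows this PR; the model name, environment variables, and values are assumptions based on the Qwen2.5-1M release material and should be checked against the final docs):

import os

# Assumed selection mechanism: force the attention backend added by this PR
# and stay on the V0 engine, which is what the PR currently targets.
os.environ["VLLM_ATTENTION_BACKEND"] = "DUAL_CHUNK_FLASH_ATTN"
os.environ["VLLM_USE_V1"] = "0"

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-7B-Instruct-1M",  # assumed model id
          max_model_len=1_048_576,              # 1M-token context
          enforce_eager=True)                   # dual chunk attention does not support CUDA graph
out = llm.generate(["Summarize the following document: ..."],
                   SamplingParams(max_tokens=128))
print(out[0].outputs[0].text)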