@@ -3,7 +3,6 @@
 
 from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
 from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache
-from vllm_flash_attn import sparse_attn_func as _sparse_attn_func
 
 
 @torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[])
@@ -102,33 +101,3 @@ def _(
     softcap: float = 0.0,
 ) -> torch.Tensor:
     return torch.empty_like(decode_query)
-
-@torch.library.custom_op("vllm::sparse_attn_func", mutates_args=[])
-def sparse_attn_func(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    block_count: torch.Tensor,
-    block_offset: torch.Tensor,
-    column_count: torch.Tensor,
-    column_index: torch.Tensor,
-    softmax_scale: Optional[float] = None,
-    causal: bool = False,
-    softcap: float = 0.0,
-    alibi_slopes: Optional[torch.Tensor] = None,
-    return_softmax_lse: Optional[bool] = False,
-) -> torch.Tensor:
-    return _sparse_attn_func(
-        q=q,
-        k=k,
-        v=v,
-        block_count=block_count,
-        block_offset=block_offset,
-        column_count=column_count,
-        column_index=column_index,
-        softmax_scale=softmax_scale,
-        causal=causal,
-        softcap=softcap,
-        alibi_slopes=alibi_slopes,
-        return_softmax_lse=return_softmax_lse,
-    )
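For reference, the wrappers in this file appear to follow PyTorch's torch.library custom-op pattern: the real kernel is registered under a qualified op name, and a fake (meta) implementation that only propagates output shape and dtype is registered alongside it so the op can be traced with fake tensors (for example under torch.compile). The sketch below shows that pattern in isolation; the op name mylib::toy_attn and the toy attention math are illustrative only, not part of vLLM, and it assumes PyTorch 2.4+ where torch.library.custom_op and register_fake are available.

from typing import Optional

import torch


@torch.library.custom_op("mylib::toy_attn", mutates_args=[])
def toy_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
             softmax_scale: Optional[float] = None) -> torch.Tensor:
    # Eager implementation: plain scaled dot-product attention over
    # (..., seq, head_dim) tensors; stands in for a real fused kernel.
    scale = softmax_scale if softmax_scale is not None else q.shape[-1] ** -0.5
    attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    return attn @ v


@toy_attn.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
      softmax_scale: Optional[float] = None) -> torch.Tensor:
    # Fake (meta) implementation: no compute, only shape/dtype propagation,
    # mirroring the `return torch.empty_like(...)` stub in the diff above.
    return torch.empty_like(q)

Calling toy_attn(q, k, v) dispatches to the eager implementation; under fake-tensor tracing the registered fake runs instead, which is what lets compilers reason about the op without executing the kernel.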