|
6 | 6 |
|
7 | 7 | import torch
|
8 | 8 | import triton
|
9 |
| -from scipy import sparse |
| 9 | + |
| 10 | +try: |
| 11 | + from scipy import sparse |
| 12 | +except ImportError as err: |
| 13 | + raise ImportError("Please install scipy via " |
| 14 | + "`pip install scipy` to use " |
| 15 | + "BlockSparseAttention in " |
| 16 | + "models such as Phi-3.") from err |
10 | 17 |
|
11 | 18 |
|
12 | 19 | def dense_to_crow_col(x: torch.Tensor):
|
@@ -77,11 +84,11 @@ def _get_sparse_attn_mask_homo_head(
|
77 | 84 | ):
|
78 | 85 | """
|
79 | 86 | :return: a tuple of 3:
|
80 |
| - - tuple of crow_indices, col_indices representation |
| 87 | + - tuple of crow_indices, col_indices representation |
81 | 88 | of CSR format.
|
82 | 89 | - block dense mask
|
83 |
| - - all token dense mask (be aware that it can be |
84 |
| - OOM if it is too big) if `return_dense==True`, |
| 90 | + - all token dense mask (be aware that it can be |
| 91 | + OOM if it is too big) if `return_dense==True`, |
85 | 92 | otherwise, None
|
86 | 93 | """
|
87 | 94 | with torch.no_grad():
|
@@ -148,10 +155,10 @@ def get_sparse_attn_mask(
|
148 | 155 | :param dense_mask_type: "binary" (0 for skip token, 1 for others)
|
149 | 156 | or "bias" (-inf for skip token, 0 for others)
|
150 | 157 | :return: a tuple of 3:
|
151 |
| - - tuple of crow_indices, col_indices representation |
| 158 | + - tuple of crow_indices, col_indices representation |
152 | 159 | of CSR format.
|
153 | 160 | - block dense mask
|
154 |
| - - all token dense mask (be aware that it can be OOM if it |
| 161 | + - all token dense mask (be aware that it can be OOM if it |
155 | 162 | is too big) if `return_dense==True`, otherwise, None
|
156 | 163 | """
|
157 | 164 | assert dense_mask_type in ("binary", "bias")
|
|
0 commit comments