Skip to content

Commit 630cae2

Browse files
authored
Add option for cauasl (#107)
1 parent 41ef5bf commit 630cae2

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

attn_gym/masks/document_mask.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import torch
66
from torch import Tensor
7-
from torch.nn.attention.flex_attention import _mask_mod_signature
7+
from torch.nn.attention.flex_attention import _mask_mod_signature, noop_mask
88
from attn_gym.masks import causal_mask
99

1010

@@ -59,7 +59,7 @@ def doc_mask_mod(b, h, q_idx, kv_idx):
5959
return doc_mask_mod
6060

6161

62-
def main(device: str = "cpu"):
62+
def main(device: str = "cpu", causal: bool = True):
6363
"""Visualize the attention scores of document causal mask mod.
6464
6565
Args:
@@ -93,7 +93,11 @@ def make_tensor():
9393
return torch.ones(B, H, SEQ_LEN, HEAD_DIM, device=device)
9494

9595
query, key = make_tensor(), make_tensor()
96-
document_causal_mask = generate_doc_mask_mod(causal_mask, offsets)
96+
if causal:
97+
base_mask_mod = causal_mask
98+
else:
99+
base_mask_mod = noop_mask
100+
document_causal_mask = generate_doc_mask_mod(base_mask_mod, offsets)
97101

98102
visualize_attention_scores(
99103
query,

0 commit comments

Comments
 (0)