
Commit 90c3ed1

move cast to before softmax in attention (tinygrad#9213)
* move cast to before softmax in attention

  Saves some memory because exp (which is used for backward) is done in half. Training BERT seems fine and can now fit BS=78 (up from 66).

* test
1 parent f0b24d2 commit 90c3ed1
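
The memory claim can be sanity-checked with a rough back-of-the-envelope calculation. This is a hedged sketch, not taken from the commit: the shapes below (BERT-large with 24 layers, 16 attention heads, sequence length 512) and the assumption that every layer's softmax/exp output was previously kept in float32 for the backward pass are assumptions, used only to show the scale of halving that buffer.

# Hedged estimate with assumed BERT-large shapes; not numbers from this commit.
layers, heads, seqlen, bs = 24, 16, 512, 78             # assumed dimensions
logit_elems = bs * heads * seqlen * seqlen               # attention-logit elements per layer
saved_per_layer = logit_elems * (4 - 2)                  # bytes saved going float32 -> half
print(f"~{layers * saved_per_layer / 2**30:.1f} GiB")    # ~14.6 GiB under these assumptions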

File tree

  test/unit/test_attention.py
  tinygrad/tensor.py

2 files changed: 21 additions and 1 deletion

test/unit/test_attention.py (20 additions, 0 deletions)

@@ -0,0 +1,20 @@
+import unittest
+from tinygrad import Tensor, dtypes
+
+# TODO: test_scheduler, but just in uint
+class TestAttention(unittest.TestCase):
+  def test_half_qkv_buffers(self):
+    BS, seqlen, dim = 10, 4, 100
+    q = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
+    k = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
+    v = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
+    attn = q.scaled_dot_product_attention(k, v)
+    sched = attn.schedule()
+    # attention has 5 kernels now
+    self.assertEqual(len(sched), 5)
+    softmax_inputs = sched[1:4]
+    for si in softmax_inputs:
+      assert all(b.dtype == dtypes.half for b in si.bufs), f"non half {si.bufs=}"
+
+if __name__ == '__main__':
+  unittest.main()

tinygrad/tensor.py (1 addition, 1 deletion)

@@ -3603,7 +3603,7 @@ def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor
     if attn_mask is not None:
       if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
       qk = qk + attn_mask
-    return qk.softmax(-1).cast(self.dtype).dropout(dropout_p) @ value
+    return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value
 
   def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
     if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
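
For intuition, here is a minimal sketch of what the one-line reorder above changes when q, k and v are half. This is not the real tinygrad internals: masking, dropout and the exact scaling are stripped out, the function names are made up, and the explicit .cast(dtypes.float32) merely stands in for the float32 accumulation the real qk matmul uses.

from tinygrad import Tensor, dtypes

def attn_before_fix(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
  qk = (q @ k.transpose(-2, -1)).cast(dtypes.float32) * (q.shape[-1] ** -0.5)
  # old order: softmax (whose exp output is saved for backward) runs in float32
  return qk.softmax(-1).cast(q.dtype) @ v

def attn_after_fix(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
  qk = (q @ k.transpose(-2, -1)).cast(dtypes.float32) * (q.shape[-1] ** -0.5)
  # new order: cast first, so softmax and its exp intermediate stay in half
  return qk.cast(q.dtype).softmax(-1) @ v

if __name__ == "__main__":
  q, k, v = [Tensor.ones(2, 4, 8, dtype=dtypes.half) for _ in range(3)]
  print(attn_before_fix(q, k, v).dtype, attn_after_fix(q, k, v).dtype)  # both half

Both variants return the same output dtype; the difference is only in the dtype of the intermediate softmax/exp buffers, which is what the test above inspects via the schedule.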
