From 1ec071d684472b9444a44743750c0ddd6d0bbc57 Mon Sep 17 00:00:00 2001
From: Vlad <45468127+acforvs@users.noreply.github.com>
Date: Wed, 19 Feb 2025 18:31:07 +0400
Subject: [PATCH 1/2] Remove l_vec from mha_forward in h100_bench.py

Fixes #91
---
 kernels/attn/h100/h100_bench.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/kernels/attn/h100/h100_bench.py b/kernels/attn/h100/h100_bench.py
index 43689172..70d1417c 100644
--- a/kernels/attn/h100/h100_bench.py
+++ b/kernels/attn/h100/h100_bench.py
@@ -37,9 +37,6 @@ def benchmark_attention(configurations):
         kg = torch.zeros_like(k, requires_grad=False, dtype=torch.float).contiguous()
         vg = torch.zeros_like(v, requires_grad=False, dtype=torch.float).contiguous()
 
-        l_vec = torch.zeros(q.shape[0], q.shape[1], q.shape[2], 1, device=q.device, dtype=torch.float, requires_grad=False).contiguous()
-        d_vec = torch.zeros(q.shape[0], q.shape[1], q.shape[2], 1, device=q.device, dtype=torch.float, requires_grad=False).contiguous()
-
         # Prepare for timing forward pass
         start_events_fwd = [torch.cuda.Event(enable_timing=True) for _ in range(10)]
         end_events_fwd = [torch.cuda.Event(enable_timing=True) for _ in range(10)]
@@ -49,12 +46,12 @@ def benchmark_attention(configurations):
 
         # Warmup for forward pass
         for _ in range(10):
-            tk.mha_forward(q, k, v, o, l_vec, causal)
+            tk.mha_forward(q, k, v, o, causal)
 
         # Time the forward pass
         for i in range(10):
             start_events_fwd[i].record()
-            tk.mha_forward(q, k, v, o, l_vec, causal)
+            _, l_vec = tk.mha_forward(q, k, v, o, causal)
             end_events_fwd[i].record()
         torch.cuda.synchronize()
 
@@ -77,12 +74,12 @@ def benchmark_attention(configurations):
 
         # Warmup for backward pass
         for _ in range(10):
-            qg, kg, vg = tk.mha_backward(q, k, v, o, l_vec, d_vec, grad_output, causal)
+            qg, kg, vg = tk.mha_backward(q, k, v, o, l_vec, grad_output, causal)
 
         # Time the backward pass
         for i in range(10):
             start_events_bwd[i].record()
-            qg, kg, vg = tk.mha_backward(q, k, v, o, l_vec, d_vec, grad_output, causal)
+            qg, kg, vg = tk.mha_backward(q, k, v, o, l_vec, grad_output, causal)
             end_events_bwd[i].record()
         torch.cuda.synchronize()
 

From b95a02101ddd22ff2295466be1c01a9f7542bb3c Mon Sep 17 00:00:00 2001
From: Vlad <45468127+acforvs@users.noreply.github.com>
Date: Wed, 19 Feb 2025 18:34:35 +0400
Subject: [PATCH 2/2] remove o from mha_forward in h100_bench

---
 kernels/attn/h100/h100_bench.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/kernels/attn/h100/h100_bench.py b/kernels/attn/h100/h100_bench.py
index 70d1417c..e9449b8c 100644
--- a/kernels/attn/h100/h100_bench.py
+++ b/kernels/attn/h100/h100_bench.py
@@ -46,12 +46,12 @@ def benchmark_attention(configurations):
 
         # Warmup for forward pass
        for _ in range(10):
-            tk.mha_forward(q, k, v, o, causal)
+            tk.mha_forward(q, k, v, causal)
 
         # Time the forward pass
         for i in range(10):
             start_events_fwd[i].record()
-            _, l_vec = tk.mha_forward(q, k, v, o, causal)
+            _, l_vec = tk.mha_forward(q, k, v, causal)
             end_events_fwd[i].record()
         torch.cuda.synchronize()