diff --git a/kernels/attn/h100/h100_bench.py b/kernels/attn/h100/h100_bench.py
index 43689172..e9449b8c 100644
--- a/kernels/attn/h100/h100_bench.py
+++ b/kernels/attn/h100/h100_bench.py
@@ -37,9 +37,6 @@ def benchmark_attention(configurations):
         kg = torch.zeros_like(k, requires_grad=False, dtype=torch.float).contiguous()
         vg = torch.zeros_like(v, requires_grad=False, dtype=torch.float).contiguous()

-        l_vec = torch.zeros(q.shape[0], q.shape[1], q.shape[2], 1, device=q.device, dtype=torch.float, requires_grad=False).contiguous()
-        d_vec = torch.zeros(q.shape[0], q.shape[1], q.shape[2], 1, device=q.device, dtype=torch.float, requires_grad=False).contiguous()
-
         # Prepare for timing forward pass
         start_events_fwd = [torch.cuda.Event(enable_timing=True) for _ in range(10)]
         end_events_fwd = [torch.cuda.Event(enable_timing=True) for _ in range(10)]
@@ -49,12 +46,12 @@ def benchmark_attention(configurations):

         # Warmup for forward pass
         for _ in range(10):
-            tk.mha_forward(q, k, v, o, l_vec, causal)
+            tk.mha_forward(q, k, v, causal)

         # Time the forward pass
         for i in range(10):
             start_events_fwd[i].record()
-            tk.mha_forward(q, k, v, o, l_vec, causal)
+            _, l_vec = tk.mha_forward(q, k, v, causal)
             end_events_fwd[i].record()
         torch.cuda.synchronize()

@@ -77,12 +74,12 @@ def benchmark_attention(configurations):

         # Warmup for backward pass
         for _ in range(10):
-            qg, kg, vg = tk.mha_backward(q, k, v, o, l_vec, d_vec, grad_output, causal)
+            qg, kg, vg = tk.mha_backward(q, k, v, o, l_vec, grad_output, causal)

         # Time the backward pass
         for i in range(10):
             start_events_bwd[i].record()
-            qg, kg, vg = tk.mha_backward(q, k, v, o, l_vec, d_vec, grad_output, causal)
+            qg, kg, vg = tk.mha_backward(q, k, v, o, l_vec, grad_output, causal)
             end_events_bwd[i].record()
         torch.cuda.synchronize()

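
For reference, a minimal standalone sketch of the call pattern this diff adapts the benchmark to. It assumes the thunderkittens extension is built and importable as tk, that mha_forward now allocates and returns its outputs (the first return value is taken here to be the attention output o), and that mha_backward computes the delta vector internally so callers no longer pass d_vec; shapes and dtypes below are illustrative only.

    import torch
    import thunderkittens as tk

    B, H, N, D = 16, 16, 4096, 64   # illustrative problem size
    causal = True

    q = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda").contiguous()
    k = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda").contiguous()
    v = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda").contiguous()

    # New-style forward: no pre-allocated o / l_vec buffers are passed in;
    # the kernel returns them instead (first value assumed to be o).
    o, l_vec = tk.mha_forward(q, k, v, causal)

    # New-style backward: d_vec is dropped from the argument list.
    grad_output = torch.randn_like(o).contiguous()
    qg, kg, vg = tk.mha_backward(q, k, v, o, l_vec, grad_output, causal)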