
Commit dc71b2e

[Kernel] Change interface to Mamba selective_state_update for continuous batching
1 parent c334b18 commit dc71b2e

File tree

2 files changed: +186 -3 lines


tests/kernels/test_mamba_ssm.py

Lines changed: 158 additions & 0 deletions
@@ -322,3 +322,161 @@ def test_selective_state_update(dim, dstate, has_z, itype):
 
     assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("itype",
+                         [torch.float32, torch.float16, torch.bfloat16])
+# @pytest.mark.parametrize('itype', [torch.float16])
+@pytest.mark.parametrize("has_z", [False, True])
+# @pytest.mark.parametrize('has_z', [True])
+@pytest.mark.parametrize("dstate", [16, 32, 64])
+# @pytest.mark.parametrize("dstate", [16])
+@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
+# @pytest.mark.parametrize("dim", [2048])
+def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
+    if itype == torch.bfloat16:
+        rtol, atol = 6e-2, 6e-2
+        if torch.version.hip:
+            atol *= 2
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 16
+
+    total_entries = 10 * batch_size
+    state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
+    state_indices = torch.randperm(total_entries)[:batch_size].to(
+        dtype=torch.int32, device=device)
+
+    x = torch.randn(batch_size, dim, device=device, dtype=itype)
+    dt = torch.randn(batch_size, dim, device=device, dtype=itype)
+    dt_bias = torch.rand(dim, device=device) - 4.0
+    A = -torch.rand(dim, dstate, device=device) - 1.0
+    B = torch.randn(batch_size, dstate, device=device)
+    C = torch.randn(batch_size, dstate, device=device)
+    D = torch.randn(dim, device=device)
+    z = torch.randn_like(x) if has_z else None
+    state_ref = state[state_indices, :].detach().clone()
+    out = selective_state_update(state,
+                                 x,
+                                 dt,
+                                 A,
+                                 B,
+                                 C,
+                                 D=D,
+                                 z=z,
+                                 dt_bias=dt_bias,
+                                 dt_softplus=True,
+                                 state_batch_indices=state_indices)
+    out_ref = selective_state_update_ref(state_ref,
+                                         x,
+                                         dt,
+                                         A,
+                                         B,
+                                         C,
+                                         D=D,
+                                         z=z,
+                                         dt_bias=dt_bias,
+                                         dt_softplus=True)
+
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    assert torch.allclose(state[state_indices, :],
+                          state_ref,
+                          rtol=rtol,
+                          atol=atol)
+    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("itype",
+                         [torch.float32, torch.float16, torch.bfloat16])
+# @pytest.mark.parametrize('itype', [torch.float32])
+@pytest.mark.parametrize("has_z", [False, True])
+# @pytest.mark.parametrize('has_z', [True])
+@pytest.mark.parametrize("tie_hdim", [False, True])
+# @pytest.mark.parametrize('tie_hdim', [True])
+@pytest.mark.parametrize("ngroups", [1, 2, 4])
+# @pytest.mark.parametrize("ngroups", [2])
+@pytest.mark.parametrize("dstate", [16, 32, 64])
+# @pytest.mark.parametrize("dstate", [16])
+@pytest.mark.parametrize("dim", [2048, 4096])
+# @pytest.mark.parametrize("dim", [2048])
+def test_selective_state_update_with_heads_with_batch_indices(
+        dim, dstate, ngroups, has_z, tie_hdim, itype):
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
+    if itype == torch.bfloat16:
+        rtol, atol = 1e-1, 1e-1
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 16
+    headdim = 64
+    nheads = dim // headdim
+
+    total_entries = 10 * batch_size
+    state = torch.randn(total_entries,
+                        nheads,
+                        headdim,
+                        dstate,
+                        dtype=itype,
+                        device=device)
+    state_indices = torch.randperm(total_entries)[:batch_size].to(
+        dtype=torch.int32, device=device)
+
+    x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
+    if not tie_hdim:
+        dt = torch.randn(batch_size,
+                         nheads,
+                         headdim,
+                         device=device,
+                         dtype=itype)
+        dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
+        A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
+        D = torch.randn(nheads, headdim, device=device)
+    else:
+        dt = repeat(torch.randn(batch_size, nheads, device=device,
+                                dtype=itype),
+                    "b h -> b h p",
+                    p=headdim)
+        dt_bias = repeat(torch.rand(nheads, device=device) - 4.0,
+                         "h -> h p",
+                         p=headdim)
+        A = repeat(-torch.rand(nheads, device=device) - 1.0,
+                   "h -> h p n",
+                   p=headdim,
+                   n=dstate)
+        D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
+    B = torch.randn(batch_size, ngroups, dstate, device=device)
+    C = torch.randn(batch_size, ngroups, dstate, device=device)
+    z = torch.randn_like(x) if has_z else None
+    state_ref = state[state_indices, :].detach().clone()
+    out = selective_state_update(state,
+                                 x,
+                                 dt,
+                                 A,
+                                 B,
+                                 C,
+                                 D=D,
+                                 z=z,
+                                 dt_bias=dt_bias,
+                                 dt_softplus=True,
+                                 state_batch_indices=state_indices)
+    out_ref = selective_state_update_ref(state_ref,
+                                         x,
+                                         dt,
+                                         A,
+                                         B,
+                                         C,
+                                         D=D,
+                                         z=z,
+                                         dt_bias=dt_bias,
+                                         dt_softplus=True)
+
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    assert torch.allclose(state[state_indices, :],
+                          state_ref,
+                          rtol=rtol,
+                          atol=atol)
+    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
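
For readers skimming the diff: with state_batch_indices, the tested semantics reduce to a gather/update/scatter around the usual selective-SSM recurrence. Below is a minimal PyTorch sketch of the no-heads case, for illustration only; the helper name is hypothetical, not part of the commit, and z-gating is omitted for brevity.

import torch
import torch.nn.functional as F


def state_update_with_indices_ref(state, x, dt, A, B, C, D, dt_bias,
                                  state_batch_indices):
    # Gather the per-request slots out of the shared state pool.
    idx = state_batch_indices.long()
    s = state[idx].float()                     # (batch, dim, dstate)
    dt = F.softplus(dt.float() + dt_bias)      # the dt_softplus=True path
    dA = torch.exp(dt.unsqueeze(-1) * A)       # discretized A
    dB = dt.unsqueeze(-1) * B.unsqueeze(1)     # discretized B
    s = s * dA + dB * x.float().unsqueeze(-1)  # h <- dA * h + dB * x
    out = (s * C.unsqueeze(1)).sum(-1) + D * x.float()  # y = C.h + D * x
    # Scatter the updated slots back in place, as the kernel does.
    state[idx] = s.to(state.dtype)
    return out.to(x.dtype)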

vllm/model_executor/layers/mamba/ops/mamba_ssm.py

Lines changed: 28 additions & 3 deletions
@@ -1,4 +1,5 @@
 # Copyright (c) 2024, Tri Dao, Albert Gu.
+# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py
 
 import torch
 import triton
@@ -27,6 +28,10 @@ def softplus(dt):
     {"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
 @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
 @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
+@triton.heuristics({
+    "HAS_STATE_BATCH_INDICES":
+    lambda args: args["state_batch_indices_ptr"] is not None
+})
 @triton.heuristics(
     {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
 @triton.jit
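
An aside on the pattern being extended here: triton.heuristics evaluates each lambda against the kernel's arguments at launch time and binds the result to the matching tl.constexpr parameter, so the compiler specializes the kernel per flag value and eliminates the untaken branch. A toy sketch of the same idiom (hypothetical kernel, not from this commit):

import triton
import triton.language as tl


@triton.heuristics({"HAS_BIAS": lambda args: args["bias_ptr"] is not None})
@triton.jit
def _toy_add_kernel(x_ptr, bias_ptr, out_ptr, n, HAS_BIAS: tl.constexpr,
                    BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    if HAS_BIAS:  # compiled out entirely when bias_ptr is None
        x += tl.load(bias_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x, mask=mask)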
@@ -42,6 +47,7 @@ def _selective_scan_update_kernel(
     D_ptr,
     z_ptr,
     out_ptr,
+    state_batch_indices_ptr,
     # Matrix dimensions
     batch,
     nheads,
@@ -85,12 +91,24 @@ def _selective_scan_update_kernel(
     HAS_DT_BIAS: tl.constexpr,
     HAS_D: tl.constexpr,
     HAS_Z: tl.constexpr,
+    HAS_STATE_BATCH_INDICES: tl.constexpr,
     BLOCK_SIZE_DSTATE: tl.constexpr,
 ):
     pid_m = tl.program_id(axis=0)
     pid_b = tl.program_id(axis=1)
     pid_h = tl.program_id(axis=2)
-    state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
+
+    # If HAS_STATE_BATCH_INDICES is true, the SSM state's batch coordinate
+    # is taken from state_batch_indices_ptr. Otherwise, the state coordinate
+    # is the same as the batch id.
+    if HAS_STATE_BATCH_INDICES:
+        state_batch_indices_ptr += pid_b
+        state_batch_idx = tl.load(state_batch_indices_ptr)
+        state_ptr += (state_batch_idx * stride_state_batch +
+                      pid_h * stride_state_head)
+    else:
+        state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
+
     x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
     dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
     if HAS_DT_BIAS:
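
This one level of indirection is what makes continuous batching cheap: the state pool never moves, and each decode step only passes the slot ids of the currently scheduled requests. A hypothetical sketch of the bookkeeping a caller might keep (none of these names are vLLM API):

import torch

total_entries = 160                    # fixed pool of SSM state slots
free_slots = list(range(total_entries))
request_to_slot = {}                   # request id -> slot in the pool


def slots_for_step(request_ids):
    # Assign a free slot to each new request; a finished request would
    # return its slot to free_slots.
    for rid in request_ids:
        if rid not in request_to_slot:
            request_to_slot[rid] = free_slots.pop()
    return torch.tensor([request_to_slot[r] for r in request_ids],
                        dtype=torch.int32,
                        device="cuda")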
@@ -177,7 +195,8 @@ def selective_state_update(state,
                           D=None,
                           z=None,
                           dt_bias=None,
-                          dt_softplus=False):
+                          dt_softplus=False,
+                          state_batch_indices=None):
     """
     Argument:
         state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
@@ -211,7 +230,10 @@ def selective_state_update(state,
         z = z.unsqueeze(1)
     if dt_bias is not None and dt_bias.dim() == 1:
         dt_bias = dt_bias.unsqueeze(0)
-    batch, nheads, dim, dstate = state.shape
+
+    _, nheads, dim, dstate = state.shape
+    batch = x.shape[0]
+
     assert x.shape == (batch, nheads, dim)
     assert dt.shape == x.shape
     assert A.shape == (nheads, dim, dstate)
@@ -225,6 +247,8 @@ def selective_state_update(state,
         assert z.shape == x.shape
     if dt_bias is not None:
         assert dt_bias.shape == (nheads, dim)
+    if state_batch_indices is not None:
+        assert state_batch_indices.shape == (batch, )
     out = torch.empty_like(x)
     grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
     z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
@@ -249,6 +273,7 @@ def selective_state_update(state,
         D,
         z,
         out,
+        state_batch_indices,
        batch,
        nheads,
        dim,
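
Putting it together, a hedged sketch of one decode step against a persistent state pool; shapes mirror the tests above, the values are illustrative, and D, z, and dt_bias are left at their defaults:

import torch

from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
    selective_state_update)

device, dtype = "cuda", torch.float16
total_entries, batch_size, dim, dstate = 160, 16, 2048, 16

state = torch.zeros(total_entries, dim, dstate, device=device, dtype=dtype)
slots = torch.randperm(total_entries, device=device)[:batch_size].int()

x = torch.randn(batch_size, dim, device=device, dtype=dtype)
dt = torch.randn(batch_size, dim, device=device, dtype=dtype)
A = -torch.rand(dim, dstate, device=device) - 1.0
B = torch.randn(batch_size, dstate, device=device)
C = torch.randn(batch_size, dstate, device=device)

# Updates state[slots] in place and returns this step's output.
out = selective_state_update(state,
                             x,
                             dt,
                             A,
                             B,
                             C,
                             dt_softplus=True,
                             state_batch_indices=slots)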
