Add higher level custom op for int8 matmul + dequant + bias
matthewdouglas committed Feb 25, 2025
1 parent 6aeea81 commit cbd1670
Showing 5 changed files with 82 additions and 23 deletions.
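
In short, the new `bitsandbytes::int8_linear_dequant` op fuses two existing ops into a single call: the int32 matmul of the int8 operands, followed by dequantization with the per-row/per-column absmax stats and an optional bias. A minimal sketch of the two equivalent call paths, mirroring the shapes used in the new test (requires a CUDA build of bitsandbytes; the values are illustrative only):

```python
import torch
import bitsandbytes  # noqa: F401  (importing registers the torch.library ops)

device = "cuda"  # assumes a CUDA-enabled bitsandbytes installation
A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
row_stats = torch.randn(10, dtype=torch.float32, device=device).abs()
col_stats = torch.randn(30, dtype=torch.float32, device=device).abs()
bias = torch.randn(30, dtype=torch.float16, device=device)

# Previously: two separate op calls.
out_i32 = torch.ops.bitsandbytes.int8_linear_matmul(A, B)  # (10, 30), int32
out_old = torch.ops.bitsandbytes.int8_mm_dequant(out_i32, row_stats, col_stats, bias=bias)

# With this commit: one higher-level op that composes them.
out_new = torch.ops.bitsandbytes.int8_linear_dequant(
    A, B, row_stats, col_stats, bias=bias, dtype=torch.float16
)
```

Both paths produce a (10, 30) float16 tensor; the fused op additionally lets callers request a different output dtype.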
40 changes: 38 additions & 2 deletions bitsandbytes/_ops.py
@@ -14,6 +14,41 @@
register_fake = torch.library.impl_abstract
register_kernel = torch.library.impl


# Higher level op: int8 matmul + dequant + bias
torch.library.define(
"bitsandbytes::int8_linear_dequant",
"(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType dtype=float16) -> Tensor",
)


@register_fake("bitsandbytes::int8_linear_dequant")
def _(
A: torch.Tensor,
B: torch.Tensor,
row_stats: torch.Tensor,
col_stats: torch.Tensor,
bias: Optional[torch.Tensor] = None,
dtype=torch.float16,
) -> torch.Tensor:
shapeC = (*A.shape[:-1], B.shape[0])
return torch.empty(shapeC, device=A.device, dtype=dtype)


@register_kernel("bitsandbytes::int8_linear_dequant", None)
def _(
A: torch.Tensor,
B: torch.Tensor,
row_stats: torch.Tensor,
col_stats: torch.Tensor,
bias: Optional[torch.Tensor] = None,
dtype=torch.float16,
) -> torch.Tensor:
out_i32 = torch.ops.bitsandbytes.int8_linear_matmul(A, B)
out = torch.ops.bitsandbytes.int8_mm_dequant(out_i32, row_stats, col_stats, dtype=dtype, bias=bias)
return out


# Define op
# TODO: mutable output arg as alias of return can be challenging;
# consider a separate op without aliased return:
@@ -72,7 +107,7 @@ def _(A: torch.Tensor, stats: torch.Tensor):

torch.library.define(
"bitsandbytes::int8_mm_dequant",
"(Tensor A, Tensor row_stats, Tensor col_stats, Tensor? out=None, Tensor? bias=None) -> Tensor",
"(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16, Tensor? out=None, Tensor? bias=None) -> Tensor",
)


@@ -81,11 +116,12 @@ def _(
A: torch.Tensor,
row_stats: torch.Tensor,
col_stats: torch.Tensor,
dtype=torch.float16,
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
torch._check(A.dtype == torch.int32, lambda: "A must be int32")
return torch.empty_like(A, dtype=torch.float16)
return torch.empty_like(A, dtype=dtype)


torch.library.define(
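
The hunk above follows the file's usual registration pattern: `torch.library.define` declares the schema, a fake (meta) kernel describes the output shape and dtype for tracing, and an implementation registered under the `None` device key serves as the default composition that any backend can reuse. A minimal sketch of that pattern with a made-up `mylib::scale_add` op (the namespace, op name, and schema here are illustrative and not part of bitsandbytes):

```python
import torch

# Hypothetical op, used only to illustrate the registration pattern above.
torch.library.define("mylib::scale_add", "(Tensor x, Tensor y, float alpha) -> Tensor")


@torch.library.impl_abstract("mylib::scale_add")  # aliased as register_fake in _ops.py
def _(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    # Fake/meta kernel: only output shape and dtype, no real computation.
    return torch.empty_like(x)


@torch.library.impl("mylib::scale_add", None)  # aliased as register_kernel; None = default impl
def _(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    # Default implementation composed from existing ops, analogous to
    # int8_linear_dequant calling int8_linear_matmul + int8_mm_dequant.
    return torch.add(x, y, alpha=alpha)
```

Because the default implementation only calls other registered ops, any backend that already provides `int8_linear_matmul` and `int8_mm_dequant` gets `int8_linear_dequant` without extra work, while a backend remains free to register its own fused kernel under its device key later.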
22 changes: 10 additions & 12 deletions bitsandbytes/autograd/_functions.py
@@ -364,16 +364,8 @@ def forward(
else:
subA = None

# 3. Int8 Matmul
out32 = F.int8_linear_matmul(CA, state.CB)

# Dequantize matmul result
if bias is None or bias.dtype == torch.float16:
# we apply the fused bias here
output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=bias).to(A.dtype)
else: # apply bias separately
# TODO: Fused bias for fp32/bf16?
output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=None).to(A.dtype).add_(bias)
# 3. Int8 Matmul + Dequant + Bias
output = torch.ops.bitsandbytes.int8_linear_dequant(CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype)

# 4. Mixed-precision decomposition matmul
if subA is not None and state.subB is not None:
@@ -423,8 +415,14 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
if req_gradB:
Cgrad, _, _, SCgradt, _ = F.int8_double_quant(grad_output.to(torch.float16))

gradB32 = F.int8_linear_matmul(Cgrad.t().contiguous(), CAt.t())
grad_B = F.int8_mm_dequant(gradB32, SCgradt, SCAt)
grad_B = torch.ops.bitsandbytes.int8_linear_dequant(
Cgrad.t().contiguous(),
CAt.t(),
SCgradt,
SCAt,
dtype=torch.float16,
)

if state.threshold > 0.0 and subA is not None:
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

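
For intuition on what the fused call in `forward` computes: with row-wise symmetric int8 quantization (absmax scaling by 127, stated here as an assumption rather than the kernel's exact numerics), `int8_linear_dequant(CA, CB, SCA, SCB)` approximates `A @ B.t()`. A rough end-to-end sketch, bias omitted, requiring a CUDA build:

```python
import torch
import bitsandbytes  # noqa: F401  (registers the torch.library ops)

torch.manual_seed(0)
A = torch.randn(10, 20, dtype=torch.float16, device="cuda")
B = torch.randn(30, 20, dtype=torch.float16, device="cuda")

# Row-wise symmetric int8 quantization of both operands (assumed absmax/127 scheme).
row_stats = A.abs().amax(dim=1).float()  # per-row absmax of the activation
col_stats = B.abs().amax(dim=1).float()  # per-row absmax of the weight
CA = torch.round(A.float() * 127.0 / row_stats[:, None]).to(torch.int8)
CB = torch.round(B.float() * 127.0 / col_stats[:, None]).to(torch.int8)

out = torch.ops.bitsandbytes.int8_linear_dequant(CA, CB, row_stats, col_stats, dtype=torch.float16)
ref = A @ B.t()
print((out - ref).abs().max())  # expect a small quantization error, not exact equality
```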
18 changes: 13 additions & 5 deletions bitsandbytes/backends/cuda/ops.py
@@ -89,33 +89,41 @@ def _(
A: torch.Tensor,
row_stats: torch.Tensor,
col_stats: torch.Tensor,
dtype=torch.float16,
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")

if bias is not None:
torch._check(bias.dtype == torch.float16, lambda: f"Only fp16 bias is supported, got {bias.dtype}")

if out is None:
# TODO: deprecate out arg
# Note: cuda kernel only currently supports fp16 output.
# We'll later cast to desired dtype if needed.
out = torch.empty_like(A, dtype=torch.float16)

ptrA = get_ptr(A)
ptrOut = get_ptr(out)
ptrRowStats = get_ptr(row_stats)
ptrColStats = get_ptr(col_stats)
ptrBias = get_ptr(bias)
numRows = ct.c_int32(prod(A.shape[:-1]))
numCols = ct.c_int32(A.shape[-1])

# Note: fused bias in the kernel is only supported for fp16
# TODO(matthewdouglas): Consider supporting bf16 fused bias
ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None

with _cuda_device_of(A):
lib.cdequant_mm_int32_fp16(
ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A)
)

return out
# Add bias separately if not fused in kernel
if bias is not None and bias.dtype != torch.float16:
out.add_(bias)

return out.to(dtype)


@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda")
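
As a reading aid for the CUDA path above: the kernel writes float16, fuses the bias only when it is float16, and the wrapper now adds any other bias dtype afterwards before casting to the requested `dtype`. A rough pure-PyTorch reference of that behaviour (the 1/(127*127) scale is the usual symmetric-int8 assumption, not taken from the kernel source):

```python
from typing import Optional

import torch


def int8_mm_dequant_reference(
    A_i32: torch.Tensor,       # int32 matmul result, shape (rows, cols)
    row_stats: torch.Tensor,   # float32 per-row absmax of the int8 activation
    col_stats: torch.Tensor,   # float32 per-row absmax of the int8 weight
    dtype: torch.dtype = torch.float16,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Undo both symmetric int8 scalings (x_i8 = round(x * 127 / absmax)).
    out = A_i32.float() * row_stats.view(-1, 1) * col_stats.view(1, -1) / (127.0 * 127.0)
    out = out.to(torch.float16)               # the CUDA kernel emits fp16
    if bias is not None:
        out = out + bias.to(torch.float16)    # fused in-kernel only for fp16 bias; added separately otherwise
    return out.to(dtype)                      # finally cast to the requested output dtype
```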
2 changes: 1 addition & 1 deletion bitsandbytes/functional.py
@@ -2155,7 +2155,7 @@ def int8_mm_dequant(
Returns:
`torch.Tensor`: The dequantized result with an optional bias, with dtype `torch.float16`.
"""
return torch.ops.bitsandbytes.int8_mm_dequant(A, row_stats, col_stats, out, bias)
return torch.ops.bitsandbytes.int8_mm_dequant(A, row_stats, col_stats, dtype=torch.float16, out=out, bias=bias)


def get_colrow_absmax(
23 changes: 20 additions & 3 deletions tests/test_ops.py
@@ -4,7 +4,7 @@
import torch

import bitsandbytes
from tests.helpers import id_formatter
from tests.helpers import TRUE_FALSE, id_formatter


class TestLLMInt8Ops:
@@ -66,8 +66,25 @@ def test_int8_mm_dequant(self, device):

torch.library.opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats))

# def test_int8_double_quant():
# pass
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE)
def test_int8_linear_dequant(self, device, dtype, has_bias):
if device == "cpu":
pytest.skip("CPU implementation is not available")

A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
row_stats = torch.randn(10, dtype=torch.float32, device=device)
col_stats = torch.randn(30, dtype=torch.float32, device=device)
bias = torch.randn(30, dtype=dtype, device=device) if has_bias else None
out = torch.ops.bitsandbytes.int8_linear_dequant(A, B, row_stats, col_stats, bias=bias, dtype=dtype)

assert out.shape == (10, 30)
assert out.dtype == dtype
assert out.device == A.device

torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_dequant, (A, B, row_stats, col_stats, bias, dtype))


class TestInt8BlockwiseQuantOps: