diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py
index ed7d739f3..842fc6124 100644
--- a/bitsandbytes/_ops.py
+++ b/bitsandbytes/_ops.py
@@ -14,6 +14,41 @@
     register_fake = torch.library.impl_abstract
     register_kernel = torch.library.impl
 
+
+# Higher level op: int8 matmul + dequant + bias
+torch.library.define(
+    "bitsandbytes::int8_linear_dequant",
+    "(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType dtype=float16) -> Tensor",
+)
+
+
+@register_fake("bitsandbytes::int8_linear_dequant")
+def _(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    row_stats: torch.Tensor,
+    col_stats: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    dtype=torch.float16,
+) -> torch.Tensor:
+    shapeC = (*A.shape[:-1], B.shape[0])
+    return torch.empty(shapeC, device=A.device, dtype=dtype)
+
+
+@register_kernel("bitsandbytes::int8_linear_dequant", None)
+def _(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    row_stats: torch.Tensor,
+    col_stats: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    dtype=torch.float16,
+) -> torch.Tensor:
+    out_i32 = torch.ops.bitsandbytes.int8_linear_matmul(A, B)
+    out = torch.ops.bitsandbytes.int8_mm_dequant(out_i32, row_stats, col_stats, dtype=dtype, bias=bias)
+    return out
+
+
 # Define op
 # TODO: mutable output arg as alias of return can be challenging;
 # consider a separate op without aliased return:
@@ -72,7 +107,7 @@ def _(A: torch.Tensor, stats: torch.Tensor):
 
 torch.library.define(
     "bitsandbytes::int8_mm_dequant",
-    "(Tensor A, Tensor row_stats, Tensor col_stats, Tensor? out=None, Tensor? bias=None) -> Tensor",
+    "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16, Tensor? out=None, Tensor? bias=None) -> Tensor",
 )
 
 
@@ -81,11 +116,12 @@ def _(
     A: torch.Tensor,
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
+    dtype=torch.float16,
     out: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     torch._check(A.dtype == torch.int32, lambda: "A must be int32")
-    return torch.empty_like(A, dtype=torch.float16)
+    return torch.empty_like(A, dtype=dtype)
 
 
 torch.library.define(
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index f66cdf68d..8803a97b6 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -364,16 +364,8 @@ def forward(
         else:
             subA = None
 
-        # 3. Int8 Matmul
-        out32 = F.int8_linear_matmul(CA, state.CB)
-
-        # Dequantize matmul result
-        if bias is None or bias.dtype == torch.float16:
-            # we apply the fused bias here
-            output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=bias).to(A.dtype)
-        else:  # apply bias separately
-            # TODO: Fused bias for fp32/bf16?
-            output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=None).to(A.dtype).add_(bias)
+        # 3. Int8 Matmul + Dequant + Bias
+        output = torch.ops.bitsandbytes.int8_linear_dequant(CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype)
 
         # 4. Mixed-precision decomposition matmul
         if subA is not None and state.subB is not None:
@@ -423,8 +415,14 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
 
         if req_gradB:
             Cgrad, _, _, SCgradt, _ = F.int8_double_quant(grad_output.to(torch.float16))
-            gradB32 = F.int8_linear_matmul(Cgrad.t().contiguous(), CAt.t())
-            grad_B = F.int8_mm_dequant(gradB32, SCgradt, SCAt)
+            grad_B = torch.ops.bitsandbytes.int8_linear_dequant(
+                Cgrad.t().contiguous(),
+                CAt.t(),
+                SCgradt,
+                SCAt,
+                dtype=torch.float16,
+            )
+
             if state.threshold > 0.0 and subA is not None:
                 grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
 
diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py
index 02c73a11a..20de5c206 100644
--- a/bitsandbytes/backends/cuda/ops.py
+++ b/bitsandbytes/backends/cuda/ops.py
@@ -89,6 +89,7 @@ def _(
     A: torch.Tensor,
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
+    dtype=torch.float16,
     out: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
@@ -96,26 +97,33 @@ def _(
     torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
     torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")
 
-    if bias is not None:
-        torch._check(bias.dtype == torch.float16, lambda: f"Only fp16 bias is supported, got {bias.dtype}")
-
     if out is None:
+        # TODO: deprecate out arg
+        # Note: cuda kernel only currently supports fp16 output.
+        # We'll later cast to desired dtype if needed.
         out = torch.empty_like(A, dtype=torch.float16)
 
     ptrA = get_ptr(A)
     ptrOut = get_ptr(out)
     ptrRowStats = get_ptr(row_stats)
    ptrColStats = get_ptr(col_stats)
-    ptrBias = get_ptr(bias)
     numRows = ct.c_int32(prod(A.shape[:-1]))
     numCols = ct.c_int32(A.shape[-1])
 
+    # Note: fused bias in the kernel is only supported for fp16
+    # TODO(matthewdouglas): Consider supporting bf16 fused bias
+    ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None
+
     with _cuda_device_of(A):
         lib.cdequant_mm_int32_fp16(
             ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A)
         )
 
-    return out
+    # Add bias separately if not fused in kernel
+    if bias is not None and bias.dtype != torch.float16:
+        out.add_(bias)
+
+    return out.to(dtype)
 
 
 @register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda")
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 72fed6132..327edd86b 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -2155,7 +2155,7 @@ def int8_mm_dequant(
     Returns:
         `torch.Tensor`: The dequantized result with an optional bias, with dtype `torch.float16`.
     """
-    return torch.ops.bitsandbytes.int8_mm_dequant(A, row_stats, col_stats, out, bias)
+    return torch.ops.bitsandbytes.int8_mm_dequant(A, row_stats, col_stats, dtype=torch.float16, out=out, bias=bias)
 
 
 def get_colrow_absmax(
diff --git a/tests/test_ops.py b/tests/test_ops.py
index 783ddf3ac..94df6cffb 100644
--- a/tests/test_ops.py
+++ b/tests/test_ops.py
@@ -4,7 +4,7 @@
 import torch
 
 import bitsandbytes
-from tests.helpers import id_formatter
+from tests.helpers import TRUE_FALSE, id_formatter
 
 
 class TestLLMInt8Ops:
@@ -66,8 +66,25 @@ def test_int8_mm_dequant(self, device):
 
         torch.library.opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats))
 
-    # def test_int8_double_quant():
-    #     pass
+    @pytest.mark.parametrize("device", ["cpu", "cuda"])
+    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
+    @pytest.mark.parametrize("has_bias", TRUE_FALSE)
+    def test_int8_linear_dequant(self, device, dtype, has_bias):
+        if device == "cpu":
+            pytest.skip("CPU implementation is not available")
+
+        A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
+        B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
+        row_stats = torch.randn(10, dtype=torch.float32, device=device)
+        col_stats = torch.randn(20, dtype=torch.float32, device=device)
+        bias = torch.randn(30, dtype=dtype, device=device) if has_bias else None
+        out = torch.ops.bitsandbytes.int8_linear_dequant(A, B, row_stats, col_stats, bias=bias, dtype=dtype)
+
+        assert out.shape == (10, 30)
+        assert out.dtype == dtype
+        assert out.device == A.device
+
+        torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_dequant, (A, B, row_stats, col_stats, bias, dtype))
 
 
 class TestInt8BlockwiseQuantOps: