
Commit

Clean up and improve tests
matthewdouglas committed Feb 18, 2025
1 parent 45ead33 commit 6aeea81
Showing 7 changed files with 470 additions and 432 deletions.
bitsandbytes/_ops.py (7 changes: 5 additions & 2 deletions)
@@ -150,6 +150,8 @@ def _(

@register_fake("bitsandbytes::dequantize_blockwise")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
return torch.empty_like(A, dtype=dtype)


@@ -158,8 +160,9 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,

@register_fake("bitsandbytes::quantize_blockwise")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
n = A.numel()
blocks = -(n // -blocksize)
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
out = torch.zeros_like(A, dtype=torch.uint8)
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty_like(A, dtype=torch.uint8)
return out, absmax
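
The fake (meta) kernels above only need to return tensors with the correct shape, device, and dtype, which is presumably why this commit switches them from torch.zeros/torch.zeros_like to the cheaper torch.empty/torch.empty_like. The block count uses the -(n // -blocksize) idiom, i.e. ceiling division written with floor division. A standalone sketch (not part of the commit) of that arithmetic:

# Ceiling division via floor division: -(n // -blocksize) == ceil(n / blocksize),
# so a trailing partial block still gets its own absmax entry.
import math

def num_blocks(n: int, blocksize: int) -> int:
    blocks = -(n // -blocksize)
    assert blocks == math.ceil(n / blocksize)
    return blocks

print(num_blocks(4096, 256))  # 16: n is an exact multiple of the blocksize
print(num_blocks(4097, 256))  # 17: one extra block for the single trailing element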
bitsandbytes/backends/cpu/ops.py (5 changes: 3 additions & 2 deletions)
@@ -21,7 +21,7 @@ def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtyp
@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.float32, "A must be float32")
torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on cpu, got {A.dtype}")

n = A.numel()
blocks = -(n // -blocksize)
@@ -44,7 +44,8 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor
@register_kernel("bitsandbytes::dequantize_blockwise", "cpu")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(dtype == torch.float32, "A must be float32 on cpu")
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on cpu, got {dtype}")

out = torch.empty_like(A, dtype=dtype)

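The CPU kernels now pass a lambda to torch._check so the error message, including the offending dtype, is only built when the check actually fails. A minimal sketch of that behavior, assuming only that torch._check raises RuntimeError on failure as it does in current PyTorch:

# Sketch: the message callable is evaluated lazily, only when the condition is False.
import torch

A = torch.zeros(8, dtype=torch.float16)  # deliberately the wrong dtype

try:
    torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on cpu, got {A.dtype}")
except RuntimeError as err:
    print(err)  # message built lazily: "A must be float32 on cpu, got torch.float16"
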
bitsandbytes/backends/cuda/ops.py (1 change: 1 addition & 0 deletions)
@@ -242,6 +242,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor
@register_kernel("bitsandbytes::dequantize_blockwise", "cuda")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
torch._check(
dtype in [torch.float16, torch.bfloat16, torch.float32],
lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}",
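For reference, the validation the CUDA dequantize_blockwise kernel performs after this change can be restated as a plain helper; this is an illustrative sketch, not code from the library:

# Sketch of the argument validation shown in the diff above (illustrative only).
import torch

SUPPORTED_BLOCKSIZES = (4096, 2048, 1024, 512, 256, 128, 64)

def validate_dequantize_blockwise_args(A: torch.Tensor, blocksize: int, dtype: torch.dtype) -> None:
    torch._check(blocksize in SUPPORTED_BLOCKSIZES)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    torch._check(
        dtype in (torch.float16, torch.bfloat16, torch.float32),
        lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}",
    )

# Passes: uint8 input, supported blocksize, supported output dtype.
validate_dequantize_blockwise_args(torch.zeros(16, dtype=torch.uint8), blocksize=64, dtype=torch.float16)
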
tests/test_autograd.py (186 changes: 0 additions & 186 deletions)
@@ -1,12 +1,9 @@
from typing import Tuple

import pytest
import torch

import bitsandbytes as bnb
from tests.helpers import (
BOOLEAN_TRIPLES,
BOOLEAN_TUPLES,
TRUE_FALSE,
describe_dtype,
get_test_dims,
@@ -16,189 +13,6 @@
TRANSPOSE_VALS = [(False, True), (False, False)]


@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 96, n=1), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize(
"funcs",
[(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)],
ids=["func=bmm", "func=matmul"],
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
@pytest.mark.deprecated
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool], transpose: Tuple[bool, bool]):
if dim2 > 0:
dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
dim4 = dim4 - (dim4 % 16)
for i in range(25):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1])
torch.nn.init.xavier_uniform_(B)

if not transpose[0] and not transpose[1]:
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B)
elif not transpose[0] and transpose[1]:
out_torch = funcs[0](A, B.t())
out_bnb = funcs[1](A, B.t())
elif transpose[0] and not transpose[1]:
out_torch = funcs[0](A.t(), B)
out_bnb = funcs[1](A.t(), B)
elif transpose[0] and transpose[1]:
out_torch = funcs[0](A.t(), B.t())
out_bnb = funcs[1](A.t(), B.t())

n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() < n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
assert (idx == 0).sum().item() < n * 0.001

if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None

loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None

if req_grad[0]:
torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02
torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)

# batched matrix multiply
if funcs[0] in [torch.bmm, torch.matmul]:
A = torch.randn(
size=(dim1, dim2, dim3),
device="cuda",
requires_grad=req_grad[0],
)
B = torch.randn(
size=(dim1, dim3, dim4),
device="cuda",
requires_grad=req_grad[1],
)
target = torch.randn(
size=(dim1, dim2, dim4),
device="cuda",
requires_grad=req_grad[1],
)
torch.nn.init.xavier_uniform_(B)

out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B)

n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() < n * 0.01
torch.testing.assert_close(out_bnb, out_torch, atol=0.027, rtol=0.2)

if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None

loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None

if req_grad[0]:
torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02

if funcs[0] in [torch.matmul]:
dim1 = dim1 - (dim1 % 16)
A = torch.randn(
size=(dim1, dim2, dim3),
device="cuda",
requires_grad=req_grad[0],
)
dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
target = torch.randn(
size=(dim1, dim2, dim4),
device="cuda",
requires_grad=req_grad[1],
)
torch.nn.init.xavier_uniform_(B)

if transpose[1]:
out_torch = funcs[0](A, B.t())
out_bnb = funcs[1](A, B.t())
else:
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B)

n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() < n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
assert (idx == 0).sum().item() < n * 0.001

if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None

loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None

if req_grad[0]:
torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02


@pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
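The deleted test_matmul (marked @pytest.mark.deprecated) compared bnb.matmul_cublas/bnb.bmm_cublas against the torch reference by tolerating a small fraction of mismatched elements rather than requiring every element to be close. A hypothetical helper, not part of the test suite, capturing that pattern:

# Hypothetical helper illustrating the mismatch-fraction assertion used by the
# removed test (e.g. fewer than 1.75% of elements outside atol=0.01, rtol=0.1).
import torch

def assert_mostly_close(a: torch.Tensor, b: torch.Tensor, atol: float, rtol: float, max_mismatch_frac: float) -> None:
    close = torch.isclose(a, b, atol=atol, rtol=rtol)
    mismatched = (~close).sum().item()
    assert mismatched < a.numel() * max_mismatch_frac, f"{mismatched} of {a.numel()} elements differ"

a = torch.randn(1024)
b = a + 1e-3 * torch.randn(1024)  # small perturbation of the reference
assert_mostly_close(a, b, atol=0.01, rtol=0.1, max_mismatch_frac=0.0175)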
