Skip to content

Commit 069369f

Browse files
authored
fix masking in __compute_fp32_to_bf16
1 parent 46054d1 commit 069369f

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

gguf-py/gguf/quants.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizati
2525

2626
# same as ggml_compute_fp32_to_bf16 in ggml-impl.h
2727
def __compute_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
28-
n = n.astype(np.float32, copy=False).view(np.int32)
28+
n = n.astype(np.float32, copy=False).view(np.uint32)
2929
# force nan to quiet
30-
n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n)
30+
n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | (64 << 16), n)
3131
# flush subnormals to zero
32-
n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
32+
n = np.where((n & 0x7f800000) == 0, n & np.uint32(0x80000000), n)
3333
# round to nearest even
3434
n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
35-
return n.astype(np.int16)
35+
return n.astype(np.uint16)
3636

3737

3838
# for fp32 values that are just extended bf16
@@ -55,10 +55,10 @@ def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.
5555

5656

5757
def __quantize_bf16_array(n: np.ndarray) -> np.ndarray:
58-
return __apply_over_grouped_rows(__compute_fp32_to_bf16, arr=n, otype=np.int16, oshape=n.shape)
58+
return __apply_over_grouped_rows(__compute_fp32_to_bf16, arr=n, otype=np.uint16, oshape=n.shape)
5959

6060

61-
__quantize_bf16_lazy = LazyNumpyTensor._wrap_fn(__quantize_bf16_array, meta_noop=np.int16)
61+
__quantize_bf16_lazy = LazyNumpyTensor._wrap_fn(__quantize_bf16_array, meta_noop=np.uint16)
6262

6363

6464
def quantize_bf16(n: np.ndarray):

0 commit comments

Comments
 (0)