@@ -25,14 +25,14 @@ def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizati
25
25
26
26
# same as ggml_compute_fp32_to_bf16 in ggml-impl.h
27
27
def __compute_fp32_to_bf16 (n : np .ndarray ) -> np .ndarray :
28
- n = n .astype (np .float32 , copy = False ).view (np .int32 )
28
+ n = n .astype (np .float32 , copy = False ).view (np .uint32 )
29
29
# force nan to quiet
30
- n = np .where ((n & 0x7fffffff ) > 0x7f800000 , (n & 0xffff0000 ) | (64 << 16 ), n )
30
+ n = np .where ((n & 0x7fffffff ) > 0x7f800000 , (n & np . uint32 ( 0xffff0000 ) ) | (64 << 16 ), n )
31
31
# flush subnormals to zero
32
- n = np .where ((n & 0x7f800000 ) == 0 , n & 0x80000000 , n )
32
+ n = np .where ((n & 0x7f800000 ) == 0 , n & np . uint32 ( 0x80000000 ) , n )
33
33
# round to nearest even
34
34
n = (n + (0x7fff + ((n >> 16 ) & 1 ))) >> 16
35
- return n .astype (np .int16 )
35
+ return n .astype (np .uint16 )
36
36
37
37
38
38
# for fp32 values that are just extended bf16
@@ -55,10 +55,10 @@ def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.
55
55
56
56
57
57
def __quantize_bf16_array (n : np .ndarray ) -> np .ndarray :
58
- return __apply_over_grouped_rows (__compute_fp32_to_bf16 , arr = n , otype = np .int16 , oshape = n .shape )
58
+ return __apply_over_grouped_rows (__compute_fp32_to_bf16 , arr = n , otype = np .uint16 , oshape = n .shape )
59
59
60
60
61
- __quantize_bf16_lazy = LazyNumpyTensor ._wrap_fn (__quantize_bf16_array , meta_noop = np .int16 )
61
+ __quantize_bf16_lazy = LazyNumpyTensor ._wrap_fn (__quantize_bf16_array , meta_noop = np .uint16 )
62
62
63
63
64
64
def quantize_bf16 (n : np .ndarray ):
0 commit comments