Commit 6a52bfe

add truncate_bf16
1 parent 10ceba3 commit 6a52bfe

1 file changed: +20 −0 lines changed

gguf-py/gguf/quants.py

@@ -35,6 +35,12 @@ def __compute_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
     return n.astype(np.int16)
 
 
+# for fp32 values that are just extended bf16
+def __truncate_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
+    n = n.astype(np.float32, copy=False).view(np.uint32) >> 16
+    return n.astype(np.uint16)
+
+
 # This is faster than np.vectorize and np.apply_along_axis because it works on more than one row at a time
 def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray:
     rows = arr.reshape((-1, arr.shape[-1]))
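
The new helper reinterprets each fp32 value's bits as uint32 and keeps only the high 16 bits, which hold bf16's sign bit, 8 exponent bits, and 7 mantissa bits. Unlike __compute_fp32_to_bf16 above it, it performs no rounding or NaN handling, which is why the comment restricts it to fp32 values that are already extended bf16 (low 16 bits all zero). A minimal standalone sketch of the same bit manipulation (the function name here is illustrative, not the commit's private helper):

    import numpy as np

    # Drop the low 16 bits of each fp32 value and keep the high half
    # as a uint16 bf16 bit pattern; pure truncation, no rounding.
    def truncate_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
        return (n.astype(np.float32, copy=False).view(np.uint32) >> 16).astype(np.uint16)

    x = np.array([1.0, -2.0, 3.140625], dtype=np.float32)  # all exactly representable in bf16
    print([hex(v) for v in truncate_fp32_to_bf16(x)])      # ['0x3f80', '0xc000', '0x4049']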
@@ -62,6 +68,20 @@ def quantize_bf16(n: np.ndarray):
     return __quantize_bf16_array(n)
 
 
+def __truncate_bf16_array(n: np.ndarray) -> np.ndarray:
+    return __apply_over_grouped_rows(__truncate_fp32_to_bf16, arr=n, otype=np.uint16, oshape=n.shape)
+
+
+__truncate_bf16_lazy = LazyNumpyTensor._wrap_fn(__truncate_bf16_array, meta_noop=np.uint16)
+
+
+def truncate_bf16(n: np.ndarray):
+    if type(n) is LazyNumpyTensor:
+        return __truncate_bf16_lazy(n)
+    else:
+        return __truncate_bf16_array(n)
+
+
 __q8_block_size, __q8_type_size = GGML_QUANT_SIZES[GGMLQuantizationType.Q8_0]
 
 
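The public truncate_bf16 entry point mirrors the existing quantize_bf16: plain arrays go through __truncate_bf16_array (row-grouped via __apply_over_grouped_rows), while lazy tensors are routed through the LazyNumpyTensor._wrap_fn wrapper so the conversion is deferred until the tensor is materialized. A hedged usage sketch, assuming the gguf-py package from this repo is importable; the widening round trip at the end is my illustration, not part of this commit:

    import numpy as np
    from gguf.quants import truncate_bf16  # assumes gguf-py is on the path

    x = np.array([[1.0, 0.5, -0.25]], dtype=np.float32)  # low 16 bits of each value are zero
    bits = truncate_bf16(x)                              # uint16 bf16 bit patterns, same shape
    assert bits.dtype == np.uint16 and bits.shape == x.shape

    # For such "extended bf16" inputs the truncation is lossless: reattaching
    # sixteen zero bits recovers the original fp32 values exactly.
    restored = (bits.astype(np.uint32) << 16).view(np.float32)
    assert np.array_equal(restored, x)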