diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs
index 3cb5622960..c4d5d6f41a 100644
--- a/candle-core/src/quantized/neon.rs
+++ b/candle-core/src/quantized/neon.rs
@@ -12,6 +12,14 @@ use core::arch::arm::*;
 #[cfg(target_arch = "aarch64")]
 use core::arch::aarch64::*;
 
+#[inline(always)]
+unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t {
+    // TODO: dotprod
+    let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
+    let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
+    vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))
+}
+
 #[inline(always)]
 pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
     let qk = QK8_0;
@@ -43,15 +51,8 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
             let v1_0l = vld1q_s8(y0.qs.as_ptr());
             let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
 
-            // TODO: Support dotprod when it's available outside of nightly.
-            let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
-            let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
-            let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
-            let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
-
-            let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-            let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-
+            let pl0 = vdotq_s32(v0_0ls, v1_0l);
+            let ph0 = vdotq_s32(v0_0hs, v1_0h);
             sumv0 = vmlaq_n_f32(
                 sumv0,
                 vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
@@ -82,14 +83,8 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
             let y0_0 = vld1q_s8(y0.qs.as_ptr());
             let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
 
-            // TODO dotprod once this is the intrinsics are.
-            let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
-            let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
-            let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
-            let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
-
-            let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
-            let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
+            let p0 = vdotq_s32(x0_0, y0_0);
+            let p1 = vdotq_s32(x0_1, y0_1);
 
             sumv0 = vmlaq_n_f32(
                 sumv0,
@@ -118,10 +113,7 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Res
         for i in (0..QK_K).step_by(16) {
            let xs = vld1q_s8(xs.add(i));
            let ys = vld1q_s8(ys.add(i));
-           let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
-           let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
-
-           let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
+           let xy = vdotq_s32(xs, ys);
            sum_i = vaddq_s32(sum_i, xy)
         }
         sumf += vaddvq_s32(sum_i) as f32 * scale
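Reviewer note: the three kernels above already accumulated in i32 (`vpaddlq_s16` widens before the add), so for them the new helper is pure deduplication. The k-quant hunks that follow are where behavior actually changes: they previously reduced with `vaddvq_s16`, a horizontal add that stays in i16. A standalone sketch (ours, not part of the patch) of why that wraps for Q6K-sized operands, where the 6-bit values span 0..=63 (the -32 bias is folded out later via `isum_mins`) and the Q8K values reach magnitude 127:

```rust
// Illustrative only: each i16 lane of vaddq_s16(vmull_s8, vmull_s8) holds two
// products of at most 63 * 127 = 8_001, so 16_002 per lane still fits in i16.
// Summing eight such lanes, as vaddvq_s16 does, can reach 128_016 and wraps.
fn main() {
    let lanes = [2i16 * 63 * 127; 8]; // worst-case lanes, 16_002 each
    let wrapped = lanes.iter().fold(0i16, |acc, &x| acc.wrapping_add(x));
    let widened: i32 = lanes.iter().map(|&x| i32::from(x)).sum();
    assert_eq!(widened, 128_016);
    assert_eq!(wrapped, 128_016i32 as i16); // -3056: the kind of silent overflow behind #1526
    println!("i16 horizontal add: {wrapped}, i32 accumulation: {widened}");
}
```

Routing these reductions through `vdotq_s32` keeps every partial sum in i32, which is what the `vaddvq_s16` to `vaddvq_s32` swaps below rely on.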
@@ -191,30 +183,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
                let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
                let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));
 
-               // TODO: dotprod
-
-               let p0 = vaddq_s16(
-                   vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
-                   vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
-               );
-               let p1 = vaddq_s16(
-                   vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
-                   vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
-               );
+               let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
+               let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
                let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-               isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
+               isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
                scale = scale.add(2);
 
-               let p2 = vaddq_s16(
-                   vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
-                   vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
-               );
-               let p3 = vaddq_s16(
-                   vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
-                   vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
-               );
+               let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
+               let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
                let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-               isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
+               isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
                scale = scale.add(2);
 
                let q8bytes = vld1q_s8_x4(q8);
@@ -234,29 +212,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
                let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
                let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));
 
-               // TODO: dotprod case.
-               let p0 = vaddq_s16(
-                   vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
-                   vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
-               );
-               let p1 = vaddq_s16(
-                   vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
-                   vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
-               );
+               let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
+               let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
                let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-               isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
+               isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
                scale = scale.add(2);
 
-               let p2 = vaddq_s16(
-                   vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
-                   vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
-               );
-               let p3 = vaddq_s16(
-                   vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
-                   vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
-               );
+               let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
+               let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
                let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-               isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
+               isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
                scale = scale.add(2);
            }
            sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
@@ -333,28 +298,14 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
            let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
            let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));
 
-           // TODO: dotprod
-
-           let p0 = vaddq_s16(
-               vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)),
-               vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)),
-           );
-           let p1 = vaddq_s16(
-               vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)),
-               vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)),
-           );
-           sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32;
+           let p0 = vdotq_s32(q5bytes_0, q8bytes.0);
+           let p1 = vdotq_s32(q5bytes_1, q8bytes.1);
+           sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32;
            scales = scales.add(1);
 
-           let p2 = vaddq_s16(
-               vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)),
-               vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)),
-           );
-           let p3 = vaddq_s16(
-               vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)),
-               vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)),
-           );
-           sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32;
+           let p2 = vdotq_s32(q5bytes_2, q8bytes.2);
+           let p3 = vdotq_s32(q5bytes_3, q8bytes.3);
+           sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32;
            scales = scales.add(1);
        }
        sumf += d * sumi as f32 - dmin * sumi_mins as f32;
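On the helper's `// TODO: dotprod`: the removed comments note that the real dot-product intrinsics were nightly-only when this was written. For reference, a hedged sketch of what the helper body could become once the `dotprod` target feature is usable (assumes nightly Rust and the `stdarch_neon_dotprod` gate; `vdotq_s32_native` is our name, not part of the patch):

```rust
#![feature(stdarch_neon_dotprod)] // nightly gate for the SDOT intrinsics

use core::arch::aarch64::{int32x4_t, int8x16_t, vdupq_n_s32};

#[inline(always)]
#[target_feature(enable = "dotprod")]
unsafe fn vdotq_s32_native(a: int8x16_t, b: int8x16_t) -> int32x4_t {
    // SDOT sums four adjacent i8*i8 products into each i32 lane. The lane
    // layout differs from the vmull/vpaddl version, but every call site in
    // this file only consumes the horizontal sum, so results are unchanged.
    core::arch::aarch64::vdotq_s32(vdupq_n_s32(0), a, b)
}
```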
@@ -417,22 +368,15 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
        for j in 0..QK_K / 64 {
            let q4bits = vld1q_u8_x2(q4);
            q4 = q4.add(32);
-           // TODO: dotprod
            let q8bytes = vld1q_s8_x2(q8);
            q8 = q8.add(32);
            let q4bytes = int8x16x2_t(
                vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
                vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
            );
-           let p0 = vaddq_s16(
-               vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
-               vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
-           );
-           let p1 = vaddq_s16(
-               vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
-               vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
-           );
-           sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32;
+           let p0 = vdotq_s32(q4bytes.0, q8bytes.0);
+           let p1 = vdotq_s32(q4bytes.1, q8bytes.1);
+           sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32;
 
            let q8bytes = vld1q_s8_x2(q8);
            q8 = q8.add(32);
@@ -440,15 +384,9 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
                vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
                vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
            );
-           let p2 = vaddq_s16(
-               vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
-               vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
-           );
-           let p3 = vaddq_s16(
-               vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
-               vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
-           );
-           sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32;
+           let p2 = vdotq_s32(q4bytes.0, q8bytes.0);
+           let p3 = vdotq_s32(q4bytes.1, q8bytes.1);
+           sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32;
        }
        sumf += d * (sumi1 + sumi2) as f32;
    }
@@ -526,27 +464,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
                vreinterpretq_s8_u8(q3h_3),
            );
 
-           // TODO: dotprod
-           let p0 = vaddq_s16(
-               vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)),
-               vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)),
-           );
-           let p1 = vaddq_s16(
-               vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)),
-               vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)),
-           );
-           let p2 = vaddq_s16(
-               vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)),
-               vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)),
-           );
-           let p3 = vaddq_s16(
-               vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)),
-               vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)),
-           );
-           isum += vaddvq_s16(p0) as i32 * *scale as i32
-               + vaddvq_s16(p1) as i32 * *scale.add(1) as i32
-               + vaddvq_s16(p2) as i32 * *scale.add(2) as i32
-               + vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
+           let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0);
+           let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1);
+           let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2);
+           let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3);
+           isum += vaddvq_s32(p0) * *scale as i32
+               + vaddvq_s32(p1) * *scale.add(1) as i32
+               + vaddvq_s32(p2) * *scale.add(2) as i32
+               + vaddvq_s32(p3) * *scale.add(3) as i32;
            scale = scale.add(4);
 
            let q3h_0 = vbicq_u8(m2, qhbits.0);
@@ -571,27 +496,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
                vreinterpretq_s8_u8(q3h_3),
            );
 
-           // TODO: dotprod
-           let p0 = vaddq_s16(
-               vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)),
-               vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)),
-           );
-           let p1 = vaddq_s16(
-               vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)),
-               vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)),
-           );
-           let p2 = vaddq_s16(
-               vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)),
-               vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)),
-           );
-           let p3 = vaddq_s16(
-               vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)),
-               vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)),
-           );
-           isum += vaddvq_s16(p0) as i32 * *scale as i32
-               + vaddvq_s16(p1) as i32 * *scale.add(1) as i32
-               + vaddvq_s16(p2) as i32 * *scale.add(2) as i32
-               + vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
+           let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0);
+           let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1);
+           let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2);
+           let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3);
+           isum += vaddvq_s32(p0) * *scale as i32
+               + vaddvq_s32(p1) * *scale.add(1) as i32
+               + vaddvq_s32(p2) * *scale.add(2) as i32
+               + vaddvq_s32(p3) * *scale.add(3) as i32;
            scale = scale.add(4);
 
            if j == 0 {
@@ -649,7 +561,6 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res
        let mut is = 0usize;
 
        // TODO: dotprod
-
        for _j in 0..QK_K / 128 {
            let q2bits = vld1q_u8_x2(q2);
            q2 = q2.add(32);
@@ -696,14 +607,7 @@ unsafe fn multiply_accum_with_scale(
    q2bytes: int8x16x2_t,
    q8bytes: int8x16x2_t,
 ) -> i32 {
-    let p1 = vaddq_s16(
-        vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)),
-        vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)),
-    );
-    let p2 = vaddq_s16(
-        vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)),
-        vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)),
-    );
-    vaddvq_s16(p1) as i32 * aux[is + index] as i32
-        + vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32
+    let p1 = vdotq_s32(q2bytes.0, q8bytes.0);
+    let p2 = vdotq_s32(q2bytes.1, q8bytes.1);
+    vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32
 }
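With all the kernels now funneled through one helper, a single equivalence check covers the shared building block. A minimal sketch (ours, not in the patch) of a test that could sit in a module next to the helper, assuming `vdotq_s32` is in scope; it verifies that the four i32 lanes sum to the plain scalar dot product:

```rust
#[cfg(all(test, target_arch = "aarch64"))]
#[test]
fn vdotq_s32_matches_scalar_dot() {
    use core::arch::aarch64::{vaddvq_s32, vld1q_s8};
    // Arbitrary but sign-varied inputs covering much of the i8 range.
    let a: [i8; 16] = core::array::from_fn(|i| (i as i8).wrapping_mul(23).wrapping_sub(100));
    let b: [i8; 16] = core::array::from_fn(|i| (i as i8).wrapping_mul(41).wrapping_add(7));
    let scalar: i32 = a.iter().zip(&b).map(|(&x, &y)| i32::from(x) * i32::from(y)).sum();
    let simd = unsafe { vaddvq_s32(vdotq_s32(vld1q_s8(a.as_ptr()), vld1q_s8(b.as_ptr()))) };
    assert_eq!(simd, scalar);
}
```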
diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs
index 716cca8dee..e7a2ea7f0e 100644
--- a/candle-core/tests/quantized_tests.rs
+++ b/candle-core/tests/quantized_tests.rs
@@ -1,4 +1,5 @@
 use candle_core::{
+    bail,
     quantized::{self, GgmlDType},
     test_utils::to_vec2_round,
     Device, Module, Result, Tensor,
@@ -265,7 +266,8 @@ fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) {
     }
 }
 
-/// Creates a vector simillarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
+/// Creates a vector similar to the ones used in GGML unit tests:
+/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
 fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
     (0..GGML_TEST_SIZE)
         .map(|i| 0.1 + 2.0 * (i as f32 + offset).cos())
@@ -284,14 +286,15 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
     sum / a.len() as f32
 }
 
-/// Mirrores the GGML quanitzation unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
+/// Similar to the GGML quantization unit test:
+/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
 fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
     let src = create_ggml_like_vector(0.0);
     let mut dst = vec![0.0; GGML_TEST_SIZE];
     let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?;
     let error = calculate_rmse(src.as_slice(), dst.as_slice());
     if error > max_error {
-        candle_core::bail!(
+        bail!(
             "Quantization error {} exceeds max error {}",
             error,
             max_error
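The hunk below feeds `ggml_matmul_error_test` a second input pair: a non-negative ramp instead of the GGML cosine vector. The distinction matters for the overflow above: roughly half of the cosine entries are negative, so quantized products tend to cancel and the old i16 partial sums stayed small, whereas an all-non-negative input makes them grow monotonically. A quick standalone check of that sign claim (our snippet, with 1024 standing in for GGML_TEST_SIZE):

```rust
fn main() {
    let n = 1024; // stand-in for GGML_TEST_SIZE
    let cosine: Vec<f32> = (0..n).map(|i| 0.1 + 2.0 * (i as f32).cos()).collect();
    let ramp: Vec<f32> = (0..n).map(|i| i as f32 / n as f32).collect();
    let negatives = |v: &[f32]| v.iter().filter(|x| **x < 0.0).count();
    // Roughly half of the cosine entries are negative; none of the ramp's are.
    println!("cosine: {}/{} negative", negatives(&cosine), n);
    println!("ramp:   {}/{} negative", negatives(&ramp), n);
}
```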
@@ -487,54 +490,66 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
         GgmlDType::Q5K => 0.000740,
         GgmlDType::Q6K => 0.000952,
         GgmlDType::Q4_0 => 0.001143,
-        GgmlDType::Q4_1 => 0.007784,
+        GgmlDType::Q4_1 => 0.008,
         GgmlDType::Q5_0 => 0.001353,
-        GgmlDType::Q5_1 => 0.001363,
+        GgmlDType::Q5_1 => 0.00149,
         GgmlDType::Q8_0 => 0.000092,
         // Not from the ggml repo.
         GgmlDType::Q8K => 0.00065,
-        _ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
+        _ => bail!("No GGML results for quantization type {dtype:?}",),
     };
     Ok(err)
 }
 
-/// Mirrores the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
+/// Similar to the GGML matmul unit test:
+/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
 fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
     let a = create_ggml_like_vector(0.0);
     let b = create_ggml_like_vector(1.0);
+    ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 1.0)?;
+    // Another example that is more likely to trigger the overflow reported in #1526
+    let a = (0..GGML_TEST_SIZE)
+        .map(|i| i as f32 / GGML_TEST_SIZE as f32)
+        .collect::<Vec<_>>();
+    let b = (0..GGML_TEST_SIZE)
+        .map(|i| i as f32 / GGML_TEST_SIZE as f32)
+        .collect::<Vec<_>>();
+    ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 2.0)?;
+    Ok(())
+}
+
+fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Result<()> {
     let length = a.len();
     let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
     let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
-    T::from_float(&a, &mut a_quant)?;
-    T::VecDotType::from_float(&b, &mut b_quant)?;
+    T::from_float(a, &mut a_quant)?;
+    T::VecDotType::from_float(b, &mut b_quant)?;
 
     let result = T::vec_dot(length, &a_quant, &b_quant)?;
     let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
-    let reference_result = vec_dot_reference(&a, &b);
+    let reference_result = vec_dot_reference(a, b);
 
     if (result - result_unopt).abs() / length as f32 > 1e-6 {
-        candle_core::bail!(
+        bail!(
            "the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
        )
     }
 
     let error = (result - reference_result).abs() / length as f32;
 
-    let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
+    let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m;
 
     if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
-        candle_core::bail!(
-            "Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",
-        );
+        bail!("Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",);
     }
 
     // We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML
     // => we use a slightly higher error threshold
     const ERROR_LENIENCY: f32 = 0.00001;
     if error - ERROR_LENIENCY > ggml_error {
-        candle_core::bail!(
+        bail!(
             "Dot product error {} exceeds ggml reference error {}",
             error,
             ggml_error
@@ -543,6 +558,16 @@ fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
     Ok(())
 }
 
+#[test]
+fn quantized_mm() -> Result<()> {
+    ggml_matmul_error_test::<k_quants::BlockQ4_0>()?;
+    ggml_matmul_error_test::<k_quants::BlockQ4_1>()?;
+    ggml_matmul_error_test::<k_quants::BlockQ5_0>()?;
+    ggml_matmul_error_test::<k_quants::BlockQ5_1>()?;
+    ggml_matmul_error_test::<k_quants::BlockQ8_0>()?;
+    Ok(())
+}
+
 /// generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result.
 fn get_random_tensors(
     m: usize,
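If it helps while reviewing: the new test can be run on its own with something like `cargo test -p candle-core --test quantized_tests quantized_mm`. On aarch64 it exercises the rewritten NEON paths above; elsewhere it goes through the portable fallbacks, and in both cases `vec_dot` is cross-checked against `vec_dot_unopt` and the f32 reference.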