diff --git a/crypto/src/hash/monolith/mod.rs b/crypto/src/hash/monolith/mod.rs index acf99284b..bf8c2e005 100644 --- a/crypto/src/hash/monolith/mod.rs +++ b/crypto/src/hash/monolith/mod.rs @@ -166,7 +166,11 @@ impl MonolithMersenne31 IsGroup for EdwardsProjectivePoint { let num_s2 = &y1y2 - E::a() * &x1x2; let den_s2 = &one - &dx1x2y1y2; - // SAFETY: The creation of the result point is safe because the inputs are always points that belong to the curve. - let point = Self::new([&num_s1 / &den_s1, &num_s2 / &den_s2, one]); + // We are using that den_s1 and den_s2 aren't zero. + // See Theorem 3.3 from https://eprint.iacr.org/2007/286.pdf. + let x_coord = (&num_s1 / &den_s1).unwrap(); + let y_coord = (&num_s2 / &den_s2).unwrap(); + let point = Self::new([x_coord, y_coord, one]); point.unwrap() } diff --git a/math/src/elliptic_curve/montgomery/point.rs b/math/src/elliptic_curve/montgomery/point.rs index 7e860ccdf..5bcce1cdf 100644 --- a/math/src/elliptic_curve/montgomery/point.rs +++ b/math/src/elliptic_curve/montgomery/point.rs @@ -142,7 +142,11 @@ impl IsGroup for MontgomeryProjectivePoint { let x1_square = &x1 * &x1; let num = &x1_square + &x1_square + x1_square + &x1a + x1a + &one; let den = (&b + &b) * &y1; - let div = num / den; + + // We are using that den != 0 because b and y1 aren't zero. + // b != 0 because the cofficient b of a montgomery elliptic curve has to be different from zero. + // y1 != 0 because if not, it woould be the case from above: x2 = x1 and y2 + y1 = 0. + let div = unsafe { (num / den).unwrap_unchecked() }; let new_x = &div * &div * &b - (&x1 + x2) - a; let new_y = div * (x1 - &new_x) - y1; @@ -156,7 +160,8 @@ impl IsGroup for MontgomeryProjectivePoint { } else { let num = &y2 - &y1; let den = &x2 - &x1; - let div = num / den; + + let div = unsafe { (num / den).unwrap_unchecked() }; let new_x = &div * &div * E::b() - (&x1 + &x2) - E::a(); let new_y = div * (x1 - &new_x) - y1; diff --git a/math/src/elliptic_curve/short_weierstrass/curves/bls12_377/field_extension.rs b/math/src/elliptic_curve/short_weierstrass/curves/bls12_377/field_extension.rs index c5023ac44..d806d7318 100644 --- a/math/src/elliptic_curve/short_weierstrass/curves/bls12_377/field_extension.rs +++ b/math/src/elliptic_curve/short_weierstrass/curves/bls12_377/field_extension.rs @@ -69,8 +69,9 @@ impl IsField for Degree2ExtensionField { } /// Returns the division of `a` and `b` - fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { - ::mul(a, &Self::inv(b).unwrap()) + fn div(a: &Self::BaseType, b: &Self::BaseType) -> Result { + let b_inv = &Self::inv(b).map_err(|_| FieldError::DivisionByZero)?; + Ok(::mul(a, b_inv)) } /// Returns a boolean indicating whether `a` and `b` are equal component wise. @@ -124,9 +125,11 @@ impl IsSubFieldOf for BLS12377PrimeField { fn div( a: &Self::BaseType, b: &::BaseType, - ) -> ::BaseType { - let b_inv = Degree2ExtensionField::inv(b).unwrap(); - >::mul(a, &b_inv) + ) -> Result<::BaseType, FieldError> { + let b_inv = Degree2ExtensionField::inv(b)?; + Ok(>::mul( + a, &b_inv, + )) } fn sub( @@ -378,7 +381,7 @@ mod tests { let a = Fp6E::from(3); let a_extension = Fp12E::from(3); let b = Fp12E::from(2); - assert_eq!(a / &b, a_extension / b); + assert_eq!((a / &b).unwrap(), (a_extension / b).unwrap()); } #[test] diff --git a/math/src/elliptic_curve/short_weierstrass/curves/bls12_381/field_extension.rs b/math/src/elliptic_curve/short_weierstrass/curves/bls12_381/field_extension.rs index cb41e5aeb..394282597 100644 --- a/math/src/elliptic_curve/short_weierstrass/curves/bls12_381/field_extension.rs +++ b/math/src/elliptic_curve/short_weierstrass/curves/bls12_381/field_extension.rs @@ -71,8 +71,9 @@ impl IsField for Degree2ExtensionField { } /// Returns the division of `a` and `b` - fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { - ::mul(a, &Self::inv(b).unwrap()) + fn div(a: &Self::BaseType, b: &Self::BaseType) -> Result { + let b_inv = &Self::inv(b).map_err(|_| FieldError::DivisionByZero)?; + Ok(::mul(a, b_inv)) } /// Returns a boolean indicating whether `a` and `b` are equal component wise. @@ -126,9 +127,11 @@ impl IsSubFieldOf for BLS12381PrimeField { fn div( a: &Self::BaseType, b: &::BaseType, - ) -> ::BaseType { - let b_inv = Degree2ExtensionField::inv(b).unwrap(); - >::mul(a, &b_inv) + ) -> Result<::BaseType, FieldError> { + let b_inv = Degree2ExtensionField::inv(b)?; + Ok(>::mul( + a, &b_inv, + )) } fn sub( @@ -429,7 +432,7 @@ mod tests { let a = FieldElement::::from(3); let a_extension = FieldElement::::from(3); let b = FieldElement::::from(2); - assert_eq!(a / &b, a_extension / b); + assert_eq!((a / &b).unwrap(), (a_extension / b).unwrap()); } #[test] diff --git a/math/src/elliptic_curve/short_weierstrass/curves/bn_254/field_extension.rs b/math/src/elliptic_curve/short_weierstrass/curves/bn_254/field_extension.rs index 4bd9f60d0..474666314 100644 --- a/math/src/elliptic_curve/short_weierstrass/curves/bn_254/field_extension.rs +++ b/math/src/elliptic_curve/short_weierstrass/curves/bn_254/field_extension.rs @@ -73,8 +73,9 @@ impl IsField for Degree2ExtensionField { } /// Returns the division of `a` and `b` - fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { - ::mul(a, &Self::inv(b).unwrap()) + fn div(a: &Self::BaseType, b: &Self::BaseType) -> Result { + let b_inv = &Self::inv(b).map_err(|_| FieldError::DivisionByZero)?; + Ok(::mul(a, b_inv)) } /// Returns a boolean indicating whether `a` and `b` are equal component wise. @@ -128,9 +129,11 @@ impl IsSubFieldOf for BN254PrimeField { fn div( a: &Self::BaseType, b: &::BaseType, - ) -> ::BaseType { - let b_inv = Degree2ExtensionField::inv(b).unwrap(); - >::mul(a, &b_inv) + ) -> Result<::BaseType, FieldError> { + let b_inv = Degree2ExtensionField::inv(b)?; + Ok(>::mul( + a, &b_inv, + )) } fn sub( @@ -397,7 +400,7 @@ mod tests { let a = Fp6E::from(3); let a_extension = Fp12E::from(3); let b = Fp12E::from(2); - assert_eq!(a / &b, a_extension / b); + assert_eq!((a / &b).unwrap(), (a_extension / b).unwrap()); } #[test] diff --git a/math/src/elliptic_curve/short_weierstrass/point.rs b/math/src/elliptic_curve/short_weierstrass/point.rs index 2a82e4c1e..a93863ad1 100644 --- a/math/src/elliptic_curve/short_weierstrass/point.rs +++ b/math/src/elliptic_curve/short_weierstrass/point.rs @@ -347,9 +347,22 @@ where z = ByteConversion::from_bytes_le(&bytes[len * 2..])?; } - let point = - Self::new([x, y, z]).map_err(|_| DeserializationError::FieldFromBytesError)?; - Ok(point) + let Ok(z_inv) = z.inv() else { + let point = Self::new([x, y, z]) + .map_err(|_| DeserializationError::FieldFromBytesError)?; + return if point.is_neutral_element() { + Ok(point) + } else { + Err(DeserializationError::FieldFromBytesError) + }; + }; + let x_affine = &x * &z_inv; + let y_affine = &y * &z_inv; + if E::defining_equation(&x_affine, &y_affine) == FieldElement::zero() { + Self::new([x, y, z]).map_err(|_| DeserializationError::FieldFromBytesError) + } else { + Err(DeserializationError::FieldFromBytesError) + } } PointFormat::Uncompressed => { if bytes.len() % 2 != 0 { diff --git a/math/src/field/element.rs b/math/src/field/element.rs index cc8055fbc..b605a56e3 100644 --- a/math/src/field/element.rs +++ b/math/src/field/element.rs @@ -329,12 +329,11 @@ where F: IsSubFieldOf, L: IsField, { - type Output = FieldElement; + type Output = Result, FieldError>; fn div(self, rhs: &FieldElement) -> Self::Output { - Self::Output { - value: >::div(&self.value, &rhs.value), - } + let value = >::div(&self.value, &rhs.value)?; + Ok(FieldElement:: { value }) } } @@ -343,7 +342,7 @@ where F: IsSubFieldOf, L: IsField, { - type Output = FieldElement; + type Output = Result, FieldError>; fn div(self, rhs: FieldElement) -> Self::Output { &self / &rhs @@ -355,7 +354,7 @@ where F: IsSubFieldOf, L: IsField, { - type Output = FieldElement; + type Output = Result, FieldError>; fn div(self, rhs: &FieldElement) -> Self::Output { &self / rhs @@ -367,7 +366,7 @@ where F: IsSubFieldOf, L: IsField, { - type Output = FieldElement; + type Output = Result, FieldError>; fn div(self, rhs: FieldElement) -> Self::Output { self / &rhs diff --git a/math/src/field/extensions/cubic.rs b/math/src/field/extensions/cubic.rs index d05e1fa9d..95c560971 100644 --- a/math/src/field/extensions/cubic.rs +++ b/math/src/field/extensions/cubic.rs @@ -109,8 +109,12 @@ where } /// Returns the division of `a` and `b` - fn div(a: &[FieldElement; 3], b: &[FieldElement; 3]) -> [FieldElement; 3] { - ::mul(a, &Self::inv(b).unwrap()) + fn div( + a: &[FieldElement; 3], + b: &[FieldElement; 3], + ) -> Result<[FieldElement; 3], FieldError> { + let b_inv = &Self::inv(b).map_err(|_| FieldError::DivisionByZero)?; + Ok(::mul(a, b_inv)) } /// Returns a boolean indicating whether `a` and `b` are equal component wise. @@ -180,9 +184,11 @@ where fn div( a: &Self::BaseType, b: & as IsField>::BaseType, - ) -> as IsField>::BaseType { - let b_inv = as IsField>::inv(b).unwrap(); - >>::mul(a, &b_inv) + ) -> Result< as IsField>::BaseType, FieldError> { + let b_inv = as IsField>::inv(b)?; + Ok(>>::mul( + a, &b_inv, + )) } fn sub( @@ -285,7 +291,7 @@ mod tests { let a = FEE::new([FE::new(0), FE::new(3), FE::new(2)]); let b = FEE::new([-FE::new(2), FE::new(8), FE::new(5)]); let expected_result = FEE::new([FE::new(12), FE::new(6), FE::new(1)]); - assert_eq!(a / b, expected_result); + assert_eq!((a / b).unwrap(), expected_result); } #[test] @@ -293,7 +299,7 @@ mod tests { let a = FEE::new([FE::new(12), FE::new(5), FE::new(4)]); let b = FEE::new([-FE::new(4), FE::new(2), FE::new(2)]); let expected_result = FEE::new([FE::new(3), FE::new(8), FE::new(11)]); - assert_eq!(a / b, expected_result); + assert_eq!((a / b).unwrap(), expected_result); } #[test] @@ -379,7 +385,7 @@ mod tests { let a = FE::new(2); let b = FEE::new([-FE::new(2), FE::new(8), FE::new(5)]); let expected_result = FEE::new([FE::new(8), FE::new(4), FE::new(10)]); - assert_eq!(a / b, expected_result); + assert_eq!((a / b).unwrap(), expected_result); } #[test] @@ -387,6 +393,6 @@ mod tests { let a = FE::new(4); let b = FEE::new([-FE::new(4), FE::new(2), FE::new(2)]); let expected_result = FEE::new([FE::new(3), FE::new(6), FE::new(11)]); - assert_eq!(a / b, expected_result); + assert_eq!((a / b).unwrap(), expected_result); } } diff --git a/math/src/field/extensions/quadratic.rs b/math/src/field/extensions/quadratic.rs index 77ab87a0d..00cada4fb 100644 --- a/math/src/field/extensions/quadratic.rs +++ b/math/src/field/extensions/quadratic.rs @@ -118,8 +118,12 @@ where } /// Returns the division of `a` and `b` - fn div(a: &[FieldElement; 2], b: &[FieldElement; 2]) -> [FieldElement; 2] { - ::mul(a, &Self::inv(b).unwrap()) + fn div( + a: &[FieldElement; 2], + b: &[FieldElement; 2], + ) -> Result<[FieldElement; 2], FieldError> { + let b_inv = &Self::inv(b).map_err(|_| FieldError::DivisionByZero)?; + Ok(::mul(a, b_inv)) } /// Returns a boolean indicating whether `a` and `b` are equal component wise. @@ -176,9 +180,11 @@ where fn div( a: &Self::BaseType, b: & as IsField>::BaseType, - ) -> as IsField>::BaseType { - let b_inv = as IsField>::inv(b).unwrap(); - >>::mul(a, &b_inv) + ) -> Result< as IsField>::BaseType, FieldError> { + let b_inv = as IsField>::inv(b)?; + Ok(>>::mul( + a, &b_inv, + )) } fn sub( @@ -282,7 +288,7 @@ mod tests { let a = FEE::new([FE::new(0), FE::new(3)]); let b = FEE::new([-FE::new(2), FE::new(8)]); let expected_result = FEE::new([FE::new(42), FE::new(19)]); - assert_eq!(a / b, expected_result); + assert_eq!((a / b).unwrap(), expected_result); } #[test] @@ -290,7 +296,7 @@ mod tests { let a = FEE::new([FE::new(12), FE::new(5)]); let b = FEE::new([-FE::new(4), FE::new(2)]); let expected_result = FEE::new([FE::new(4), FE::new(45)]); - assert_eq!(a / b, expected_result); + assert_eq!((a / b).unwrap(), expected_result); } #[test] @@ -383,7 +389,7 @@ mod tests { let a = FE::new(3); let b = FEE::new([-FE::new(2), FE::new(8)]); let expected_result = FEE::new([FE::new(19), FE::new(17)]); - assert_eq!(a / b, expected_result); + assert_eq!((a / b).unwrap(), expected_result); } #[test] @@ -391,6 +397,6 @@ mod tests { let a = FE::new(22); let b = FEE::new([FE::new(4), FE::new(2)]); let expected_result = FEE::new([FE::new(28), FE::new(45)]); - assert_eq!(a / b, expected_result); + assert_eq!((a / b).unwrap(), expected_result); } } diff --git a/math/src/field/fields/fft_friendly/babybear_u32.rs b/math/src/field/fields/fft_friendly/babybear_u32.rs index 5338674c6..81bc67877 100644 --- a/math/src/field/fields/fft_friendly/babybear_u32.rs +++ b/math/src/field/fields/fft_friendly/babybear_u32.rs @@ -142,12 +142,12 @@ mod tests { #[test] fn div_1() { - assert_eq!(FE::from(2) / FE::from(1), FE::from(2)) + assert_eq!((FE::from(2) / FE::from(1)).unwrap(), FE::from(2)) } #[test] fn div_4_2() { - assert_eq!(FE::from(4) / FE::from(2), FE::from(2)) + assert_eq!((FE::from(4) / FE::from(2)).unwrap(), FE::from(2)) } #[test] diff --git a/math/src/field/fields/fft_friendly/quadratic_babybear.rs b/math/src/field/fields/fft_friendly/quadratic_babybear.rs index d3d4ea0c9..81a1630f0 100644 --- a/math/src/field/fields/fft_friendly/quadratic_babybear.rs +++ b/math/src/field/fields/fft_friendly/quadratic_babybear.rs @@ -63,14 +63,6 @@ mod tests { assert_eq!(a.inv().unwrap(), expected_result); } - #[test] - fn test_div_quadratic() { - let a = Fee::new([FE::from(12), FE::from(5)]); - let b = Fee::new([-FE::from(4), FE::from(2)]); - let expected_result = &a * b.inv().unwrap(); - assert_eq!(a / b, expected_result); - } - #[test] fn test_conjugate_quadratic() { let a = Fee::new([FE::from(12), FE::from(5)]); diff --git a/math/src/field/fields/fft_friendly/quartic_babybear.rs b/math/src/field/fields/fft_friendly/quartic_babybear.rs index 23cc0227e..817f446c7 100644 --- a/math/src/field/fields/fft_friendly/quartic_babybear.rs +++ b/math/src/field/fields/fft_friendly/quartic_babybear.rs @@ -73,8 +73,9 @@ impl IsField for Degree4BabyBearExtensionField { ]) } - fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { - ::mul(a, &Self::inv(b).unwrap()) + fn div(a: &Self::BaseType, b: &Self::BaseType) -> Result { + let b_inv = &Self::inv(b).map_err(|_| FieldError::DivisionByZero)?; + Ok(::mul(a, b_inv)) } fn eq(a: &Self::BaseType, b: &Self::BaseType) -> bool { @@ -183,9 +184,12 @@ impl IsSubFieldOf for Babybear31PrimeField { fn div( a: &Self::BaseType, b: &::BaseType, - ) -> ::BaseType { - let b_inv = Degree4BabyBearExtensionField::inv(b).unwrap(); - >::mul(a, &b_inv) + ) -> Result<::BaseType, FieldError> { + let b_inv = + Degree4BabyBearExtensionField::inv(b).map_err(|_| FieldError::DivisionByZero)?; + Ok(>::mul( + a, &b_inv, + )) } fn sub( diff --git a/math/src/field/fields/fft_friendly/quartic_babybear_u32.rs b/math/src/field/fields/fft_friendly/quartic_babybear_u32.rs index a92dfe075..3da691f31 100644 --- a/math/src/field/fields/fft_friendly/quartic_babybear_u32.rs +++ b/math/src/field/fields/fft_friendly/quartic_babybear_u32.rs @@ -77,8 +77,9 @@ impl IsField for Degree4BabyBearU32ExtensionField { ]) } - fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { - ::mul(a, &Self::inv(b).unwrap()) + fn div(a: &Self::BaseType, b: &Self::BaseType) -> Result { + let b_inv = &Self::inv(b).map_err(|_| FieldError::DivisionByZero)?; + Ok(::mul(a, b_inv)) } fn eq(a: &Self::BaseType, b: &Self::BaseType) -> bool { @@ -187,9 +188,9 @@ impl IsSubFieldOf for Babybear31PrimeField { fn div( a: &Self::BaseType, b: &::BaseType, - ) -> ::BaseType { - let b_inv = Degree4BabyBearU32ExtensionField::inv(b).unwrap(); - >::mul(a, &b_inv) + ) -> Result<::BaseType, FieldError> { + let b_inv = Degree4BabyBearU32ExtensionField::inv(b)?; + Ok(>::mul(a, &b_inv)) } fn sub( diff --git a/math/src/field/fields/mersenne31/extensions.rs b/math/src/field/fields/mersenne31/extensions.rs index 69c64f096..081548901 100644 --- a/math/src/field/fields/mersenne31/extensions.rs +++ b/math/src/field/fields/mersenne31/extensions.rs @@ -64,8 +64,9 @@ impl IsField for Degree2ExtensionField { } /// Returns the division of `a` and `b` - fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { - ::mul(a, &Self::inv(b).unwrap()) + fn div(a: &Self::BaseType, b: &Self::BaseType) -> Result { + let b_inv = &Self::inv(b).map_err(|_| FieldError::DivisionByZero)?; + Ok(::mul(a, b_inv)) } /// Returns a boolean indicating whether `a` and `b` are equal component wise. @@ -117,9 +118,11 @@ impl IsSubFieldOf for Mersenne31Field { fn div( a: &Self::BaseType, b: &::BaseType, - ) -> ::BaseType { - let b_inv = Degree2ExtensionField::inv(b).unwrap(); - >::mul(a, &b_inv) + ) -> Result<::BaseType, FieldError> { + let b_inv = Degree2ExtensionField::inv(b).map_err(|_| FieldError::DivisionByZero)?; + Ok(>::mul( + a, &b_inv, + )) } fn embed(a: Self::BaseType) -> ::BaseType { @@ -186,8 +189,9 @@ impl IsField for Degree4ExtensionField { Ok([&a[0] * &inv_norm, -&a[1] * &inv_norm]) } - fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { - ::mul(a, &Self::inv(b).unwrap()) + fn div(a: &Self::BaseType, b: &Self::BaseType) -> Result { + let b_inv = &Self::inv(b).map_err(|_| FieldError::DivisionByZero)?; + Ok(::mul(a, b_inv)) } fn eq(a: &Self::BaseType, b: &Self::BaseType) -> bool { @@ -238,9 +242,11 @@ impl IsSubFieldOf for Mersenne31Field { fn div( a: &Self::BaseType, b: &::BaseType, - ) -> ::BaseType { - let b_inv = Degree4ExtensionField::inv(b).unwrap(); - >::mul(a, &b_inv) + ) -> Result<::BaseType, FieldError> { + let b_inv = Degree4ExtensionField::inv(b).map_err(|_| FieldError::DivisionByZero)?; + Ok(>::mul( + a, &b_inv, + )) } fn embed(a: Self::BaseType) -> ::BaseType { @@ -551,7 +557,7 @@ mod tests { let a = Fp2E::new([FpE::from(12), FpE::from(5)]); let b = Fp2E::new([FpE::from(4), FpE::from(2)]); let expected_result = Fp2E::new([FpE::from(644245097), FpE::from(1288490188)]); - assert_eq!(a / b, expected_result); + assert_eq!((a / b).unwrap(), expected_result); } #[test] @@ -559,7 +565,7 @@ mod tests { let a = Fp2E::new([FpE::from(4), FpE::from(7)]); let b = Fp2E::new([FpE::one(), FpE::zero()]); let expected_result = Fp2E::new([FpE::from(4), FpE::from(7)]); - assert_eq!(a / b, expected_result); + assert_eq!((a / b).unwrap(), expected_result); } #[test] @@ -567,7 +573,7 @@ mod tests { let a = Fp2E::new([FpE::zero(), FpE::zero()]); let b = Fp2E::new([FpE::from(3), FpE::from(12)]); let expected_result = Fp2E::new([FpE::zero(), FpE::zero()]); - assert_eq!(a / b, expected_result); + assert_eq!((a / b).unwrap(), expected_result); } #[test] diff --git a/math/src/field/fields/mersenne31/field.rs b/math/src/field/fields/mersenne31/field.rs index 1c8b2dc58..9b5bea83b 100644 --- a/math/src/field/fields/mersenne31/field.rs +++ b/math/src/field/fields/mersenne31/field.rs @@ -117,9 +117,9 @@ impl IsField for Mersenne31Field { } /// Returns the division of `a` and `b`. - fn div(a: &u32, b: &u32) -> u32 { - let b_inv = Self::inv(b).expect("InvZeroError"); - Self::mul(a, &b_inv) + fn div(a: &u32, b: &u32) -> Result { + let b_inv = Self::inv(b).map_err(|_| FieldError::DivisionByZero)?; + Ok(Self::mul(a, &b_inv)) } /// Returns a boolean indicating whether `a` and `b` are equal or not. @@ -372,18 +372,27 @@ mod tests { #[test] fn div_1() { - assert_eq!(FE::from(&2u32) / FE::from(&1u32), FE::from(&2u32)); + assert_eq!( + (FE::from(&2u32) / FE::from(&1u32)).unwrap(), + FE::from(&2u32) + ); } #[test] fn div_4_2() { - assert_eq!(FE::from(&4u32) / FE::from(&2u32), FE::from(&2u32)); + assert_eq!( + (FE::from(&4u32) / FE::from(&2u32)).unwrap(), + FE::from(&2u32) + ); } #[test] fn div_4_3() { // sage: F(4) / F(3) = 1431655766 - assert_eq!(FE::from(&4u32) / FE::from(&3u32), FE::from(1431655766)); + assert_eq!( + (FE::from(&4u32) / FE::from(&3u32)).unwrap(), + FE::from(1431655766) + ); } #[test] diff --git a/math/src/field/fields/montgomery_backed_prime_fields.rs b/math/src/field/fields/montgomery_backed_prime_fields.rs index 763d77526..1f0455c01 100644 --- a/math/src/field/fields/montgomery_backed_prime_fields.rs +++ b/math/src/field/fields/montgomery_backed_prime_fields.rs @@ -247,8 +247,9 @@ where } #[inline(always)] - fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { - Self::mul(a, &Self::inv(b).unwrap()) + fn div(a: &Self::BaseType, b: &Self::BaseType) -> Result { + let b_inv = &Self::inv(b)?; + Ok(Self::mul(a, b_inv)) } #[inline(always)] @@ -585,7 +586,7 @@ mod tests_u384_prime_fields { #[test] fn div_1() { assert_eq!( - U384F23Element::from(2) / U384F23Element::from(1), + (U384F23Element::from(2) / U384F23Element::from(1)).unwrap(), U384F23Element::from(2) ) } @@ -593,7 +594,7 @@ mod tests_u384_prime_fields { #[test] fn div_4_2() { assert_eq!( - U384F23Element::from(4) / U384F23Element::from(2), + (U384F23Element::from(4) / U384F23Element::from(2)).unwrap(), U384F23Element::from(2) ) } @@ -608,7 +609,7 @@ mod tests_u384_prime_fields { #[test] fn div_4_3() { assert_eq!( - U384F23Element::from(4) / U384F23Element::from(3) * U384F23Element::from(3), + (U384F23Element::from(4) / U384F23Element::from(3)).unwrap() * U384F23Element::from(3), U384F23Element::from(4) ) } @@ -941,7 +942,7 @@ mod tests_u256_prime_fields { #[test] fn div_1() { assert_eq!( - U256F29Element::from(2) / U256F29Element::from(1), + (U256F29Element::from(2) / U256F29Element::from(1)).unwrap(), U256F29Element::from(2) ) } @@ -950,13 +951,13 @@ mod tests_u256_prime_fields { fn div_4_2() { let a = U256F29Element::from(4); let b = U256F29Element::from(2); - assert_eq!(a / &b, b) + assert_eq!((a / &b).unwrap(), b) } #[test] fn div_4_3() { assert_eq!( - U256F29Element::from(4) / U256F29Element::from(3) * U256F29Element::from(3), + (U256F29Element::from(4) / U256F29Element::from(3)).unwrap() * U256F29Element::from(3), U256F29Element::from(4) ) } diff --git a/math/src/field/fields/p448_goldilocks_prime_field.rs b/math/src/field/fields/p448_goldilocks_prime_field.rs index eadd61541..935354117 100644 --- a/math/src/field/fields/p448_goldilocks_prime_field.rs +++ b/math/src/field/fields/p448_goldilocks_prime_field.rs @@ -158,9 +158,9 @@ impl IsField for P448GoldilocksPrimeField { )) } - fn div(a: &U56x8, b: &U56x8) -> U56x8 { - let b_inv = Self::inv(b).unwrap(); - Self::mul(a, &b_inv) + fn div(a: &U56x8, b: &U56x8) -> Result { + let b_inv = &Self::inv(b)?; + Ok(Self::mul(a, b_inv)) } /// Taken from https://sourceforge.net/p/ed448goldilocks/code/ci/master/tree/src/per_field/f_generic.tmpl.c @@ -428,7 +428,7 @@ mod tests { let num1 = U56x8::from_hex("b86e226f5ac29af28c74e272fc129ab167798f70dedd2ce76aa76204a23beb74c8ddba2a643196c62ee35a18472d6de7d82b6af4b2fc5e58").unwrap(); let num2 = U56x8::from_hex("bb2bd89a1297c7a6052b41be503aa7de2cd6e6775396e76bf995f27f1dccf69131067824ded693bdd6e58fe7c2276fa92ec1d9a0048b9be6").unwrap(); let num3 = P448GoldilocksPrimeField::div(&num1, &num2); - assert_eq!(num3, U56x8::from_hex("707b5cc75967b58ebd28d14d4ed7ed9eaae1187d0b359c7733cf61b1a5c87fc88228ca532c50f19d1ba57146ca2e38417922033f647c8d9").unwrap()); + assert_eq!(num3.unwrap(), U56x8::from_hex("707b5cc75967b58ebd28d14d4ed7ed9eaae1187d0b359c7733cf61b1a5c87fc88228ca532c50f19d1ba57146ca2e38417922033f647c8d9").unwrap()); } #[test] diff --git a/math/src/field/fields/u32_montgomery_backend_prime_field.rs b/math/src/field/fields/u32_montgomery_backend_prime_field.rs index 439b6fab0..0980f59af 100644 --- a/math/src/field/fields/u32_montgomery_backend_prime_field.rs +++ b/math/src/field/fields/u32_montgomery_backend_prime_field.rs @@ -166,8 +166,9 @@ impl IsField for U32MontgomeryBackendPrimeField { } #[inline(always)] - fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { - Self::mul(a, &Self::inv(b).unwrap()) + fn div(a: &Self::BaseType, b: &Self::BaseType) -> Result { + let b_inv = &Self::inv(b)?; + Ok(Self::mul(a, b_inv)) } #[inline(always)] diff --git a/math/src/field/fields/u64_goldilocks_field.rs b/math/src/field/fields/u64_goldilocks_field.rs index 5c1a29d3c..a3883afc2 100644 --- a/math/src/field/fields/u64_goldilocks_field.rs +++ b/math/src/field/fields/u64_goldilocks_field.rs @@ -118,9 +118,9 @@ impl IsField for Goldilocks64Field { } /// Returns the division of `a` and `b`. - fn div(a: &u64, b: &u64) -> u64 { - let b_inv = Self::inv(b).unwrap(); - Self::mul(a, &b_inv) + fn div(a: &u64, b: &u64) -> Result { + let b_inv = &Self::inv(b)?; + Ok(Self::mul(a, b_inv)) } /// Returns a boolean indicating whether `a` and `b` are equal or not. @@ -386,12 +386,18 @@ mod tests { #[test] fn div_one() { - assert_eq!(F::div(&F::from_base_type(2), &F::from_base_type(1)), 2) + assert_eq!( + F::div(&F::from_base_type(2), &F::from_base_type(1)).unwrap(), + 2 + ) } #[test] fn div_4_2() { - assert_eq!(F::div(&F::from_base_type(4), &F::from_base_type(2)), 2) + assert_eq!( + F::div(&F::from_base_type(4), &F::from_base_type(2)).unwrap(), + 2 + ) } // 1431655766 @@ -399,7 +405,7 @@ mod tests { fn div_4_3() { // sage: F(4) / F(3) = 12297829379609722882 assert_eq!( - F::div(&F::from_base_type(4), &F::from_base_type(3)), + F::div(&F::from_base_type(4), &F::from_base_type(3)).unwrap(), 12297829379609722882 ) } diff --git a/math/src/field/fields/u64_prime_field.rs b/math/src/field/fields/u64_prime_field.rs index 9d16b3d78..ffd075407 100644 --- a/math/src/field/fields/u64_prime_field.rs +++ b/math/src/field/fields/u64_prime_field.rs @@ -39,8 +39,9 @@ impl IsField for U64PrimeField { ((*a as u128 * *b as u128) % MODULUS as u128) as u64 } - fn div(a: &u64, b: &u64) -> u64 { - Self::mul(a, &Self::inv(b).unwrap()) + fn div(a: &u64, b: &u64) -> Result { + let b_inv = &Self::inv(b)?; + Ok(Self::mul(a, b_inv)) } fn inv(a: &u64) -> Result { @@ -278,17 +279,20 @@ mod tests { #[test] fn div_1() { - assert_eq!(FE::new(2) / FE::new(1), FE::new(2)) + assert_eq!(FE::new(2) * FE::new(1).inv().unwrap(), FE::new(2)) } #[test] fn div_4_2() { - assert_eq!(FE::new(4) / FE::new(2), FE::new(2)) + assert_eq!(FE::new(4) * FE::new(2).inv().unwrap(), FE::new(2)) } #[test] fn div_4_3() { - assert_eq!(FE::new(4) / FE::new(3) * FE::new(3), FE::new(4)) + assert_eq!( + FE::new(4) * FE::new(3).inv().unwrap() * FE::new(3), + FE::new(4) + ) } #[test] diff --git a/math/src/field/test_fields/u32_test_field.rs b/math/src/field/test_fields/u32_test_field.rs index 723f2501c..2f138b50d 100644 --- a/math/src/field/test_fields/u32_test_field.rs +++ b/math/src/field/test_fields/u32_test_field.rs @@ -57,8 +57,9 @@ impl IsField for U32Field { ((*a as u128 * *b as u128) % MODULUS as u128) as u32 } - fn div(a: &u32, b: &u32) -> u32 { - Self::mul(a, &Self::inv(b).unwrap()) + fn div(a: &u32, b: &u32) -> Result { + let b_inv = &Self::inv(b)?; + Ok(Self::mul(a, b_inv)) } fn inv(a: &u32) -> Result { diff --git a/math/src/field/test_fields/u64_test_field.rs b/math/src/field/test_fields/u64_test_field.rs index afb7f104d..7dd945361 100644 --- a/math/src/field/test_fields/u64_test_field.rs +++ b/math/src/field/test_fields/u64_test_field.rs @@ -30,8 +30,9 @@ impl IsField for U64Field { ((*a as u128 * *b as u128) % MODULUS as u128) as u64 } - fn div(a: &u64, b: &u64) -> u64 { - Self::mul(a, &Self::inv(b).unwrap()) + fn div(a: &u64, b: &u64) -> Result { + let b_inv = &Self::inv(b)?; + Ok(Self::mul(a, b_inv)) } fn inv(a: &u64) -> Result { diff --git a/math/src/field/traits.rs b/math/src/field/traits.rs index 320531f42..ab9534204 100644 --- a/math/src/field/traits.rs +++ b/math/src/field/traits.rs @@ -18,7 +18,7 @@ pub enum RootsConfig { pub trait IsSubFieldOf: IsField { fn mul(a: &Self::BaseType, b: &F::BaseType) -> F::BaseType; fn add(a: &Self::BaseType, b: &F::BaseType) -> F::BaseType; - fn div(a: &Self::BaseType, b: &F::BaseType) -> F::BaseType; + fn div(a: &Self::BaseType, b: &F::BaseType) -> Result; fn sub(a: &Self::BaseType, b: &F::BaseType) -> F::BaseType; fn embed(a: Self::BaseType) -> F::BaseType; #[cfg(feature = "alloc")] @@ -45,7 +45,7 @@ where } #[inline(always)] - fn div(a: &Self::BaseType, b: &F::BaseType) -> F::BaseType { + fn div(a: &Self::BaseType, b: &F::BaseType) -> Result { F::div(a, b) } @@ -167,7 +167,7 @@ pub trait IsField: Debug + Clone { fn inv(a: &Self::BaseType) -> Result; /// Returns the division of `a` and `b`. - fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType; + fn div(a: &Self::BaseType, b: &Self::BaseType) -> Result; /// Returns a boolean indicating whether `a` and `b` are equal or not. fn eq(a: &Self::BaseType, b: &Self::BaseType) -> bool; diff --git a/math/src/polynomial/mod.rs b/math/src/polynomial/mod.rs index 679a83aec..1542f6e04 100644 --- a/math/src/polynomial/mod.rs +++ b/math/src/polynomial/mod.rs @@ -1098,11 +1098,11 @@ mod tests { #[test] fn simple_interpolating_polynomial_by_hand_works() { - let denominator = Polynomial::new(&[FE::new(1) / (FE::new(2) - FE::new(4))]); + let denominator = Polynomial::new(&[FE::new(1) * (FE::new(2) - FE::new(4)).inv().unwrap()]); let numerator = Polynomial::new(&[-FE::new(4), FE::new(1)]); let interpolating = numerator * denominator; assert_eq!( - (FE::new(2) - FE::new(4)) * (FE::new(1) / (FE::new(2) - FE::new(4))), + (FE::new(2) - FE::new(4)) * (FE::new(1) * (FE::new(2) - FE::new(4)).inv().unwrap()), FE::new(1) ); assert_eq!(interpolating.evaluate(&FE::new(2)), FE::new(1)); diff --git a/provers/plonk/src/constraint_system/operations.rs b/provers/plonk/src/constraint_system/operations.rs index f4a977944..c9e207f98 100644 --- a/provers/plonk/src/constraint_system/operations.rs +++ b/provers/plonk/src/constraint_system/operations.rs @@ -287,7 +287,7 @@ mod tests { let inputs = HashMap::from([(input1, a), (input2, b)]); let assignments = system.solve(inputs).unwrap(); - assert_eq!(assignments.get(&result).unwrap(), &(a / b)); + assert_eq!(assignments.get(&result).unwrap(), &(a / b).unwrap()); } #[test] diff --git a/provers/plonk/src/prover.rs b/provers/plonk/src/prover.rs index 3281c48db..b8742cc57 100644 --- a/provers/plonk/src/prover.rs +++ b/provers/plonk/src/prover.rs @@ -15,6 +15,10 @@ use lambdaworks_math::{ }; use lambdaworks_math::{field::traits::IsField, traits::ByteConversion}; +#[derive(Debug)] +pub enum ProverError { + DivisionByZero, +} /// Plonk proof. /// The challenges are denoted /// Round 2: β,γ, @@ -361,7 +365,11 @@ where * lp(b_i, &(&cpi.domain[i] * &cpi.k1)) * lp(c_i, &(&cpi.domain[i] * &k2)); let den = lp(a_i, &s1[i]) * lp(b_i, &s2[i]) * lp(c_i, &s3[i]); - let new_factor = num / den; + // We are using that den != 0 with high probability because beta and gamma are random elements. + let new_factor = (num / den) + .map_err(|_| ProverError::DivisionByZero) + .unwrap(); + let new_term = coefficients.last().unwrap() * &new_factor; coefficients.push(new_term); } @@ -574,9 +582,11 @@ where let zeta_raised_n = Polynomial::new_monomial(r4.zeta.pow(cpi.n + 2), 0); // TODO: Paper says n and 2n, but Gnark uses n+2 and 2n+4 let zeta_raised_2n = Polynomial::new_monomial(r4.zeta.pow(2 * cpi.n + 4), 0); - let l1_zeta = (&r4.zeta.pow(cpi.n as u64) - FieldElement::::one()) - / (&r4.zeta - FieldElement::::one()) - / FieldElement::::from(cpi.n as u64); + // We are using that zeta != 0 because is sampled outside the set of roots of unity, + // and n != 0 because is the length of the trace. + let l1_zeta = ((&r4.zeta.pow(cpi.n as u64) - FieldElement::::one()) + / ((&r4.zeta - FieldElement::::one()) * FieldElement::::from(cpi.n as u64))) + .unwrap(); let mut p_non_constant = &cpi.qm * &r4.a_zeta * &r4.b_zeta + &r4.a_zeta * &cpi.ql diff --git a/provers/plonk/src/verifier.rs b/provers/plonk/src/verifier.rs index 8393148dd..9c9657e68 100644 --- a/provers/plonk/src/verifier.rs +++ b/provers/plonk/src/verifier.rs @@ -81,9 +81,11 @@ impl> Verifier { let k1 = &input.k1; let k2 = k1 * k1; - let l1_zeta = (zeta.pow(input.n as u64) - FieldElement::::one()) - / (&zeta - FieldElement::::one()) - / FieldElement::from(input.n as u64); + // We are using that zeta != 0 because is sampled outside the set of roots of unity, + // and n != 0 because is the length of the trace. + let l1_zeta = ((zeta.pow(input.n as u64) - FieldElement::::one()) + / ((&zeta - FieldElement::::one()) * FieldElement::from(input.n as u64))) + .unwrap(); // Use the following equality to compute PI(ζ) // without interpolating: @@ -97,7 +99,8 @@ impl> Verifier { for (i, value) in public_input.iter().enumerate().skip(1) { li_zeta = &input.omega * &li_zeta - * ((&zeta - &input.domain[i - 1]) / (&zeta - &input.domain[i])); + // We are using that zeta is sampled outside the domain. + * ((&zeta - &input.domain[i - 1]) / (&zeta - &input.domain[i])).unwrap(); p_pi_zeta = &p_pi_zeta + value * &li_zeta; } p_pi_zeta diff --git a/provers/stark/src/constraints/transition.rs b/provers/stark/src/constraints/transition.rs index c35d4da90..6cb9102c0 100644 --- a/provers/stark/src/constraints/transition.rs +++ b/provers/stark/src/constraints/transition.rs @@ -142,7 +142,10 @@ where let denominator = offset_times_x.pow(trace_length / self.period()) - trace_primitive_root.pow(self.offset() * trace_length / self.period()); - numerator.div(denominator) + // The denominator is guaranteed to be non-zero because the sets of powers of `offset_times_x` + // and `trace_primitive_root` are disjoint, provided that the offset is neither an element of the + // interpolation domain nor part of a subgroup with order less than n. + unsafe { numerator.div(denominator).unwrap_unchecked() } }) .collect(); @@ -228,8 +231,9 @@ where let denominator = -trace_primitive_root .pow(self.offset() * trace_length / self.period()) + z.pow(trace_length / self.period()); - - return numerator.div(denominator) * end_exemptions_poly.evaluate(z); + // The denominator isn't zero because z is sampled outside the set of primitive roots. + return unsafe { numerator.div(denominator).unwrap_unchecked() } + * end_exemptions_poly.evaluate(z); } (-trace_primitive_root.pow(self.offset() * trace_length / self.period()) diff --git a/provers/stark/src/examples/fibonacci_rap.rs b/provers/stark/src/examples/fibonacci_rap.rs index b3ee3d496..90581232c 100644 --- a/provers/stark/src/examples/fibonacci_rap.rs +++ b/provers/stark/src/examples/fibonacci_rap.rs @@ -223,7 +223,8 @@ where let n_p_term = not_perm[i - 1].clone() + gamma; let p_term = &perm[i - 1] + gamma; - aux_col.push(z_i * n_p_term.div(p_term)); + // We are using that with high probability p_term != 0 because gamma is a random element. + aux_col.push(z_i * n_p_term.div(p_term).unwrap()); } } @@ -377,7 +378,7 @@ mod test { let n_p_term = not_perm[i - 1] + gamma; let p_term = perm[i - 1] + gamma; - aux_col.push(z_i * n_p_term.div(p_term)); + aux_col.push(z_i * n_p_term.div(p_term).unwrap()); } } diff --git a/provers/stark/src/examples/read_only_memory.rs b/provers/stark/src/examples/read_only_memory.rs index 57fecd1f7..cb1db9162 100644 --- a/provers/stark/src/examples/read_only_memory.rs +++ b/provers/stark/src/examples/read_only_memory.rs @@ -296,12 +296,14 @@ where let mut aux_col = Vec::new(); let num = z - (&a[0] + alpha * &v[0]); let den = z - (&a_sorted[0] + alpha * &v_sorted[0]); - aux_col.push(num / den); + // We are using that den != 0 with high probability because alpha is a random element. + aux_col.push((num / den).unwrap()); // Apply the same equation given in the permutation case to the rest of the trace for i in 0..trace_len - 1 { let num = (z - (&a[i + 1] + alpha * &v[i + 1])) * &aux_col[i]; let den = z - (&a_sorted[i + 1] + alpha * &v_sorted[i + 1]); - aux_col.push(num / den); + // We are using that den != 0 with high probability because alpha is a random element. + aux_col.push((num / den).unwrap()); } for (i, aux_elem) in aux_col.iter().enumerate().take(trace.num_rows()) { @@ -345,7 +347,7 @@ where let den = z - (a_sorted0 + alpha * v_sorted0); let p0_value = num / den; - let c_aux1 = BoundaryConstraint::new_aux(0, 0, p0_value); + let c_aux1 = BoundaryConstraint::new_aux(0, 0, p0_value.unwrap()); let c_aux2 = BoundaryConstraint::new_aux( 0, self.trace_length - 1,