Skip to content

Commit 0a9554f

Browse files
committed
Adds polynomials and Lagrange polynomials.
1 parent b20ab68 commit 0a9554f

File tree

6 files changed

+289
-3
lines changed

6 files changed

+289
-3
lines changed

group/group.go

+2
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ type Scalar interface {
101101
Set(x Scalar) Scalar
102102
// Copy returns a new scalar equal to the receiver.
103103
Copy() Scalar
104+
// IsZero returns true if the receiver is equal to zero.
105+
IsZero() bool
104106
// IsEqual returns true if the receiver is equal to x.
105107
IsEqual(x Scalar) bool
106108
// SetUint64 set the receiver to x, and returns the receiver.

group/group_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,8 @@ func testScalar(t *testing.T, testTimes int, g group.Group) {
314314

315315
c.Inv(a)
316316
c.Mul(c, a)
317-
if !one.IsEqual(c) {
317+
c.Sub(c, one)
318+
if !c.IsZero() {
318319
test.ReportError(t, c, one, a)
319320
}
320321
}

group/ristretto255.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88

99
r255 "github.com/bwesterb/go-ristretto"
1010
"github.com/cloudflare/circl/expander"
11+
"github.com/cloudflare/circl/internal/conv"
1112
)
1213

1314
// Ristretto255 is a quotient group generated from the edwards25519 curve.
@@ -203,9 +204,9 @@ func (e *ristrettoElement) UnmarshalBinary(data []byte) error {
203204
}
204205

205206
func (s *ristrettoScalar) Group() Group { return Ristretto255 }
206-
func (s *ristrettoScalar) String() string { return fmt.Sprintf("0x%x", s.s.Bytes()) }
207+
func (s *ristrettoScalar) String() string { return conv.BytesLe2Hex(s.s.Bytes()) }
207208
func (s *ristrettoScalar) SetUint64(n uint64) Scalar { s.s.SetUint64(n); return s }
208-
209+
func (s *ristrettoScalar) IsZero() bool { return s.s.IsNonZeroI() == 0 }
209210
func (s *ristrettoScalar) IsEqual(x Scalar) bool {
210211
return s.s.Equals(&x.(*ristrettoScalar).s)
211212
}

group/short.go

+4
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,10 @@ type wScl struct {
274274
func (s *wScl) Group() Group { return s.wG }
275275
func (s *wScl) String() string { return fmt.Sprintf("0x%x", s.k) }
276276
func (s *wScl) SetUint64(n uint64) Scalar { s.fromBig(new(big.Int).SetUint64(n)); return s }
277+
func (s *wScl) IsZero() bool {
278+
return subtle.ConstantTimeCompare(s.k, make([]byte, (s.wG.c.Params().BitSize+7)/8)) == 1
279+
}
280+
277281
func (s *wScl) IsEqual(a Scalar) bool {
278282
aa := s.cvtScl(a)
279283
return subtle.ConstantTimeCompare(s.k, aa.k) == 1

math/polynomial/polynomial.go

+147
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
// Package polynomial provides representations of polynomials over the scalars
2+
// of a group.
3+
package polynomial
4+
5+
import "github.com/cloudflare/circl/group"
6+
7+
// Polynomial stores a polynomial over the set of scalars of a group.
8+
type Polynomial struct {
9+
// Internal representation is in polynomial basis:
10+
// Thus,
11+
// p(x) = \sum_i^k c[i] x^i,
12+
// where k = len(c)-1 is the degree of the polynomial.
13+
c []group.Scalar
14+
}
15+
16+
// New creates a new polynomial given its coefficients in ascending order.
17+
// Thus,
18+
// p(x) = \sum_i^k c[i] x^i,
19+
// where k = len(c)-1 is the degree of the polynomial.
20+
//
21+
// The zero polynomial has degree equal to -1 and can be instantiated passing
22+
// nil to New.
23+
func New(coeffs []group.Scalar) (p Polynomial) {
24+
if l := len(coeffs); l != 0 {
25+
p.c = make([]group.Scalar, l)
26+
for i := range coeffs {
27+
p.c[i] = coeffs[i].Copy()
28+
}
29+
}
30+
31+
return
32+
}
33+
34+
func (p Polynomial) Degree() int {
35+
i := len(p.c) - 1
36+
for i > 0 && p.c[i].IsZero() {
37+
i--
38+
}
39+
return i
40+
}
41+
42+
func (p Polynomial) Evaluate(x group.Scalar) group.Scalar {
43+
px := x.Group().NewScalar()
44+
if l := len(p.c); l != 0 {
45+
px.Set(p.c[l-1])
46+
for i := l - 2; i >= 0; i-- {
47+
px.Mul(px, x)
48+
px.Add(px, p.c[i])
49+
}
50+
}
51+
return px
52+
}
53+
54+
// LagrangePolynomial stores a Lagrange polynomial over the set of scalars of a group.
55+
type LagrangePolynomial struct {
56+
// Internal representation is in Lagrange basis:
57+
// Thus,
58+
// p(x) = \sum_i^k y[i] L_j(x), where k is the degree of the polynomial,
59+
// L_j(x) = \prod_i^k (x-x[i])/(x[j]-x[i]),
60+
// y[i] = p(x[i]), and
61+
// all x[i] are different.
62+
x, y []group.Scalar
63+
}
64+
65+
// NewLagrangePolynomial creates a polynomial in Lagrange basis given a list
66+
// of nodes (x) and values (y), such that:
67+
// p(x) = \sum_i^k y[i] L_j(x), where k is the degree of the polynomial,
68+
// L_j(x) = \prod_i^k (x-x[i])/(x[j]-x[i]),
69+
// y[i] = p(x[i]), and
70+
// all x[i] are different.
71+
// It panics if one of these conditions does not hold.
72+
//
73+
// The zero polynomial has degree equal to -1 and can be instantiated passing
74+
// (nil,nil) to NewLagrangePolynomial.
75+
func NewLagrangePolynomial(x, y []group.Scalar) (l LagrangePolynomial) {
76+
if len(x) != len(y) {
77+
panic("lagrange: invalid length")
78+
}
79+
80+
if !areAllDifferent(x) {
81+
panic("lagrange: x[i] must be different")
82+
}
83+
84+
if n := len(x); n != 0 {
85+
l.x, l.y = make([]group.Scalar, n), make([]group.Scalar, n)
86+
for i := range x {
87+
l.x[i], l.y[i] = x[i].Copy(), y[i].Copy()
88+
}
89+
}
90+
91+
return
92+
}
93+
94+
func (l LagrangePolynomial) Degree() int { return len(l.x) - 1 }
95+
96+
func (l LagrangePolynomial) Evaluate(x group.Scalar) group.Scalar {
97+
px := x.Group().NewScalar()
98+
tmp := x.Group().NewScalar()
99+
for i := range l.x {
100+
LjX := baseRatio(uint(i), l.x, x)
101+
tmp.Mul(l.y[i], LjX)
102+
px.Add(px, tmp)
103+
}
104+
105+
return px
106+
}
107+
108+
// LagrangeBase returns the j-th Lagrange polynomial base evaluated at x.
109+
// Thus, L_j(x) = \prod (x - x[i]) / (x[j] - x[i]) for 0 <= i < k, and i != j.
110+
func LagrangeBase(jth uint, xi []group.Scalar, x group.Scalar) group.Scalar {
111+
if jth >= uint(len(xi)) {
112+
panic("lagrange: invalid index")
113+
}
114+
return baseRatio(jth, xi, x)
115+
}
116+
117+
func baseRatio(jth uint, xi []group.Scalar, x group.Scalar) group.Scalar {
118+
num := x.Copy()
119+
num.SetUint64(1)
120+
den := x.Copy()
121+
den.SetUint64(1)
122+
123+
tmp := x.Copy()
124+
for i := range xi {
125+
if uint(i) != jth {
126+
num.Mul(num, tmp.Sub(x, xi[i]))
127+
den.Mul(den, tmp.Sub(xi[jth], xi[i]))
128+
}
129+
}
130+
131+
return num.Mul(num, den.Inv(den))
132+
}
133+
134+
func areAllDifferent(x []group.Scalar) bool {
135+
m := make(map[string]struct{})
136+
for i := range x {
137+
k, err := x[i].MarshalBinary()
138+
if err != nil {
139+
panic(err)
140+
}
141+
if _, exists := m[string(k)]; exists {
142+
return false
143+
}
144+
m[string(k)] = struct{}{}
145+
}
146+
return true
147+
}

math/polynomial/polynomial_test.go

+131
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
package polynomial_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/cloudflare/circl/group"
7+
"github.com/cloudflare/circl/internal/test"
8+
"github.com/cloudflare/circl/math/polynomial"
9+
)
10+
11+
func TestPolyDegree(t *testing.T) {
12+
g := group.P256
13+
14+
t.Run("zeroPoly", func(t *testing.T) {
15+
p := polynomial.New(nil)
16+
test.CheckOk(p.Degree() == -1, "it should be -1", t)
17+
p = polynomial.New([]group.Scalar{})
18+
test.CheckOk(p.Degree() == -1, "it should be -1", t)
19+
})
20+
21+
t.Run("constantPoly", func(t *testing.T) {
22+
c := []group.Scalar{
23+
g.NewScalar().SetUint64(0),
24+
g.NewScalar().SetUint64(0),
25+
}
26+
p := polynomial.New(c)
27+
test.CheckOk(p.Degree() == 0, "it should be 0", t)
28+
})
29+
30+
t.Run("linearPoly", func(t *testing.T) {
31+
c := []group.Scalar{
32+
g.NewScalar().SetUint64(0),
33+
g.NewScalar().SetUint64(1),
34+
g.NewScalar().SetUint64(0),
35+
}
36+
p := polynomial.New(c)
37+
test.CheckOk(p.Degree() == 1, "it should be 1", t)
38+
})
39+
}
40+
41+
func TestPolyEval(t *testing.T) {
42+
g := group.P256
43+
c := []group.Scalar{
44+
g.NewScalar(),
45+
g.NewScalar(),
46+
g.NewScalar(),
47+
}
48+
c[0].SetUint64(5)
49+
c[1].SetUint64(5)
50+
c[2].SetUint64(2)
51+
p := polynomial.New(c)
52+
53+
x := g.NewScalar()
54+
x.SetUint64(10)
55+
56+
got := p.Evaluate(x)
57+
want := g.NewScalar()
58+
want.SetUint64(255)
59+
if !got.IsEqual(want) {
60+
test.ReportError(t, got, want)
61+
}
62+
}
63+
64+
func TestLagrange(t *testing.T) {
65+
g := group.P256
66+
c := []group.Scalar{
67+
g.NewScalar(),
68+
g.NewScalar(),
69+
g.NewScalar(),
70+
}
71+
c[0].SetUint64(1234)
72+
c[1].SetUint64(166)
73+
c[2].SetUint64(94)
74+
p := polynomial.New(c)
75+
76+
x := []group.Scalar{g.NewScalar(), g.NewScalar(), g.NewScalar()}
77+
x[0].SetUint64(2)
78+
x[1].SetUint64(4)
79+
x[2].SetUint64(5)
80+
81+
y := []group.Scalar{}
82+
for i := range x {
83+
y = append(y, p.Evaluate(x[i]))
84+
}
85+
86+
zero := g.NewScalar()
87+
l := polynomial.NewLagrangePolynomial(x, y)
88+
test.CheckOk(l.Degree() == p.Degree(), "bad degree", t)
89+
90+
got := l.Evaluate(zero)
91+
want := p.Evaluate(zero)
92+
93+
if !got.IsEqual(want) {
94+
test.ReportError(t, got, want)
95+
}
96+
97+
// Test Kronecker's delta of LagrangeBase.
98+
// Thus:
99+
// L_j(x[i]) = { 1, if i == j;
100+
// { 0, otherwise.
101+
one := g.NewScalar()
102+
one.SetUint64(1)
103+
for j := range x {
104+
for i := range x {
105+
got := polynomial.LagrangeBase(uint(j), x, x[i])
106+
107+
if i == j {
108+
want = one
109+
} else {
110+
want = zero
111+
}
112+
113+
if !got.IsEqual(want) {
114+
test.ReportError(t, got, want)
115+
}
116+
}
117+
}
118+
119+
// Test that inputs are different length
120+
err := test.CheckPanic(func() { polynomial.NewLagrangePolynomial(x, y[:1]) })
121+
test.CheckNoErr(t, err, "should panic")
122+
123+
// Test that nodes must be different.
124+
x[0].Set(x[1])
125+
err = test.CheckPanic(func() { polynomial.NewLagrangePolynomial(x, y) })
126+
test.CheckNoErr(t, err, "should panic")
127+
128+
// Test LagrangeBase wrong index
129+
err = test.CheckPanic(func() { polynomial.LagrangeBase(10, x, zero) })
130+
test.CheckNoErr(t, err, "should panic")
131+
}

0 commit comments

Comments
 (0)