Skip to content

Commit e71402a

Browse files
authored
[Feat] proof building (#7)
* port proof building from genesys pr * add tests * use union instead of pipe
1 parent 0ad75e1 commit e71402a

File tree

5 files changed

+357
-0
lines changed

5 files changed

+357
-0
lines changed

tests/test_poly.py

+196
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import pytest
2+
import torch
3+
import base64
4+
from toploc.poly import (
5+
find_injective_modulus,
6+
build_proofs,
7+
build_proofs_base64,
8+
ProofPoly,
9+
)
10+
11+
12+
def test_find_injective_modulus():
13+
"""Test finding injective modulus"""
14+
x = torch.randint(0, 4_000_000_000, (100,)).tolist()
15+
modulus = find_injective_modulus(x)
16+
assert isinstance(modulus, int)
17+
# Check that all values are unique under modulus
18+
modded = [i % modulus for i in x]
19+
assert len(set(modded)) == len(x)
20+
21+
22+
@pytest.fixture
23+
def sample_poly():
24+
return ProofPoly([1, 2, 3, 4], 65497)
25+
26+
27+
def test_proof_poly_init(sample_poly):
28+
"""Test initialization of ProofPoly"""
29+
assert sample_poly.coeffs == [1, 2, 3, 4]
30+
assert sample_poly.modulus == 65497
31+
32+
33+
def test_proof_poly_call(sample_poly):
34+
"""Test polynomial evaluation"""
35+
x = 42
36+
result = sample_poly(x)
37+
assert isinstance(result, int)
38+
assert result == (1 + 2 * x + 3 * x**2 + 4 * x**3) % 65497
39+
40+
41+
def test_proof_poly_len(sample_poly):
42+
"""Test length of polynomial"""
43+
assert len(sample_poly) == 4
44+
45+
46+
def test_proof_poly_null():
47+
"""Test null polynomial creation"""
48+
length = 5
49+
null_poly = ProofPoly.null(length)
50+
assert len(null_poly) == length
51+
assert null_poly.modulus == 0
52+
assert null_poly.coeffs == [0] * length
53+
54+
55+
def test_proof_poly_from_points_list():
56+
"""Test creation from list points"""
57+
x = [1, 2, 3]
58+
y = [4, 5, 6]
59+
poly = ProofPoly.from_points(x, y)
60+
assert isinstance(poly, ProofPoly)
61+
assert len(poly.coeffs) > 0
62+
63+
64+
def test_proof_poly_from_points_tensor():
65+
"""Test creation from tensor points"""
66+
x = torch.tensor([1, 2, 3])
67+
y = torch.tensor([4, 5, 6])
68+
poly = ProofPoly.from_points(x, y)
69+
assert isinstance(poly, ProofPoly)
70+
assert len(poly.coeffs) == 3
71+
assert poly(1) == 4
72+
assert poly(2) == 5
73+
assert poly(3) == 6
74+
75+
76+
def test_proof_poly_from_points_bfloat16():
77+
"""Test creation from bfloat16 tensor"""
78+
x = torch.tensor([1, 2, 3])
79+
y = torch.tensor([4, 5, 6], dtype=torch.bfloat16)
80+
poly = ProofPoly.from_points(x, y)
81+
assert isinstance(poly, ProofPoly)
82+
assert len(poly.coeffs) == 3
83+
84+
85+
def test_proof_poly_to_base64(sample_poly):
86+
"""Test base64 encoding"""
87+
encoded = sample_poly.to_base64()
88+
assert isinstance(encoded, str)
89+
# Verify it's valid base64
90+
base64.b64decode(encoded)
91+
92+
93+
def test_proof_poly_to_bytes(sample_poly):
94+
"""Test bytes conversion"""
95+
byte_data = sample_poly.to_bytes()
96+
assert isinstance(byte_data, bytes)
97+
assert len(byte_data) > 0
98+
99+
100+
def test_proof_poly_from_bytes(sample_poly):
101+
"""Test creation from bytes"""
102+
byte_data = sample_poly.to_bytes()
103+
reconstructed = ProofPoly.from_bytes(byte_data)
104+
assert reconstructed.coeffs == sample_poly.coeffs
105+
assert reconstructed.modulus == sample_poly.modulus
106+
107+
108+
def test_proof_poly_from_base64(sample_poly):
109+
"""Test creation from base64"""
110+
encoded = sample_poly.to_base64()
111+
reconstructed = ProofPoly.from_base64(encoded)
112+
assert reconstructed.coeffs == sample_poly.coeffs
113+
assert reconstructed.modulus == sample_poly.modulus
114+
115+
116+
def test_proof_poly_repr(sample_poly):
117+
"""Test string representation"""
118+
repr_str = repr(sample_poly)
119+
assert isinstance(repr_str, str)
120+
assert str(65497) in repr_str
121+
assert str([1, 2, 3, 4]) in repr_str
122+
123+
124+
@pytest.fixture
125+
def sample_activations():
126+
DIM = 16
127+
a = [torch.randn(3, DIM, dtype=torch.bfloat16)]
128+
for _ in range(3 * 2 + 1):
129+
a.append(torch.randn(DIM, dtype=torch.bfloat16))
130+
return a
131+
132+
133+
def test_build_proofs(sample_activations):
134+
"""Test building proofs"""
135+
proofs = build_proofs(sample_activations, decode_batching_size=2, topk=5)
136+
assert isinstance(proofs, list)
137+
assert all(isinstance(p, bytes) for p in proofs)
138+
assert len(proofs) == 5
139+
140+
141+
def test_build_proofs_base64(sample_activations):
142+
"""Test building base64 proofs"""
143+
proofs = build_proofs_base64(sample_activations, decode_batching_size=2, topk=5)
144+
assert isinstance(proofs, list)
145+
assert all(isinstance(p, str) for p in proofs)
146+
# Verify each proof is valid base64
147+
for proof in proofs:
148+
base64.b64decode(proof)
149+
assert len(proofs) == 5
150+
151+
152+
def test_build_proofs_skip_prefill(sample_activations):
153+
"""Test building proofs with skip_prefill"""
154+
proofs = build_proofs(
155+
sample_activations, decode_batching_size=2, topk=5, skip_prefill=True
156+
)
157+
assert isinstance(proofs, list)
158+
assert all(isinstance(p, bytes) for p in proofs)
159+
assert len(proofs) == 4
160+
161+
162+
def test_build_proofs_error_handling():
163+
"""Test error handling in proof building"""
164+
invalid_activations = [
165+
torch.randn(0, 16, dtype=torch.bfloat16),
166+
torch.randn(16, dtype=torch.bfloat16),
167+
]
168+
proofs = build_proofs(invalid_activations, decode_batching_size=2, topk=5)
169+
assert isinstance(proofs, list)
170+
assert all(isinstance(p, bytes) for p in proofs)
171+
172+
nullproof = ProofPoly.null(5).to_bytes()
173+
assert all(p == nullproof for p in proofs)
174+
175+
176+
def test_build_proofs_edge_cases(sample_activations):
177+
"""Test edge cases for proof building"""
178+
# Test with minimal topk
179+
proofs_min = build_proofs(sample_activations, decode_batching_size=2, topk=1)
180+
assert len(proofs_min) > 0
181+
182+
# Test with large batching size
183+
proofs_large_batch = build_proofs(
184+
sample_activations, decode_batching_size=10, topk=5
185+
)
186+
assert len(proofs_large_batch) > 0
187+
188+
# Test with only one prefill activation
189+
proofs_one = build_proofs(sample_activations[:1], decode_batching_size=2, topk=5)
190+
assert len(proofs_one) == 1
191+
192+
# Test with only one activation and skip_prefill
193+
proofs_one_skip = build_proofs(
194+
sample_activations[:1], decode_batching_size=2, topk=5, skip_prefill=True
195+
)
196+
assert len(proofs_one_skip) == 0

toploc/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
1+
from toploc.poly import ProofPoly, build_proofs, build_proofs_base64 # noqa: F401
2+
from toploc.utils import sha256sum # noqa: F401
3+
14
__version__ = "0.0.0.dev1"

toploc/ndd.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# TODO: Deprecate this file and move to C
12
MOD_N = 65497
23

34

toploc/poly.py

+139
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
from typing import Union
2+
import base64
3+
from toploc.C.csrc.ndd import compute_newton_coefficients, evaluate_polynomial
4+
import torch
5+
import logging
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
def find_injective_modulus(x: list[int]) -> int:
11+
for i in range(65497, 2**15, -1):
12+
if len(set([j % i for j in x])) == len(x):
13+
return i
14+
raise ValueError("No injective modulus found!")
15+
16+
17+
class ProofPoly:
18+
def __init__(self, coeffs: list[int], modulus: int):
19+
self.coeffs = coeffs
20+
self.modulus = modulus
21+
22+
def __call__(self, x: int):
23+
return evaluate_polynomial(self.coeffs, x % self.modulus)
24+
25+
def __len__(self):
26+
return len(self.coeffs)
27+
28+
@classmethod
29+
def null(cls, length: int) -> "ProofPoly":
30+
return cls([0] * length, 0)
31+
32+
@classmethod
33+
def from_points(
34+
cls, x: Union[list[int], torch.Tensor], y: Union[list[int], torch.Tensor]
35+
) -> "ProofPoly":
36+
if isinstance(x, torch.Tensor):
37+
x = x.tolist()
38+
if isinstance(y, torch.Tensor):
39+
if y.dtype == torch.bfloat16:
40+
y = y.view(dtype=torch.uint16).tolist()
41+
elif y.dtype == torch.float32:
42+
raise NotImplementedError(
43+
"float32 not supported yet because interpolate has hardcode prime"
44+
)
45+
else:
46+
y = y.tolist()
47+
modulus = find_injective_modulus(x)
48+
x = [i % modulus for i in x]
49+
return cls(compute_newton_coefficients(x, y), modulus)
50+
51+
def to_base64(self):
52+
base64_encoded = base64.b64encode(self.to_bytes()).decode("utf-8")
53+
return base64_encoded
54+
55+
def to_bytes(self):
56+
return self.modulus.to_bytes(2, byteorder="big", signed=False) + b"".join(
57+
coeff.to_bytes(2, byteorder="big", signed=False) for coeff in self.coeffs
58+
)
59+
60+
@classmethod
61+
def from_bytes(cls, byte_data: bytes) -> "ProofPoly":
62+
modulus = int.from_bytes(byte_data[:2], byteorder="big", signed=False)
63+
coeffs = [
64+
int.from_bytes(byte_data[i : i + 2], byteorder="big", signed=False)
65+
for i in range(2, len(byte_data), 2)
66+
]
67+
return cls(coeffs, modulus)
68+
69+
@classmethod
70+
def from_base64(cls, base64_encoded: str) -> "ProofPoly":
71+
byte_data = base64.b64decode(base64_encoded)
72+
return cls.from_bytes(byte_data)
73+
74+
def __repr__(self) -> str:
75+
return f"ProofPoly[{self.modulus}]({self.coeffs})"
76+
77+
78+
def build_proofs(
79+
activations: list[torch.Tensor],
80+
decode_batching_size: int,
81+
topk: int,
82+
skip_prefill: bool = False,
83+
) -> list[bytes]:
84+
return [
85+
proof.to_bytes()
86+
for proof in _build_proofs(
87+
activations, decode_batching_size, topk, skip_prefill
88+
)
89+
]
90+
91+
92+
def build_proofs_base64(
93+
activations: list[torch.Tensor],
94+
decode_batching_size: int,
95+
topk: int,
96+
skip_prefill: bool = False,
97+
) -> list[str]:
98+
return [
99+
proof.to_base64()
100+
for proof in _build_proofs(
101+
activations, decode_batching_size, topk, skip_prefill
102+
)
103+
]
104+
105+
106+
def _build_proofs(
107+
activations: list[torch.Tensor],
108+
decode_batching_size: int,
109+
topk: int,
110+
skip_prefill: bool = False,
111+
) -> list[ProofPoly]:
112+
proofs = []
113+
114+
# In order to not crash, we return null proofs if there is an error
115+
try:
116+
# Prefill
117+
if not skip_prefill:
118+
flat_view = activations[0].view(-1)
119+
topk_indices = flat_view.abs().topk(topk).indices
120+
topk_values = flat_view[topk_indices]
121+
proof = ProofPoly.from_points(topk_indices, topk_values)
122+
proofs.append(proof)
123+
124+
# Batched Decode
125+
for i in range(1, len(activations), decode_batching_size):
126+
flat_view = torch.cat(
127+
[i.view(-1) for i in activations[i : i + decode_batching_size]]
128+
)
129+
topk_indices = flat_view.abs().topk(topk).indices
130+
topk_values = flat_view[topk_indices]
131+
proof = ProofPoly.from_points(topk_indices, topk_values)
132+
proofs.append(proof)
133+
except Exception as e:
134+
logger.error(f"Error building proofs: {e}")
135+
proofs = [ProofPoly.null(topk)] * (
136+
1 + (len(activations) - 1 + decode_batching_size) // decode_batching_size
137+
)
138+
139+
return proofs

toploc/utils.py

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import hashlib
2+
3+
4+
def sha256sum(filename: str, chunk_size: int = 65536) -> str:
5+
"""Calculate the SHA-256 checksum of a file efficiently.
6+
7+
Args:
8+
filename (str): Path to the file.
9+
chunk_size (int, optional): Size of chunks read at a time. Defaults to 64 KB.
10+
11+
Returns:
12+
str: The SHA-256 hash of the file as a hexadecimal string.
13+
"""
14+
sha256 = hashlib.sha256()
15+
with open(filename, "rb", buffering=0) as f:
16+
for chunk in iter(lambda: f.read(chunk_size), b""):
17+
sha256.update(memoryview(chunk))
18+
return sha256.hexdigest()

0 commit comments

Comments
 (0)