Skip to content

Commit 4437a54

Browse files
committed
add some utils for ease of use
1 parent e481a08 commit 4437a54

File tree

3 files changed

+55
-37
lines changed

3 files changed

+55
-37
lines changed

tests/test_poly.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import base64
44
from toploc.poly import (
55
find_injective_modulus,
6-
build_proofs,
6+
build_proofs_bytes,
77
build_proofs_base64,
88
ProofPoly,
99
)
@@ -132,7 +132,7 @@ def sample_activations():
132132

133133
def test_build_proofs(sample_activations):
134134
"""Test building proofs"""
135-
proofs = build_proofs(sample_activations, decode_batching_size=2, topk=5)
135+
proofs = build_proofs_bytes(sample_activations, decode_batching_size=2, topk=5)
136136
assert isinstance(proofs, list)
137137
assert all(isinstance(p, bytes) for p in proofs)
138138
assert len(proofs) == 5
@@ -151,7 +151,7 @@ def test_build_proofs_base64(sample_activations):
151151

152152
def test_build_proofs_skip_prefill(sample_activations):
153153
"""Test building proofs with skip_prefill"""
154-
proofs = build_proofs(
154+
proofs = build_proofs_bytes(
155155
sample_activations, decode_batching_size=2, topk=5, skip_prefill=True
156156
)
157157
assert isinstance(proofs, list)
@@ -165,7 +165,7 @@ def test_build_proofs_error_handling():
165165
torch.randn(0, 16, dtype=torch.bfloat16),
166166
torch.randn(16, dtype=torch.bfloat16),
167167
]
168-
proofs = build_proofs(invalid_activations, decode_batching_size=2, topk=5)
168+
proofs = build_proofs_bytes(invalid_activations, decode_batching_size=2, topk=5)
169169
assert isinstance(proofs, list)
170170
assert all(isinstance(p, bytes) for p in proofs)
171171

@@ -176,21 +176,23 @@ def test_build_proofs_error_handling():
176176
def test_build_proofs_edge_cases(sample_activations):
177177
"""Test edge cases for proof building"""
178178
# Test with minimal topk
179-
proofs_min = build_proofs(sample_activations, decode_batching_size=2, topk=1)
179+
proofs_min = build_proofs_bytes(sample_activations, decode_batching_size=2, topk=1)
180180
assert len(proofs_min) > 0
181181

182182
# Test with large batching size
183-
proofs_large_batch = build_proofs(
183+
proofs_large_batch = build_proofs_bytes(
184184
sample_activations, decode_batching_size=10, topk=5
185185
)
186186
assert len(proofs_large_batch) > 0
187187

188188
# Test with only one prefill activation
189-
proofs_one = build_proofs(sample_activations[:1], decode_batching_size=2, topk=5)
189+
proofs_one = build_proofs_bytes(
190+
sample_activations[:1], decode_batching_size=2, topk=5
191+
)
190192
assert len(proofs_one) == 1
191193

192194
# Test with only one activation and skip_prefill
193-
proofs_one_skip = build_proofs(
195+
proofs_one_skip = build_proofs_bytes(
194196
sample_activations[:1], decode_batching_size=2, topk=5, skip_prefill=True
195197
)
196198
assert len(proofs_one_skip) == 0

toploc/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from toploc.poly import ProofPoly, build_proofs, build_proofs_base64 # noqa: F401
1+
from toploc.poly import ProofPoly, build_proofs, build_proofs_bytes, build_proofs_base64 # noqa: F401
22
from toploc.utils import sha256sum # noqa: F401
33

44
__version__ = "0.0.0.dev1"

toploc/poly.py

+44-28
Original file line numberDiff line numberDiff line change
@@ -80,34 +80,6 @@ def build_proofs(
8080
decode_batching_size: int,
8181
topk: int,
8282
skip_prefill: bool = False,
83-
) -> list[bytes]:
84-
return [
85-
proof.to_bytes()
86-
for proof in _build_proofs(
87-
activations, decode_batching_size, topk, skip_prefill
88-
)
89-
]
90-
91-
92-
def build_proofs_base64(
93-
activations: list[torch.Tensor],
94-
decode_batching_size: int,
95-
topk: int,
96-
skip_prefill: bool = False,
97-
) -> list[str]:
98-
return [
99-
proof.to_base64()
100-
for proof in _build_proofs(
101-
activations, decode_batching_size, topk, skip_prefill
102-
)
103-
]
104-
105-
106-
def _build_proofs(
107-
activations: list[torch.Tensor],
108-
decode_batching_size: int,
109-
topk: int,
110-
skip_prefill: bool = False,
11183
) -> list[ProofPoly]:
11284
proofs = []
11385

@@ -137,3 +109,47 @@ def _build_proofs(
137109
)
138110

139111
return proofs
112+
113+
114+
def build_proofs_bytes(
    activations: list[torch.Tensor],
    decode_batching_size: int,
    topk: int,
    skip_prefill: bool = False,
) -> list[bytes]:
    """Build proofs and return each one serialized via ``to_bytes()``.

    All arguments are forwarded unchanged (positionally) to ``build_proofs``;
    see that function for their meaning.

    Returns:
        One ``bytes`` object per proof returned by ``build_proofs``, in the
        same order.
    """
    proof_polys = build_proofs(activations, decode_batching_size, topk, skip_prefill)
    return [poly.to_bytes() for poly in proof_polys]
124+
125+
126+
def build_proofs_base64(
    activations: list[torch.Tensor],
    decode_batching_size: int,
    topk: int,
    skip_prefill: bool = False,
) -> list[str]:
    """Build proofs and return each one encoded via ``to_base64()``.

    All arguments are forwarded unchanged (positionally) to ``build_proofs``;
    see that function for their meaning.

    Returns:
        One ``str`` per proof returned by ``build_proofs``, in the same order.
    """
    proof_polys = build_proofs(activations, decode_batching_size, topk, skip_prefill)
    return [poly.to_base64() for poly in proof_polys]
136+
137+
138+
def batch_activations(
    activations: list[torch.Tensor],
    decode_batching_size: int,
) -> list[torch.Tensor]:
    """Group activations into flattened 1-D batches.

    The first tensor (the prefill activation) becomes its own batch. The
    remaining tensors (decode activations) are flattened and concatenated in
    consecutive groups of at most ``decode_batching_size`` tensors.

    Args:
        activations: Tensors to batch; element 0 is treated as the prefill.
        decode_batching_size: Maximum number of decode tensors per batch.
            Must be at least 1.

    Returns:
        A list of 1-D tensors: the flattened prefill followed by one
        concatenated tensor per decode group. Empty if ``activations`` is
        empty.

    Raises:
        ValueError: If ``decode_batching_size`` is less than 1.
    """
    if decode_batching_size < 1:
        # A zero step makes range() raise and a negative step would silently
        # drop every decode activation, so reject both up front.
        raise ValueError(
            f"decode_batching_size must be >= 1, got {decode_batching_size}"
        )
    if not activations:
        return []

    # reshape(-1) returns a view when the tensor is contiguous (same as
    # view(-1)) and falls back to a copy for non-contiguous tensors instead
    # of raising.
    batches = [activations[0].reshape(-1)]

    # Batched decode: walk the tail in strides of decode_batching_size.
    for start in range(1, len(activations), decode_batching_size):
        group = activations[start : start + decode_batching_size]
        batches.append(torch.cat([act.reshape(-1) for act in group]))

    return batches

0 commit comments

Comments
 (0)