Skip to content

Commit e481a08

Browse files
committed
usage docs
1 parent 048deda commit e481a08

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed

README.md

+70
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,76 @@ For code used by experiments in our paper, check out: https://github.com/PrimeIn
1616
pip install -U toploc
1717
```
1818

19+
### Usage
20+
21+
#### Build proofs from activations:
22+
As bytes (more compact when stored in binary formats):
23+
```python
24+
import torch
25+
from toploc import build_proofs_bytes
26+
27+
torch.manual_seed(42)
28+
activations = [torch.randn(5, 16, dtype=torch.bfloat16), *(torch.randn(16, dtype=torch.bfloat16) for _ in range(10))]
29+
proofs = build_proofs_bytes(activations, decode_batching_size=3, topk=4, skip_prefill=False)
30+
31+
print(f"Activation shapes: {[i.shape for i in activations]}")
32+
print(f"Proofs: {proofs}")
33+
```
34+
```text
35+
Activation shapes: [torch.Size([5, 16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16])]
36+
Proofs: [b'\xff\xd9\x1bB+g\xbaKum', b'\xff\xd9\xcb\xb8\x9a\xf1\x86%T\xa0', b'\xff\xd9\xb4h\xda\xe6\xe4\xabA\xb6', b'\xff\xd9\x80d\xd6X0\xe2\xafs', b'\xff\xd9\xd2\x04d\xea\x91\x91\xf6\xd7']
37+
```
38+
39+
As base64 (more compact when stored in text formats):
40+
```python
41+
import torch
42+
from toploc import build_proofs_base64
43+
44+
torch.manual_seed(42)
45+
activations = [torch.randn(1, 5, 16, dtype=torch.bfloat16), *(torch.randn(1, 16, dtype=torch.bfloat16) for _ in range(10))]
46+
proofs = build_proofs_base64(activations, decode_batching_size=3, topk=4, skip_prefill=False)
47+
48+
print(f"Activation shapes: {[i.shape for i in activations]}")
49+
print(f"Proofs: {proofs}")
50+
```
51+
```text
52+
Activation shapes: [torch.Size([1, 5, 16]), torch.Size([1, 16]), torch.Size([1, 16]), torch.Size([1, 16]), torch.Size([1, 16]), torch.Size([1, 16]), torch.Size([1, 16]), torch.Size([1, 16]), torch.Size([1, 16]), torch.Size([1, 16]), torch.Size([1, 16])]
53+
Proofs: ['/9kbQitnukt1bQ==', '/9nLuJrxhiVUoA==', '/9m0aNrm5KtBtg==', '/9mAZNZYMOKvcw==', '/9nSBGTqkZH21w==']
54+
```
55+
56+
#### Verify proofs:
57+
```python
58+
import torch
59+
from toploc import ProofPoly
60+
from toploc.poly import batch_activations
61+
from toploc.C.csrc.utils import get_fp_parts
62+
from statistics import mean, median
63+
64+
torch.manual_seed(42)
65+
activations = [torch.randn(1, 5, 16, dtype=torch.bfloat16), *(torch.randn(1, 16, dtype=torch.bfloat16) for _ in range(10))]
66+
proofs = ['/9kbQitnukt1bQ==', '/9nLuJrxhiVUoA==', '/9m0aNrm5KtBtg==', '/9mAZNZYMOKvcw==', '/9nSBGTqkZH21w==']
67+
proofs = [ProofPoly.from_base64(proof) for proof in proofs]
68+
69+
# apply some jitter to the activations
70+
activations = [i * 1.01 for i in activations]
71+
72+
for index, (proof, chunk) in enumerate(zip(proofs, batch_activations(activations, decode_batching_size=3))):
73+
chunk = chunk.view(-1).cpu()
74+
topk_indices = chunk.abs().topk(k=4).indices.tolist()
75+
topk_values = chunk[topk_indices]
76+
proof_topk_values = torch.tensor([proof(i) for i in topk_indices], dtype=torch.uint16).view(dtype=torch.bfloat16)
77+
exps, mants = get_fp_parts(topk_values)
78+
proof_exps, proof_mants = get_fp_parts(proof_topk_values)
79+
80+
exp_intersections = [i == j for i, j in zip(exps, proof_exps)]
81+
mant_errs = [abs(i - j) for i, j, k in zip(mants, proof_mants, exp_intersections) if k]
82+
print(f"=== Proof {index}")
83+
print(f"Exp intersections: {sum(exp_intersections)}")
84+
print(f"Mean mantissa error: {mean(mant_errs)}")
85+
print(f"Median mantissa error: {median(mant_errs)}")
86+
```
87+
88+
1989
# Citing
2090

2191
```bibtex

0 commit comments

Comments
 (0)