@@ -16,6 +16,76 @@ For code used by experiments in our paper, check out: https://github.com/PrimeIn
pip install -U toploc
```
### Usage

#### Build proofs from activations:
As bytes (more compact when stored in binary formats):
23
```python
import torch
from toploc import build_proofs_bytes

torch.manual_seed(42)
activations = [torch.randn(5, 16, dtype=torch.bfloat16), *(torch.randn(16, dtype=torch.bfloat16) for _ in range(10))]
proofs = build_proofs_bytes(activations, decode_batching_size=3, topk=4, skip_prefill=False)

print(f"Activation shapes: {[i.shape for i in activations]}")
print(f"Proofs: {proofs}")
```
34
```python
Activation shapes: [torch.Size([5, 16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16])]
Proofs: [b'\xff\xd9\x1bB+g\xbaKum', b'\xff\xd9\xcb\xb8\x9a\xf1\x86%T\xa0', b'\xff\xd9\xb4h\xda\xe6\xe4\xabA\xb6', b'\xff\xd9\x80d\xd6X0\xe2\xafs', b'\xff\xd9\xd2\x04d\xea\x91\x91\xf6\xd7']
```
38

As base64 (more compact when stored in text formats):
```python
import torch
from toploc import build_proofs_base64

torch.manual_seed(42)
activations = [torch.randn(1, 5, 16, dtype=torch.bfloat16), *(torch.randn(1, 16, dtype=torch.bfloat16) for _ in range(10))]
proofs = build_proofs_base64(activations, decode_batching_size=3, topk=4, skip_prefill=False)

print(f"Activation shapes: {[i.shape for i in activations]}")
print(f"Proofs: {proofs}")
```
51
```python
Activation shapes: [torch.Size([1, 5, 16]), torch.Size([1, 16]), torch.Size([1, 16]), torch.Size([1, 16]), torch.Size([1, 16]), torch.Size([1, 16]), torch.Size([1, 16]), torch.Size([1, 16]), torch.Size([1, 16]), torch.Size([1, 16]), torch.Size([1, 16])]
Proofs: ['/9kbQitnukt1bQ==', '/9nLuJrxhiVUoA==', '/9m0aNrm5KtBtg==', '/9mAZNZYMOKvcw==', '/9nSBGTqkZH21w==']
```
55

#### Verify proofs:
```python
import torch
from toploc import ProofPoly
from toploc.poly import batch_activations
from toploc.C.csrc.utils import get_fp_parts
from statistics import mean, median

torch.manual_seed(42)
activations = [torch.randn(1, 5, 16, dtype=torch.bfloat16), *(torch.randn(1, 16, dtype=torch.bfloat16) for _ in range(10))]
proofs = ['/9kbQitnukt1bQ==', '/9nLuJrxhiVUoA==', '/9m0aNrm5KtBtg==', '/9mAZNZYMOKvcw==', '/9nSBGTqkZH21w==']
proofs = [ProofPoly.from_base64(proof) for proof in proofs]

# apply some jitter to the activations
activations = [i * 1.01 for i in activations]

for index, (proof, chunk) in enumerate(zip(proofs, batch_activations(activations, decode_batching_size=3))):
    chunk = chunk.view(-1).cpu()
    topk_indices = chunk.abs().topk(k=4).indices.tolist()
    topk_values = chunk[topk_indices]
    proof_topk_values = torch.tensor([proof(i) for i in topk_indices], dtype=torch.uint16).view(dtype=torch.bfloat16)
    exps, mants = get_fp_parts(proof_topk_values)
    proof_exps, proof_mants = get_fp_parts(topk_values)

    exp_intersections = [i == j for i, j in zip(exps, proof_exps)]
    mant_errs = [abs(i - j) for i, j, k in zip(mants, proof_mants, exp_intersections) if k]
    print(f"=== Proof {index}")
    print(f"Exp intersections: {sum(exp_intersections)}")
    print(f"Mean mantissa error: {mean(mant_errs)}")
    print(f"Median mantissa error: {median(mant_errs)}")
```
87


# Citing

```bibtex