Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some tidbit repo improvements #12

Merged
merged 8 commits into from
Mar 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[run]
source = toploc
omit =
toploc/__init__.py
toploc/C/csrc/__init__.py

[report]
exclude_lines =
pragma: no cover
def __repr__
raise NotImplementedError
if __name__ == .__main__.:
pass
raise ImportError

5 changes: 5 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ Run single test:
uv run pytest tests/test_utils.py::test_get_fp32_parts
```

Run coverage:
```bash
uv run pytest --cov=toploc --cov-report=term-missing --cov-report=html
```

## Code Quality

Install pre-commit hooks:
Expand Down
42 changes: 17 additions & 25 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ The feature set includes:

For code used by experiments in our paper, check out: https://github.com/PrimeIntellect-ai/toploc-experiments

### Installation
## Installation

```bash
pip install -U toploc
```

### Usage
## Usage

#### Build proofs from activations:
### Build proofs from activations:
As bytes (more compact when stored in binary formats):
```python
import torch
Expand Down Expand Up @@ -53,38 +53,30 @@ Activation shapes: [torch.Size([1, 5, 16]), torch.Size([1, 16]), torch.Size([1,
Proofs: ['/9kbQitnukt1bQ==', '/9nLuJrxhiVUoA==', '/9m0aNrm5KtBtg==', '/9mAZNZYMOKvcw==', '/9nSBGTqkZH21w==']
```

#### Verify proofs:
### Verify proofs:
```python
import torch
from toploc import ProofPoly
from toploc.poly import batch_activations
from toploc.C.csrc.utils import get_fp_parts
from statistics import mean, median
from toploc import verify_proofs_base64

torch.manual_seed(42)
activations = [torch.randn(1, 5, 16, dtype=torch.bfloat16), *(torch.randn(1, 16, dtype=torch.bfloat16) for _ in range(10))]
proofs = ['/9kbQitnukt1bQ==', '/9nLuJrxhiVUoA==', '/9m0aNrm5KtBtg==', '/9mAZNZYMOKvcw==', '/9nSBGTqkZH21w==']
proofs = [ProofPoly.from_base64(proof) for proof in proofs]

# apply some jitter to the activations
activations = [i * 1.01 for i in activations]

for index, (proof, chunk) in enumerate(zip(proofs, batch_activations(activations, decode_batching_size=3))):
chunk = chunk.view(-1).cpu()
topk_indices = chunk.abs().topk(k=4).indices.tolist()
topk_values = chunk[topk_indices]
proof_topk_values = torch.tensor([proof(i) for i in topk_indices], dtype=torch.uint16).view(dtype=torch.bfloat16)
exps, mants = get_fp_parts(proof_topk_values)
proof_exps, proof_mants = get_fp_parts(topk_values)

exp_intersections = [i == j for i, j in zip(exps, proof_exps)]
mant_errs = [abs(i - j) for i, j, k in zip(mants, proof_mants, exp_intersections) if k]
print(f"=== Proof {index}")
print(f"Exp intersections: {sum(exp_intersections)}")
print(f"Mean mantissa error: {mean(mant_errs)}")
print(f"Median mantissa error: {median(mant_errs)}")
```
results = verify_proofs_base64(activations, proofs, decode_batching_size=3, topk=4, skip_prefill=False)

print("Results:")
print(*results, sep="\n")
```
```python
Results:
VerificationResult(exp_intersections=4, mant_err_mean=1.75, mant_err_median=2.0)
VerificationResult(exp_intersections=4, mant_err_mean=2, mant_err_median=2.0)
VerificationResult(exp_intersections=4, mant_err_mean=1.25, mant_err_median=1.0)
VerificationResult(exp_intersections=4, mant_err_mean=1, mant_err_median=1.0)
VerificationResult(exp_intersections=4, mant_err_mean=2, mant_err_median=2.0)
```

# Citing

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dev-dependencies = [
"ruff",
"pre-commit",
"pytest",
"pytest-cov"
]

[tool.setuptools.dynamic]
Expand Down
67 changes: 67 additions & 0 deletions tests/test_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
build_proofs_bytes,
build_proofs_base64,
ProofPoly,
VerificationResult,
verify_proofs_bytes,
verify_proofs_base64,
)


Expand Down Expand Up @@ -123,6 +126,7 @@ def test_proof_poly_repr(sample_poly):

@pytest.fixture
def sample_activations():
torch.manual_seed(42)
DIM = 16
a = [torch.randn(3, DIM, dtype=torch.bfloat16)]
for _ in range(3 * 2 + 1):
Expand Down Expand Up @@ -196,3 +200,66 @@ def test_build_proofs_edge_cases(sample_activations):
sample_activations[:1], decode_batching_size=2, topk=5, skip_prefill=True
)
assert len(proofs_one_skip) == 0


def test_verify_proofs_bytes(sample_activations):
    """Bytes proofs verified against activations scaled by 1.01"""
    settings = {"decode_batching_size": 3, "topk": 4}

    # Proofs are built from the unmodified activations.
    proofs_bytes = build_proofs_bytes(sample_activations, **settings)

    # Verification runs on a copy of the activations scaled by 1.01.
    jittered = [act * 1.01 for act in sample_activations]
    results = verify_proofs_bytes(jittered, proofs_bytes, **settings)

    # One VerificationResult is expected per proof.
    assert isinstance(results, list)
    assert len(results) == len(proofs_bytes)

    for result in results:
        assert isinstance(result, VerificationResult)
        assert result.exp_intersections == 4
        # Mantissa errors must be non-zero but no larger than 2.
        assert 0 < result.mant_err_mean <= 2
        assert 0 < result.mant_err_median <= 2


def test_verify_proofs_base64(sample_activations):
    """Base64 proofs verified against the very activations they were built from"""
    settings = {"decode_batching_size": 2, "topk": 5}

    proofs_base64 = build_proofs_base64(sample_activations, **settings)
    results = verify_proofs_base64(sample_activations, proofs_base64, **settings)

    # One VerificationResult is expected per proof.
    assert isinstance(results, list)
    assert len(results) == len(proofs_base64)

    for result in results:
        assert isinstance(result, VerificationResult)
        assert result.exp_intersections == 5
        # Identical activations: both mantissa error statistics must be zero.
        assert result.mant_err_mean == 0
        assert result.mant_err_median == 0


def test_verify_proofs_bytes_invalid(sample_activations):
    """Bytes proofs verified against activations scaled by 1.10"""
    settings = {"decode_batching_size": 3, "topk": 4}

    # Proofs are built from the unmodified activations.
    proofs_bytes = build_proofs_bytes(sample_activations, **settings)

    # Verification runs on a copy of the activations scaled by 1.10.
    scaled = [act * 1.10 for act in sample_activations]
    results = verify_proofs_bytes(scaled, proofs_bytes, **settings)

    # Emitted before the assertions so the values are visible in pytest's
    # captured output when a check below fails.
    print(results)

    assert isinstance(results, list)
    assert len(results) == len(proofs_bytes)

    for result in results:
        assert isinstance(result, VerificationResult)
        assert result.exp_intersections >= 3
        # Both mantissa error statistics are expected to exceed 10 here.
        assert result.mant_err_mean > 10
        assert result.mant_err_median > 10
12 changes: 12 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import torch
import time
import pytest
import tempfile
from toploc.utils import sha256sum


@pytest.mark.parametrize(
Expand Down Expand Up @@ -114,3 +116,13 @@ def py_get_fp_parts(tensor: torch.FloatTensor) -> tuple[list[int], list[int]]:
assert exps == ref_exps
assert mantissas == ref_mants
assert new_time < old_time


def test_sha256sum():
    """Check sha256sum of a file holding b"Hello, world!" repeated 1000 times"""
    import os

    # The payload is written to a regular file inside a temporary directory
    # and closed before hashing. Re-opening a still-open NamedTemporaryFile
    # by name is not portable (it fails on Windows), so that is avoided here.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "payload.bin")
        with open(path, "wb") as f:
            f.write(b"Hello, world!" * 1000)

        assert (
            sha256sum(path)
            == "a8f764e70df94be2c911fb51b3d0c56c03882078dbdb215de8b7bd0374b0fb10"
        )
13 changes: 11 additions & 2 deletions toploc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
from toploc.poly import ProofPoly, build_proofs, build_proofs_bytes, build_proofs_base64 # noqa: F401
from toploc.utils import sha256sum # noqa: F401
# ruff: noqa: F401
from toploc.poly import (
ProofPoly,
build_proofs,
build_proofs_bytes,
build_proofs_base64,
verify_proofs,
verify_proofs_bytes,
verify_proofs_base64,
)
from toploc.utils import sha256sum

__version__ = "0.0.0.dev1"
98 changes: 0 additions & 98 deletions toploc/commits.py

This file was deleted.

Loading
Loading