Update float8nocompile readme (#1693)
danielvegamyhre authored Feb 11, 2025
1 parent 999b16d commit d99785c
Showing 2 changed files with 71 additions and 2 deletions.
73 changes: 71 additions & 2 deletions torchao/prototype/float8nocompile/README.md

# float8nocompile

A prototype API for high performance eager mode float8 training: a version of Float8Linear that is performant without `torch.compile`, using handwritten Triton kernels for quantization.

### Usage

Prepare your model for high performance eager mode float8 training with a single conversion function: `convert_to_float8_nocompile_training` ([source](https://github.com/pytorch/ao/blob/32a51eca14257bbaafd3671a5349189e30c65e2b/torchao/prototype/float8nocompile/float8nocompile_linear_utils.py#L24)).

This function replaces `nn.Linear` layers in-place with `Float8NoCompileLinear` layers, which use **dynamic, tensorwise scaling**
to perform all matmuls in the linear layer forward and backward passes as FP8 GEMMs.
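
Conceptually, dynamic tensorwise scaling derives a single scale per tensor from its current max absolute value (amax), so that the scaled values span the FP8 range before being cast down. Below is a minimal PyTorch sketch of the idea; the function name is hypothetical, and the prototype implements this step with handwritten Triton kernels rather than eager PyTorch ops.

```python
import torch

def quantize_tensorwise_fp8(x: torch.Tensor, float8_dtype=torch.float8_e4m3fn):
    # One scale per tensor, computed dynamically from the current amax.
    fp8_max = torch.finfo(float8_dtype).max
    amax = x.abs().max().to(torch.float32)
    scale = fp8_max / torch.clamp(amax, min=1e-12)
    # Scale into the representable range, clamp, and cast down to float8.
    x_fp8 = (x.to(torch.float32) * scale).clamp(-fp8_max, fp8_max).to(float8_dtype)
    return x_fp8, scale
```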

**Example**:

```python
from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)

# define your model, data loaders, etc
...

# convert specified `torch.nn.Linear` modules to `Float8NoCompileLinear` (in-place)
convert_to_float8_nocompile_training(model)

# training loop
for i in range(num_epochs):
    ...
```
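
For context, here is a more complete (hypothetical) end-to-end sketch with the elided pieces filled in: a toy model, an optimizer, and a standard forward/backward step. The converted model is used exactly like any other `nn.Module`; an fp8-capable GPU (e.g. H100) is assumed.

```python
import torch
import torch.nn as nn

from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)

# toy model; dims are multiples of 16 to satisfy fp8 GEMM shape constraints
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))
model = model.to("cuda", dtype=torch.bfloat16)

# swap nn.Linear layers for float8 nocompile linear layers in-place
convert_to_float8_nocompile_training(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# toy training loop with random data
for step in range(10):
    x = torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16)
    target = torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16)

    loss = nn.functional.mse_loss(model(x), target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```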

### Performance benchmarks

Performance benchmarking was done via [experimental integration into torchtitan](https://github.com/pytorch/torchtitan/pull/778).

The results indicate a solid 6-10% tokens/sec speedup with relatively flat memory usage (+/- 1% peak memory) compared to the bf16 eager baseline.
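
The Δ columns below are relative changes versus the bf16 eager baseline in the same table; for example (values taken from the first table, for illustration):

```python
# tokens/sec and peak memory deltas for the float8nocompile row of the "No AC" table
tokens_per_sec_delta = (5871.4 - 5339.0) / 5339.0  # ~0.0997 -> +9.97%
peak_memory_delta = (52.7 - 53.12) / 53.12         # ~-0.0079 -> -0.79%
```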

#### No AC (seq len 4096) - 8 H100s

| Configuration | Tokens/sec | Peak memory (GB) | Tokens/sec Δ | Peak memory Δ |
|-------------------------------------------------|------------|------------------|--------------|---------------|
| bfloat16, eager | 5339.0 | 53.12 | 0% | 0.00% |
| float8nocompile prototype | 5871.4 | 52.7 | 9.97% | -0.79% |
| float8 + torch.compile | 6667.6 | 46.64 | 24.88% | -12.20% |

---

#### Selective per layer AC (AC every 2nd layer, seq len 4096) - 8 H100s

| Configuration | Tokens/sec | Peak memory (GB) | Tokens/sec Δ | Peak memory Δ |
|-------------------------------------------------|------------|------------------|--------------|---------------|
| bfloat16, eager | 4882.4 | 40.6 | 0% | 0.00% |
| float8nocompile prototype | 5302.0 | 40.97 | 8.59% | 0.91% |
| float8 + torch.compile | 6199.6 | 37.38 | 26.98% | -7.93% |

---

#### Full AC (seq len 4096) - 8 H100s

| Configuration | Tokens/sec | Peak memory (GB) | Tokens/sec Δ | Peak memory Δ |
|-------------------------------------------------|------------|------------------|--------------|---------------|
| bfloat16, eager | 4502.0 | 28.07 | 0% | 0.00% |
| float8nocompile prototype | 4773.4 | 28.07 | 6.03% | 0.00% |
| float8 + torch.compile | 5775.2 | 28.03 | 28.28% | -0.14% |


### Numerical accuracy

Numerical accuracy has been verified via unit tests and by manually confirming that the training loss curves maintain fidelity with those of bf16 eager and production float8 + torch.compile:

![loss curves](float8nocompile_loss_curves.png "Loss curves")
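
A simplified sketch of the kind of unit-level check involved is below; it compares the forward output of an fp8-converted copy of a model against a bf16 reference. The model shapes and tolerances are illustrative, not taken from the repo's tests, and an fp8-capable GPU (e.g. H100) is assumed.

```python
import copy

import torch
import torch.nn as nn

from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)

torch.manual_seed(0)

# bf16 reference model, plus an fp8-converted deep copy of it
ref_model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))
ref_model = ref_model.to("cuda", dtype=torch.bfloat16)
fp8_model = copy.deepcopy(ref_model)
convert_to_float8_nocompile_training(fp8_model)

x = torch.randn(64, 256, device="cuda", dtype=torch.bfloat16)

# fp8 quantization introduces some error, so compare with loose tolerances
torch.testing.assert_close(fp8_model(x), ref_model(x), atol=1e-1, rtol=1e-1)
```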
