
Commit 0654ea1

Update torchao api reference and add contributor guide

Summary:
1. Updated the torchao API reference for quantization to include the APIs we want to expose, renamed torchao/quantization/linear_activation_weight_observer.py to torchao/quantization/linear_activation_weight_observed_tensor.py, and removed safe_int_mm and int_scaled_matmul from quant_primitives.py (they are now exposed through torchao.kernel).
2. Added pytorch#391 (the contributor guide) to the torchao docs.

Test Plan: CI

1 parent f96e5ec commit 0654ea1

16 files changed: +807 −56 lines
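
The net effect on imports runs through most of the files below; a sketch of the before/after, grounded in the diffs that follow (the kernels themselves are unchanged):

# before this commit:
#   from torchao.quantization.quant_primitives import safe_int_mm, int_scaled_matmul
# after it, the kernels live under torchao.kernel:
from torchao.kernel import safe_int_mm, int_scaled_matmul
# and are also re-exported from the quantization namespace:
from torchao.quantization import safe_int_mm, int_scaled_matmul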

docs/source/api_ref_dtypes.rst (+3 −1)

@@ -12,9 +12,11 @@ torchao.dtypes
 
     to_nf4
     to_affine_quantized_intx
-    to_affine_quantized_floatx
     to_affine_quantized_intx_static
+    to_affine_quantized_floatx
     to_affine_quantized_floatx_static
+    to_affine_quantized_fpx
+    NF4Tensor
     AffineQuantizedTensor
 
 ..
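
As a quick illustration of the dtypes surface documented here, a minimal sketch of the NF4 conversion helper (the block_size/scaler_block_size values and the get_original_weight call reflect my reading of the NF4Tensor API and are assumptions, not part of this diff):

import torch
from torchao.dtypes import to_nf4

w = torch.randn(64, 1024, dtype=torch.bfloat16)
# two-level blocking: block_size for quantization, scaler_block_size for the scales
w_nf4 = to_nf4(w, block_size=64, scaler_block_size=256)
w_back = w_nf4.get_original_weight()  # dequantize back to bfloat16 for inspection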

docs/source/api_ref_intro.rst (+3 −6)

@@ -1,16 +1,13 @@
 ``torchao`` API Reference
 =========================
 
-This section introduces the torchao API reference.
-Dive into the details of how torchao integrates with PyTorch to
-optimize your machine learning models.
+This section introduces the torchao API reference. Dive into the details of how torchao integrates with PyTorch to optimize your machine learning models.
 
 .. toctree::
    :glob:
    :maxdepth: 1
    :caption: Python API Reference
 
-   api_ref_sparsity
-   api_ref_quantization
    api_ref_dtypes
-   api_ref_kernel
+   api_ref_quantization
+   api_ref_sparsity

docs/source/api_ref_quantization.rst (+32 −7)

@@ -9,15 +9,40 @@ torchao.quantization
 .. autosummary::
     :toctree: generated/
     :nosignatures:
-
-    SmoothFakeDynQuantMixin
-    SmoothFakeDynamicallyQuantizedLinear
-    swap_linear_with_smooth_fq_linear
-    smooth_fq_linear_to_inference
-    Int4WeightOnlyGPTQQuantizer
-    Int4WeightOnlyQuantizer
+    autoquant
+
     quantize_
     int8_dynamic_activation_int4_weight
     int8_dynamic_activation_int8_weight
     int4_weight_only
     int8_weight_only
+    float8_weight_only
+    float8_dynamic_activation_float8_weight
+    float8_static_activation_float8_weight
+    uintx_weight_only
+    fpx_weight_only
+
+    to_linear_activation_quantized
+    to_linear_activation_weight_observed
+
+    swap_linear_with_smooth_fq_linear
+    smooth_fq_linear_to_inference
+
+    choose_qparams_affine
+    choose_qparams_affine_with_min_max
+    choose_qparams_affine_floatx
+    quantize_affine
+    quantize_affine_floatx
+    dequantize_affine
+    dequantize_affine_floatx
+    choose_qparams_and_quantize_affine_hqq
+    fake_quantize_affine
+    fake_quantize_affine_cachemask
+
+    safe_int_mm
+    int_scaled_matmul
+
+    MappingType
+    ZeroPointDomain
+    TorchAODType
+
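
The workflow these docs now lead with is quantize_ plus a config constructor; a minimal sketch in the style of the torchao README (the group_size value is an arbitrary choice, and int4 weight-only currently targets bfloat16 CUDA models):

import torch
import torch.nn as nn
from torchao.quantization import quantize_, int4_weight_only

model = nn.Sequential(nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
# swaps each Linear weight for an int4 weight-only quantized tensor, in place
quantize_(model, int4_weight_only(group_size=32))
out = model(torch.randn(1, 1024, device="cuda", dtype=torch.bfloat16))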

docs/source/contributor_guide.rst (+674)

Large diffs are not rendered by default.

docs/source/index.rst (+11 −7)

@@ -1,9 +1,7 @@
 Welcome to the torchao Documentation
 =======================================
 
-**torchao** is an open-source library that provides the functionality
-to quantize and prune your models using native PyTorch. Our documentation is under development
-with more content coming soon.
+`**torchao** <https://github.com/pytorch/ao>`__ is a library for custom data types & optimizations. Quantize and sparsify weights, gradients, optimizers & activations for inference and training using native PyTorch. Please check out the torchao `README <https://github.com/pytorch/ao#torchao-pytorch-architecture-optimization>`__ for an overall introduction to the library and recent highlights and updates. The documentation here will focus on: 1. API Reference 2. Developer / Researcher Contribution Guide 3. Tutorials.
 
 ..
 .. grid:: 3
@@ -81,13 +79,19 @@ with more content coming soon.
    :maxdepth: 1
    :caption: API Reference
 
-   api_ref_sparsity
-   api_ref_intro
-   api_ref_quantization
    api_ref_dtypes
+   api_ref_quantization
+   api_ref_sparsity
 ..
    api_ref_kernel
-
+
+
+.. toctree::
+   :glob:
+   :maxdepth: 1
+   :caption: Contributor Guide
+
+   contributor_guide
+
 .. toctree::
    :glob:
    :maxdepth: 1

test/integration/test_integration.py (+3 −1)

@@ -34,8 +34,10 @@
     change_linear_weights_to_int8_woqtensors,
     change_linear_weights_to_int4_woqtensors,
 )
-from torchao.quantization.quant_primitives import (
+from torchao.quantization import (
     safe_int_mm,
+)
+from torchao.quantization.quant_primitives import (
     choose_qparams_affine,
     quantize_affine,
     dequantize_affine,

torchao/dtypes/affine_quantized_tensor.py (+3 −1)

@@ -31,10 +31,12 @@
     choose_qparams_and_quantize_affine_hqq,
     dequantize_affine,
     dequantize_affine_floatx,
-    int_scaled_matmul,
     quantize_affine,
     quantize_affine_floatx,
 )
+from torchao.kernel import (
+    int_scaled_matmul,
+)
 from torchao.quantization.utils import (
     pack_tinygemm_scales_and_zeros,
 )

torchao/kernel/__init__.py (+7)

@@ -0,0 +1,7 @@
+from torchao.kernel.intmm import int_scaled_matmul
+from torchao.kernel.intmm import safe_int_mm
+
+__all__ = [
+    "safe_int_mm",
+    "int_scaled_matmul",
+]
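
With the new package entry point the kernels can be exercised directly; a small sketch (shapes are arbitrary, and treating int_scaled_matmul's third argument as per-row scales is my assumption from its use in the int8 dynamic-quant path):

import torch
from torchao.kernel import safe_int_mm, int_scaled_matmul

a = torch.randint(-128, 127, (32, 64), dtype=torch.int8)
b = torch.randint(-128, 127, (64, 16), dtype=torch.int8)

# int8 x int8 -> int32 matmul, with a fallback on backends where
# torch._int_mm is unavailable or shape-constrained
c = safe_int_mm(a, b)
assert c.dtype == torch.int32

# additionally applies floating-point scales to the int32 result
scales = torch.rand(32, 1)
d = int_scaled_matmul(a, b, scales)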

torchao/quantization/__init__.py (+56 −25)

@@ -24,6 +24,10 @@
     PerTensor,
     PerToken,
 )
+from torchao.kernel import (
+    safe_int_mm,
+    int_scaled_matmul,
+)
 from .linear_activation_quantized_tensor import (
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
@@ -70,52 +74,79 @@
     compute_error,
 )
 from .weight_only import WeightOnlyInt8QuantLinear
+from .linear_activation_weight_observed_tensor import (
+    to_linear_activation_weight_observed,
+)
 
 __all__ = [
-    "swap_conv2d_1x1_to_linear",
+    # top level API - auto
     "autoquant",
     "DEFAULT_AUTOQUANT_CLASS_LIST",
     "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
     "OTHER_AUTOQUANT_CLASS_LIST",
-    "get_scale",
-    "SmoothFakeDynQuantMixin",
-    "SmoothFakeDynamicallyQuantizedLinear",
-    "swap_linear_with_smooth_fq_linear",
-    "smooth_fq_linear_to_inference",
-    "set_smooth_fq_attribute",
-    "compute_error",
-    "Int4WeightOnlyGPTQQuantizer",
-    "Int4WeightOnlyQuantizer",
-    "quantize_affine",
-    "dequantize_affine",
-    "choose_qparams_affine",
+
+    # top level API - manual
     "quantize_",
     "int8_dynamic_activation_int4_weight",
     "int8_dynamic_activation_int8_weight",
     "int8_dynamic_activation_int8_semi_sparse_weight",
     "int4_weight_only",
     "int8_weight_only",
+    "float8_weight_only",
+    "float8_dynamic_activation_float8_weight",
+    "float8_static_activation_float8_weight",
    "uintx_weight_only",
    "fpx_weight_only",
-    "LinearActivationQuantizedTensor",
+
+    # smooth quant - subject to change
+    "swap_conv2d_1x1_to_linear",
+    "get_scale",
+    "SmoothFakeDynQuantMixin",
+    "SmoothFakeDynamicallyQuantizedLinear",
+    "swap_linear_with_smooth_fq_linear",
+    "smooth_fq_linear_to_inference",
+    "set_smooth_fq_attribute",
+    "compute_error",
+
+    # building blocks
     "to_linear_activation_quantized",
     "to_weight_tensor_with_linear_activation_scale_metadata",
-    "float8_weight_only",
-    "float8_dynamic_activation_float8_weight",
-    "float8_static_activation_float8_weight",
-    "Int8DynActInt4WeightGPTQQuantizer",
-    "Int8DynActInt4WeightQuantizer",
-    "Int8DynActInt4WeightLinear",
-    "WeightOnlyInt8QuantLinear",
-    "TwoStepQuantizer",
-    "Quantizer",
-    "ZeroPointDomain",
-    "MappingType",
     "AffineQuantizedMinMaxObserver",
     "AffineQuantizedObserverBase",
+
+    # quant primitive ops
+    "choose_qparams_affine",
+    "choose_qparams_affine_with_min_max",
+    "choose_qparams_affine_floatx",
+    "quantize_affine",
+    "quantize_affine_floatx",
+    "dequantize_affine",
+    "dequantize_affine_floatx",
+    "choose_qparams_and_quantize_affine_hqq",
+    "fake_quantize_affine",
+    "fake_quantize_affine_cachemask",
+
+    # operators/kernels
+    "safe_int_mm",
+    "int_scaled_matmul",
+
+    # dataclasses and types
+    "MappingType",
+    "ZeroPointDomain",
+    "TorchAODType",
     "PerTensor",
     "PerAxis",
     "PerGroup",
     "PerRow",
     "PerToken",
+
+    "LinearActivationQuantizedTensor",
+    "Int4WeightOnlyGPTQQuantizer",
+    "Int4WeightOnlyQuantizer",
+    "Int8DynActInt4WeightGPTQQuantizer",
+    "Int8DynActInt4WeightQuantizer",
+    "Int8DynActInt4WeightLinear",
+    "WeightOnlyInt8QuantLinear",
+    "TwoStepQuantizer",
+    "Quantizer",
 ]
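
Since the affine quant primitives are now part of the documented top-level namespace, a round-trip sketch (the argument layout follows the signatures in quant_primitives.py; the per-row block size is an arbitrary choice):

import torch
from torchao.quantization import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

x = torch.randn(4, 8)
block_size = (1, 8)  # one scale/zero_point per row

scale, zero_point = choose_qparams_affine(
    x, MappingType.ASYMMETRIC, block_size, torch.int8,
    quant_min=-128, quant_max=127,
)
xq = quantize_affine(x, block_size, scale, zero_point, torch.int8,
                     quant_min=-128, quant_max=127)
xdq = dequantize_affine(xq, block_size, scale, zero_point, torch.int8,
                        quant_min=-128, quant_max=127)
print((x - xdq).abs().max())  # small quantization error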

torchao/quantization/autoquant.py (+2 −1)

@@ -24,13 +24,14 @@
     PerRow,
     PerTensor,
 )
-from .quant_primitives import safe_int_mm
+from torchao.kernel import safe_int_mm
 from .subclass import ( # noqa
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
     QuantizedLinearWeightBase,
 )
 
+
 __all__ = [
     "AutoQuantizableLinearWeight",
     "autoquant",

torchao/quantization/linear_activation_quantized_tensor.py (+2 −2)

@@ -19,7 +19,7 @@
 class LinearActivationQuantizedTensor(TorchAOBaseTensor):
     """
     Applies activation quantization for linear operator, this is used to support
-    dynamic quantization or static quantization, user can pass in a `input_quant_func`
+    dynamic quantization; the user can pass in an `input_quant_func`
     that is used to quantize the activation
 
     Args:
@@ -60,7 +60,7 @@ def __init__(
         self.quant_kwargs = quant_kwargs
 
     def __repr__(self):
-        return f"LinearActivationQuantizedTensor({self.original_weight_tensor}, {self.input_quant_func}, quant_kwargs={self.quant_kwargs}))"
+        return f"{self.__class__.__name__}({self.original_weight_tensor}, {self.input_quant_func}, quant_kwargs={self.quant_kwargs})"
 
     def __tensor_flatten__(self):
         return ["original_weight_tensor"], [self.input_quant_func, self.quant_kwargs]

torchao/quantization/linear_activation_weight_observer.py renamed to torchao/quantization/linear_activation_weight_observed_tensor.py (+3)

@@ -11,6 +11,7 @@
 
 __all__ = [
     "LinearActivationWeightObservedTensor",
+    "to_linear_activation_weight_observed",
 ]
 
 aten = torch.ops.aten
@@ -147,6 +148,8 @@ def _(func, types, args, kwargs):
         args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
     )
 
+to_linear_activation_weight_observed = LinearActivationWeightObservedTensor.from_float
+
 
 if TORCH_VERSION_AT_LEAST_2_5:
     # Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True`
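
The rename plus the new alias only change how the tensor subclass is reached; a sketch of the before/after import paths:

# before this commit:
#   from torchao.quantization.linear_activation_weight_observer import (
#       LinearActivationWeightObservedTensor,
#   )
# after it:
from torchao.quantization.linear_activation_weight_observed_tensor import (
    LinearActivationWeightObservedTensor,
    to_linear_activation_weight_observed,
)

# the helper is just the classmethod constructor under a functional name
assert (
    to_linear_activation_weight_observed
    == LinearActivationWeightObservedTensor.from_float
)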

torchao/quantization/quant_api.py (+1 −1)

@@ -385,7 +385,7 @@ def insert_observers_(
     def convert_to_linear_observer(linear_module: nn.Linear):
         # Wrap the weight with LinearActivationWeightObservedTensor and then with nn.Parameter
         linear_module.weight = nn.Parameter(
-            LinearActivationWeightObservedTensor.from_float(
+            to_linear_activation_weight_observed(
                 linear_module.weight,
                 input_observer=input_observer,
                 weight_observer=weight_observer,

torchao/quantization/quant_primitives.py (+3 −2)

@@ -24,8 +24,6 @@
 )
 
 __all__ = [
-    "safe_int_mm",
-    "int_scaled_matmul",
     "choose_qparams_affine",
     "choose_qparams_affine_with_min_max",
     "choose_qparams_affine_floatx",
@@ -36,6 +34,9 @@
     "fake_quantize_affine",
     "fake_quantize_affine_cachemask",
     "choose_qparams_and_quantize_affine_hqq",
+    "MappingType",
+    "ZeroPointDomain",
+    "TorchAODType",
 ]
 
 
torchao/quantization/utils.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
import torch
1010
from torch.utils._python_dispatch import TorchDispatchMode
1111

12+
from torchao.kernel import (
13+
int_scaled_matmul,
14+
)
1215
from torchao.quantization.quant_primitives import (
1316
MappingType,
1417
ZeroPointDomain,
1518
choose_qparams_affine,
1619
dequantize_affine,
17-
int_scaled_matmul,
1820
quantize_affine,
1921
)
2022
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

torchao/quantization/weight_tensor_linear_activation_quantization.py (+1 −1)

@@ -70,7 +70,7 @@ def __init__(
         self.quant_kwargs = quant_kwargs
 
     def __repr__(self):
-        return f"LinearActivationQuantizedTensor({self.original_weight_tensor}, {self.input_quant_func_static}, scale={self.scale}, zero_point={self.zero_point}, quant_kwargs={self.quant_kwargs})"
+        return f"{self.__class__.__name__}({self.original_weight_tensor}, {self.input_quant_func_static}, scale={self.scale}, zero_point={self.zero_point}, quant_kwargs={self.quant_kwargs})"
 
     def __tensor_flatten__(self):
         tensor_data = ["original_weight_tensor", "scale"]
