
Commit 6fa2d96

Renaming quantize to quantize_ (#467)

Summary: Addressing feedback on the `quantize` API from #391 (comment). This API changes the model in place, so we want the name to reflect that. In-place model quantization is especially important for LLMs, since the full-precision model can be hard to fit in memory; we typically load the model onto the meta device and then load the quantized weights.

Test Plan:
python test/quantization/test_quant_api.py
python test/integration/test_integration.py

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent cdb6e98 commit 6fa2d96
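
As context for the rename, here is a minimal sketch of the call-site change it implies. The toy `nn.Sequential` model and layer size below are illustrative only, not taken from this commit; the imports and the `quantize_(model, int8_weight_only())` call pattern match the diffs that follow.

```python
import torch
from torchao.quantization.quant_api import quantize_, int8_weight_only

# Toy model for illustration (hypothetical; any module containing nn.Linear works).
m = torch.nn.Sequential(torch.nn.Linear(1024, 1024))

# Before this commit:  m = quantize(m, int8_weight_only())
# After this commit: quantize_ mutates the module's linear weights in place,
# so the result is no longer re-assigned.
quantize_(m, int8_weight_only())
```

Making the mutation explicit in the name fits the memory-constrained flow described in the summary: build the module structure first, then swap quantized weights in place instead of returning a new model.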

File tree

12 files changed: +67 -67 lines changed

README.md (+3, -3)

@@ -19,8 +19,8 @@ All with no intrusive code changes and minimal accuracy degradation.
 Quantizing your models is a 1 liner that should work on any model with an `nn.Linear` including your favorite HuggingFace model. You can find a more comprehensive usage instructions [here](torchao/quantization/) and a HuggingFace inference example [here](scripts/hf_eval.py)
 
 ```python
-from torchao.quantization.quant_api import quantize, int4_weight_only
-m = quantize(m, int4_weight_only())
+from torchao.quantization.quant_api import quantize_, int4_weight_only
+quantize_(m, int4_weight_only())
 ```
 
 Benchmarks are run on a machine with a single A100 GPU using the script in `_models/llama` which generates text in a latency-optimized way (batchsize=1)
@@ -83,7 +83,7 @@ swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})
 
 * [MX](torchao/prototype/mx_formats) implementing training and inference support with tensors using the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) data types, which can be described as groupwise scaled float8/float6/float4/int8, with the scales being constrained to powers of two. This work is prototype as the hardware support is not available yet.
 * [nf4](torchao/dtypes/nf4tensor.py) which was used to [implement QLoRA](https://github.com/pytorch/torchtune/blob/main/docs/source/tutorials/qlora_finetune.rst) one of the most popular finetuning algorithms without writing custom Triton or CUDA code. Accessible talk [here](https://x.com/HamelHusain/status/1800315287574847701)
-* [fp6](torchao/prototype/quant_llm/) for 2x faster inference over fp16 with an easy to use API `quantize(model, fp6_llm_weight_only())`
+* [fp6](torchao/prototype/quant_llm/) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fp6_llm_weight_only())`
 
 ## Composability
test/integration/test_integration.py (+8, -8)

@@ -23,7 +23,7 @@
     int4_weight_only,
     int8_weight_only,
     int8_dynamic_activation_int8_weight,
-    quantize,
+    quantize_,
     _replace_with_custom_fn_if_matches_filter,
 )
 # APIs to be deprecated (used for torch 2.2.2 and 2.3)
@@ -98,21 +98,21 @@
 
 def _int8wo_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int8_weight_only(), set_inductor_config=False)
+        quantize_(mod, int8_weight_only(), set_inductor_config=False)
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_woqtensors(mod)
 
 def _int8da_int8w_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
+        quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_dqtensors(mod)
 
 def _int4wo_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int4_weight_only(), set_inductor_config=False)
+        quantize_(mod, int4_weight_only(), set_inductor_config=False)
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int4_woqtensors(mod)
@@ -127,8 +127,8 @@ def _int4wo_api(mod):
 def undo_recommended_configs():
     torch._inductor.config.coordinate_descent_tuning = False
     torch._inductor.config.coordinate_descent_check_all_directions = False
-    torch._inductor.config.force_fuse_int_mm_with_mul = False
-    torch._inductor.config.fx_graph_cache = False
+    torch._inductor.config.force_fuse_int_mm_with_mul = False
+    torch._inductor.config.fx_graph_cache = False
     torch._inductor.config.triton.unique_kernel_names = False
     torch.set_float32_matmul_precision("highest")
 
@@ -844,7 +844,7 @@ def api(mod):
             kwargs_copy = kwargs.copy()
             kwargs_copy["group_size"] = groupsize
             del kwargs_copy["groupsize"]
-            quantize(mod, int4_weight_only(**kwargs_copy))
+            quantize_(mod, int4_weight_only(**kwargs_copy))
             unwrap_tensor_subclass(mod)
         else:
             change_linear_weights_to_int4_woqtensors(mod, **kwargs)
@@ -865,7 +865,7 @@ def test_dynamic_quant(self):
         m = nn.Sequential(nn.Linear(K, N))
 
         y_ref = m(x)
-        quantize(m, int8_dynamic_activation_int8_weight())
+        quantize_(m, int8_dynamic_activation_int8_weight())
         y_test = m(x)
 
         sqnr = compute_error(y_ref, y_test)

test/prototype/test_quant_llm.py (+2, -2)

@@ -16,7 +16,7 @@
 )
 from torchao.prototype.quant_llm.quant_llm import _pack_tc_fpx, _pack_tc_fp6
 from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32
-from torchao.quantization.quant_api import quantize
+from torchao.quantization.quant_api import quantize_
 
 
 _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
@@ -91,7 +91,7 @@ def test_quant_llm_quantize(self, ebits, mbits, bias):
 
         linear = torch.nn.Linear(IC, OC, bias=bias, device=device)
         fpx_linear = copy.deepcopy(linear)
-        quantize(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits))
+        quantize_(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits))
 
         x = torch.randn(N, IC, device=device, dtype=torch.half)
         expected = fpx_linear(x)

test/quantization/test_quant_api.py (+9, -9)

@@ -31,7 +31,7 @@
     Int8WeightOnlyQuantizedLinearWeight,
     Int4WeightOnlyQuantizedLinearWeight,
 )
-from torchao import quantize
+from torchao import quantize_
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
     Quantizer,
@@ -89,7 +89,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module:
 
 class TorchCompileDynamicQuantizer(Quantizer):
     def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
-        quantize(model, int8_dynamic_activation_int8_weight())
+        quantize_(model, int8_dynamic_activation_int8_weight())
         return model
 
 class ToyLinearModel(torch.nn.Module):
@@ -152,7 +152,7 @@ class TestQuantFlow(TestCase):
     def test_dynamic_quant_gpu_singleline(self):
         m = ToyLinearModel().eval()
         example_inputs = m.example_inputs()
-        m = quantize(m, int8_dynamic_activation_int8_weight())
+        quantize_(m, int8_dynamic_activation_int8_weight())
         quantized = m(*example_inputs)
         # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
         # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
@@ -195,7 +195,7 @@ def test_int8_wo_quant_save_load(self):
         )
         m = ToyLinearModel().eval().cpu()
         def api(model):
-            model = quantize(model, int8_weight_only())
+            quantize_(model, int8_weight_only())
             unwrap_tensor_subclass(model)
 
         api(m)
@@ -501,7 +501,7 @@ def test_quantized_tensor_subclass_8da4w(self):
         m = ToyLinearModel().eval()
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs()
-        m = quantize(m, int8_dynamic_activation_int4_weight(group_size=group_size))
+        quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))
 
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -530,7 +530,7 @@ def test_quantized_tensor_subclass_int4(self):
         example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
 
         group_size = 32
-        m = quantize(m, int4_weight_only(group_size=group_size))
+        quantize_(m, int4_weight_only(group_size=group_size))
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
 
@@ -550,7 +550,7 @@ def test_quantized_tensor_subclass_int8_wo(self):
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
 
-        m = quantize(m, int8_weight_only())
+        quantize_(m, int8_weight_only())
 
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
@@ -573,7 +573,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
         m_copy = copy.deepcopy(m)
         # setting batch_size to 20 to be compatible with the kernel
         example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
-        m = quantize(m, int8_dynamic_activation_int8_weight())
+        quantize_(m, int8_dynamic_activation_int8_weight())
 
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -607,7 +607,7 @@ def test_quantized_tensor_subclass_save_load(self):
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs(dtype=torch.bfloat16)
 
-        m = quantize(m, int8_weight_only())
+        quantize_(m, int8_weight_only())
         ref = m(*example_inputs)
         with tempfile.NamedTemporaryFile() as f:
             torch.save(m.state_dict(), f)

torchao/__init__.py (+2, -2)

@@ -30,14 +30,14 @@
 
 from torchao.quantization import (
     autoquant,
-    quantize,
+    quantize_,
 )
 from . import dtypes
 
 __all__ = [
     "dtypes",
     "autoquant",
-    "quantize",
+    "quantize_",
 ]
 
 # test-pytorchbot

torchao/_models/llama/eval.py (+18, -18)

@@ -13,7 +13,7 @@
 
 )
 from torchao.quantization.quant_api import (
-    quantize, int4_weight_only, int8_weight_only, int8_dynamic_activation_int8_weight, unwrap_tensor_subclass
+    quantize_, int4_weight_only, int8_weight_only, int8_dynamic_activation_int8_weight, unwrap_tensor_subclass
 
 )
 from torchao._models._eval import TransformerEvalWrapper, InputRecorder
@@ -60,13 +60,13 @@ def run_evaluation(
 
     if quantization:
         if "int8wo" in quantization:
-            quantize(model, int8_weight_only())
+            quantize_(model, int8_weight_only())
         if "int8dq" in quantization:
-            quantize(model, int8_dynamic_activation_int8_weight())
+            quantize_(model, int8_dynamic_activation_int8_weight())
         if "int4wo" in quantization and not "gptq" in quantization:
            groupsize=int(quantization.split("-")[-1])
            assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
-            quantize(model.to(device), int4_weight_only(group_size=groupsize))
+            quantize_(model.to(device), int4_weight_only(group_size=groupsize))
         if "int4wo" in quantization and "gptq" in quantization:
            groupsize=int(quantization.split("-")[-2])
            assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
@@ -94,8 +94,8 @@ def run_evaluation(
     model = torch.compile(model, mode="max-autotune", fullgraph=True)
     with torch.no_grad():
         TransformerEvalWrapper(
-            model=model.to(device),
-            tokenizer=tokenizer,
+            model=model.to(device),
+            tokenizer=tokenizer,
             max_seq_length=max_length,
             input_prep_func=prepare_inputs_for_model,
             device=device,
@@ -122,16 +122,16 @@ def run_evaluation(
 
     args = parser.parse_args()
     run_evaluation(
-        args.checkpoint_path,
-        args.tasks,
-        args.limit,
-        args.device,
-        args.precision,
-        args.quantization,
-        args.compile,
-        args.max_length,
-        args.calibration_tasks,
-        args.calibration_limit,
-        args.calibration_seq_length,
-        args.pad_calibration_inputs,
+        args.checkpoint_path,
+        args.tasks,
+        args.limit,
+        args.device,
+        args.precision,
+        args.quantization,
+        args.compile,
+        args.max_length,
+        args.calibration_tasks,
+        args.calibration_limit,
+        args.calibration_seq_length,
+        args.pad_calibration_inputs,
     )

torchao/_models/llama/generate.py (+7, -7)

@@ -100,7 +100,7 @@ def generate(
     T_new = T + max_new_tokens
     seq = torch.empty(T_new, dtype=prompt.dtype, device=device)
     seq[:T] = prompt.view(-1)
-
+
     # setup model cache
     max_seq_length = min(T_new, model.config.block_size) if not interactive else 350
     with torch.device(device):
@@ -158,7 +158,7 @@ def main(
     """
 
     torchao.quantization.utils.recommended_inductor_config_setter()
-
+
     assert checkpoint_path.is_file(), checkpoint_path
     tokenizer_path = checkpoint_path.parent / "tokenizer.model"
     assert tokenizer_path.is_file(), str(tokenizer_path)
@@ -180,11 +180,11 @@ def main(
     prompt_length = encoded.size(0)
 
     torch.manual_seed(1234)
-
+
 
     if quantization:
         from torchao.quantization.quant_api import (
-            quantize,
+            quantize_,
             int8_weight_only,
             int8_dynamic_activation_int8_weight,
             int4_weight_only,
@@ -193,13 +193,13 @@ def main(
         )
 
         if "int8wo" in quantization:
-            quantize(model, int8_weight_only())
+            quantize_(model, int8_weight_only())
         if "int8dq" in quantization:
-            quantize(model, int8_dynamic_activation_int8_weight())
+            quantize_(model, int8_dynamic_activation_int8_weight())
         if "int4wo" in quantization:
            groupsize=int(quantization.split("-")[-1])
            assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
-            quantize(model, int4_weight_only(group_size=groupsize))
+            quantize_(model, int4_weight_only(group_size=groupsize))
         if "autoquant" == quantization:
            model = autoquant(model, manual=True)
 

torchao/prototype/quant_llm/README.md (+3, -3)

@@ -5,15 +5,15 @@ This is a FP16 x FPx mixed matmul kernel optimized for io bound workloads per [F
 ## Usage
 
 ```python
-from torchao.quantization.quant_api import quantize
+from torchao.quantization.quant_api import quantize_
 from torchao.prototype.quant_llm import fp6_llm_weight_only, quant_llm_fpx_weight_only
 
 model = ...
 model.half() # not necessary, but recommeneded to maintain accuracy
-quantize(model, fp6_llm_weight_only()) # convert nn.Lineaer.weight to FP6 E3M2 in-place
+quantize_(model, fp6_llm_weight_only()) # convert nn.Lineaer.weight to FP6 E3M2 in-place
 
 # for generic FPx EyMz where x = 1 + y + z
-# quantize(model, quant_llm_fpx_weight_only(2, 2)) # use FP5 E2M2 instead
+# quantize_(model, quant_llm_fpx_weight_only(2, 2)) # use FP5 E2M2 instead
 
 # fully compatible with torch.compile()
 model.compile(mode="max-autotune", fullgraph=True)

torchao/quantization/README.md (+5, -5)

@@ -79,7 +79,7 @@ from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
 from torchao.dtypes import to_affine_quantized
 import copy
 from torchao.quantization.quant_api import (
-    quantize,
+    quantize_,
     int4_weight_only,
 )
 
@@ -106,7 +106,7 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune')
 # apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao)
 group_size = 32
 # only works for torch 2.4+
-m = quantize(m, int4_weight_only(group_size=group_size))
+quantize_(m, int4_weight_only(group_size=group_size))
 
 # temporary workaround for tensor subclass + torch.compile
 from torchao.utils import unwrap_tensor_subclass
@@ -173,7 +173,7 @@ torch._inductor.config.force_fuse_int_mm_with_mul = True
 
 # for torch 2.4+
 from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
-quantize(model, int8_dynamic_activation_int8_weight())
+quantize_(model, int8_dynamic_activation_int8_weight())
 
 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
@@ -185,7 +185,7 @@ change_linear_weights_to_int8_dqtensors(model)
 ```python
 # for torch 2.4+
 from torchao.quantization import quantize, int8_weight_only
-quantize(model, int8_weight_only())
+quantize_(model, int8_weight_only())
 
 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
@@ -200,7 +200,7 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is
 ```python
 # for torch 2.4+
 from torchao.quantization import quantize, int4_weight_only
-quantize(model, int4_weight_only())
+quantize_(model, int4_weight_only())
 
 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors

torchao/quantization/__init__.py (+1, -1)

@@ -29,7 +29,7 @@
     "quantize_affine",
     "dequantize_affine",
     "choose_qprams_affine",
-    "quantize",
+    "quantize_",
     "int8_dynamic_activation_int4_weight",
     "int8_dynamic_activation_int8_weight",
     "int4_weight_only",
