problem cause when doing w8a8kvfp8 for qwen2.5-vl #1354

Open
coolKeen opened this issue Apr 16, 2025 · 5 comments
coolKeen commented Apr 16, 2025

Hi, I want to produce a w8a8kvfp8 model (8-bit weights and activations plus an FP8 KV cache) of Qwen2.5-VL and then run it on vLLM. But when I run the quantization with the code below, I hit the error shown in the figure below. Is there anything wrong with the recipe? Could you please give an example of how to get w8a8kvfp8 results for a vision-language model? Thanks in advance!

BTW, do you have any documentation that explains the meaning of the parameters of QuantizationModifier?

[Image: screenshot of the error]

from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from datasets import load_dataset

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
import argparse
import os
import base64
from io import BytesIO
import torch
from qwen_vl_utils import process_vision_info
from modelscope import AutoProcessor
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.tracing import (
    TraceableQwen2_5_VLForConditionalGeneration,
)
from compressed_tensors.quantization.quant_args import (
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--src_model_path", type=str)
    parser.add_argument("--dst_model_path", type=str)
    args = parser.parse_args()

    MODEL_ID = args.src_model_path
    dst_model_path = args.dst_model_path
    os.makedirs(dst_model_path, exist_ok=True)

    # Load model.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID, device_map="auto", torch_dtype="auto"
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID)

    # Oneshot arguments
    DATASET_ID = "lmms-lab/flickr30k"
    DATASET_SPLIT = {"calibration": "test[:512]"}
    NUM_CALIBRATION_SAMPLES = 512
    MAX_SEQUENCE_LENGTH = 2048

    # Load dataset and preprocess.
    ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
    ds = ds.shuffle(seed=42)

    dampening_frac = 0.01

    def preprocess_and_tokenize(example):
        # Preprocess: encode the image as a base64 data URI for the chat template.
        buffered = BytesIO()
        example["image"].save(buffered, format="PNG")
        encoded_image = base64.b64encode(buffered.getvalue())
        encoded_image_text = encoded_image.decode("utf-8")
        base64_qwen = f"data:image;base64,{encoded_image_text}"
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": base64_qwen},
                    {"type": "text", "text": "What does the image show?"},
                ],
            }
        ]
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)

        # Tokenize text and vision inputs together.
        return processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=False,
            max_length=MAX_SEQUENCE_LENGTH,
            truncation=True,
        )

    ds = ds.map(preprocess_and_tokenize, remove_columns=ds["calibration"].column_names)

    # Define a oneshot data collator for multimodal inputs.
    def data_collator(batch):
        assert len(batch) == 1
        return {key: torch.tensor(value) for key, value in batch[0].items()}

    recipe = [
        QuantizationModifier(scheme={
            "weights": QuantizationArgs(
                num_bits=8,
                type=QuantizationType.FLOAT,
                # group_size=128,
                strategy=QuantizationStrategy.TENSOR,
                symmetric=True,
                dynamic=False,
                ignore=["re:.*lm_head", "re:visual.*"],
                targets=["Linear"],
            ),
            "input_activations": QuantizationArgs(
                num_bits=8,
                type=QuantizationType.FLOAT,
                strategy=QuantizationStrategy.TENSOR,
                symmetric=True,
                dynamic=True,
                observer=None,
                ignore=["re:.*lm_head", "re:visual.*"],
                targets=["Linear"],
            ),
            "kv_cache_scheme": QuantizationArgs(
                num_bits=8,
                type=QuantizationType.FLOAT,
                strategy=QuantizationStrategy.TENSOR,
                symmetric=True,
                dynamic=True,
                observer=None,
                ignore=["re:.*lm_head", "re:visual.*"],
                targets=["Linear"],
            ),
        })
    ]

    oneshot(
        model=model,
        dataset=ds,
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    )

    SAVE_DIR = dst_model_path
    oneshot(model=model, recipe=recipe, output_dir=SAVE_DIR)
    processor.save_pretrained(SAVE_DIR)
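For context, the plan is to serve the saved checkpoint with vLLM roughly like this (just a sketch; the path is a placeholder and kv_cache_dtype has not been verified against this checkpoint):

from vllm import LLM, SamplingParams

# Load the quantized checkpoint saved by the script above (placeholder path).
llm = LLM(
    model="/path/to/dst_model_path",
    kv_cache_dtype="fp8",
    trust_remote_code=True,
)
outputs = llm.generate(
    ["Describe the weather in this picture."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)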


coolKeen reopened this Apr 16, 2025
kylesayrs self-assigned this Apr 16, 2025
kylesayrs (Collaborator) commented Apr 16, 2025

Hi @coolKeen! Thanks for using LLM Compressor!

The format used by Compressed Tensors is slightly different from what you've specified: the scheme field maps preset names to targets, and kv_cache_scheme is a field of QuantizationModifier, not of QuantizationScheme.
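For illustration, the preset-name form of scheme looks roughly like this (a sketch only; "FP8_DYNAMIC" is just an example preset, not part of the recipe below):

from llmcompressor.modifiers.quantization import QuantizationModifier

# Rough sketch of the preset-name form of `scheme`:
# keys are preset scheme names, values are lists of module targets.
recipe = QuantizationModifier(
    scheme={"FP8_DYNAMIC": ["Linear"]},
    ignore=["re:.*lm_head", "re:visual.*"],
)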

Beyond this, there is a validation error where kv_cache_schemes with dynamic quantization are not being parsed correctly. If you want dynamic quantization, simply remove the kv_cache_scheme for now.

In the meantime, this recipe should work for your use case:

recipe = [
    QuantizationModifier(
        config_groups={
            "group_0": dict(
                targets=["Linear"],
                weights=QuantizationArgs(
                    num_bits=8,
                    type=QuantizationType.FLOAT,
                    # group_size=128,
                    strategy=QuantizationStrategy.TENSOR,
                    symmetric=True,
                    dynamic=False,
                    ignore=["re:.*lm_head", "re:visual.*"],
                    targets=["Linear"]
                ),
                input_activations=QuantizationArgs(
                    num_bits=8,
                    type=QuantizationType.FLOAT,
                    strategy=QuantizationStrategy.TENSOR,
                    symmetric=True,
                    dynamic=True,
                    observer=None,
                    ignore=["re:.*lm_head", "re:visual.*"],
                    targets=["Linear"]
                )
            )
        },
    )
]

@coolKeen (Author)

@kylesayrs Thanks for your response! But I also want to use FP8 for the KV cache and run it on vLLM. How should I change this recipe?

@coolKeen (Author)

@kylesayrs I changed the recipe to:
recipe = [
    QuantizationModifier(
        config_groups={
            "group_0": dict(
                targets=["Linear"],
                weights=QuantizationArgs(
                    num_bits=8,
                    type=QuantizationType.FLOAT,
                    # group_size=128,
                    strategy=QuantizationStrategy.TENSOR,
                    symmetric=True,
                    dynamic=False,
                    ignore=["re:.*lm_head", "re:visual.*"],
                    targets=["Linear"]
                ),
                input_activations=QuantizationArgs(
                    num_bits=8,
                    type=QuantizationType.FLOAT,
                    strategy=QuantizationStrategy.TENSOR,
                    symmetric=True,
                    dynamic=True,
                    observer=None,
                    ignore=["re:.*lm_head", "re:visual.*"],
                    targets=["Linear"]
                )
            )
        },
        kv_cache_scheme=QuantizationArgs(
            num_bits=8,
            type=QuantizationType.FLOAT,
            strategy=QuantizationStrategy.TENSOR,
            symmetric=True,
            dynamic=False,
            observer=None,
            ignore=["re:.*lm_head", "re:visual.*"],
            targets=["Linear"]
        )
    )
]

but another problem occurred:

[Image: screenshot of the error]

It looks like something is wrong with the FP8 KV cache calibration?

@coolKeen (Author)

BTW, what is the difference between static and dynamic quantization for FP8?
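My rough understanding (please correct me if wrong): with dynamic=True the scales are recomputed from each batch at runtime, so no observer or calibration data is needed, while with dynamic=False (static) the scales are fixed ahead of time from calibration statistics collected by an observer. In code, the two argument shapes would look like this (a sketch using only the fields already shown above):

from compressed_tensors.quantization.quant_args import (
    QuantizationArgs, QuantizationStrategy, QuantizationType,
)

# Static FP8: scales are calibrated once from observer statistics, then frozen.
static_fp8 = QuantizationArgs(
    num_bits=8, type=QuantizationType.FLOAT,
    strategy=QuantizationStrategy.TENSOR, symmetric=True, dynamic=False,
)

# Dynamic FP8: scales are recomputed from each batch at runtime, so no observer
# or calibration pass is needed (observer=None).
dynamic_fp8 = QuantizationArgs(
    num_bits=8, type=QuantizationType.FLOAT,
    strategy=QuantizationStrategy.TENSOR, symmetric=True, dynamic=True, observer=None,
)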

kylesayrs added a commit that referenced this issue May 2, 2025
## Purpose ## 
* Abstract functionality which allows modifiers to act as quantization
configs into a mixin called `QuantizationMixin`
* This gives #1279 an interface to properly infer which pipeline to use
based on the recipe (if a recipe contains modifiers that require
calibration, then use the "basic" or "sequential" pipelines)
* This enables future modifiers to act as quantization modifiers (in the
same way that GPTQ does now)
* Related to #1354 where previous logic would attempt to add a
QuantizedKVCache for dynamic kv_quant

## Changes ##
* Implement `QuantizationMixin` which implements five public methods
  * Lifecycle methods
* `initialize_quantization` is used to apply a config and attach
observers to a model
* quantization is disabled so that modules aren't quantized before
they're calibrated
* `start_calibration` is used to initialize calibration hooks and status
* quantization is enabled, since we currently quantize as we calibrate,
although this decision is somewhat arbitrary
* `end_calibration` is used to remove calibration hooks and apply the
frozen status
* quantization remains enabled, since we want future forward passes to
simulate quantization
  * Recipe-related methods
* `has_config` returns true if a config was specified, used for checking
against duplicate configs in the recipe
* `resolve_quantization_config` returns the quantization config
specified by the modifier fields
* `QuantizationModifier` inherits from `QuantizationMixin`
* `GPTQModifier` inherits from `QuantizationMixin`
* Unlike QMod, GPTQ disables quantization during calibration. As noted
before, this is a somewhat arbitrary choice but one which matches the
current implementation

* Calibration utils
* Replace `set_unset_kv_cache` with `initialize_quantized_kv_cache` and
`freeze_module_quantization`
    * Treat the `QuantizedKVCache` as analogous to another observer
  * Pull setting the calibration status out of `update_weight_zp_scale`
* This better matches the lifecycle detailed in `QuantizationMixin`
description
* Implement `reset_quantization_status` which is used to remove any
existing quantization configs before the current config is applied by
`initialize_quantization`

## Remove Support ##
* Removing support for recipe with multiple quantization modifiers
active at the same time (a check for this will be added by #1279)
* Remove `num_calibration_steps`, `quantize`,
`disable_quantization_observer_epoch` and `min_tokens_per_module`
* `num_calibration_steps` is already controlled by
https://github.com/vllm-project/llm-compressor/blob/42b62f5283d0234b26623fe1f1bf02a77c6e4019/src/llmcompressor/datasets/utils.py#L106
* `quantize` was implemented as a workaround for GPTQ's modifier
builder. Similar functionality may be required to support SpinQuant +
GPTQ, but such functionality should exist at a higher level
* `disable_quantization_observer_epoch` seems to implement functionality
where a model's observers are removed but quantization remains active.
This functionality is maintained by setting an "end" epoch for qmod
* `min_tokens_per_module` requires that the modifier have references to
the calibration dataset, which is disallowed by #1279. This information
is already printed in GPTQ's logs. If research still wants this tool
specifically for `QuantizationModifier`, then it can be reimplemented to
avoid using references to the calibration dataset
  
## Testing ##
* Updated tests to reflect new mixin
* Ran a set of GPTQ and QuantizationModifier examples to completion
* CI tests pass

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>