AWQ Qwen3-235B-A22B and Qwen3-30B-A3B #1406

Open
ehartford opened this issue May 1, 2025 · 9 comments
Labels: bug (Something isn't working)

ehartford commented May 1, 2025

Describe the bug
When I try to AWQ these models, it hangs forever.

Expected behavior
I expect it to quantize the model.

Environment
Nvidia DGX A100

To Reproduce

I used examples/awq/awq_one_shot.py and modified it:

from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
    QuantizationType,
)
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization import QuantizationModifier

MODEL_ID = "Qwen/Qwen3-30B-A3B"
DATASET_ID = "mit-han-lab/pile-val-backup"
DATASET_SPLIT = "validation"
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 512
OUTPUT_DIR = MODEL_ID.split("/")[-1] + "-awq-asym"


def get_calib_dataset(tokenizer):
    from datasets import load_dataset

    ds = load_dataset(
        DATASET_ID,
        split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES*100}]",
    )

    def preprocess(example):
        return {
            "input_ids": tokenizer.encode(example["text"].strip())[:MAX_SEQUENCE_LENGTH]
        }

    ds = (
        ds.shuffle(seed=42)
        .map(preprocess, remove_columns=ds.column_names)
        .filter(lambda example: len(example["input_ids"]) >= MAX_SEQUENCE_LENGTH)
        .select(range(NUM_CALIBRATION_SAMPLES))
    )

    return ds


if __name__ == "__main__":
    recipe = [
        AWQModifier(bits=4, symmetric=False),
        QuantizationModifier(
            # Ignore these layers during quantization
            ignore=[
                "lm_head",
                ".*norm.*",
                ".*gate.*",
            ],
            config_groups={
                "group_0": QuantizationScheme(
                    targets=["Linear"],
                    weights=QuantizationArgs(
                        num_bits=4,
                        type=QuantizationType.INT,
                        dynamic=False,
                        symmetric=False,
                        strategy=QuantizationStrategy.GROUP,
                        group_size=128,
                    ),
                )
            },
        ),
    ]

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, device_map="auto", torch_dtype="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    oneshot(
        model=model,
        dataset=get_calib_dataset(tokenizer=tokenizer),
        recipe=recipe,
        output_dir=OUTPUT_DIR,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    )

    print("Done! model saved to", OUTPUT_DIR)

The output

(vllm) dgxuser@linux:~/workspace/llm-compressor$ python qwen_moe_awq.py 
Loading checkpoint shards: 100%|████████████████████████████████████████████| 16/16 [00:22<00:00,  1.39s/it]
Repo card metadata block was not found. Setting CardData to empty.
2025-04-30T18:26:27.175014-0700 | reset | INFO - Compression lifecycle reset
2025-04-30T18:26:27.175475-0700 | from_modifiers | INFO - Creating recipe from modifiers
ehartford (Author) commented

FYI - I also tried W8A16 and it works; the problem is in AWQ.

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer 
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.compression.helpers import calculate_offload_device_map

MODEL_ID = "Qwen/Qwen3-235B-A22B"

device_map = calculate_offload_device_map(
    MODEL_ID,
    reserve_for_hessians=True,
    num_gpus=8,  
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

model = AutoModelForCausalLM.from_pretrained( # Reverted to AutoModelForCausalLM
    MODEL_ID, device_map=device_map, torch_dtype=torch.bfloat16, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

DATASET_ID = "neuralmagic/LLM_compression_calibration"
NUM_CALIBRATION_SAMPLES = 256 
MAX_SEQUENCE_LENGTH = 8192 

ds = load_dataset(DATASET_ID, split="train") 
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))

def preprocess(example):
  return {"text": tokenizer.apply_chat_template(example["messages"], add_generation_prompt=False, tokenize=False)}

ds = ds.map(preprocess)

def tokenize(example):
    return tokenizer(
        example["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )

ds = ds.map(tokenize, remove_columns=["messages", "text"])

recipe = GPTQModifier(
    targets="Linear",
    scheme="W8A16", 
    ignore=["lm_head", ".*gate.*", ".*norm.*"],
    dampening_frac=0.1, 
)

SAVE_DIR = "Qwen3-235B-A22B-quantized.w8a16"


oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    save_compressed=True,
    trust_remote_code_model=True,
    output_dir=SAVE_DIR,
)

input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=20)
print(tokenizer.decode(output[0]))

brian-dellabetta self-assigned this May 1, 2025

brian-dellabetta (Collaborator) commented May 1, 2025

Hi @ehartford , thanks for your interest in AWQ and for bringing this to our attention. While the non-MoE Qwen3 models seem to run, these MoE models hang while resolving the mappings. We are using string matches, which causes runtime to increase dramatically when looping over 48 layers, each with 128 experts, in the case of Qwen/Qwen3-30B-A3B.
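
For context, here is a toy sketch of the kind of name-based matching involved (an illustration under assumed module names, not llm-compressor's actual resolution code): for every module matching a "smooth" pattern, all module names are scanned again to find its "balance" layers, so the work grows roughly quadratically with the number of expert projections.

import re

# Illustration only (assumed structure, not llm-compressor's implementation):
# resolving mappings by name means an inner scan over every module name for
# each matched "smooth" layer. Qwen3-30B-A3B has 48 layers x 128 experts x 3
# projections, so this nested loop is noticeably slow even in this toy form.
def resolve_mappings(module_names, smooth_pattern, balance_pattern):
    smooth_re = re.compile(smooth_pattern)
    balance_re = re.compile(balance_pattern)
    resolved = []
    for smooth_name in filter(smooth_re.search, module_names):
        parent = smooth_name.rsplit(".", 1)[0]
        # inner pass over all module names for every smooth layer
        balance = [n for n in module_names
                   if balance_re.search(n) and n.startswith(parent)]
        resolved.append((smooth_name, balance))
    return resolved

# hypothetical Qwen3-MoE-like module tree
names = [f"model.layers.{l}.mlp.experts.{e}.{p}"
         for l in range(48) for e in range(128)
         for p in ("gate_proj", "up_proj", "down_proj")]
pairs = resolve_mappings(names, r"up_proj$", r"down_proj$")
print(f"{len(names)} modules scanned, {len(pairs)} mappings resolved")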

This isn't an issue in AutoAWQ, which has custom wrappers for each model (Qwen3MoE example here).

I will try to address this by the end of next week.

ubergarm commented May 2, 2025

@ehartford

I'm running your AWQ code on a single RTX A6000 with 48 GB of VRAM, and after allocating ~42 GB for the model it sits with no GPU utilization and a single CPU core spinning at 100% for Python. I'll let it sit overnight; possibly it will loop over the 48 layers x 128 experts eventually?

$ CUDA_VISIBLE_DEVICES=0 python compressor.py
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:12<00:00,  1.30it/s]
Some parameters are on the meta device because they were offloaded to the cpu.
Repo card metadata block was not found. Setting CardData to empty.
2025-05-02T01:15:02.456847-0400 | reset | INFO - Compression lifecycle reset
2025-05-02T01:15:02.457368-0400 | from_modifiers | INFO - Creating recipe from modifiers

When I tried AutoAWQ directly after updating transformers, it said:

TypeError: qwen3_moe isn't supported yet.

I saw an open issue on the Hugging Face repo too: https://huggingface.co/Qwen/Qwen3-30B-A3B/discussions/12

Will check in later, thanks!

ehartford (Author) commented

Ok, but I think it will hang there forever; I let mine sit overnight.

ubergarm commented May 2, 2025

@ehartford

lmao, it seems like it got through the loop, but then of course it OOM'd when it went to do the actual thing hahah

2025-05-02T01:15:02.456847-0400 | reset | INFO - Compression lifecycle reset
2025-05-02T01:15:02.457368-0400 | from_modifiers | INFO - Creating recipe from modifiers
2025-05-02T01:47:01.661135-0400 | _set_resolved_mappings | INFO - Excluded 48 from resolved mappings due to shape mismatch
2025-05-02T01:47:02.270633-0400 | _calibrate | INFO - Running AWQModifier calibration with 256 samples...
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [11:34<00:00,  2.71s/it]
2025-05-02T02:00:16.284428-0400 | _apply_smoothing | INFO - Smoothing activation scales...
  0%|                                                                                                                        | 0/6240 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/w/projects/vllm/compressor.py", line 76, in <module>
    oneshot(
  File "/home/w/projects/vllm/venv/lib/python3.12/site-packages/llmcompressor/entrypoints/oneshot.py", line 179, in oneshot
    one_shot()
  File "/home/w/projects/vllm/venv/lib/python3.12/site-packages/llmcompressor/entrypoints/oneshot.py", line 131, in __call__
    self.apply_recipe_modifiers(
.
.
.
  File "/home/w/projects/vllm/venv/lib/python3.12/site-packages/transformers/integrations/sdpa_attention.py", line 54, in sdpa_attention_forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 1024.00 MiB. GPU 0 has a total capacity of 47.41 GiB of which 131.44 MiB is free. Including non-PyTorch memory, this process has 47.26 GiB memory in use. Of the allocated memory 46.00 GiB is allocated by PyTorch, and 972.30 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

So if you have enough VRAM you might wake up to the world's first Qwen3-30B-A3B AWQ, who knows xD!

Looking at the timestamps in the logs, it took a little over 30 minutes to work through the loop on an AMD Ryzen Threadripper PRO 7965WX 24-core (running single-threaded on one core in Python).

brian-dellabetta (Collaborator) commented May 2, 2025

Yes, it will likely OOM for larger models. We cache the calibrated activations for the entire model rather than layer-by-layer, so memory requirements do not scale well with model size. AutoAWQ handles this, but we need to integrate our own pipelining abstraction and wanted to do that in a follow-up PR. We need to add that feature for our implementation of AWQ to be fully ready; what we have so far is a basic port of AutoAWQ, not quite ready for primetime.

Related issue -- #1369 (comment)
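
For illustration, here is a generic layer-by-layer calibration loop (an assumed scheme sketched for this thread, not llm-compressor's actual pipelining code): only one layer's inputs and outputs stay resident at a time, so peak memory stops growing with model depth.

import torch
from torch import nn

@torch.no_grad()
def calibrate_layer_by_layer(layers, calib_inputs, calibrate_fn, device="cuda"):
    # calib_inputs: per-sample activation tensors feeding the first layer
    current = [x.cpu() for x in calib_inputs]
    for layer in layers:
        layer.to(device)                               # load one layer at a time
        outputs = []
        for x in current:
            outputs.append(layer(x.to(device)).cpu())  # cache outputs off-device
        calibrate_fn(layer, current)                   # e.g. compute AWQ scales for this layer
        layer.to("cpu")                                # free device memory before the next layer
        current = outputs
    return current

# tiny smoke test with stand-in layers
layers = [nn.Linear(16, 16) for _ in range(4)]
samples = [torch.randn(8, 16) for _ in range(32)]
calibrate_layer_by_layer(layers, samples, lambda layer, xs: None, device="cpu")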

ubergarm commented May 2, 2025

@brian-dellabetta

Thanks! Yeah, it also seems there's no support for a CPU backend; I tried CUDA_VISIBLE_DEVICES="NONE" and got RuntimeError: No CUDA GPUs are available.

I'd love to get AWQ going and output GGUFs to test against ik_llama.cpp imatrix quants, e.g. my ubergarm/Qwen3-30B-A3B-GGUF.

I'm guessing inference speed with vLLM would be better, and I'm not sure how to test perplexity, KLD, etc. on AWQ quants. Anyway, that's beyond the scope here. Cheers and thanks for all your efforts!
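
(Aside: a minimal perplexity sketch, assuming the AWQ checkpoint loads through transformers; the model id and dataset below are just examples, and non-overlapping 2048-token windows only give a rough number.)

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "swift/Qwen3-30B-A3B-AWQ"  # example checkpoint; any AWQ model loadable by transformers
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype="auto")

text = "\n\n".join(load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"])
ids = tok(text, return_tensors="pt").input_ids[:, :32768]  # keep the eval short

nlls = []
for start in range(0, ids.size(1), 2048):
    chunk = ids[:, start : start + 2048].to(model.device)
    with torch.no_grad():
        out = model(chunk, labels=chunk)  # HF shifts labels internally
    nlls.append(out.loss)
print("perplexity:", torch.exp(torch.stack(nlls).mean()).item())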

ubergarm commented May 2, 2025

@ehartford just got this running a moment ago; it takes about 17 GB of VRAM to load, plus extra for parallel inference slots:

CUDA_VISIBLE_DEVICES="0" \
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
VLLM_USE_MODELSCOPE=True \
vllm \
  serve swift/Qwen3-30B-A3B-AWQ \
  --gpu-memory-utilization 0.9 \
  --max-model-len 32768 \
  --max-num-seqs 64 \
  --served-model-name swift/Qwen3-30B-A3B-AWQ \
  --host 127.0.0.1 \
  --port 8080

Not sure how they quantized their model, but maybe the same way you were trying, given enough time and VRAM.
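
A quick way to smoke-test the served model (a minimal sketch against vLLM's OpenAI-compatible endpoint; the port and served model name match the command above):

from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8080/v1", api_key="none")  # vLLM accepts any key unless --api-key is set
resp = client.chat.completions.create(
    model="swift/Qwen3-30B-A3B-AWQ",
    messages=[{"role": "user", "content": "Hello, who are you?"}],
    max_tokens=64,
)
print(resp.choices[0].message.content)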

brian-dellabetta (Collaborator) commented

Hi @ubergarm , yes, AWQ will require a GPU to run in a reasonable amount of time for most models. We've got that somewhat hard-coded for now, and we'll have better support for offloaded models in a future release.

Yeah, I noticed Qwen publishes some AWQ-ed models (https://huggingface.co/Qwen/Qwen3-32B-AWQ) but no MoE models. There do seem to be lots in the community though 💪
