
Commit f2cb745

Added DeepSeek V3 support. (#688)
1 parent: 4c9a791

File tree

6 files changed: +137 additions, -4 deletions


awq/models/__init__.py (+1)

@@ -24,6 +24,7 @@
 from .phi3_v import Phi3VAWQForCausalLM
 from .cohere import CohereAWQForCausalLM
 from .deepseek_v2 import DeepseekV2AWQForCausalLM
+from .deepseek_v3 import DeepseekV3AWQForCausalLM
 from .minicpm import MiniCPMAWQForCausalLM
 from .internlm2 import InternLM2AWQForCausalLM
 from .minicpm3 import MiniCPM3AWQForCausalLM

awq/models/auto.py (+1)

@@ -34,6 +34,7 @@
     "phi3_v": Phi3VAWQForCausalLM,
     "cohere": CohereAWQForCausalLM,
     "deepseek_v2": DeepseekV2AWQForCausalLM,
+    "deepseek_v3": DeepseekV3AWQForCausalLM,
     "minicpm": MiniCPMAWQForCausalLM,
     "internlm2": InternLM2AWQForCausalLM,
     "minicpm3": MiniCPM3AWQForCausalLM,

awq/models/base.py (+1)

@@ -82,6 +82,7 @@
     "phi3_v": "AutoModelForCausalLM",
     "cohere": "AutoModelForCausalLM",
     "deepseek_v2": "AutoModelForCausalLM",
+    "deepseek_v3": "AutoModelForCausalLM",
     "minicpm": "AutoModelForCausalLM",
     "minicpm3":"AutoModelForCausalLM",
     "internlm2": "AutoModelForCausalLM",

awq/models/deepseek_v3.py (new file, +128)

import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM


class DeepseekV3AWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "DeepseekV3DecoderLayer"
    max_seq_len_key = "max_position_embeddings"

    @staticmethod
    def get_model_layers(model):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module):
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(
        module, input_feat, module_kwargs
    ):
        layers = []

        if hasattr(module.self_attn, "q_proj"):
            # attention input
            layers.append(
                dict(
                    prev_op=module.input_layernorm,
                    layers=[
                        module.self_attn.q_proj,
                        module.self_attn.kv_a_proj_with_mqa,
                    ],
                    inp=input_feat["self_attn.q_proj"],
                    module2inspect=module.self_attn,
                    kwargs=module_kwargs,
                )
            )
        else:
            # attention input
            layers.append(
                dict(
                    prev_op=module.input_layernorm,
                    layers=[
                        module.self_attn.q_a_proj,
                        module.self_attn.kv_a_proj_with_mqa,
                    ],
                    inp=input_feat["self_attn.q_a_proj"],
                    module2inspect=module.self_attn,
                    kwargs=module_kwargs,
                )
            )
            layers.append(
                dict(
                    prev_op=module.self_attn.q_a_layernorm,
                    layers=[
                        module.self_attn.q_b_proj,
                    ],
                    inp=input_feat["self_attn.q_b_proj"],
                )
            )

        # kv layernorm
        layers.append(
            dict(
                prev_op=module.self_attn.kv_a_layernorm,
                layers=[
                    module.self_attn.kv_b_proj,
                ],
                inp=input_feat["self_attn.kv_b_proj"],
            )
        )

        if hasattr(module.mlp, "gate"):
            # linear in
            layers.append(
                dict(
                    prev_op=module.post_attention_layernorm,
                    layers=[
                        w
                        for expert in module.mlp.experts
                        for w in [expert.gate_proj, expert.up_proj]
                    ] + [module.mlp.shared_experts.gate_proj, module.mlp.shared_experts.up_proj],
                    inp=input_feat["mlp"],
                    module2inspect=module.mlp,
                )
            )

            # linear out
            for i, expert in enumerate(module.mlp.experts):
                layers.append(
                    dict(
                        prev_op=expert.up_proj,
                        layers=[expert.down_proj],
                        inp=input_feat[f"mlp.experts.{i}.down_proj"],
                    )
                )
            layers.append(
                dict(
                    prev_op=module.mlp.shared_experts.up_proj,
                    layers=[module.mlp.shared_experts.down_proj],
                    inp=input_feat[f"mlp.shared_experts.down_proj"],
                )
            )
        else:
            # linear 1
            layers.append(
                dict(
                    prev_op=module.post_attention_layernorm,
                    layers=[module.mlp.gate_proj, module.mlp.up_proj],
                    inp=input_feat["mlp.gate_proj"],
                    module2inspect=module.mlp,
                )
            )

            # linear 2
            layers.append(
                dict(
                    prev_op=module.mlp.up_proj,
                    layers=[module.mlp.down_proj],
                    inp=input_feat["mlp.down_proj"],
                )
            )

        return layers
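As a rough intuition for what each dict returned by get_layers_for_scaling encodes (general AWQ behavior, not code from this commit): the modules listed under layers all consume the output of prev_op, so the quantizer can multiply their input channels by a per-channel scale and fold the inverse scale into prev_op without changing the block's output. A toy, self-contained illustration with made-up shapes:

import torch
import torch.nn as nn

torch.manual_seed(0)
prev_op = nn.LayerNorm(8)              # stands in for e.g. input_layernorm
linear = nn.Linear(8, 8, bias=False)   # stands in for e.g. q_proj
x = torch.randn(2, 8)

y_ref = linear(prev_op(x))

s = torch.rand(8) + 0.5                # per-input-channel scales (illustrative)
prev_op.weight.data /= s               # fold 1/s into the preceding op
prev_op.bias.data /= s
linear.weight.data *= s                # scale the linear's input channels by s

y_scaled = linear(prev_op(x))
print(torch.allclose(y_ref, y_scaled, atol=1e-5))  # True: the output is unchanged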

awq/quantize/quantizer.py (+4, -2)

@@ -73,7 +73,7 @@ def __init__(
     def pseudo_quantize_tensor(self, w: torch.Tensor):
         org_w_shape = w.shape
         if self.group_size > 0:
-            assert org_w_shape[-1] % self.group_size == 0
+            assert org_w_shape[-1] % self.group_size == 0, f"org_w_shape ({org_w_shape[-1]}) must be a multiple of group_size ({self.group_size})!"
             w = w.reshape(-1, self.group_size)
         assert w.dim() == 2
         assert torch.isnan(w).sum() == 0
@@ -338,6 +338,7 @@ def _search_best_scale(
         with torch.no_grad():
             module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)
             fp16_output = self._module_forward(inp, module2inspect, module_kwargs)
+            fp16_output = fp16_output.clip(torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max)

         # [STEP 4]: Compute loss
         best_scales = self._compute_best_scale(
@@ -406,6 +407,7 @@ def _compute_best_scale(

             # W * X
             int_w_output = self._module_forward(x, module2inspect, kwargs)
+            int_w_output = int_w_output.clip(torch.finfo(int_w_output.dtype).min, torch.finfo(int_w_output.dtype).max)

             # compute mean squared error (L2 norm)
             loss = self._compute_loss(fp16_output, int_w_output, device)
@@ -623,7 +625,7 @@ def cache_input_hook(m, x, y, name, feat_dict):
                 "block_sparse_moe": layer.block_sparse_moe,
             }

-        if self.awq_model.model_type == "deepseek_v2":
+        if self.awq_model.model_type == "deepseek_v2" or self.awq_model.model_type == "deepseek_v3":
             named_linears = {
                 **named_linears,
                 "mlp": layer.mlp,

awq/utils/utils.py (+2, -2)

@@ -72,8 +72,8 @@ def set_module_name(model, name, value):
 def clear_memory(weight=None):
     if weight is not None:
         del weight
-    gc.collect()
-    torch.cuda.empty_cache()
+    # gc.collect()
+    # torch.cuda.empty_cache()


 def compute_memory_used_pct(device):
