Commit e66bbff

use DefaultWeightsLoader in skip modules
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Committed by jiqing-feng on Feb 28, 2025
Parent: b7bdbbd

Showing 2 changed files with 8 additions and 14 deletions.
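
Background for the diff below: every loader entry point guards its quantized
path with is_layer_skipped_quantization, which tests a weight prefix against
modules_to_not_convert. The helper's body is outside this diff; the following
is a minimal sketch of how such a prefix test is commonly written, so the
function body here is an assumption, not the upstream code.

    from typing import List

    def is_layer_skipped_quantization(
        prefix: str, modules_to_not_convert: List[str]
    ) -> bool:
        # Assumed behavior: a layer stays unquantized when any configured
        # module name occurs in its weight prefix.
        return any(module in prefix for module in modules_to_not_convert)

    # With the [] default introduced in quantization.py below, an empty list
    # simply means "skip nothing"; no None guard is required.
    assert is_layer_skipped_quantization("model.lm_head", ["lm_head"])
    assert not is_layer_skipped_quantization("model.layers.0.q_proj", [])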

server/text_generation_server/layers/gptq/__init__.py (7 additions, 13 deletions)

@@ -10,7 +10,7 @@
     Weight,
     Weights,
     WeightsLoader,
-    UnquantizedWeight,
+    DefaultWeightsLoader,
 )

 if SYSTEM == "ipex":

@@ -95,7 +95,7 @@ def __init__(
         quant_method: str,
         quantize: str,
         sym: bool,
-        modules_to_not_convert: Optional[List[str]],
+        modules_to_not_convert: List[str],
     ):
         self.bits = bits
         self.desc_act = desc_act

@@ -117,8 +117,7 @@ def get_weights(self, weights: Weights, prefix: str):
             use_exllama = False

         if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
-            w = weights.get_tensor(f"{prefix}.weight")
-            return UnquantizedWeight(w)
+            return DefaultWeightsLoader.get_weights(weights, prefix)

         try:
             qweight = weights.get_tensor(f"{prefix}.qweight")

@@ -200,10 +199,9 @@ def get_weights_col_packed(
         block_sizes: Union[int, List[int]],
     ):
         if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
-            w = weights.get_packed_sharded(
-                f"{prefix}.weight", dim=0, block_sizes=block_sizes
-            )
-            return UnquantizedWeight(w)
+            return DefaultWeightsLoader.get_weights_col_packed(
+                weights, prefix, block_sizes
+            )
         try:
             qweight = weights.get_packed_sharded(
                 f"{prefix}.qweight", dim=1, block_sizes=block_sizes

@@ -256,10 +254,7 @@ def get_weights_col_packed(

     def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
         if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
-            w = torch.cat(
-                [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes], dim=dim
-            )
-            return UnquantizedWeight(w)
+            return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim)
         try:
             qweight = torch.cat(
                 [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1

@@ -339,8 +334,7 @@ def get_weights_row(self, weights: Weights, prefix: str):
             use_exllama = False

         if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
-            w = weights.get_sharded(f"{prefix}.weight", dim=1)
-            return UnquantizedWeight(w)
+            return DefaultWeightsLoader.get_weights_row(weights, prefix)
         try:
             qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
         except RuntimeError:
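
All four hunks above replace a hand-rolled UnquantizedWeight construction with
a call into DefaultWeightsLoader, so the unquantized fallback lives in one
place. A toy illustration of that design, using stand-in classes rather than
the real TGI loaders (names and bodies here are invented for the sketch):

    import torch
    from typing import Dict, List

    class StubWeights:
        # Dict-backed stand-in for TGI's Weights object.
        def __init__(self, tensors: Dict[str, torch.Tensor]):
            self._tensors = tensors

        def get_tensor(self, name: str) -> torch.Tensor:
            return self._tensors[name]

    class FallbackLoader:
        # Plays the role DefaultWeightsLoader takes in the diff: the single
        # place that knows how to fetch a plain ".weight" tensor.
        def get_weights(self, weights: StubWeights, prefix: str) -> torch.Tensor:
            return weights.get_tensor(f"{prefix}.weight")

    class QuantizedLoader:
        def __init__(self, modules_to_not_convert: List[str]):
            self.modules_to_not_convert = modules_to_not_convert
            self.fallback = FallbackLoader()

        def get_weights(self, weights: StubWeights, prefix: str) -> torch.Tensor:
            if any(m in prefix for m in self.modules_to_not_convert):
                # Skipped module: delegate rather than inline the fetch,
                # mirroring each hunk above.
                return self.fallback.get_weights(weights, prefix)
            return weights.get_tensor(f"{prefix}.qweight")

    loader = QuantizedLoader(["lm_head"])
    weights = StubWeights({"lm_head.weight": torch.zeros(2, 2)})
    print(loader.get_weights(weights, "lm_head").shape)  # torch.Size([2, 2])

Centralizing the fallback means sharding and packing details for unquantized
tensors are maintained once instead of being re-implemented in four
quantized-loader methods.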

server/text_generation_server/utils/quantization.py (1 addition, 1 deletion)

@@ -77,7 +77,7 @@ def _get_quantizer_config(model_id, revision):
         checkpoint_format = data["quantization_config"].get("checkpoint_format")
         desc_act = data["quantization_config"].get("desc_act", False)
         modules_to_not_convert = data["quantization_config"].get(
-            "modules_to_not_convert", None
+            "modules_to_not_convert", []
         )
     except Exception:
         filename = "quantize_config.json"
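
The one-line change above makes the parsed default an empty list rather than
None, matching the tightened List[str] annotation in the first file. A small,
self-contained sketch of the parsing path; the JSON payload is invented for
illustration, while the .get(...) call matches the diff:

    import json

    raw = """
    {
      "quantization_config": {
        "bits": 4,
        "checkpoint_format": "gptq",
        "desc_act": false,
        "modules_to_not_convert": ["lm_head"]
      }
    }
    """

    data = json.loads(raw)
    modules_to_not_convert = data["quantization_config"].get(
        "modules_to_not_convert", []
    )
    print(modules_to_not_convert)  # ['lm_head']

    # A config that omits the key now yields [] instead of None, so callers
    # such as is_layer_skipped_quantization can iterate unconditionally.
    empty = {"quantization_config": {}}
    print(empty["quantization_config"].get("modules_to_not_convert", []))  # []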
