Commit 07ff5a2

HandH1998, zhangying.1998, and Taka152 authored

compatible gcq params (#409)

* fix gcq params
* fix format

Co-authored-by: zhangying.1998 <zhangying.1998@bytedance.com>
Co-authored-by: xiongying.taka <xiongying.taka@bytedance.com>
1 parent b665742 commit 07ff5a2

File tree

2 files changed: +4 -6 lines changed


examples/training/huggingface/gcq/ls_hf_gcq_trainer.py

+2 -3

@@ -26,9 +26,8 @@ def __init__(self, gcq_args: GCQArguments = None, *args, **kwargs):
     def _wrap_model(self, model, training=True, dataloader=None):
         model = super()._wrap_model(model, training, dataloader)
         # Enable GCQ.
-        if (
-            isinstance(model, torch.nn.parallel.DistributedDataParallel)
-            and self.gcq_args.enable_GCQ
+        if isinstance(model, torch.nn.parallel.DistributedDataParallel) and getattr(
+            self.gcq_args, "enable_GCQ", False
         ):
             assert version.parse(torch.__version__) >= version.parse(
                 "1.10"

lightseq/training/gcq/ls_fs_gcq_trainer.py

+2 -3

@@ -24,9 +24,8 @@ def __init__(self, *args, **kwargs):
     def model(self):
         if self._wrapped_model is None:
             super().model
-            if (
-                isinstance(self._wrapped_model, DistributedDataParallel)
-                and self.args.enable_GCQ
+            if isinstance(self._wrapped_model, DistributedDataParallel) and getattr(
+                self.args, "enable_GCQ", False
             ):
                 assert version.parse(torch.__version__) >= version.parse(
                     "1.10"
