Commit 135d672 (1 parent: d21b2d1)

Commit message: pre commit hooks

All four diffs below are mechanical formatting fixes of the kind pre-commit hooks apply: trailing whitespace is trimmed, imports are regrouped and sorted, and long lines are re-wrapped (consistent with common hooks such as trailing-whitespace, isort, and black).

File tree: 4 files changed, +39 −25 lines


examples/llama-3/sft.yaml (+3 −3)

@@ -1,4 +1,4 @@
-base_model: "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed" 
+base_model: "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed"
 # TODO: change to
 # base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4
 
@@ -68,11 +68,11 @@ llmcompressor:
       ConstantPruningModifier:
         targets: [
           're:.*q_proj.weight',
-          're:.*k_proj.weight', 
+          're:.*k_proj.weight',
           're:.*v_proj.weight',
           're:.*o_proj.weight',
           're:.*gate_proj.weight',
           're:.*up_proj.weight',
           're:.*down_proj.weight',
        ]
-        start: 0 
+        start: 0
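
For context, the `llmcompressor:` section above is the recipe payload that the plugin validates (see the `Recipe.model_validate` call in the next file). A minimal sketch of that validation in dict form, assuming llm-compressor's usual stage → `pruning_modifiers` nesting (the stage name here is hypothetical; the modifier and its fields come from the YAML above):

    from llmcompressor.recipe import Recipe

    # Hypothetical dict form of the YAML recipe; ConstantPruningModifier
    # holds the listed weights' existing sparsity pattern fixed during
    # finetuning rather than introducing new pruning.
    recipe_dict = {
        "finetuning_stage": {
            "pruning_modifiers": {
                "ConstantPruningModifier": {
                    "targets": [
                        "re:.*q_proj.weight",
                        "re:.*k_proj.weight",
                        "re:.*v_proj.weight",
                        "re:.*o_proj.weight",
                        "re:.*gate_proj.weight",
                        "re:.*up_proj.weight",
                        "re:.*down_proj.weight",
                    ],
                    "start": 0,  # apply from the first training step onward
                }
            }
        }
    }

    # Same call the plugin makes in its __init__ (next file's diff).
    recipe = Recipe.model_validate(recipe_dict)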

src/axolotl/integrations/llmcompressor_sft/__init__.py (+29 −11)

@@ -5,19 +5,19 @@
 
 import logging
 from functools import wraps
-from typing import Callable, TypeVar, ParamSpec, Any
+from typing import Any, Callable, ParamSpec, TypeVar
 
+from llmcompressor import active_session
+from llmcompressor.core import callbacks as session_callbacks
+from llmcompressor.recipe import Recipe
 from transformers.trainer import Trainer
-from transformers.trainer_callback import TrainerCallback, TrainerState, TrainerControl
+from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
 from transformers.training_args import TrainingArguments
 
 from ..base import BasePlugin
-from llmcompressor import active_session
-from llmcompressor.core import callbacks as session_callbacks
-from llmcompressor.recipe import Recipe
 
 P = ParamSpec("P")  # Params for generic function signatures
-R = TypeVar("R") # Return type for generic function signatures
+R = TypeVar("R")  # Return type for generic function signatures
 
 LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft")

@@ -39,11 +39,17 @@ def __init__(self, trainer: Trainer, recipe: Any):
         """
         super().__init__()
         self.trainer = trainer
-        self.recipe = Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
+        self.recipe = (
+            Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
+        )
         self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
 
     def on_train_begin(
-        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
     ) -> None:
         """
         Called at the beginning of training. Initializes the compression session.

@@ -63,7 +69,11 @@ def on_train_begin(
         )
 
     def on_step_begin(
-        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
     ) -> None:
         """
         Called at the beginning of a training step. Triggers batch_start callback.

@@ -72,7 +82,11 @@ def on_step_begin(
         session_callbacks.batch_start()
 
     def on_step_end(
-        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
     ) -> None:
         """
         Called at the end of a training step. Triggers optimizer and batch_end callbacks.

@@ -83,7 +97,11 @@ def on_step_end(
         session_callbacks.batch_end()
 
     def on_train_end(
-        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
     ) -> None:
         """
         Called at the end of training. Finalizes the compression session.
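
Taken together, this callback brackets the Hugging Face training loop with llm-compressor session events: session start at train begin, batch_start/batch_end (plus optimizer callbacks) around each step, and finalization at train end. A minimal usage sketch; the class name `SparseFinetuningCallback` is an assumption (the hunks never show it), and `model`, `training_args`, `dataset`, and `recipe` are presumed to be defined already:

    from transformers import Trainer

    # Hypothetical wiring; only Trainer.add_callback and the
    # __init__(trainer, recipe) signature come from the diff above.
    trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
    trainer.add_callback(SparseFinetuningCallback(trainer, recipe=recipe))
    trainer.train()  # on_train_begin / on_step_* / on_train_end fire around each step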

src/axolotl/integrations/llmcompressor_sft/args.py (+5 −9)

@@ -2,19 +2,18 @@
 LLMCompressor and Sparse Finetuning config models.
 """
 
-from pydantic import BaseModel, Field, ConfigDict
 from typing import Any
+
+from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import Annotated
 
+
 class SFTArgs(BaseModel):
     """Sparse Finetuning config for LLMCompressor."""
 
     # Typing for recipe is set to Any due to:
     # https://github.com/vllm-project/llm-compressor/issues/1319
-    recipe: Annotated[
-        Any,
-        Field(description="Recipe config.")
-    ]
+    recipe: Annotated[Any, Field(description="Recipe config.")]
 
     model_config = ConfigDict(
         arbitrary_types_allowed=True,

@@ -25,10 +24,7 @@ class SFTArgs(BaseModel):
 class LLMCompressorArgs(BaseModel):
     """LLMCompressor configuration BaseModel."""
 
-    llmcompressor: Annotated[
-        SFTArgs,
-        Field(description="SFT llmcompressor args")
-    ]
+    llmcompressor: Annotated[SFTArgs, Field(description="SFT llmcompressor args")]
 
     model_config = ConfigDict(
         validate_assignment=True,
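
As a quick illustration of these models, a hedged sketch of validating a plugin config. Because `recipe` is typed `Any` (per the linked issue), pydantic accepts any payload here and defers real parsing to `Recipe.model_validate` in the plugin; the placeholder dict below is arbitrary:

    from axolotl.integrations.llmcompressor_sft.args import LLMCompressorArgs, SFTArgs

    # Placeholder recipe payload; structure is not checked at this layer.
    cfg = LLMCompressorArgs(llmcompressor=SFTArgs(recipe={"stage": {"modifiers": {}}}))
    print(cfg.llmcompressor.recipe)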

src/axolotl/utils/models.py (+2 −2)

@@ -137,7 +137,7 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
         hasattr(model_config, "quantization_config")
         and model_config.quantization_config
     )
-
+
     # Detect compressed-tensors config
     is_compressed_tensors_config = (
         quant_config_exists

@@ -152,7 +152,7 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
         )
         # Skip further quant checks for compressed-tensors
         return
-
+
     quant_config_method_is_gptq = (
         quant_config_exists
         and "quant_method" in model_config.quantization_config
