Skip to content

Commit 46296bc

Browse files
Commit message: Address review comments from @markurtz
1 parent: 5cf596c · commit: 46296bc

File tree

4 files changed: +19 −12 lines

examples/llama-3/sft.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4
22

33
plugins:
4-
- axolotl.integrations.llmcompressor_sft.SFTPlugin
4+
- axolotl.integrations.llm_compressor.LLMCompressorPlugin
55

66
load_in_8bit: false
77
load_in_4bit: false

setup.py

-3
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,6 @@ def get_package_version():
143143
"vllm": [
144144
"vllm==0.7.2",
145145
],
146-
"llmcompressor": [
147-
"llm-compressor==0.5.0",
148-
],
149146
}
150147

151148
install_requires, dependency_links, extras_require_build = parse_requirements(

src/axolotl/integrations/llmcompressor_sft/__init__.py → src/axolotl/integrations/llm_compressor/__init__.py (renamed)

+5-5
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
P = ParamSpec("P") # Params for generic function signatures
2020
R = TypeVar("R") # Return type for generic function signatures
2121

22-
LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft")
22+
LOG = logging.getLogger("axolotl.integrations.llm_compressor")
2323

2424

25-
class SFTCallbackHandler(TrainerCallback):
25+
class LLMCompressorCallbackHandler(TrainerCallback):
2626
"""
2727
Trainer callback for Sparse Finetuning.
2828
Maintains sparsity patterns during training by applying masks after optimization steps,
@@ -111,7 +111,7 @@ def on_train_end(
111111
session.finalize()
112112

113113

114-
class SFTPlugin(BasePlugin):
114+
class LLMCompressorPlugin(BasePlugin):
115115
"""
116116
Sparse Finetuning plugin for Axolotl integration.
117117
"""
@@ -123,7 +123,7 @@ def get_input_args(self) -> str:
123123
Returns:
124124
str: Dotted path to the LLMCompressorArgs class.
125125
"""
126-
return "axolotl.integrations.llmcompressor_sft.args.LLMCompressorArgs"
126+
return "axolotl.integrations.llm_compressor.args.LLMCompressorArgs"
127127

128128
def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
129129
"""
@@ -137,7 +137,7 @@ def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
137137
list: List containing the configured callback instances.
138138
"""
139139
LOG.info("Adding Sparse Finetuning callback to the trainer")
140-
callback = SFTCallbackHandler(
140+
callback = LLMCompressorCallbackHandler(
141141
trainer=trainer,
142142
recipe=cfg.llmcompressor.recipe,
143143
)

src/axolotl/integrations/llmcompressor_sft/args.py → src/axolotl/integrations/llm_compressor/args.py (renamed)

+13-3
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,17 @@
88
from typing_extensions import Annotated
99

1010

11-
class SFTArgs(BaseModel):
11+
class CompressionArgs(BaseModel):
1212
"""Sparse Finetuning config for LLMCompressor."""
1313

1414
# Typing for recipe is set to Any due to:
1515
# https://github.com/vllm-project/llm-compressor/issues/1319
16-
recipe: Annotated[Any, Field(description="The recipe containing the compression algorithms and hyperparameters to apply.")]
16+
recipe: Annotated[
17+
Any,
18+
Field(
19+
description="The recipe containing the compression algorithms and hyperparameters to apply."
20+
),
21+
]
1722

1823
model_config = ConfigDict(
1924
arbitrary_types_allowed=True,
@@ -24,7 +29,12 @@ class SFTArgs(BaseModel):
2429
class LLMCompressorArgs(BaseModel):
2530
"""LLMCompressor configuration BaseModel."""
2631

27-
llmcompressor: Annotated[SFTArgs, Field(description="Arguments enabling compression pathways through the LLM Compressor plugins")]
32+
llmcompressor: Annotated[
33+
CompressionArgs,
34+
Field(
35+
description="Arguments enabling compression pathways through the LLM Compressor plugins"
36+
),
37+
]
2838

2939
model_config = ConfigDict(
3040
validate_assignment=True,

0 commit comments

Comments (0)