5
5
6
6
import logging
7
7
from functools import wraps
8
- from typing import Callable , TypeVar , ParamSpec , Any
8
+ from typing import Any , Callable , ParamSpec , TypeVar
9
9
10
+ from llmcompressor import active_session
11
+ from llmcompressor .core import callbacks as session_callbacks
12
+ from llmcompressor .recipe import Recipe
10
13
from transformers .trainer import Trainer
11
- from transformers .trainer_callback import TrainerCallback , TrainerState , TrainerControl
14
+ from transformers .trainer_callback import TrainerCallback , TrainerControl , TrainerState
12
15
from transformers .training_args import TrainingArguments
13
16
14
17
from ..base import BasePlugin
15
- from llmcompressor import active_session
16
- from llmcompressor .core import callbacks as session_callbacks
17
- from llmcompressor .recipe import Recipe
18
18
19
19
P = ParamSpec ("P" ) # Params for generic function signatures
20
- R = TypeVar ("R" ) # Return type for generic function signatures
20
+ R = TypeVar ("R" ) # Return type for generic function signatures
21
21
22
22
LOG = logging .getLogger ("axolotl.integrations.llmcompressor_sft" )
23
23
@@ -39,11 +39,17 @@ def __init__(self, trainer: Trainer, recipe: Any):
39
39
"""
40
40
super ().__init__ ()
41
41
self .trainer = trainer
42
- self .recipe = Recipe .model_validate (recipe ) if not isinstance (recipe , Recipe ) else recipe
42
+ self .recipe = (
43
+ Recipe .model_validate (recipe ) if not isinstance (recipe , Recipe ) else recipe
44
+ )
43
45
self .trainer .compute_loss = compute_loss_wrapper (self .trainer .compute_loss )
44
46
45
47
def on_train_begin (
46
- self , args : TrainingArguments , state : TrainerState , control : TrainerControl , ** kwargs
48
+ self ,
49
+ args : TrainingArguments ,
50
+ state : TrainerState ,
51
+ control : TrainerControl ,
52
+ ** kwargs ,
47
53
) -> None :
48
54
"""
49
55
Called at the beginning of training. Initializes the compression session.
@@ -63,7 +69,11 @@ def on_train_begin(
63
69
)
64
70
65
71
def on_step_begin (
66
- self , args : TrainingArguments , state : TrainerState , control : TrainerControl , ** kwargs
72
+ self ,
73
+ args : TrainingArguments ,
74
+ state : TrainerState ,
75
+ control : TrainerControl ,
76
+ ** kwargs ,
67
77
) -> None :
68
78
"""
69
79
Called at the beginning of a training step. Triggers batch_start callback.
@@ -72,7 +82,11 @@ def on_step_begin(
72
82
session_callbacks .batch_start ()
73
83
74
84
def on_step_end (
75
- self , args : TrainingArguments , state : TrainerState , control : TrainerControl , ** kwargs
85
+ self ,
86
+ args : TrainingArguments ,
87
+ state : TrainerState ,
88
+ control : TrainerControl ,
89
+ ** kwargs ,
76
90
) -> None :
77
91
"""
78
92
Called at the end of a training step. Triggers optimizer and batch_end callbacks.
@@ -83,7 +97,11 @@ def on_step_end(
83
97
session_callbacks .batch_end ()
84
98
85
99
def on_train_end (
86
- self , args : TrainingArguments , state : TrainerState , control : TrainerControl , ** kwargs
100
+ self ,
101
+ args : TrainingArguments ,
102
+ state : TrainerState ,
103
+ control : TrainerControl ,
104
+ ** kwargs ,
87
105
) -> None :
88
106
"""
89
107
Called at the end of training. Finalizes the compression session.
0 commit comments