@@ -85,14 +85,20 @@ run.influence.compute_self_influence(test_log) # Uncertainty estimation
### HuggingFace Integration
Our software design allows for seamless integration with HuggingFace's
[Transformers](https://github.com/huggingface/transformers/tree/main), a popular DL framework
- that conveniently handles distributed training, data loading, etc. We plan to support more
- frameworks (e.g. Lightning) in the future.
+ that conveniently handles distributed training, data loading, etc.
```python
from transformers import Trainer, Seq2SeqTrainer
from logix.huggingface import patch_trainer, LogIXArguments

- logix_args = LogIXArguments(project, config, lora=True, hessian="raw", save="grad")
+ # Define LogIX arguments
+ logix_args = LogIXArguments(project="myproject",
+                             config="config.yaml",
+                             lora=True,
+                             hessian="raw",
+                             save="grad")
+
+ # Patch HF Trainer
LogIXTrainer = patch_trainer(Trainer)

# Pass LogIXArguments as TrainingArguments
@@ -108,6 +114,42 @@ trainer.influence()
trainer.self_influence()
```
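The `patch_trainer` call above returns a patched drop-in replacement for the HF `Trainer` class. As a rough illustration of that class-patching pattern, here is a minimal sketch in plain Python; `BaseTrainer`, `patch_trainer_sketch`, and the log buffer are invented for this example and are not LogIX's actual implementation.

```python
# Minimal sketch of the class-patching pattern (hypothetical, not LogIX's
# real code): dynamically subclass a Trainer so extra logging behavior is
# injected without changing how user code constructs or runs the trainer.

class BaseTrainer:
    """Stand-in for a framework Trainer (e.g. transformers.Trainer)."""
    def __init__(self, args):
        self.args = args

    def train(self):
        return "trained"


def patch_trainer_sketch(trainer_cls):
    """Return a subclass that records a log entry before delegating to train()."""
    class PatchedTrainer(trainer_cls):
        def train(self):
            self.log_buffer = ["extract_log"]  # where per-step logs would accumulate
            return super().train()
    PatchedTrainer.__name__ = "Patched" + trainer_cls.__name__
    return PatchedTrainer


PatchedTrainer = patch_trainer_sketch(BaseTrainer)
trainer = PatchedTrainer(args={"lora": True})
print(trainer.train())      # delegates to the original train(): "trained"
print(trainer.log_buffer)   # the injected logging side effect
```

Because the patched class subclasses the original, all existing trainer arguments and methods keep working unchanged.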
+ ### PyTorch Lightning Integration
+ Similarly, we support the LogIX + PyTorch Lightning integration; a code example
+ is provided below.
+
+ ```python
+ from lightning import LightningModule, Trainer
+ from logix.lightning import patch, LogIXArguments
+
+ class MyLitModule(LightningModule):
+     ...
+
+ def data_id_extractor(batch):
+     return tokenizer.batch_decode(batch["input_ids"])
+
+ # Define LogIX arguments
+ logix_args = LogIXArguments(project="myproject",
+                             config="config.yaml",
+                             lora=True,
+                             hessian="raw",
+                             save="grad")
+
+ # Patch Lightning Module and Trainer
+ LogIXModule, LogIXTrainer = patch(MyLitModule,
+                                   Trainer,
+                                   logix_args=logix_args,
+                                   data_id_extractor=data_id_extractor)
+
+ # Use the patched Module and Trainer as before
+ module = LogIXModule(user_args)
+ trainer = LogIXTrainer(user_args)
+
+ # Instead of trainer.fit(module, train_loader), use
+ trainer.extract_log(module, train_loader)
+ trainer.influence(module, train_loader)
+ ```
+
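The `data_id_extractor` above maps each batch to one stable string ID per example, so logged gradients can later be matched back to the training data. A minimal sketch of that contract, using a toy tokenizer and vocabulary invented for illustration (a real setup would use the HF tokenizer's `batch_decode` as in the snippet above):

```python
# Toy illustration of the data_id_extractor contract (hypothetical vocab):
# given a batch dict, return one human-readable string ID per example.

ID_TO_TOKEN = {0: "<pad>", 1: "hello", 2: "world", 3: "logix"}

class ToyTokenizer:
    def batch_decode(self, batch_ids):
        # Skip padding (id 0) and join the remaining tokens into text.
        return [" ".join(ID_TO_TOKEN[i] for i in ids if i != 0)
                for ids in batch_ids]

tokenizer = ToyTokenizer()

def data_id_extractor(batch):
    # Decoded text serves as a (ideally unique) data ID for each example.
    return tokenizer.batch_decode(batch["input_ids"])

batch = {"input_ids": [[1, 2, 0], [3, 2, 0]]}
print(data_id_extractor(batch))   # ['hello world', 'logix world']
```

The key requirement is that the extractor is deterministic, so the same example always yields the same ID across logging and influence runs.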
Please check out [Examples](/examples) for more detailed examples!