Skip to content

Commit d4b7d9f

Browse files
authored
Update README.md
1 parent c928e4b commit d4b7d9f

File tree

1 file changed

+45
-3
lines changed

1 file changed

+45
-3
lines changed

README.md

+45-3
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,20 @@ run.influence.compute_self_influence(test_log) # Uncertainty estimation
8585
### HuggingFace Integration
8686
Our software design allows for the seamless integration with HuggingFace's
8787
[Transformer](https://github.com/huggingface/transformers/tree/main), a popular DL framework
88-
that conveniently handles distributed training, data loading, etc. We plan to support more
89-
frameworks (e.g. Lightning) in the future.
88+
that conveniently handles distributed training, data loading, etc.
9089

9190
```python
9291
from transformers import Trainer, Seq2SeqTrainer
9392
from logix.huggingface import patch_trainer, LogIXArguments
9493

95-
logix_args = LogIXArguments(project, config, lora=True, hessian="raw", save="grad")
94+
# Define LogIX arguments
95+
logix_args = LogIXArguments(project="myproject",
96+
config="config.yaml",
97+
lora=True,
98+
hessian="raw",
99+
save="grad")
100+
101+
# Patch HF Trainer
96102
LogIXTrainer = patch_trainer(Trainer)
97103

98104
# Pass LogIXArguments as TrainingArguments
@@ -108,6 +114,42 @@ trainer.influence()
108114
trainer.self_influence()
109115
```
110116

117+
### PyTorch Lightning Integration
118+
Similarly, we also support the LogIX + PyTorch Lightning integration. The code example
119+
is provided below.
120+
121+
```python
122+
from lightning import LightningModule, Trainer
123+
from logix.lightning import patch, LogIXArguments
124+
125+
class MyLitModule(LightningModule):
126+
...
127+
128+
def data_id_extractor(batch):
129+
return tokenizer.batch_decode(batch["input_ids"])
130+
131+
# Define LogIX arguments
132+
logix_args = LogIXArguments(project="myproject",
133+
config="config.yaml",
134+
lora=True,
135+
hessian="raw",
136+
save="grad")
137+
138+
# Patch Lightning Module and Trainer
139+
LogIXModule, LogIXTrainer = patch(MyLitModule,
140+
Trainer,
141+
logix_args=logix_args,
142+
data_id_extractor=data_id_extractor)
143+
144+
# Use patched Module and Trainer as before
145+
module = LogIXModule(user_args)
146+
trainer = LogIXTrainer(user_args)
147+
148+
# Instead of trainer.fit(module, train_loader), use
149+
trainer.extract_log(module, train_loader)
150+
trainer.influence(module, train_loader)
151+
```
152+
111153
Please check out [Examples](/examples) for more detailed examples!
112154

113155

0 commit comments

Comments
 (0)