Skip to content

Commit a5af954

Browse files
authored
Update README.md
1 parent bc22e8b commit a5af954

File tree

1 file changed

+14
-24
lines changed

1 file changed

+14
-24
lines changed

README.md

+14-24
Original file line numberDiff line numberDiff line change
@@ -48,28 +48,20 @@ import logix
4848
# Initialze LogIX
4949
run = logix.init(project="my_project")
5050

51-
# Users can specify artifacts they want to log
52-
run.setup({"log": "grad", "save": "grad", "statistic": "kfac"})
53-
54-
# Users can specify specific modules they want to track logs for
51+
# Specify modules to be tracked for logging
5552
run.watch(model, name_filter=["mlp"], type_filter=[nn.Linear])
5653

57-
for input, target in data_loader:
58-
# Set data_id for the log from the current batch
59-
with run(data_id=input):
60-
out = model(input)
61-
loss = loss_fn(out, target, reduction="sum")
62-
loss.backward()
63-
model.zero_grad()
64-
65-
# Access log extracted in the LogIX context block
66-
log = run.get_log() # (data_id, log_dict)
67-
# For example, users can print gradient for the specific module
68-
# print(log[1]["model.layers.23.mlp.down_proj"]["grad"])
69-
# or perform any custom analysis
54+
# Specify plugins to be used in logging
55+
run.setup({"grad": ["log", "covariance"]})
56+
run.save(True)
7057

71-
# Synchronize statistics (e.g. grad covariance) and
72-
# write remaining logs to disk
58+
for batch in data_loader:
59+
# Set `data_id` (and optionally `mask`) for the current batch
60+
with run(data_id=batch["input_ids"], mask=batch["attention_mask"]):
61+
model.zero_grad()
62+
loss = model(batch)
63+
loss.backward()
64+
# Synchronize statistics (e.g. covariance) and write logs to disk
7365
run.finalize()
7466
```
7567

@@ -81,13 +73,11 @@ pre-implemented interpretability algorithms if there is a demand.
8173
# Build PyTorch DataLoader from saved log data
8274
log_loader = run.build_log_dataloader()
8375

84-
with run(data_id=test_input):
85-
test_out = model(test_input)
86-
test_loss = loss_fn(test_out, test_target, reduction="sum")
76+
with run(data_id=test_batch["input_ids"]):
77+
test_loss = model(test_batch)
8778
test_loss.backward()
88-
# Extract a log for test data
89-
test_log = run.get_log()
9079

80+
test_log = run.get_log()
9181
run.influence.compute_influence_all(test_log, log_loader) # Data attribution
9282
run.influence.compute_self_influence(test_log) # Uncertainty estimation
9383
```

0 commit comments

Comments
 (0)