@@ -48,28 +48,20 @@ import logix
 
 # Initialize LogIX
 run = logix.init(project="my_project")
 
-# Users can specify artifacts they want to log
-run.setup({"log": "grad", "save": "grad", "statistic": "kfac"})
-
-# Users can specify specific modules they want to track logs for
+# Specify modules to be tracked for logging
 run.watch(model, name_filter=["mlp"], type_filter=[nn.Linear])
 
-for input, target in data_loader:
-    # Set data_id for the log from the current batch
-    with run(data_id=input):
-        out = model(input)
-        loss = loss_fn(out, target, reduction="sum")
-        loss.backward()
-        model.zero_grad()
-
-        # Access log extracted in the LogIX context block
-        log = run.get_log()  # (data_id, log_dict)
-        # For example, users can print gradient for the specific module
-        # print(log[1]["model.layers.23.mlp.down_proj"]["grad"])
-        # or perform any custom analysis
+# Specify plugins to be used in logging
+run.setup({"grad": ["log", "covariance"]})
+run.save(True)
 
-# Synchronize statistics (e.g. grad covariance) and
-# write remaining logs to disk
+for batch in data_loader:
+    # Set `data_id` (and optionally `mask`) for the current batch
+    with run(data_id=batch["input_ids"], mask=batch["attention_mask"]):
+        model.zero_grad()
+        loss = model(batch)
+        loss.backward()
+# Synchronize statistics (e.g. covariance) and write logs to disk
 run.finalize()
 ```
@@ -81,13 +73,11 @@ pre-implemented interpretability algorithms if there is a demand.
 # Build PyTorch DataLoader from saved log data
 log_loader = run.build_log_dataloader()
 
-with run(data_id=test_input):
-    test_out = model(test_input)
-    test_loss = loss_fn(test_out, test_target, reduction="sum")
+with run(data_id=test_batch["input_ids"]):
+    test_loss = model(test_batch)
     test_loss.backward()
-    # Extract a log for test data
-    test_log = run.get_log()
 
+test_log = run.get_log()
 run.influence.compute_influence_all(test_log, log_loader)  # Data attribution
 run.influence.compute_self_influence(test_log)  # Uncertainty estimation
 ```
0 commit comments