Skip to content
This repository was archived by the owner on Jul 9, 2024. It is now read-only.

Commit 110f877

Browse files
authored
fix: Revert "fix: mask-out test labels during gradient updates in training loop (#728)" which has negative impact for roc/auc (#764)
This reverts commit 95f650d.
1 parent 97ca5ba commit 110f877

File tree

1 file changed

+2
-8
lines changed

1 file changed

+2
-8
lines changed

src/sagemaker/FD_SL_DGL/gnn_fraud_detection_dgl/fd_sl_train_entry_point.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,14 @@ def train_fg(model, optim, loss, features, labels, train_g, test_g, test_mask,
3535
A full graph verison of RGCN training
3636
"""
3737

38-
# get list of indecies for train labels
39-
train_mask= th.logical_not(test_mask)
40-
train_idx= th.nonzero(train_mask, as_tuple=True)[0]
41-
4238
duration = []
4339
for epoch in range(n_epochs):
4440
tic = time.time()
4541
loss_val = 0.
4642

4743
pred = model(train_g, features.to(device))
4844

49-
# only compute gradient updates for labels in train split
50-
l = loss(th.index_select(pred, 0, train_idx),
51-
th.index_select(labels, 0, train_idx))
45+
l = loss(pred, labels)
5246

5347
optim.zero_grad()
5448
l.backward()
@@ -262,4 +256,4 @@ def get_model(ntype_dict, etypes, hyperparams, in_feats, n_classes, device):
262256

263257
print("Saving model")
264258
save_model(g, model, args.model_dir, id_to_node, mean, stdev)
265-
print("Model and metadata saved")
259+
print("Model and metadata saved")

0 commit comments

Comments
 (0)