Skip to content

Commit 333fc54

Browse files
Yada Pruksachatkun and pyeres
authored
Fixing Edge-Probing after multi-GPU release (#1025)
* Fixing update_metrics for EdgeProbing
* Throwing error on multi-GPU
* Fixing weight and model in different GPU multi-GPU error
* Remove exception on multi-GPU
* Remove unbind_predictions()
* Move unbind_predictions into edge probing task handle_preds method
* Update comments and docstrings

Co-authored-by: Yada Pruksachatkun <pruks22y@mtholyoke.edu>
Co-authored-by: Phil Yeres <6176602+pyeres@users.noreply.github.com>
1 parent 57ea962 commit 333fc54

File tree

3 files changed

+54
-40
lines changed

3 files changed

+54
-40
lines changed

jiant/models.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -868,7 +868,13 @@ def forward(self, task, batch, predict=False):
868868
# Just get embeddings and invoke task module.
869869
word_embs_in_context, sent_mask = self.sent_encoder(batch["input1"], task)
870870
module = getattr(self, "%s_mdl" % task.name)
871-
out = module.forward(batch, word_embs_in_context, sent_mask, task, predict)
871+
out = module.forward(
872+
batch=batch,
873+
word_embs_in_context=word_embs_in_context,
874+
sent_mask=sent_mask,
875+
task=task,
876+
predict=predict,
877+
)
872878
elif isinstance(task, SequenceGenerationTask):
873879
out = self._seq_gen_forward(batch, task, predict)
874880
elif isinstance(task, (MultiRCTask, ReCoRDTask)):

jiant/modules/edge_probing.py

Lines changed: 7 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# Implementation of edge probing module.
2+
from typing import Dict
23

3-
from typing import Dict, Iterable
4-
5-
import numpy as np
64
import torch
75
import torch.nn as nn
86
import torch.nn.functional as F
@@ -80,19 +78,19 @@ def __init__(self, task, d_inp: int, task_params):
8078
if self.is_symmetric or self.single_sided:
8179
# Use None as dummy padding for readability,
8280
# so that we can index projs[1] and projs[2]
83-
self.projs = [None, self.proj1, self.proj1]
81+
self.projs = nn.ModuleList([None, self.proj1, self.proj1])
8482
else:
8583
# Separate params for span2
8684
self.proj2 = self._make_cnn_layer(d_inp)
87-
self.projs = [None, self.proj1, self.proj2]
85+
self.projs = nn.ModuleList([None, self.proj1, self.proj2])
8886

8987
# Span extractor, shared for both span1 and span2.
9088
self.span_extractor1 = self._make_span_extractor()
9189
if self.is_symmetric or self.single_sided:
92-
self.span_extractors = [None, self.span_extractor1, self.span_extractor1]
90+
self.span_extractors = nn.ModuleList([None, self.span_extractor1, self.span_extractor1])
9391
else:
9492
self.span_extractor2 = self._make_span_extractor()
95-
self.span_extractors = [None, self.span_extractor1, self.span_extractor2]
93+
self.span_extractors = nn.ModuleList([None, self.span_extractor1, self.span_extractor2])
9694

9795
# Classifier gets concatenated projections of span1, span2
9896
clf_input_dim = self.span_extractors[1].get_output_dim()
@@ -131,11 +129,9 @@ def forward(
131129
"""
132130
out = {}
133131

134-
batch_size = word_embs_in_context.shape[0]
135-
out["n_inputs"] = batch_size
136-
137132
# Apply projection CNN layer for each span.
138133
word_embs_in_context_t = word_embs_in_context.transpose(1, 2) # needed for CNN layer
134+
139135
se_proj1 = self.projs[1](word_embs_in_context_t).transpose(2, 1).contiguous()
140136
if not self.single_sided:
141137
se_proj2 = self.projs[2](word_embs_in_context_t).transpose(2, 1).contiguous()
@@ -169,28 +165,10 @@ def forward(
169165
out["loss"] = self.compute_loss(logits[span_mask], batch["labels"][span_mask], task)
170166

171167
if predict:
172-
# Return preds as a list.
173-
preds = self.get_predictions(logits)
174-
out["preds"] = list(self.unbind_predictions(preds, span_mask))
168+
out["preds"] = self.get_predictions(logits)
175169

176170
return out
177171

178-
def unbind_predictions(self, preds: torch.Tensor, masks: torch.Tensor) -> Iterable[np.ndarray]:
179-
""" Unpack preds to varying-length numpy arrays.
180-
181-
Args:
182-
preds: [batch_size, num_targets, ...]
183-
masks: [batch_size, num_targets] boolean mask
184-
185-
Yields:
186-
np.ndarray for each row of preds, selected by the corresponding row
187-
of span_mask.
188-
"""
189-
preds = preds.detach().cpu()
190-
masks = masks.detach().cpu()
191-
for pred, mask in zip(torch.unbind(preds, dim=0), torch.unbind(masks, dim=0)):
192-
yield pred[mask].numpy() # only non-masked predictions
193-
194172
def get_predictions(self, logits: torch.Tensor):
195173
"""Return class probabilities, same shape as logits.
196174
@@ -218,16 +196,6 @@ def compute_loss(self, logits: torch.Tensor, labels: torch.Tensor, task: EdgePro
218196
Returns:
219197
loss: scalar Tensor
220198
"""
221-
binary_preds = logits.ge(0).long() # {0,1}
222-
223-
# Matthews coefficient and accuracy computed on {0,1} labels.
224-
task.mcc_scorer(binary_preds, labels.long())
225-
task.acc_scorer(binary_preds, labels.long())
226-
227-
# F1Measure() expects [total_num_targets, n_classes, 2]
228-
# to compute binarized F1.
229-
binary_scores = torch.stack([-1 * logits, logits], dim=2)
230-
task.f1_scorer(binary_scores, labels)
231199

232200
if self.loss_type == "sigmoid":
233201
return F.binary_cross_entropy(torch.sigmoid(logits), labels.float())

jiant/tasks/edge_probing.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import itertools
44
import logging as log
55
import os
6+
import torch
67
from typing import Dict, Iterable, List, Sequence, Type
78

89
# Fields for instance processing
@@ -159,6 +160,45 @@ def load_data(self):
159160
iters_by_split[split] = iter
160161
self._iters_by_split = iters_by_split
161162

163+
def update_metrics(self, out, batch):
    """Update the task's MCC, accuracy, and F1 scorers from one batch.

    Only real targets contribute: a span1 start index of -1 marks padding
    and is excluded via the boolean mask.

    Parameters
    ----------
    out : dict
        Model output; must contain "logits" of shape [batch_size, num_targets, n_classes].
    batch : dict
        Must contain "span1s" (start index -1 = padded target) and "labels".
    """
    # Keep only targets whose span1 start index is not the -1 padding value.
    valid = batch["span1s"][:, :, 0] != -1
    flat_logits = out["logits"][valid]
    flat_labels = batch["labels"][valid]

    # Binarize logits at 0 to get {0,1} predictions for MCC and accuracy.
    hard_preds = flat_logits.ge(0).long()
    self.mcc_scorer(hard_preds, flat_labels.long())
    self.acc_scorer(hard_preds, flat_labels.long())

    # F1Measure() expects scores shaped [total_num_targets, n_classes, 2]
    # to compute binarized F1, so pair each logit with its negation.
    paired_scores = torch.stack([-1 * flat_logits, flat_logits], dim=2)
    self.f1_scorer(paired_scores, flat_labels)
178+
179+
def handle_preds(self, preds, batch):
    """Unpack preds into varying-length numpy arrays, return the non-masked preds in a list.

    Parameters
    ----------
    preds : torch.Tensor
        [batch_size, num_targets, ...] prediction tensor.
    batch : dict
        dict with key "span1s"; a span1 start index of -1 marks a padded
        (masked-out) target.

    Returns
    -------
    list[np.ndarray]
        One array per example, containing only the rows selected by the
        corresponding row of the span mask.
    """
    # Padding targets have span1 start index -1; everything else is real.
    keep = (batch["span1s"][:, :, 0] != -1).detach().cpu()
    preds_cpu = preds.detach().cpu()
    return [
        row[row_mask].numpy()  # only non-masked predictions
        for row, row_mask in zip(torch.unbind(preds_cpu, dim=0), torch.unbind(keep, dim=0))
    ]
201+
162202
def get_split_text(self, split: str):
163203
""" Get split text as iterable of records.
164204

0 commit comments

Comments (0)