```diff
@@ -1,8 +1,6 @@
 # Implementation of edge probing module.
+from typing import Dict
 
-from typing import Dict, Iterable
-
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
```
|
```diff
@@ -80,19 +78,19 @@ def __init__(self, task, d_inp: int, task_params):
         if self.is_symmetric or self.single_sided:
             # Use None as dummy padding for readability,
             # so that we can index projs[1] and projs[2]
-            self.projs = [None, self.proj1, self.proj1]
+            self.projs = nn.ModuleList([None, self.proj1, self.proj1])
         else:
             # Separate params for span2
             self.proj2 = self._make_cnn_layer(d_inp)
-            self.projs = [None, self.proj1, self.proj2]
+            self.projs = nn.ModuleList([None, self.proj1, self.proj2])
 
         # Span extractor, shared for both span1 and span2.
         self.span_extractor1 = self._make_span_extractor()
         if self.is_symmetric or self.single_sided:
-            self.span_extractors = [None, self.span_extractor1, self.span_extractor1]
+            self.span_extractors = nn.ModuleList([None, self.span_extractor1, self.span_extractor1])
         else:
             self.span_extractor2 = self._make_span_extractor()
-            self.span_extractors = [None, self.span_extractor1, self.span_extractor2]
+            self.span_extractors = nn.ModuleList([None, self.span_extractor1, self.span_extractor2])
 
         # Classifier gets concatenated projections of span1, span2
         clf_input_dim = self.span_extractors[1].get_output_dim()
```
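Wrapping these lists in `nn.ModuleList` is the substantive fix in this hunk: modules held in a plain Python list are invisible to PyTorch's module registry, so their parameters are silently excluded from `model.parameters()` (and hence from the optimizer), from `state_dict()`, and from device moves. A minimal sketch of the difference, using a hypothetical `Wrapper` module that is not part of this diff (`None` entries are legal in a `ModuleList` and serve the same dummy-padding role as above):

```python
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain list: the Linear's parameters are invisible to PyTorch,
        # so .parameters(), .to(device), and state_dict() all skip them.
        self.plain = [nn.Linear(4, 4)]
        # ModuleList: the Linear is registered as a submodule.
        # None entries are allowed and act as inert padding.
        self.registered = nn.ModuleList([None, nn.Linear(4, 4)])

w = Wrapper()
print(len(list(w.parameters())))  # 2 -- only the registered Linear's weight and bias
```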
|
```diff
@@ -131,11 +129,9 @@ def forward(
         """
         out = {}
 
-        batch_size = word_embs_in_context.shape[0]
-        out["n_inputs"] = batch_size
-
         # Apply projection CNN layer for each span.
         word_embs_in_context_t = word_embs_in_context.transpose(1, 2)  # needed for CNN layer
+
         se_proj1 = self.projs[1](word_embs_in_context_t).transpose(2, 1).contiguous()
         if not self.single_sided:
             se_proj2 = self.projs[2](word_embs_in_context_t).transpose(2, 1).contiguous()
```
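For context on the surrounding lines: `nn.Conv1d` consumes input as `[batch, channels, length]`, while the contextual embeddings arrive as `[batch, seq_len, d_emb]`, hence the transpose before the projection and back afterwards; `.contiguous()` ensures view-based ops downstream work on the re-laid-out tensor. A shape-only sketch under assumed dimensions (the `Conv1d` stand-in for `_make_cnn_layer` and the concrete sizes are illustrative, not taken from this file):

```python
import torch
import torch.nn as nn

# Illustrative dimensions only.
batch_size, seq_len, d_inp, d_proj = 2, 7, 16, 8
word_embs = torch.randn(batch_size, seq_len, d_inp)  # [batch, seq_len, d_inp]

conv = nn.Conv1d(d_inp, d_proj, kernel_size=1)  # stand-in for _make_cnn_layer
x = word_embs.transpose(1, 2)                   # -> [batch, d_inp, seq_len] for Conv1d
proj = conv(x).transpose(2, 1).contiguous()     # back to [batch, seq_len, d_proj]
assert proj.shape == (batch_size, seq_len, d_proj)
```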
|
```diff
@@ -169,28 +165,10 @@ def forward(
             out["loss"] = self.compute_loss(logits[span_mask], batch["labels"][span_mask], task)
 
         if predict:
-            # Return preds as a list.
-            preds = self.get_predictions(logits)
-            out["preds"] = list(self.unbind_predictions(preds, span_mask))
+            out["preds"] = self.get_predictions(logits)
 
         return out
 
-    def unbind_predictions(self, preds: torch.Tensor, masks: torch.Tensor) -> Iterable[np.ndarray]:
-        """ Unpack preds to varying-length numpy arrays.
-
-        Args:
-            preds: [batch_size, num_targets, ...]
-            masks: [batch_size, num_targets] boolean mask
-
-        Yields:
-            np.ndarray for each row of preds, selected by the corresponding row
-            of span_mask.
-        """
-        preds = preds.detach().cpu()
-        masks = masks.detach().cpu()
-        for pred, mask in zip(torch.unbind(preds, dim=0), torch.unbind(masks, dim=0)):
-            yield pred[mask].numpy()  # only non-masked predictions
-
     def get_predictions(self, logits: torch.Tensor):
         """Return class probabilities, same shape as logits.
 
```
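After this change, `out["preds"]` is the dense `[batch_size, num_targets, ...]` probability tensor rather than a list of per-example arrays. A caller that still wants the old varying-length output can apply the removed logic itself; the deleted helper works essentially unchanged as a free function:

```python
from typing import Iterable

import numpy as np
import torch

def unbind_predictions(preds: torch.Tensor, masks: torch.Tensor) -> Iterable[np.ndarray]:
    """Yield one variable-length numpy array per example, keeping only
    the rows of each example where the boolean mask is True."""
    preds = preds.detach().cpu()
    masks = masks.detach().cpu()
    for pred, mask in zip(torch.unbind(preds, dim=0), torch.unbind(masks, dim=0)):
        yield pred[mask].numpy()  # only non-masked predictions

# e.g. with out["preds"] of shape [batch_size, num_targets, n_classes]:
# preds_list = list(unbind_predictions(out["preds"], span_mask))
```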
|
```diff
@@ -218,16 +196,6 @@ def compute_loss(self, logits: torch.Tensor, labels: torch.Tensor, task: EdgePro
         Returns:
             loss: scalar Tensor
         """
-        binary_preds = logits.ge(0).long()  # {0,1}
-
-        # Matthews coefficient and accuracy computed on {0,1} labels.
-        task.mcc_scorer(binary_preds, labels.long())
-        task.acc_scorer(binary_preds, labels.long())
-
-        # F1Measure() expects [total_num_targets, n_classes, 2]
-        # to compute binarized F1.
-        binary_scores = torch.stack([-1 * logits, logits], dim=2)
-        task.f1_scorer(binary_scores, labels)
 
         if self.loss_type == "sigmoid":
             return F.binary_cross_entropy(torch.sigmoid(logits), labels.float())
```
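The deleted block updated the task's MCC, accuracy, and F1 scorers inside the loss computation. If that metric bookkeeping moves elsewhere, the same inputs are easy to rebuild from raw logits: thresholding at `logit >= 0` is exactly `sigmoid(logit) >= 0.5`, and the `F1Measure()`-style scorer named in the removed comment takes stacked negative/positive scores. A standalone sketch reproducing those inputs (shapes and the call site are assumptions, not part of this diff):

```python
import torch

logits = torch.randn(5, 3)            # [total_num_targets, n_classes]
labels = torch.randint(0, 2, (5, 3))  # k-hot {0,1} labels

# Hard predictions: logit >= 0  <=>  sigmoid(logit) >= 0.5.
binary_preds = logits.ge(0).long()

# [total_num_targets, n_classes, 2] scores (negative class, positive class),
# the layout the removed F1Measure call expected.
binary_scores = torch.stack([-logits, logits], dim=2)
```

Unrelated to the deletion but worth noting: the `sigmoid` branch above could use `F.binary_cross_entropy_with_logits(logits, labels.float())`, which is mathematically equivalent to `F.binary_cross_entropy(torch.sigmoid(logits), ...)` but numerically more stable.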
|