diff --git a/metrics/functional/query_span_f1.py b/metrics/functional/query_span_f1.py index 488532a..663ce97 100644 --- a/metrics/functional/query_span_f1.py +++ b/metrics/functional/query_span_f1.py @@ -59,7 +59,7 @@ def extract_nested_spans(start_preds, end_preds, match_preds, start_label_mask, match_label_mask = torch.triu(match_label_mask, 0) # start should be less or equal to end match_preds = match_label_mask & match_preds match_pos_pairs = np.transpose(np.nonzero(match_preds.numpy())).tolist() - return [(pos[0], pos[1], pseudo_tag) for pos in match_pos_pairs] + return [(pos[1], pos[2], pseudo_tag) for pos in match_pos_pairs] def extract_flat_spans(start_pred, end_pred, match_pred, label_mask, pseudo_tag = "TAG"):