xp_embeddings.py
from typing import List, Optional, Tuple, Literal, Set
import glob
import os
import torch
from sacred import Experiment
from sacred.run import Run
from sacred.commands import print_config
from sacred.observers import FileStorageObserver, TelegramObserver
from sacred.utils import apply_backspaces_and_linefeeds
from conivel.datas.dataset import NERDataset
from conivel.datas.conll import CoNLLDataset
from conivel.datas.dekker import load_book
from conivel.train import train_ner_model
from conivel.predict import predict
from conivel.utils import (
    pretrained_bert_for_token_classification,
    entities_from_bio_tags,
    sacred_archive_picklable_as_file,
)

script_dir = os.path.abspath(os.path.dirname(__file__))
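
# Sacred experiment setup: captured stdout is filtered (backspaces and
# linefeeds applied as a terminal would), runs are archived under ./runs, and
# a Telegram observer is attached only when its configuration file sits next
# to this script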
ex = Experiment()
ex.captured_out_filter = apply_backspaces_and_linefeeds  # type: ignore
ex.observers.append(FileStorageObserver("runs"))
if os.path.isfile(f"{script_dir}/telegram_observer_config.json"):
    ex.observers.append(
        TelegramObserver.from_config(f"{script_dir}/telegram_observer_config.json")
    )


def load_dekker(
    dataset_path: str, keep_only_classes: Optional[Set[str]]
) -> Tuple[NERDataset, NERDataset]:
    """Load the version of the Dekker dataset annotated with PER, LOC and ORG

    :param dataset_path: root directory of the dataset
    :param keep_only_classes: passed to :func:`load_book`
    :return: ``(train_dataset, test_dataset)``
    """
    dataset_path = dataset_path.rstrip("/")
    # TODO: file names
    paths = glob.glob(f"{dataset_path}/*.conll.annotated")
    # TODO: hardcoded train/test split for now (first 11 books for training)
    train_dataset = NERDataset(
        [load_book(path, keep_only_classes=keep_only_classes) for path in paths[:11]]
    )
    test_dataset = NERDataset(
        [load_book(path, keep_only_classes=keep_only_classes) for path in paths[11:]]
    )
    return train_dataset, test_dataset


@ex.config
def config():
    # one of: "dekker", "conll"
    dataset_name: str
    dataset_path: str
    epochs_nb: int
    batch_size: int
    # Optional[List[str]]
    keep_only_classes: Optional[list] = None


@ex.automain
def main(
    _run: Run,
    dataset_name: Literal["dekker", "conll"],
    # path to the Dekker et al. dataset - use .annotated files for now
    # since ORG/LOC classes are important here
    dataset_path: str,
    epochs_nb: int,
    batch_size: int,
    keep_only_classes: Optional[List[str]],
):
    print_config(_run)

    koc = set(keep_only_classes) if keep_only_classes is not None else None

    if dataset_name == "dekker":
        train_dataset, test_dataset = load_dekker(dataset_path, koc)
    elif dataset_name == "conll":
        train_dataset = CoNLLDataset.train_dataset(keep_only_classes=keep_only_classes)
        test_dataset = CoNLLDataset.test_dataset(keep_only_classes=keep_only_classes)
    else:
        raise RuntimeError(f"Unknown dataset {dataset_name}")

    model = pretrained_bert_for_token_classification(
        "bert-base-cased", train_dataset.tag_to_id
    )
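    # NOTE: train_dataset is passed twice below, presumably serving as both
    # the training and the validation set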
    model = train_ner_model(
        model,
        train_dataset,
        train_dataset,
        _run,
        epochs_nb=epochs_nb,
        batch_size=batch_size,
    )

    preds = predict(
        model,
        test_dataset,
        batch_size=batch_size,
        additional_outputs={"embeddings"},
        transfer_additional_outputs_to_cpu=True,
    )
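    # preds.embeddings is expected to hold one tensor of per-token embeddings
    # per sentence (indexed by token position below)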
    assert preds.embeddings is not None

    # reload the test dataset without class restriction, to analyse the
    # embeddings of all entities even if we did not train on them
    if dataset_name == "dekker":
        _, test_dataset = load_dekker(dataset_path, None)
    elif dataset_name == "conll":
        test_dataset = CoNLLDataset.test_dataset()
    else:
        raise RuntimeError(f"Unknown dataset {dataset_name}")

    entities_embeddings = {}
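    # each entity is mapped to the mean of its tokens' embeddings
    # (mean pooling over the entity span)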
    for sent, sent_embeddings in zip(test_dataset.sents(), preds.embeddings):
        sent_entities = entities_from_bio_tags(sent.tokens, sent.tags)
        for entity in sent_entities:
            entities_embeddings[entity] = torch.mean(
                sent_embeddings[entity.start_idx : entity.end_idx + 1], dim=0
            )
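
    # archive the computed embeddings as a pickled artifact of this run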
    sacred_archive_picklable_as_file(_run, entities_embeddings, "entities_embeddings")
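

# Example invocation through the Sacred CLI (a sketch: the dataset path and
# hyperparameter values below are illustrative assumptions, not values taken
# from this repository):
#
#   python xp_embeddings.py with dataset_name=dekker \
#       dataset_path=/path/to/dekker epochs_nb=2 batch_size=4 \
#       'keep_only_classes=["PER", "LOC", "ORG"]'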