forked from microsoft/nlp-recipes
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgensen_wrapper.py
158 lines (119 loc) · 4.86 KB
/
gensen_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import json
import os
from examples.sentence_similarity.gensen_train import train
from utils_nlp.eval.classification import compute_correlation_coefficients
from utils_nlp.models.gensen.create_gensen_model import (
create_multiseq2seq_model,
)
from utils_nlp.models.gensen.gensen import GenSenSingle
from utils_nlp.models.gensen.preprocess_utils import gensen_preprocess
class GenSenClassifier:
    """GenSen classifier that trains a model on several NLP tasks.

    Args:
        config_file (str): Path to the JSON configuration file that is used
            to train the model. This specifies the batch size and the
            directories to load and save the model. The parsed content is
            stored on the instance as ``self.config``.
        pretrained_embedding_path (str): Path to an existing pretrained
            embedding file. It is handed to ``GenSenSingle`` in ``predict``.
        learning_rate (float): The learning rate for the model. Must be a
            ``float`` strictly greater than 0.
        cache_dir (str): Location of GenSen's data directory.
        max_epoch (int, optional): Forwarded unchanged to the training
            routine in ``fit``. Defaults to None.

    Raises:
        ValueError: If ``learning_rate`` is not a float or is not positive.
        AssertionError: If ``pretrained_embedding_path`` is not an existing
            file.
        FileNotFoundError: If ``config_file`` does not exist.
    """

    def __init__(
        self,
        config_file,
        pretrained_embedding_path,
        learning_rate=0.0001,
        cache_dir=".",
        max_epoch=None,
    ):
        self.learning_rate = learning_rate
        self.config_file = config_file
        self.cache_dir = cache_dir
        self.pretrained_embedding_path = pretrained_embedding_path
        self.model_name = "gensen_multiseq2seq"
        self.max_epoch = max_epoch
        # Validation also loads the JSON config into ``self.config``.
        self._validate_params()

    def _validate_params(self):
        """Validate the constructor arguments and load the config file.

        Side effect: on success ``self.config`` holds the parsed JSON config.

        Raises:
            ValueError: If the learning rate is not a float greater than 0.
            AssertionError: If the pretrained embedding path is not a file.
            FileNotFoundError: If the config file does not exist.
        """
        if not isinstance(self.learning_rate, float) or (
            self.learning_rate <= 0.0
        ):
            raise ValueError(
                "Learning rate must be of type float and greater than 0"
            )

        # Raised explicitly (instead of using an ``assert`` statement) so the
        # check is not stripped when Python runs with -O. The exception type
        # stays AssertionError so existing callers keep working.
        if not os.path.isfile(self.pretrained_embedding_path):
            raise AssertionError(
                "Pretrained embedding path is not an existing file: "
                "{}".format(self.pretrained_embedding_path)
            )

        # Reading the config is itself the existence check; no separate
        # probing ``open`` is needed (and no extra handle can leak).
        try:
            self.config = self._read_config(self.config_file)
        except FileNotFoundError as err:
            raise FileNotFoundError(
                "Provided config file does not exist!"
            ) from err

    def _get_gensen_tokens(self, train_df=None, dev_df=None, test_df=None):
        """Preprocess the datasets into GenSen token files.

        Args:
            train_df (pd.Dataframe): A dataframe containing tokenized
                sentences from the training set.
            dev_df (pd.Dataframe): A dataframe containing tokenized
                sentences from the validation set.
            test_df (pd.Dataframe): A dataframe containing tokenized
                sentences from the test set.

        Returns:
            str: Path to the folder containing all preprocessed token files.
        """
        return gensen_preprocess(train_df, dev_df, test_df, self.cache_dir)

    @staticmethod
    def _read_config(config_file):
        """Read a JSON config file.

        Args:
            config_file: Path to the config file (read as UTF-8).

        Returns:
            dict: The loaded JSON file as a Python object.
        """
        # The context manager guarantees the handle is closed even when the
        # file content is not valid JSON.
        with open(config_file, "r", encoding="utf-8") as handle:
            return json.load(handle)

    def _create_multiseq2seq_model(self):
        """Create a GenSen model from a trained MultiSeq2Seq model.

        The trained model is read from, and the GenSen model is saved to,
        the same folder: ``<cache_dir>/<config["data"]["save_dir"]>``.
        """
        model_dir = os.path.join(
            self.cache_dir, self.config["data"]["save_dir"]
        )
        create_multiseq2seq_model(
            save_folder=model_dir,
            save_name=self.model_name,
            trained_model_folder=model_dir,
        )

    def fit(self, train_df, dev_df, test_df):
        """Train the GenSen model.

        Note:
            ``self.cache_dir`` is replaced by the folder returned from the
            preprocessing step, so later calls (including ``predict``)
            resolve the model directory relative to that folder.

        Args:
            train_df: A dataframe containing tokenized sentences from the
                training set.
            dev_df: A dataframe containing tokenized sentences from the
                validation set.
            test_df: A dataframe containing tokenized sentences from the
                test set.
        """
        self.cache_dir = self._get_gensen_tokens(train_df, dev_df, test_df)
        train(
            data_folder=os.path.abspath(self.cache_dir),
            config=self.config,
            learning_rate=self.learning_rate,
            max_epoch=self.max_epoch,
        )
        self._create_multiseq2seq_model()

    def predict(self, sentences):
        """Compute pairwise similarity scores for the given sentences.

        The model is loaded from ``<cache_dir>/<config["data"]["save_dir"]>``,
        which is where ``fit`` stores it.

        Args:
            sentences (list): List of sentences.

        Returns:
            The output of ``compute_correlation_coefficients`` applied to
            the GenSen sentence representations. Presumably a pd.Dataframe
            of pairwise scores between the sentences -- confirm against
            that helper.
        """
        gensen_model = GenSenSingle(
            model_folder=os.path.join(
                self.cache_dir, self.config["data"]["save_dir"]
            ),
            filename_prefix=self.model_name,
            pretrained_emb=self.pretrained_embedding_path,
        )
        # Only the second array returned by get_representation is used.
        _, reps_h_t = gensen_model.get_representation(
            sentences, pool="last", return_numpy=True, tokenize=True
        )
        return compute_correlation_coefficients(reps_h_t)