diff --git a/code_soup/ch8/pwws.py b/code_soup/ch8/pwws.py new file mode 100644 index 0000000..689a8d2 --- /dev/null +++ b/code_soup/ch8/pwws.py @@ -0,0 +1,400 @@ +""" +PWWS Attack implementation. The code has been adapted from +https://github.com/thunlp/OpenAttack/blob/master/OpenAttack/attackers/pwws/__init__.py. +""" + + +import sys + +sys.path.append("./") + +from typing import Any, Optional + +import datasets +import numpy as np +import transformers + +from code_soup.common.text.datasets.utils import dataset_mapping +from code_soup.common.text.models import classifier, transformers_classifier +from code_soup.common.text.utils.attack_helpers import * +from code_soup.common.text.utils.exceptions import WordNotInDictionaryException +from code_soup.common.text.utils.metrics import * +from code_soup.common.text.utils.misc import ENGLISH_FILTER_WORDS +from code_soup.common.text.utils.tokenizer import PunctTokenizer, Tokenizer +from code_soup.common.text.utils.word_substitute import WordNetSubstitute + + +def check(prediction, target, targeted): + """ + A utility function to check if the attack was successful. If the attack is + targeted, then the "predicted" class must be same as "target" class. + Otherwise, the "predicted" class must be different from the "target" class. + + Args: + prediction (int): Predicted class (as returned by the model). + target (int): Has a dual meaning. If targeted = True, then target is the + class we want the model to predict (on the adversarial + sample). Otherwise, target is the class the model predicted + for the original sample. + targeted (bool): Whether the attack is targeted or not. Targeted attack + here means that we want to obtain an adversarial sample + such that the model predicts the specified target class. + + Returns: + (bool): Returns whether the attack was successful or not. + """ + if targeted: + return prediction == target + else: + return prediction != target + + +class PWWSAttacker: + def __init__( + self, + tokenizer: Optional[Tokenizer] = None, + token_unk: str = "", + ): + """ + Generating Natural Language Adversarial Examples through Probability + Weighted Word Saliency. + Shuhuai Ren, Yihe Deng, Kun He, Wanxiang Che. ACL 2019. + + `[pdf] `__ + `[code] `__ + + Args: + tokenizer: A tokenizer that will be used during the attack procedure. + Must be an instance of Tokenizer + token_unk: The token id or the token name for out-of-vocabulary + words in victim model. Default: + """ + # PWWS attack substitutes words using synonyms obtained from WordNet. + # For a detailed description of the method, please refer to Section 3.2.1. + # You can also refer to code_soup/ch8/common/text/utils/word_substitute.py. + self.substitute = WordNetSubstitute() + + if tokenizer is None: + tokenizer = PunctTokenizer() + self.tokenizer = tokenizer + + self.token_unk = token_unk + self.filter_words = set(ENGLISH_FILTER_WORDS) + + def __call__(self, victim: classifier.Classifier, input_: Any): + """ + Generates the adversarial sample when the attacker object is called. + + Args: + victim (classifier.Classifier): A classifier which is to be attacked. + input_ (Any): A dictionary which contains the input data (text and + label). Example: + {'label': 0.625, + 'x': 'Singer\\/composer Bryan Adams contributes a + slew of songs .', + 'y': 1 + } + + Raises: + RuntimeError: If the attack is not successful. + + Returns: + adversarial_sample (str): Adversarial sample generated by PWWS. 
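+
+        Example (illustrative sketch; assumes ``victim`` is a classifier built
+        as in ``main()`` below):
+            >>> attacker = PWWSAttacker()
+            >>> sample = {"x": "inception is an awesome movie .", "y": 1}
+            >>> adversarial_sample = attacker(victim, sample)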
+ """ + # If the attack is targeted + if "target" in input_: + target = input_["target"] + targeted = True + # If the attack is not targeted, keep the target as the predicted label + # of the original text; in untargeted attack, we will generate a sample + # with predicted label different from the predicted label of the + # original text. + else: + target = victim.get_pred([input_["x"]])[0] + targeted = False + + # Generate the adversarial sample. + adversarial_sample = self.attack(victim, input_["x"], target, targeted) + + if adversarial_sample is not None: + # Obtain the predicted label of the adversarial sample. + y_adv = victim.get_pred([adversarial_sample])[0] + # Verify if the attack was successful. If not, raise an error. + if not check(y_adv, target, targeted): + raise RuntimeError( + "Check attacker result failed: " + "result ([%d] %s) expect (%s%d)" + % (y_adv, adversarial_sample, "" if targeted else "not ", target) + ) + return adversarial_sample + + def attack( + self, victim: classifier.Classifier, sentence: str, target=0, targeted=True + ): + """ + Given an input sample, generate the adversarial text. + + Args: + victim (classifier.Classifier): A classifier which is to be attacked. + sentence (str): Input text. + target (int): Has a dual meaning. If targeted = True, then target is + the class we want the model to predict (on the + adversarial sample). Otherwise, target is the class + the model predicted for the original sample. Defaults + to 0. + targeted (bool): Whether the attack is targeted or not. Targeted + attack here means that we want to obtain an adversarial + sample such that the model predicts the specified + target class. Defaults to True. + + Returns: + (str): Adversarial sample generated by PWWS. + """ + # Example of x_orig: "inception is an awesome movie ." + x_orig = sentence.lower() + # Words: ['inception', 'is', 'an', 'awesome', 'movie', '.'] + # POS Tags: ['noun', 'verb', 'other', 'adj', 'noun', 'other'] + + # Obtain words and their respective POS tags. + x_orig_pos = self.tokenizer.tokenize(x_orig) + x_orig, poss = list(map(list, zip(*x_orig_pos))) + + # Get the saliency score for every word in the input text. Example: + # [1.19209290e-06, 4.29153442e-06, 1.41859055e-05, 5.17034531e-03, + # 7.03334808e-06, 4.76837158e-07] + S = self.get_saliency(victim, x_orig, target, targeted) + # Normalise the saliency scores. Example: + # [0.16652223, 0.16652276, 0.16652441, 0.16738525, 0.1665232, 0.16652212] + S_softmax = np.exp(S - S.max()) + S_softmax = S_softmax / S_softmax.sum() + + # Obtain the best replacement word for every word in the input text. + # Example: + # [('origination', -2.3841858e-07), ('is', 0), ('an', 0), + # ('awful', 0.9997573), ('pic', 1.180172e-05), ('.', 0)] + w_star = [ + self.get_wstar(victim, x_orig, i, poss[i], target, targeted) + for i in range(len(x_orig)) + ] + # Compute "H" score for every word. It is simply the product of the w_star + # score and the saliency scores. See Eqn (7) in the paper. Example: + # [(0, 'origination', -3.9701995e-08), (1, 'is', 0.0), + # (2, 'an', 0.0), (3, 'awful', 0.16734463), + # (4, 'pic', 1.9652603e-06), (5, '.', 0.0)] + H = [ + (idx, w_star[idx][0], S_softmax[idx] * w_star[idx][1]) + for idx in range(len(x_orig)) + ] + + # Sort the words in the input text by their "H" score (descending order). 
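+        # Words are then replaced greedily in this order until the victim's
+        # prediction satisfies the attack goal (see the loop below).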
+ H = sorted(H, key=lambda x: -x[2]) + ret_sent = x_orig.copy() + for i in range(len(H)): + idx, wd, _ = H[i] + if ret_sent[idx] in self.filter_words: + continue + ret_sent[idx] = wd + + curr_sent = self.tokenizer.detokenize(ret_sent) + pred = victim.get_pred([curr_sent])[0] + # Verify if the attack was successful. + if check(pred, target, targeted): + return curr_sent + return None + + def get_saliency( + self, clsf: classifier.Classifier, sent: List[str], target=0, targeted=True + ): + """ + Get saliency scores for every score. Simply put, saliency score of a + word is the degree of change in the output probability of the classifier + if the word is set to unknown (out of vocabulary). See Section 3.2.2 + in the paper for more details. + + Args: + clsf (Classifier): A classifier that will be used to get the + saliency scores. + sent (list): List of tokens in a sentence. + target (int): Has a dual meaning. If targeted = True, then target is + the class we want the model to predict (on the + adversarial sample). Otherwise, target is the class + the model predicted for the original sample. Defaults + to 0. + targeted (bool): Whether the attack is targeted or not. Targeted + attack here means that we want to obtain an adversarial + sample such that the model predicts the specified + target class. Defaults to True. + """ + # Replace words with one by one. Compute probability for every such + # sample. + # Example: sent = ["inception", "is", "an", "awesome", "movie", "."] + # A few samples generated: ['', 'is', 'an', 'awesome', 'movie', '.'], + # ['inception', '', 'an', 'awesome', 'movie', '.'], etc. + x_hat_raw = [] + for i in range(len(sent)): + left = sent[:i] + right = sent[i + 1 :] + # Replace the word with unknown token + x_i_hat = left + [self.token_unk] + right + x_hat_raw.append(self.tokenizer.detokenize(x_i_hat)) + # Concatenate the original text as well; we want to compute the probability + # for the original sample too (because we want the difference in probs) + # between generated samples and original sample). + x_hat_raw.append(self.tokenizer.detokenize(sent)) + + # Compute the probabilities. Example: + # [0.9999354, 0.9999323, 0.9999224, 0.99476624, 0.99992955, 0.9999361, + # 0.9999366]. Clearly, the 4th element of the list differs the most + # from the last element (probability of the original sample). The 4th + # element is the probability of ["inception", "is", "an", "", "movie", "."]. + # This proves that the word "awesome" plays a major role in determining + # the classification output. + res = clsf.get_prob(x_hat_raw)[:, target] + if not targeted: + res = res[-1] - res[:-1] + else: + res = res[:-1] - res[-1] + return res + + def get_wstar( + self, + clsf: classifier.Classifier, + sent: List[str], + idx: int, + pos: str, + target=0, + targeted=True, + ): + """ + Given a word in a sentence, find the replacment word (from a list of + candidate replacements) that maximises the difference in probabilities + between the original sample and the generated sample (generated sample + is the sample with the word replaced by the candidate word). This score + is given as delta(P) in the paper. See Section 3.2.1 for more details. + + Args: + clsf (classifier.Classifier): A classifier which is to be attacked. + sent ([str]): Input text. + idx (int): Index of word in sentence. + pos (str): POS Tag. + target (int): Has a dual meaning. If targeted = True, then target is + the class we want the model to predict (on the + adversarial sample). 
Otherwise, target is the class + the model predicted for the original sample. Defaults + to 0. + targeted (bool): Whether the attack is targeted or not. Targeted + attack here means that we want to obtain an adversarial + sample such that the model predicts the specified + target class. Defaults to True. + + Returns: + ((str, float)): Best replacement word (w_star) and its score (delta(P) + in the paper). + """ + # Example: sent = ["inception", "is", "an", "awesome", movie, "."] + # idx = 3, word = "awesome", pos = "adj" + # Its replacement words are: ['awing', 'amazing', 'awful', 'awe-inspiring'] + word = sent[idx] + try: + # Obtain replacement words. + rep_words = list(map(lambda x: x[0], self.substitute(word, pos))) + except WordNotInDictionaryException: + rep_words = [] + # Remove the word itself from the list of replacement words. + rep_words = list(filter(lambda x: x != word, rep_words)) + # If there are no replacement words, return the original word with score 0. + if len(rep_words) == 0: + return (word, 0) + + sents = [] + for rw in rep_words: + # Step 1: Replace word with candidate word. + new_sent = sent[:idx] + [rw] + sent[idx + 1 :] + sents.append(self.tokenizer.detokenize(new_sent)) + # Append the original sentence as well, we want to compute the difference + # in probabilities between original sample and generated samples. + sents.append(self.tokenizer.detokenize(sent)) + # Get the probabilities. Example: + # Word: awesome + # rep_words: ['awe-inspiring', 'awful', 'awing', 'amazing'] + # [5.1087904e-01, 9.9993670e-01, 9.9991834e-01, 1.7930799e-04, 9.9993658e-01] + res = clsf.get_prob(sents)[:, target] + prob_orig = res[-1] + res = res[:-1] + # Find the best replacement word, i.e., w_star. We maximise delta(P) here. + # Clearly, the best replacement word is the 4th word, i.e., awing. 
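+        # For a targeted attack, w_star maximises the probability of the target
+        # class; for an untargeted attack, it minimises the probability of the
+        # originally predicted class.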
+ if targeted: + return (rep_words[res.argmax()], res.max() - prob_orig) + else: + return (rep_words[res.argmin()], prob_orig - res.min()) + + +# Example +def main(): + def_tokenizer = PunctTokenizer() + + path = "gchhablani/bert-base-cased-finetuned-sst2" + + # define the attack + attacker = PWWSAttacker() + + # define the victim model (classifier) + tokenizer = transformers.AutoTokenizer.from_pretrained(path) + model = transformers.AutoModelForSequenceClassification.from_pretrained( + path, num_labels=2, output_hidden_states=False + ) + victim = transformers_classifier.TransformersClassifier( + model, tokenizer, model.bert.embeddings.word_embeddings + ) + + # load the dataset + dataset = datasets.load_dataset("sst", split="train[:10]").map( + function=dataset_mapping + ) + + # define the metric(s) which are to be computed between the original sample + # and the adversarial sample + metrics = [Levenshtein(def_tokenizer)] + + result_iterator = attack_process(attacker, victim, dataset, metrics) + + total_inst = 0 + success_inst = 0 + + for i, res in enumerate(result_iterator): + try: + total_inst += 1 + success_inst += int(res["success"]) + + x_orig = res["data"]["x"] + x_adv = res["result"] + + probs = victim.get_prob([x_orig, x_adv]) + y_orig_prob = probs[0] + y_adv_prob = probs[1] + + preds = victim.get_pred([x_orig, x_adv]) + y_orig_preds = int(preds[0]) + y_adv_preds = int(preds[1]) + + print("======================================================") + print(f"{i}th sample") + print("Original: ") + print(f"TEXT: {x_orig}") + print(f"Probabilities: {y_orig_prob}") + print(f"Predictions: {y_orig_preds}") + + print("Adversarial: ") + print(f"TEXT: {x_adv}") + print(f"Probabilities: {y_adv_prob}") + print(f"Predictions: {y_adv_preds}") + + print("\nMetrics: ") + print(res["metrics"]) + print("======================================================") + except Exception as e: + print(e) + + +if __name__ == "__main__": + main() diff --git a/code_soup/common/text/datasets/utils.py b/code_soup/common/text/datasets/utils.py new file mode 100644 index 0000000..22fe9bc --- /dev/null +++ b/code_soup/common/text/datasets/utils.py @@ -0,0 +1,5 @@ +def dataset_mapping(x): + return { + "x": x["sentence"], + "y": 1 if x["label"] > 0.5 else 0, + } diff --git a/code_soup/common/text/models/classifier.py b/code_soup/common/text/models/classifier.py new file mode 100644 index 0000000..7ac92f3 --- /dev/null +++ b/code_soup/common/text/models/classifier.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod +from typing import List, Tuple + +import numpy as np + + +class Classifier(ABC): # no pragma: no cover + def __init__(self): + pass + + @abstractmethod + def get_prob(input_: List[str]) -> np.ndarray: + pass + + @abstractmethod + def get_pred(input_: List[str]) -> np.ndarray: + pass + + def get_grad(input_: List[str], labels: List[int]) -> Tuple[np.ndarray, np.ndarray]: + pass diff --git a/code_soup/common/text/models/transformers_classifier.py b/code_soup/common/text/models/transformers_classifier.py new file mode 100644 index 0000000..d42a7e9 --- /dev/null +++ b/code_soup/common/text/models/transformers_classifier.py @@ -0,0 +1,180 @@ +""" +Class for transformers-based classifiers. 
Adapted from +https://github.com/thunlp/OpenAttack/blob/master/OpenAttack/victim/classifiers/transformers.py""" +import numpy as np +import torch +import transformers + +from code_soup.common.text.models.classifier import Classifier +from code_soup.common.text.utils.tokenizer import TransformersTokenizer +from code_soup.common.text.utils.word_embedding import WordEmbedding + + +class HookCloser: + def __init__(self, model_wrapper): + self.model_wrapper = model_wrapper + + def __call__(self, module, input_, output_): + self.model_wrapper.curr_embedding = output_ + output_.retain_grad() + + +class TransformersClassifier(Classifier): + def __init__( + self, + model: transformers.PreTrainedModel, + tokenizer: transformers.PreTrainedTokenizer, + embedding_layer, + device: torch.device = None, + max_length: int = 128, + batch_size: int = 8, + ): + """ + Args: + model: Huggingface model for classification. + tokenizer: Huggingface tokenizer for classification. **Default:** None + embedding_layer: The module of embedding_layer used in transformers models. For example, + ``BertModel.bert.embeddings.word_embeddings``. **Default:** None + device: Device of pytorch model. **Default:** "cpu" if cuda is not available else "cuda" + max_len: Max length of input tokens. If input token list is too long, it will be truncated. Uses None for no + truncation. **Default:** None + batch_size: Max batch size of this classifier. + """ + + self.model = model + + if device is None: # no pragma: no cover + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + self.to(device) + + self.curr_embedding = None + self.hook = embedding_layer.register_forward_hook(HookCloser(self)) + self.embedding_layer = embedding_layer + + self.word2id = dict() + for i in range(tokenizer.vocab_size): + self.word2id[tokenizer.convert_ids_to_tokens(i)] = i + self.__tokenizer = tokenizer + + self.embedding = embedding_layer.weight.detach().cpu().numpy() + + self.token_unk = tokenizer.unk_token + self.token_unk_id = tokenizer.unk_token_id + + self.max_length = max_length + self.batch_size = batch_size + + @property + def tokenizer(self): + return TransformersTokenizer(self.__tokenizer) # no pragma: no cover + + def to(self, device: torch.device): + """ + Args: + device: Device that moves model to. 
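+
+        Example (illustrative; ``clf`` is a ``TransformersClassifier`` instance):
+            >>> clf = clf.to(torch.device("cpu"))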
+ """ + self.device = device + self.model = self.model.to(device) + return self + + def get_pred(self, input_): + return self.get_prob(input_).argmax(axis=1) + + def get_prob(self, input_): + return self.get_grad( + [self.__tokenizer.tokenize(sent) for sent in input_], [0] * len(input_) + )[0] + + def get_grad(self, input_, labels): + v = self.predict(input_, labels) + return v[0], v[1] + + def predict(self, sen_list, labels=None): + sen_list = [sen[: self.max_length - 2] for sen in sen_list] + sent_lens = [len(sen) for sen in sen_list] + batch_len = max(sent_lens) + 2 + + attentions = np.array( + [ + [1] * (len(sen) + 2) + [0] * (batch_len - 2 - len(sen)) + for sen in sen_list + ], + dtype="int64", + ) + sen_list = [self.__tokenizer.convert_tokens_to_ids(sen) for sen in sen_list] + tokeinzed_sen = np.array( + [ + [self.__tokenizer.cls_token_id] + + sen + + [self.__tokenizer.sep_token_id] + + ([self.__tokenizer.pad_token_id] * (batch_len - 2 - len(sen))) + for sen in sen_list + ], + dtype="int64", + ) + + result = None + result_grad = None + all_hidden_states = None + + if labels is None: + labels = [0] * len(sen_list) + labels = torch.LongTensor(labels).to(self.device) + + for i in range((len(sen_list) + self.batch_size - 1) // self.batch_size): + curr_sen = tokeinzed_sen[i * self.batch_size : (i + 1) * self.batch_size] + curr_mask = attentions[i * self.batch_size : (i + 1) * self.batch_size] + + xs = torch.from_numpy(curr_sen).long().to(self.device) + masks = torch.from_numpy(curr_mask).long().to(self.device) + outputs = self.model( + input_ids=xs, + attention_mask=masks, + output_hidden_states=True, + labels=labels[i * self.batch_size : (i + 1) * self.batch_size], + ) + if i == 0: + all_hidden_states = outputs.hidden_states[-1].detach().cpu() + loss = outputs.loss + logits = outputs.logits + logits = torch.nn.functional.softmax(logits, dim=-1) + loss = -loss + loss.backward() + + result_grad = self.curr_embedding.grad.clone().cpu() + self.curr_embedding.grad.zero_() + self.curr_embedding = None + result = logits.detach().cpu() + else: + all_hidden_states = torch.cat( + (all_hidden_states, outputs.hidden_states[-1].detach().cpu()), dim=0 + ) + loss = outputs.loss + logits = outputs.logits + logits = torch.nn.functional.softmax(logits, dim=-1) + loss = -loss + loss.backward() + + result_grad = torch.cat( + (result_grad, self.curr_embedding.grad.clone().cpu()), dim=0 + ) + self.curr_embedding.grad.zero_() + self.curr_embedding = None + + result = torch.cat((result, logits.detach().cpu())) + + result = result.numpy() + all_hidden_states = all_hidden_states.numpy() + result_grad = result_grad.numpy()[:, 1:-1] + return result, result_grad, all_hidden_states + + def get_hidden_states(self, input_, labels=None): + """ + :param list input_: A list of sentences of which we want to get the hidden states in the model. + :rtype torch.tensor + """ + return self.predict(input_, labels)[2] + + def get_embedding(self): + return WordEmbedding(self.word2id, self.embedding) diff --git a/code_soup/common/text/utils/attack_helpers.py b/code_soup/common/text/utils/attack_helpers.py new file mode 100644 index 0000000..fe33acf --- /dev/null +++ b/code_soup/common/text/utils/attack_helpers.py @@ -0,0 +1,41 @@ +"""Utility functions for text-based attacks. 
Adapted from https://github.com/thunlp/OpenAttack.""" + + +def __measure(data, adversarial_sample, metrics): # no pragma: no cover + ret = {} + for it in metrics: + value = it.after_attack(data, adversarial_sample) + if value is not None: + ret[it.name] = value + return ret + + +def __iter_dataset(dataset, metrics): # no pragma: no cover + for data in dataset: + v = data + for it in metrics: + ret = it.before_attack(v) + if ret is not None: + v = ret + yield v + + +def __iter_metrics(iterable_result, metrics): # no pragma: no cover + for data, result in iterable_result: + adversarial_sample = result + ret = { + "data": data, + "success": adversarial_sample is not None, + "result": adversarial_sample, + "metrics": {**__measure(data, adversarial_sample, metrics)}, + } + yield ret + + +def attack_process(attacker, victim, dataset, metrics): # no pragma: no cover + def result_iter(): + for data in __iter_dataset(dataset, metrics): + yield attacker(victim, data) + + for ret in __iter_metrics(zip(dataset, result_iter()), metrics): + yield ret diff --git a/code_soup/common/text/utils/exceptions.py b/code_soup/common/text/utils/exceptions.py new file mode 100644 index 0000000..e2a61c0 --- /dev/null +++ b/code_soup/common/text/utils/exceptions.py @@ -0,0 +1,13 @@ +"""Exceptions for text-based attacks.""" + + +class AttackException(Exception): + pass + + +class WordNotInDictionaryException(AttackException): + pass + + +class UnknownPOSException(AttackException): + pass diff --git a/code_soup/common/text/utils/metrics.py b/code_soup/common/text/utils/metrics.py new file mode 100644 index 0000000..c7dc15d --- /dev/null +++ b/code_soup/common/text/utils/metrics.py @@ -0,0 +1,60 @@ +"""Various metrics for text. Adapted from https://github.com/thunlp/OpenAttack/tree/master/OpenAttack/metric/algorithms.""" +from typing import List + +import torch + +from code_soup.common.text.utils.tokenizer import Tokenizer + + +class AttackMetric(object): # no pragma: no cover + """ + Base class of all metrics. + """ + + def before_attack(self, input): + return + + def after_attack(self, input, adversarial_sample): + return + + +class Levenshtein(AttackMetric): + def __init__(self, tokenizer: Tokenizer) -> None: + """ + Args: + tokenizer: A tokenizer that will be used in this metric. Must be an instance of :py:class:`.Tokenizer` + """ + self.tokenizer = tokenizer + self.name = "Levenshtein Edit Distance" + + def calc_score(self, a: List[str], b: List[str]) -> int: + """ + Args: + a: The first list. + b: The second list. + Returns: + Levenshtein edit distance between two sentences. + + Both parameters can be str or list, str for char-level edit distance while list for token-level edit distance. 
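+
+        Example (illustrative; assumes a ``PunctTokenizer`` from the ``tokenizer``
+        module; here the distance is token-level):
+            >>> Levenshtein(PunctTokenizer()).calc_score(["a", "cat"], ["a", "dog"])
+            1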
+ """ + la = len(a) + lb = len(b) + f = torch.zeros(la + 1, lb + 1, dtype=torch.long) + for i in range(la + 1): + for j in range(lb + 1): + if i == 0: + f[i][j] = j + elif j == 0: + f[i][j] = i + elif a[i - 1] == b[j - 1]: + f[i][j] = f[i - 1][j - 1] + else: + f[i][j] = min(f[i - 1][j - 1], f[i - 1][j], f[i][j - 1]) + 1 + return f[la][lb].item() + + def after_attack(self, input, adversarial_sample): + if adversarial_sample is not None: + return self.calc_score( + self.tokenizer.tokenize(input["x"], pos_tagging=False), + self.tokenizer.tokenize(adversarial_sample, pos_tagging=False), + ) diff --git a/code_soup/common/text/utils/misc.py b/code_soup/common/text/utils/misc.py new file mode 100644 index 0000000..1d4ffe3 --- /dev/null +++ b/code_soup/common/text/utils/misc.py @@ -0,0 +1,274 @@ +""" +English filter words (stopwords, etc.). Obtained from +https://github.com/thunlp/OpenAttack/blob/master/OpenAttack/attack_assist/filter_words/english.py. +""" +ENGLISH_FILTER_WORDS = [ + "a", + "about", + "above", + "across", + "after", + "afterwards", + "again", + "against", + "ain", + "all", + "almost", + "alone", + "along", + "already", + "also", + "although", + "am", + "among", + "amongst", + "an", + "and", + "another", + "any", + "anyhow", + "anyone", + "anything", + "anyway", + "anywhere", + "are", + "aren", + "aren't", + "around", + "as", + "at", + "back", + "been", + "before", + "beforehand", + "behind", + "being", + "below", + "beside", + "besides", + "between", + "beyond", + "both", + "but", + "by", + "can", + "cannot", + "could", + "couldn", + "couldn't", + "d", + "didn", + "didn't", + "doesn", + "doesn't", + "don", + "don't", + "down", + "due", + "during", + "either", + "else", + "elsewhere", + "empty", + "enough", + "even", + "ever", + "everyone", + "everything", + "everywhere", + "except", + "first", + "for", + "former", + "formerly", + "from", + "hadn", + "hadn't", + "hasn", + "hasn't", + "haven", + "haven't", + "he", + "hence", + "her", + "here", + "hereafter", + "hereby", + "herein", + "hereupon", + "hers", + "herself", + "him", + "himself", + "his", + "how", + "however", + "hundred", + "i", + "if", + "in", + "indeed", + "into", + "is", + "isn", + "isn't", + "it", + "it's", + "its", + "itself", + "just", + "latter", + "latterly", + "least", + "ll", + "may", + "me", + "meanwhile", + "mightn", + "mightn't", + "mine", + "more", + "moreover", + "most", + "mostly", + "must", + "mustn", + "mustn't", + "my", + "myself", + "namely", + "needn", + "needn't", + "neither", + "never", + "nevertheless", + "next", + "no", + "nobody", + "none", + "noone", + "nor", + "not", + "nothing", + "now", + "nowhere", + "o", + "of", + "off", + "on", + "once", + "one", + "only", + "onto", + "or", + "other", + "others", + "otherwise", + "our", + "ours", + "ourselves", + "out", + "over", + "per", + "please", + "s", + "same", + "shan", + "shan't", + "she", + "she's", + "should've", + "shouldn", + "shouldn't", + "somehow", + "something", + "sometime", + "somewhere", + "such", + "t", + "than", + "that", + "that'll", + "the", + "their", + "theirs", + "them", + "themselves", + "then", + "thence", + "there", + "thereafter", + "thereby", + "therefore", + "therein", + "thereupon", + "these", + "they", + "this", + "those", + "through", + "throughout", + "thru", + "thus", + "to", + "too", + "toward", + "towards", + "under", + "unless", + "until", + "up", + "upon", + "used", + "ve", + "was", + "wasn", + "wasn't", + "we", + "were", + "weren", + "weren't", + "what", + "whatever", + "when", + "whence", + "whenever", + "where", 
+ "whereafter", + "whereas", + "whereby", + "wherein", + "whereupon", + "wherever", + "whether", + "which", + "while", + "whither", + "who", + "whoever", + "whole", + "whom", + "whose", + "why", + "with", + "within", + "without", + "won", + "won't", + "would", + "wouldn", + "wouldn't", + "y", + "yet", + "you", + "you'd", + "you'll", + "you're", + "you've", + "your", + "yours", + "yourself", + "yourselves", + "have", + "be", +] diff --git a/code_soup/common/text/utils/tokenizer.py b/code_soup/common/text/utils/tokenizer.py new file mode 100644 index 0000000..265859b --- /dev/null +++ b/code_soup/common/text/utils/tokenizer.py @@ -0,0 +1,109 @@ +"""Tokenizer classes. Based on https://github.com/thunlp/OpenAttack/tree/master/OpenAttack/text_process/tokenizer.""" + +from typing import List, Tuple, Union + +import nltk +import transformers +from nltk.tag.perceptron import PerceptronTagger +from nltk.tokenize import WordPunctTokenizer, sent_tokenize + +nltk.download("averaged_perceptron_tagger") +nltk.download("punkt") + + +class Tokenizer: + """ + Tokenizer is the base class of all tokenizers. + """ + + def tokenize( + self, x: str, pos_tagging: bool = True + ) -> Union[List[str], List[Tuple[str, str]]]: + """ + Args: + x: A sentence. + pos_tagging: Whether to return Pos Tagging results. + Returns: + A list of tokens if **pos_tagging** is `False` + + A list of (token, pos) tuples if **pos_tagging** is `True` + + POS tag must be one of the following tags: ``["noun", "verb", "adj", "adv", "other"]`` + """ + return self.do_tokenize(x, pos_tagging) + + def detokenize(self, x: Union[List[str], List[Tuple[str, str]]]) -> str: + """ + Args: + x: The result of :py:meth:`.Tokenizer.tokenize`, can be a list of tokens or tokens with POS tags. + Returns: + A sentence. + """ + if not isinstance(x, list): + raise TypeError("`x` must be a list of tokens") + if len(x) == 0: + return "" + x = [it[0] if isinstance(it, tuple) else it for it in x] + return self.do_detokenize(x) + + def do_tokenize(self, x, pos_tagging): + raise NotImplementedError() + + def do_detokenize(self, x): + raise NotImplementedError() + + +_POS_MAPPING = {"JJ": "adj", "VB": "verb", "NN": "noun", "RB": "adv"} + + +class PunctTokenizer(Tokenizer): + """ + Tokenizer based on nltk.word_tokenizer. + :Language: english + """ + + def __init__(self) -> None: + self.sent_tokenizer = sent_tokenize + self.word_tokenizer = WordPunctTokenizer().tokenize + self.pos_tagger = PerceptronTagger() + + def do_tokenize(self, x, pos_tagging=True): + sentences = self.sent_tokenizer(x) + tokens = [] + for sent in sentences: + tokens.extend(self.word_tokenizer(sent)) + + if not pos_tagging: + return tokens + ret = [] + for word, pos in self.pos_tagger.tag(tokens): + if pos[:2] in _POS_MAPPING: + mapped_pos = _POS_MAPPING[pos[:2]] + else: + mapped_pos = "other" + ret.append((word, mapped_pos)) + return ret + + def do_detokenize(self, x): + return " ".join(x) + + +class TransformersTokenizer(Tokenizer): + """ + Pretrained Tokenizer from transformers. + Usually returned by :py:class:`.TransformersClassifier` . 
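+
+    Example (illustrative; assumes a pretrained BERT tokenizer):
+        >>> tok = TransformersTokenizer(
+        ...     transformers.BertTokenizer.from_pretrained("bert-base-uncased")
+        ... )
+        >>> tok.tokenize("short sentence .", pos_tagging=False)
+        ['short', 'sentence', '.']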
+ + """ + + def __init__(self, tokenizer: transformers.PreTrainedTokenizerBase): + self.__tokenizer = tokenizer + + def do_tokenize(self, x, pos_tagging): + if pos_tagging: # no pragma: no cover + raise ValueError( + "`%s` does not support pos tagging" % self.__class__.__name__ + ) + return self.__tokenizer.tokenize(x) + + def do_detokenize(self, x): + return self.__tokenizer.convert_tokens_to_string(x) diff --git a/code_soup/common/text/utils/word_embedding.py b/code_soup/common/text/utils/word_embedding.py new file mode 100644 index 0000000..af088ca --- /dev/null +++ b/code_soup/common/text/utils/word_embedding.py @@ -0,0 +1,16 @@ +from typing import Dict + + +class WordEmbedding: # no pragma: no cover + def __init__(self, word2id: Dict[str, int], embedding) -> None: + self.word2id = word2id + self.embedding = embedding + + def transform(self, word, token_unk): + if word in self.word2id: + return self.embedding[self.word2id[word]] + else: + if isinstance(token_unk, int): + return self.embedding[token_unk] + else: + return self.embedding[self.word2id[token_unk]] diff --git a/code_soup/common/text/utils/word_substitute.py b/code_soup/common/text/utils/word_substitute.py new file mode 100644 index 0000000..afaa19b --- /dev/null +++ b/code_soup/common/text/utils/word_substitute.py @@ -0,0 +1,153 @@ +""" +Contains different word subsitution methods such as replacing a word in a +sentence with its synonyms. +Adapted from +https://github.com/thunlp/OpenAttack/blob/master/OpenAttack/attack_assist/substitute/word/base.py. +""" +from typing import List, Optional, Tuple + +import nltk +from nltk.corpus import wordnet as nltk_wn + +from code_soup.common.text.utils.exceptions import ( + UnknownPOSException, + WordNotInDictionaryException, +) + +nltk.download("wordnet") +nltk.download("omw-1.4") + +POS_LIST = ["adv", "adj", "noun", "verb", "other"] + + +def prefilter(token, synonym): # 预过滤(原词,一个候选词 + if ( + len(synonym.split()) > 2 + or (synonym == token) # the synonym produced is a phrase + or (token == "be") # the pos of the token synonyms are different + or (token == "is") + or (token == "are") + or (token == "am") + ): # token is be + return False + else: + return True + + +class WordSubstitute(object): + def __call__(self, word: str, pos: Optional[str] = None) -> List[Tuple[str, float]]: + """ + In WordSubstitute, we return a list of words that are semantically + similar to the input word. + + Args: + word: A single word. + pos: POS tag of input word. Must be one of the following: + ``["adv", "adj", "noun", "verb", "other", None]`` + + Raises: + WordNotInDictionaryException: input word not in the dictionary of substitute algorithm + UnknownPOSException: invalid pos tagging + + Returns: + A list of words and their distance to original word + (distance is a number between 0 and 1, with smaller indicating more + similarity). 
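+
+        Example (illustrative; :py:class:`.WordNetSubstitute` scores every
+        synonym as 1, so the ordering of results is not fixed):
+            >>> WordNetSubstitute()("compute", "verb")  # doctest: +SKIP
+            [('calculate', 1), ('cipher', 1), ('figure', 1), ...]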
+ """ + + if pos is None: + ret = {} + for sub_pos in POS_LIST: + try: + for word, sim in self.substitute(word, sub_pos): + if word not in ret: + ret[word] = sim + else: + ret[word] = max(ret[word], sim) + except WordNotInDictionaryException: + continue + list_ret = [] + for word, sim in ret.items(): + list_ret.append((word, sim)) + if len(list_ret) == 0: + raise WordNotInDictionaryException() + return sorted(list_ret, key=lambda x: -x[1]) + elif pos not in POS_LIST: + raise UnknownPOSException("Invalid `pos` %s (expect %s)" % (pos, POS_LIST)) + return self.substitute(word, pos) + + def substitute(self, word: str, pos: str) -> List[Tuple[str, float]]: + raise NotImplementedError() + + +class WordNetSubstitute(WordSubstitute): + def __init__(self, k=50): + """ + English word substitute based on WordNet. WordNet is used to find + synonyms (same named entity as the original word). + See Section 3.2.1 of the PWWS paper to get a better idea of how this works. + Args: + k: Top-k results to return. If k is `None`, all results will be + returned. Default: 50 + """ + + self.wn = nltk_wn + self.k = k + + def substitute(self, word: str, pos: str): + """ + Finds candidate substitutes for the input word. + + Args: + word (str): Input word (obtained after tokenising the input text). + pos (str): POS tag (part of speech) of the input word (noun, verb, + etc.). + + Raises: + WordNotInDictionaryException: If the word does not have a POS tag + from list + ["adv", "adj", "noun", "verb"]. + + Returns: + synonyms ([str]): List of candidate replacements. + """ + token = word.replace("_", " ").split()[0] + if pos == "other": + raise WordNotInDictionaryException() + pos_in_wordnet = {"adv": "r", "adj": "a", "verb": "v", "noun": "n"}[pos] + + # Find synonyms using WordNet which belong to the same named entity. + # Example (wordnet_synonyms for word "new"): + """ + [Lemma('new.a.01.new'), Lemma('fresh.s.04.fresh'), Lemma('fresh.s.04.new'), + Lemma('fresh.s.04.novel'), Lemma('raw.s.12.raw'), Lemma('raw.s.12.new'), + Lemma('new.s.04.new'), Lemma('new.s.04.unexampled'), Lemma('new.s.05.new'), + Lemma('new.a.06.new'), Lemma('newfangled.s.01.newfangled'), + Lemma('newfangled.s.01.new'), Lemma('new.s.08.New'), + Lemma('modern.s.05.Modern'), Lemma('modern.s.05.New'), + Lemma('new.s.10.new'), Lemma('new.s.10.young'), Lemma('new.s.11.new')] + """ + + wordnet_synonyms = [] + synsets = self.wn.synsets(word, pos=pos_in_wordnet) + for synset in synsets: + wordnet_synonyms.extend(synset.lemmas()) + + # Preprocess the synonyms. Example: + # {'young', 'novel', 'unexampled', 'new', 'fresh', 'newfangled', 'modern', + # 'raw'} + synonyms = set() + for wordnet_synonym in wordnet_synonyms: + # Step 1: Obtain the base word from the lemma. + # Step 2: For multi-word synonyms, we only consider the first word. + # Step 3: Prefilter the synonyms, i.e., remove words like "be", "is", + # "are", "am", etc. + preprocessed_synonym = wordnet_synonym.name().split("_")[0] + if prefilter(token, preprocessed_synonym): + synonyms.add(preprocessed_synonym.lower()) + + synonyms = [(syn, 1) for syn in synonyms] + + if self.k is not None and self.k > len(synonyms): + synonyms = synonyms[: self.k] + return synonyms diff --git a/code_soup/misc.py b/code_soup/misc.py new file mode 100644 index 0000000..79b035a --- /dev/null +++ b/code_soup/misc.py @@ -0,0 +1,16 @@ +import random + +import numpy as np +import torch + + +def seed(value=42): + """Set random seed for everything. 
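+
+    Example (illustrative): call ``seed(42)`` once at program start to make
+    ``numpy``, ``torch`` and Python's ``random`` deterministic.
+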
+ Args: + value (int): Seed + """ + np.random.seed(value) + torch.manual_seed(value) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + random.seed(value) diff --git a/requirements.txt b/requirements.txt index 67f034a..457b759 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,10 @@ +datasets==1.17.0 +nltk==3.6.7 numpy==1.21.1 Pillow==8.3.1 torch==1.9.0 torchvision==0.10.0 +transformers==4.15.0 parameterized==0.8.1 scipy==1.6.2 opencv-python==4.5.3.56 diff --git a/tests/test_ch8/test_pwws.py b/tests/test_ch8/test_pwws.py new file mode 100644 index 0000000..bdb32cb --- /dev/null +++ b/tests/test_ch8/test_pwws.py @@ -0,0 +1,47 @@ +import unittest + +import datasets +import transformers +from parameterized import parameterized_class + +from code_soup.ch8.pwws import PWWSAttacker +from code_soup.common.text.datasets.utils import dataset_mapping +from code_soup.common.text.models import transformers_classifier +from code_soup.common.text.utils.attack_helpers import attack_process +from code_soup.common.text.utils.metrics import Levenshtein +from code_soup.common.text.utils.tokenizer import PunctTokenizer +from code_soup.misc import seed + +seed(42) + + +class TestPWWSAttacker(unittest.TestCase): + """ + pwws.PWWSAttacker() test cases + """ + + @classmethod + def setUpClass(cls) -> None: + + path = "distilbert-base-uncased-finetuned-sst-2-english" + + # define the attack + cls.attacker = PWWSAttacker() + + # define the victim model (classifier) + tokenizer = transformers.AutoTokenizer.from_pretrained(path) + model = transformers.AutoModelForSequenceClassification.from_pretrained( + path, num_labels=2, output_hidden_states=False + ) + cls.victim = transformers_classifier.TransformersClassifier( + model, tokenizer, model.distilbert.embeddings.word_embeddings + ) + + # load the dataset + cls.dataset = datasets.load_dataset("sst", split="train[:2]").map( + function=dataset_mapping + ) + + def test_output(cls): + for sample in cls.dataset: + cls.attacker(cls.victim, sample) diff --git a/tests/test_common/test_text/test_datasets/test_utils.py b/tests/test_common/test_text/test_datasets/test_utils.py new file mode 100644 index 0000000..f757042 --- /dev/null +++ b/tests/test_common/test_text/test_datasets/test_utils.py @@ -0,0 +1,30 @@ +import unittest + +from parameterized import parameterized_class + +from code_soup.common.text.datasets.utils import dataset_mapping + + +@parameterized_class( + ("x", "expected_output"), + [ + ( + {"sentence": "Chuffed to bits!", "label": 0.598}, + {"x": "Chuffed to bits!", "y": 1}, + ), + ({"sentence": "Hello", "label": 0.342}, {"x": "Hello", "y": 0}), + ], +) +class TestTextDatasetUtilsDatasetMapping(unittest.TestCase): + """ + Parameterized test cases for the common/text/datasets/utils/dataset_mapping + function. 
+ + Args: ("x", "expected_output") + """ + + def setUp(self): + pass + + def test_output(self): + self.assertDictEqual(dataset_mapping(self.x), self.expected_output) diff --git a/tests/test_common/test_text/test_models/test_transformers_classifier.py b/tests/test_common/test_text/test_models/test_transformers_classifier.py new file mode 100644 index 0000000..0b47bec --- /dev/null +++ b/tests/test_common/test_text/test_models/test_transformers_classifier.py @@ -0,0 +1,69 @@ +import random +import unittest + +import numpy as np +import torch +from parameterized import parameterized_class +from transformers import BertForSequenceClassification, BertTokenizer + +from code_soup.common.text.models.transformers_classifier import TransformersClassifier +from code_soup.misc import seed + +seed(42) +model = BertForSequenceClassification.from_pretrained( + "textattack/bert-base-uncased-imdb" +) +tokenizer = BertTokenizer.from_pretrained("textattack/bert-base-uncased-imdb") +embedding_layer = model.bert.embeddings.word_embeddings +device = torch.device("cpu") + + +@parameterized_class( + ("input_", "expected_output"), + [ + (["inception is an awesome movie ."], [1]), + (["marvel is cliche .", "Fascinating movie, that !"], [0, 1]), + ], +) +class TestTransformersClassifierGetPred(unittest.TestCase): + """ + Parameterized test cases for the TransformersClassifier.get_pred() function + from the common/text/models/transformers_classifier.py file. + + Args: ("x", "expected_output") + """ + + def setUp(self): + self.clf = TransformersClassifier(model, tokenizer, embedding_layer, device) + + def test_output(self): + self.assertEqual(list(self.clf.get_pred(self.input_)), self.expected_output) + + +@parameterized_class( + ("input_", "expected_output"), + [ + (["inception is an awesome movie ."], np.array([[0.01, 0.99]])), + ( + ["marvel is cliche .", "Fascinating movie, that !"], + np.array([[0.997, 0.003], [0.032, 0.968]]), + ), + ], +) +class TestTransformersClassifierGetProb(unittest.TestCase): + """ + Parameterized test cases for the TransformersClassifier.get_prob() function + from the common/text/models/transformers_classifier.py file. 
+ + Args: ("x", "expected_output") + """ + + def setUp(self): + self.clf = TransformersClassifier(model, tokenizer, embedding_layer, device) + + def test_output(self): + self.assertIsNone( + np.testing.assert_almost_equal( + self.clf.get_prob(self.input_), self.expected_output, decimal=3 + ) + ) diff --git a/tests/test_common/test_text/test_utils/test_metrics.py b/tests/test_common/test_text/test_utils/test_metrics.py new file mode 100644 index 0000000..0c5c027 --- /dev/null +++ b/tests/test_common/test_text/test_utils/test_metrics.py @@ -0,0 +1,27 @@ +import random +import unittest + +from parameterized import parameterized_class + +from code_soup.common.text.utils import metrics, tokenizer +from code_soup.misc import seed + + +@parameterized_class( + ("input", "adversarial_sample", "expected_output"), + [({"x": "compute"}, "comp te", 2), ({"x": "bottle"}, "abossme", 1)], +) +class TestLevenshteinParameterized(unittest.TestCase): + """ + Levenshtein.after_attack Parameterized test case + Args: ("input", "adversarial_sample", "expected_output") + """ + + def setUp(self): + self.levenshtein = metrics.Levenshtein(tokenizer.PunctTokenizer()) + + def test_output(self): + self.assertEqual( + self.levenshtein.after_attack(self.input, self.adversarial_sample), + self.expected_output, + ) diff --git a/tests/test_common/test_text/test_utils/test_tokenizer.py b/tests/test_common/test_text/test_utils/test_tokenizer.py new file mode 100644 index 0000000..3e41552 --- /dev/null +++ b/tests/test_common/test_text/test_utils/test_tokenizer.py @@ -0,0 +1,210 @@ +import unittest + +from parameterized import parameterized_class +from transformers import BertTokenizer + +from code_soup.common.text.utils import tokenizer + + +@parameterized_class( + ("x", "expected_result"), + [ + ( + "xlnet is better than bert . but bert has less parameters .", + [ + ("xlnet", "noun"), + ("is", "verb"), + ("better", "adj"), + ("than", "other"), + ("bert", "noun"), + (".", "other"), + ("but", "other"), + ("bert", "noun"), + ("has", "verb"), + ("less", "adj"), + ("parameters", "noun"), + (".", "other"), + ], + ), + ( + "reformers are efficient transformers . longformers can handle long texts .", + [ + ("reformers", "noun"), + ("are", "verb"), + ("efficient", "adj"), + ("transformers", "noun"), + (".", "other"), + ("longformers", "noun"), + ("can", "other"), + ("handle", "verb"), + ("long", "adj"), + ("texts", "noun"), + (".", "other"), + ], + ), + ], +) +class TestPunctTokenizerTokenizeWPosParameterized(unittest.TestCase): + """ + PunctTokenizer.tokenize() Parameterized TestCase + Args: ("x", "expected_result") + """ + + def setUp(self): + self.tok = tokenizer.PunctTokenizer() + + def test_output(self): + self.assertEqual(self.tok.tokenize(self.x), self.expected_result) + + +@parameterized_class( + ("x", "expected_result"), + [ + ( + "xlnet is better than bert . but bert has less parameters .", + [ + "xlnet", + "is", + "better", + "than", + "bert", + ".", + "but", + "bert", + "has", + "less", + "parameters", + ".", + ], + ), + ( + "reformers are efficient transformers . 
longformers can handle long texts .", + [ + "reformers", + "are", + "efficient", + "transformers", + ".", + "longformers", + "can", + "handle", + "long", + "texts", + ".", + ], + ), + ], +) +class TestPunctTokenizerTokenizeWoPosParameterized(unittest.TestCase): + """ + PunctTokenizer.tokenize() Parameterized TestCase + Args: ("x", "expected_result") + """ + + def setUp(self): + self.tok = tokenizer.PunctTokenizer() + + def test_output(self): + self.assertEqual(self.tok.tokenize(self.x, False), self.expected_result) + + +@parameterized_class( + ("x", "expected_result"), + [ + ( + [ + "xlnet", + "is", + "better", + "than", + "bert", + ".", + "but", + "bert", + "has", + "less", + "parameters", + ".", + ], + "xlnet is better than bert . but bert has less parameters .", + ), + ( + [ + "reformers", + "are", + "efficient", + "transformers", + ".", + "longformers", + "can", + "handle", + "long", + "texts", + ".", + ], + "reformers are efficient transformers . longformers can handle long texts .", + ), + ([], ""), + ], +) +class TestPunctTokenizerDetokenizeParameterized(unittest.TestCase): + """ + PunctTokenizer.tokenize() Parameterized TestCase + Args: ("x", "expected_result") + """ + + def setUp(self): + self.tok = tokenizer.PunctTokenizer() + + def test_output(self): + self.assertEqual(self.tok.detokenize(self.x), self.expected_result) + + +@parameterized_class( + ("x", "expected_result"), + [ + ("short sentence .", ["short", "sentence", "."]), + ( + "another sentence, slightly longer .", + ["another", "sentence", ",", "slightly", "longer", "."], + ), + ], +) +class TestTransformersTokenizerTokenizeParameterized(unittest.TestCase): + """ + TransformersTokenizer.tokenize() Parameterized TestCase + Args: ("x", "expected_result") + """ + + def setUp(self): + self.tok = tokenizer.TransformersTokenizer( + BertTokenizer.from_pretrained("bert-base-uncased") + ) + + def test_output(self): + self.assertEqual(self.tok.tokenize(self.x, False), self.expected_result) + + +@parameterized_class( + ("x", "expected_result"), + [ + (["short", "sentence", "."], "short sentence ."), + ( + ["another", "sentence", ",", "slightly", "longer", "."], + "another sentence , slightly longer .", + ), + ], +) +class TestTransformersTokenizerDetokenizeParameterized(unittest.TestCase): + """ + TransformersTokenizer.detokenize() Parameterized TestCase + Args: ("x", "expected_result") + """ + + def setUp(self): + self.tok = tokenizer.TransformersTokenizer( + BertTokenizer.from_pretrained("bert-base-uncased") + ) + + def test_output(self): + self.assertEqual(self.tok.detokenize(self.x), self.expected_result) diff --git a/tests/test_common/test_text/test_utils/test_word_substitute.py b/tests/test_common/test_text/test_utils/test_word_substitute.py new file mode 100644 index 0000000..796e5df --- /dev/null +++ b/tests/test_common/test_text/test_utils/test_word_substitute.py @@ -0,0 +1,109 @@ +import random +import unittest + +from parameterized import parameterized_class + +from code_soup.common.text.utils import word_substitute +from code_soup.common.text.utils.exceptions import UnknownPOSException +from code_soup.misc import seed + +seed(42) + + +@parameterized_class( + ("word", "pos", "expected_result"), + [ + ( + "compute", + "verb", + [ + ("calculate", 1), + ("cipher", 1), + ("figure", 1), + ("cypher", 1), + ("work", 1), + ("reckon", 1), + ], + ), + ("bottle", "noun", [("bottleful", 1), ("feeding", 1), ("nursing", 1)]), + ], +) +class TestWordNetSubstituteParameterized(unittest.TestCase): + """ + WordNetSubstitute.substitute() 
Parameterized TestCase + Args: ("word", "pos", "expected_result") + """ + + def setUp(self): + self.wordnet_substitute = word_substitute.WordNetSubstitute() + + def test_output(self): + self.assertEqual( + sorted(self.wordnet_substitute.substitute(self.word, self.pos)), + sorted(self.expected_result), + ) + + +@parameterized_class( + ("word", "pos", "expected_result"), + [ + ( + "compute", + "verb", + [ + ("calculate", 1), + ("cipher", 1), + ("figure", 1), + ("cypher", 1), + ("work", 1), + ("reckon", 1), + ], + ), + ( + "chair", + None, + [ + ("hot", 1), + ("electric", 1), + ("death", 1), + ("chairwoman", 1), + ("professorship", 1), + ("chairman", 1), + ("chairperson", 1), + ("president", 1), + ], + ), + ], +) +class TestWordNetSubstituteCallParameterized(unittest.TestCase): + """ + WordNetSubstitute() Parameterized TestCase + Args: ("word", "pos", "expected_result") + """ + + def setUp(self): + self.wordnet_substitute = word_substitute.WordNetSubstitute() + + def test_output(self): + # instead of checking for equality, ensure that 85% of the synonyms are in the result + self.assertGreater( + len( + set(self.wordnet_substitute(self.word, self.pos)).intersection( + set(self.expected_result) + ) + ) + / len(self.expected_result), + 0.85, + ) + + +class TestWordNetSubstituteCallException(unittest.TestCase): + """ + WordNetSubstitute() TestCase for UnknownPOSException + """ + + def setUp(self): + self.wordnet_substitute = word_substitute.WordNetSubstitute() + + def test_output(self): + self.assertRaises(UnknownPOSException, self.wordnet_substitute, "dummy", "none")