"""
Used for n-gram decontamination.
First build an index using the tasks we want to use to decontaminate our training dataset.
Then read your training data and apply the filter with the index loaded.
"""
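
# A minimal two-stage usage sketch (paths, the tasks file, and the executor choice are
# illustrative; NGramsDecontIndexer/NGramsDecontFilter are defined below):
#
#   from datatrove.executor import LocalPipelineExecutor
#   from datatrove.pipeline.readers import JsonlReader
#
#   # stage 1: build the index from the reference tasks (must run as a single task)
#   LocalPipelineExecutor(
#       pipeline=[NGramsDecontIndexer("decont_index/", lighteval_tasks="tasks.txt")],
#       tasks=1,
#   ).run()
#
#   # stage 2: drop training documents whose n-grams appear in the index
#   LocalPipelineExecutor(
#       pipeline=[JsonlReader("training_data/"), NGramsDecontFilter("decont_index/")],
#   ).run()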

import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Tuple

import numpy as np
from loguru import logger

from datatrove.data import Document, DocumentsPipeline
from datatrove.io import DataFolderLike, file_exists, get_datafolder, open_file
from datatrove.pipeline.base import PipelineStep
from datatrove.pipeline.filters.base_filter import BaseFilter
from datatrove.pipeline.writers.disk_base import DiskWriter
from datatrove.utils.binaryio import read_np_from_file
from datatrove.utils.text import TextNormConfig, simplify_text, xxhash64


@dataclass
class NGramsDecontConfig:
    """
    Example for n_grams=4
    query = ['A', 'B', 'C', 'D', 'E'] (the prompt/instruction)
    label = ['F', 'G', 'H', 'I', 'J'] (the answer/gold)
    Will find the following n-grams in the training data:
        'F G H I'
        'G H I J'
    + IF find_query_ngrams:
        'A B C D'
        'B C D E'
    + IF find_overlap_ngrams:
        'C D E F'
        'D E F G'
        'E F G H'
    """

    n_grams: int = 12
    find_query_ngrams: bool = False  # enable to also check for matches in n-grams containing only the input/prompt
    find_overlap_ngrams: bool = True  # also find matches for n-grams overlapping BOTH the query and the label
    norm_config: TextNormConfig = field(default_factory=TextNormConfig)


DEFAULT_NGRAMS_DECONT_CONFIG = NGramsDecontConfig()
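
# Example of a custom config (illustrative values, not tuned recommendations): shorter
# n-grams and query matching make the filter stricter, at the cost of more false positives:
#   config = NGramsDecontConfig(n_grams=8, find_query_ngrams=True)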


class NGramsDecontIndexer(PipelineStep):
    """
    Creates a decontamination index (basically a list of uint64 hashes from n-grams) for each reference task.
    Ways to provide task data:
    - as input documents from the previous pipeline step, with "text=label/correct answer"
      and metadata={"query": query/prompt/input, "task": task name}
    - as a list of strings in the format "suite|task" from the lighteval metadata table
      (https://github.com/huggingface/lighteval/blob/main/src/lighteval/tasks/tasks_table.jsonl), passed as
      `lighteval_tasks`
    - as a path to a text file containing one such list, with one "suite|task" per line, passed as `lighteval_tasks`

    You can also define your own custom tasks with `custom_lighteval_tasks`. See the explanation for `custom_tasks`
    here: https://github.com/huggingface/lighteval/tree/main?tab=readme-ov-file#evaluate-a-model-on-extended-community-or-custom-tasks
    """
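
    # For example (hypothetical suite/task names, just to illustrate the "suite|task" format):
    #   NGramsDecontIndexer("decont_index/", lighteval_tasks=["my_suite|my_task"])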

    type = "🦠 - DECONT"
    name = "💥 N-grams build index"
    _requires_dependencies = ["nltk", "lighteval", "xxhash"]

    def __init__(
        self,
        output_folder: DataFolderLike,
        lighteval_tasks: str | list[str] | None = None,  # list in the format suite|task or path to one such list
        custom_lighteval_tasks: str | None = None,
        config: NGramsDecontConfig = DEFAULT_NGRAMS_DECONT_CONFIG,
        language: str = "english",
    ):
        super().__init__()
        self.output_folder = get_datafolder(output_folder)
        # parse list of tasks
        if isinstance(lighteval_tasks, str):
            if file_exists(lighteval_tasks):
                with open_file(lighteval_tasks, "rt") as f:
                    self.lighteval_tasks = f.read().strip().splitlines()
            else:
                self.lighteval_tasks = [lighteval_tasks]
        else:
            self.lighteval_tasks = lighteval_tasks
        self.custom_lighteval_tasks = custom_lighteval_tasks
        self.config = config
        self.language = language

    def compute_hashes(self, label: str, query: str | None = None) -> list[int]:
        from nltk import ngrams
        from nltk.tokenize import word_tokenize

        label_tokens = word_tokenize(simplify_text(label, self.config.norm_config), language=self.language)
        ngrams_to_compute = list(ngrams(label_tokens, self.config.n_grams))
        if query is not None:
            query_tokens = word_tokenize(simplify_text(query, self.config.norm_config), language=self.language)
            if self.config.find_query_ngrams:
                ngrams_to_compute.extend(ngrams(query_tokens, self.config.n_grams))
            if self.config.find_overlap_ngrams:
                # add n-grams overlapping the query and the label
                """
                A, B, C, D, E | F, G, H, I, J
                5 grams
                B, C, D, E, F (-N + 1 + i:) + (:i + 1)
                ...
                E, F, G, H, I
                """
                ngrams_to_compute.extend(
                    [
                        query_tokens[-self.config.n_grams + 1 + i :] + label_tokens[: i + 1]
                        for i in range(self.config.n_grams - 1)
                        # make sure we actually get a list of size N
                        if len(query_tokens) >= self.config.n_grams - 1 - i and len(label_tokens) >= i + 1
                    ]
                )
        return list(map(xxhash64, map(" ".join, ngrams_to_compute)))
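
    # Worked example of the overlap windows computed above, for n_grams=4:
    #   query = [A, B, C, D, E], label = [F, G, H, I, J]
    #   i = 0 -> query[-3:] + label[:1] -> 'C D E F'
    #   i = 1 -> query[-2:] + label[:2] -> 'D E F G'
    #   i = 2 -> query[-1:] + label[:3] -> 'E F G H'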

    def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1):
        if world_size != 1:
            raise ValueError("Decontamination index building requires a single worker.")
        hashes = defaultdict(set)
        # use whatever data is passed in, with the following format:
        # doc.text -> label
        # doc.metadata["query"] -> query/prompt/input
        if data:
            for doc in data:
                if (
                    self.config.find_query_ngrams or self.config.find_overlap_ngrams
                ) and "query" not in doc.metadata:
                    raise ValueError(
                        "find_query_ngrams/find_overlap_ngrams is enabled but could not find 'query' field in "
                        "documents metadata"
                    )
                hashes[doc.metadata.get("task", "input")].update(
                    self.compute_hashes(doc.text, doc.metadata.get("query", None))
                )

        # parse data from lighteval defined tasks
        from lighteval.tasks.lighteval_task import LightevalTask
        from lighteval.tasks.registry import Registry

        task_dict = Registry(cache_dir=os.getenv("HF_HOME")).get_task_dict(
            self.lighteval_tasks, custom_tasks=self.custom_lighteval_tasks
        )
        LightevalTask.load_datasets(task_dict.values())

        for task_name, task in task_dict.items():
            for eval_doc in task.eval_docs():
                try:
                    golds = eval_doc.get_golds()
                    query = eval_doc.query
                except Exception as e:
                    logger.warning(f"Error while fetching doc data: {e}")
                    continue
                for gold in golds:
                    hashes[task_name].update(self.compute_hashes(gold, query))

        for task_name, task_hashes in hashes.items():
            hashes_array = np.array(list(task_hashes), dtype="<u8")
            logger.info(f"Saving {len(task_hashes)} hashes for {task_name}")
            with self.output_folder.open(f"{task_name.replace(' ', '_')}.index.hashes", mode="wb") as f:
                if self.output_folder.is_local():
                    hashes_array.tofile(f)
                else:
                    f.write(hashes_array.tobytes())


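# On-disk index format: each "<task>.index.hashes" file is a flat array of little-endian
# uint64 n-gram hashes. A quick way to inspect one locally (sketch, assuming a local file):
#
#   import numpy as np
#   hashes = np.fromfile("decont_index/my_task.index.hashes", dtype="<u8")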
class NGramsDecontFilter(BaseFilter):
    """
    Loads the lists of hashes created by the Indexer step.
    For each document in the block's input, we check whether any of its n-grams appear in one of the reference
    eval tasks. If so, the document is removed. The contaminated n-gram and the task where it was found are saved
    in the removed document's metadata.
    """
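
    # e.g., to keep the removed documents for inspection (JsonlWriter is datatrove's jsonl
    # disk writer; the output path is illustrative):
    #   from datatrove.pipeline.writers import JsonlWriter
    #   NGramsDecontFilter("decont_index/", exclusion_writer=JsonlWriter("removed/"))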

    type = "🦠 - DECONT"
    name = "💥 N-grams decontaminate"
    _requires_dependencies = ["nltk", "xxhash"]

    def __init__(
        self,
        index_folder: DataFolderLike,
        config: NGramsDecontConfig = DEFAULT_NGRAMS_DECONT_CONFIG,
        exclusion_writer: DiskWriter = None,
        language: str = "english",
    ):
        super().__init__()
        self.index_folder = get_datafolder(index_folder)
        self.config = config
        self.exclusion_writer = exclusion_writer
        self.language = language
        self._index_hashes = None

    def load_index_hashes(self):
        def load_index_from_file(file):
            with self.index_folder.open(file, mode="rb") as f:
                return file, read_np_from_file(f, np.dtype("<u8"), self.index_folder.is_local()).tolist()

        with ThreadPoolExecutor() as pool:
            hashes = pool.map(load_index_from_file, self.index_folder.list_files())

        self._index_hashes = {}
        for filename, hashlist in hashes:
            taskname = filename.removesuffix(".index.hashes")
            logger.info(f"Loading {len(hashlist)} hashes for {taskname}")
            for hash in hashlist:
                self._index_hashes[hash] = taskname

    def filter(self, doc: Document) -> bool | Tuple[bool, str]:
        if self._index_hashes is None:
            self.load_index_hashes()

        from nltk import ngrams
        from nltk.tokenize import word_tokenize

        text_tokens = word_tokenize(simplify_text(doc.text, self.config.norm_config), language=self.language)
        ngrams_to_compute = list(ngrams(text_tokens, self.config.n_grams))
        for n_gram in map(" ".join, ngrams_to_compute):
            task = self._index_hashes.get(xxhash64(n_gram), None)
            if task is not None:
                doc.metadata["contaminated_ngram"] = n_gram
                doc.metadata["contaminated_task"] = task
                self.stat_update(f"contaminated_{task}")
                if ":" in task:
                    self.stat_update(f"contaminated_tg_{task[:task.index(':')]}")
                return False, "contaminated"
        return True
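
# Note on stats: each removal increments a per-task counter ("contaminated_<task>") and,
# for task names containing ":", a per-group counter ("contaminated_tg_<prefix>"), so
# contamination counts can be inspected per task and per task group in the pipeline stats.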