
Commit 98f48eb

Merge branch 'main' into main
2 parents c300760 + b2b96e4

30 files changed: +757 -88 lines changed

Diff for: pyproject.toml (+4)
@@ -58,6 +58,9 @@ processing = [
     "fasteners",
     "xxhash"
 ]
+decont = [
+    "lighteval>=0.3.0"
+]
 quality = [
     "ruff>=0.1.5"
 ]
@@ -66,6 +69,7 @@ testing = [
     "datatrove[io]",
     "datatrove[processing]",
     "datatrove[s3]",
+    "datatrove[decont]",
     "pytest",
     "pytest-timeout",
     "pytest-xdist",

Diff for: src/datatrove/executor/slurm.py (+22 -10)
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import json
+import math
 import os
 import random
 import signal
@@ -25,7 +26,7 @@
 def requeue_handler(signum, _frame):
     signame = signal.Signals(signum).name
     logger.warning(f"Received signal {signum} ({signame}). Requeueing and exiting...")
-    subprocess.run(["scontrol", "requeue", "${SLURM_JOB_ID}"])
+    subprocess.run(["scontrol", "requeue", os.environ.get("SLURM_JOB_ID")])
     sys.exit(15)
 
 
@@ -79,7 +80,7 @@ class SlurmPipelineExecutor(PipelineExecutor):
         mail_type: see https://slurm.schedmd.com/sbatch.html. Common values are (NONE, BEGIN, END, FAIL, REQUEUE, ALL)
         mail_user: email address to send notifications to
         requeue: requeue the job if it fails
-
+        tasks_per_job: each slurm job in the job array will run this many datatrove tasks. This reduces the total nb of slurm jobs launched.
     """
 
     def __init__(
@@ -111,13 +112,16 @@ def __init__(
         mail_type: str = "ALL",
         mail_user: str = None,
         requeue: bool = True,
+        srun_args: dict = None,
+        tasks_per_job: int = 1,
     ):
         super().__init__(pipeline, logging_dir, skip_completed)
         self.tasks = tasks
         self.workers = workers
         self.partition = partition
         self.cpus_per_task = cpus_per_task
         self.mem_per_cpu_gb = mem_per_cpu_gb
+        self.tasks_per_job = tasks_per_job
         self.time = time
         self.job_name = job_name
         self.qos = qos
@@ -136,6 +140,7 @@ def __init__(
         self.requeue_signals = requeue_signals
         self.mail_type = mail_type
         self.mail_user = mail_user
+        self.srun_args = srun_args
         self.slurm_logs_folder = (
             slurm_logs_folder
             if slurm_logs_folder
@@ -160,18 +165,23 @@ def run(self):
             slurm_rank = int(os.environ["SLURM_ARRAY_TASK_ID"]) + self.max_array_size * int(
                 os.environ.get("RUN_OFFSET", 0)
             )
+            ranks_to_run_range = (slurm_rank * self.tasks_per_job, (slurm_rank + 1) * self.tasks_per_job)
             with self.logging_dir.open("ranks_to_run.json", "r") as ranks_to_run_file:
                 all_ranks = json.load(ranks_to_run_file)
-            if slurm_rank >= len(all_ranks):
+            if ranks_to_run_range[0] >= len(all_ranks):
                 return
-            rank = all_ranks[slurm_rank]
 
             for ss in self.requeue_signals or []:
                 signal.signal(signal.Signals[ss], requeue_handler)
 
-            if self.randomize_start:
-                time.sleep(random.randint(0, 60 * 3))
-            self._run_for_rank(rank)
+            for rank_to_run in range(*ranks_to_run_range):
+                if rank_to_run >= len(all_ranks):
+                    break
+                rank = all_ranks[rank_to_run]
+
+                if self.randomize_start:
+                    time.sleep(random.randint(0, 60 * 3))
+                self._run_for_rank(rank)
         else:
             # we still have to launch the job
             self.launch_job()
@@ -244,12 +254,14 @@ def launch_job(self):
             # we actually save this (only once) to avoid race conditions
            json.dump(ranks_to_run, ranks_to_run_file)
 
-        max_array = min(len(ranks_to_run), self.max_array_size) if self.max_array_size != -1 else len(ranks_to_run)
+        nb_jobs_to_launch = math.ceil(len(ranks_to_run) / self.tasks_per_job)
+        max_array = min(nb_jobs_to_launch, self.max_array_size) if self.max_array_size != -1 else nb_jobs_to_launch
 
         # create the actual sbatch script
+        srun_args_str = " ".join([f"--{k}={v}" for k, v in self.srun_args.items()]) if self.srun_args else ""
         launch_file_contents = self.get_launch_file_contents(
             self.get_sbatch_args(max_array),
-            f"srun -l launch_pickled_pipeline {self.logging_dir.resolve_paths('executor.pik')}",
+            f"srun {srun_args_str} -l launch_pickled_pipeline {self.logging_dir.resolve_paths('executor.pik')}",
         )
         # save it
         with self.logging_dir.open("launch_script.slurm", "w") as launchscript_f:
@@ -261,7 +273,7 @@ def launch_job(self):
 
         # launch (possibly multiple) jobs
         launched_jobs = 0
-        while launched_jobs * max_array < len(ranks_to_run):
+        while launched_jobs * max_array < nb_jobs_to_launch:
             if launched_jobs and self.max_array_launch_parallel and self.stagger_max_array_jobs > 0:
                 time.sleep(self.stagger_max_array_jobs)
             args = [f"--export=ALL,RUN_OFFSET={launched_jobs}"]

Diff for: src/datatrove/io.py (+5)
@@ -284,6 +284,11 @@ def open_file(file: IO | str, mode="rt", **kwargs):
     return file
 
 
+def file_exists(path: str):
+    fs, a, fpath = get_fs_token_paths(path)
+    return fs.exists(fpath[0])
+
+
 def download_file(remote_path: str, local_path: str, progress: bool = True):
     fs, _, paths = get_fs_token_paths(remote_path)
     fs.get_file(
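
Not part of the diff: a small usage sketch of the new helper. It resolves the path with fsspec, so both local and remote (e.g. s3://) paths work; the paths below are placeholders.

from datatrove.io import file_exists

print(file_exists("/tmp/some_local_file.txt"))              # False unless that local file exists
print(file_exists("s3://my-bucket/path/to/file.jsonl.gz"))  # placeholder bucket; requires s3 credentials/s3fs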

Diff for: src/datatrove/pipeline/decont/__init__.py (+1)
@@ -0,0 +1 @@
+from .n_grams import NGramsDecontConfig, NGramsDecontFilter, NGramsDecontIndexer

Diff for: src/datatrove/pipeline/decont/n_grams.py (+228)
@@ -0,0 +1,228 @@
+"""
+Used for n-gram decontamination.
+First build an index using the tasks we want to use to decontaminate our training dataset.
+Then read your training data and apply the filter with the index loaded.
+"""
+
+import os
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass, field
+from typing import Tuple
+
+import numpy as np
+from loguru import logger
+
+from datatrove.data import Document, DocumentsPipeline
+from datatrove.io import DataFolderLike, file_exists, get_datafolder, open_file
+from datatrove.pipeline.base import PipelineStep
+from datatrove.pipeline.filters.base_filter import BaseFilter
+from datatrove.pipeline.writers.disk_base import DiskWriter
+from datatrove.utils.binaryio import read_np_from_file
+from datatrove.utils.text import TextNormConfig, simplify_text, xxhash64
+
+
+@dataclass
+class NGramsDecontConfig:
+    """
+    Example for n_grams=4
+    query = ['A', 'B', 'C', 'D', 'E'] (the prompt/instruction)
+    label = ['F', 'G', 'H', 'I', 'J'] (the answer/gold)
+    Will find the following N-GRAMS in the training data:
+        'F G H I'
+        'G H I J'
+        + IF find_query_ngrams:
+            'A B C D'
+            'B C D E'
+        + IF find_overlap_ngrams:
+            'C D E F'
+            'D E F G'
+            'E F G H'
+    """
+
+    n_grams: int = 12
+    find_query_ngrams: bool = False  # enable to also check for matches in n-grams containing only the input/prompt
+    find_overlap_ngrams: bool = True  # will also find matches for n-grams containing BOTH input and query
+    norm_config: TextNormConfig = field(default_factory=TextNormConfig)
+
+
+DEFAULT_NGRAMS_DECONT_CONFIG = NGramsDecontConfig()
+
+
+class NGramsDecontIndexer(PipelineStep):
+    """
+    Creates a decontamination index (basically a list of uint64 hashes from ngrams) for each reference task.
+    Ways to provide task data:
+    - as input documents from the previous pipeline step with "text=label/correct answer"
+      and metadata={"query": query/prompt/input, "task": task name}
+    - as a list of strings in the format "suite|task" from the lighteval metadata table
+      https://github.com/huggingface/lighteval/blob/main/src/lighteval/tasks/tasks_table.jsonl as `lighteval_tasks`
+    - a path to a text file containing one such list, with one "suite|task" per line, as `lighteval_tasks`
+    You can also define your custom tasks with `custom_lighteval_tasks`. See the explanation for `custom_tasks` here:
+    https://github.com/huggingface/lighteval/tree/main?tab=readme-ov-file#evaluate-a-model-on-extended-community-or-custom-tasks
+    """
+
+    type = "🦠 - DECONT"
+    name = "💥 N-grams build index"
+    _requires_dependencies = ["nltk", "lighteval", "xxhash"]
+
+    def __init__(
+        self,
+        output_folder: DataFolderLike,
+        lighteval_tasks: str | list[str] | None = None,  # list in the format suite|task or path to one such list
+        custom_lighteval_tasks: str | None = None,
+        config: NGramsDecontConfig = DEFAULT_NGRAMS_DECONT_CONFIG,
+        language: str = "english",
+    ):
+        super().__init__()
+        self.output_folder = get_datafolder(output_folder)
+        # parse list of tasks
+        if isinstance(lighteval_tasks, str):
+            if file_exists(lighteval_tasks):
+                with open_file(lighteval_tasks, "rt") as f:
+                    self.lighteval_tasks = f.read().strip().splitlines()
+            else:
+                self.lighteval_tasks = [lighteval_tasks]
+        else:
+            self.lighteval_tasks = lighteval_tasks
+        self.custom_lighteval_tasks = custom_lighteval_tasks
+        self.config = config
+        self.language = language
+
+    def compute_hashes(self, label: str, query: str | None = None) -> list[int]:
+        from nltk import ngrams
+        from nltk.tokenize import word_tokenize
+
+        label_tokens = word_tokenize(simplify_text(label, self.config.norm_config), language=self.language)
+        ngrams_to_compute = list(ngrams(label_tokens, self.config.n_grams))
+        if query is not None:
+            query_tokens = word_tokenize(simplify_text(query, self.config.norm_config), language=self.language)
+            if self.config.find_query_ngrams:
+                ngrams_to_compute.extend(ngrams(query_tokens, self.config.n_grams))
+            if self.config.find_overlap_ngrams:
+                # add tokens overlapping query and label
+                """
+                A, B, C, D, E | F, G, H, I, J
+                5 grams
+                B, C, D, E, F (-N + 1 + i:) + (:i + 1)
+                ...
+                E, F, G, H, I
+                """
+                ngrams_to_compute.extend(
+                    [
+                        query_tokens[-self.config.n_grams + 1 + i :] + label_tokens[: i + 1]
+                        for i in range(self.config.n_grams - 1)
+                        # make sure we actually get a list of size N
+                        if len(query_tokens) >= self.config.n_grams - 1 - i and len(label_tokens) >= i + 1
+                    ]
+                )
+        return list(map(xxhash64, map(" ".join, ngrams_to_compute)))
+
+    def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1):
+        if world_size != 1:
+            raise ValueError("Decontamination index building requires a single worker.")
+        hashes = defaultdict(set)
+        # use whatever data is passed in with the following format:
+        # doc.text -> label
+        # doc.metadata["input"] -> input
+        if data:
+            for doc in data:
+                if not self.config.find_query_ngrams and "query" not in doc.metadata:
+                    raise ValueError(
+                        "only_label_ngrams is False but could not find 'query' field in documents metadata"
+                    )
+                hashes[doc.metadata.get("task", "input")].update(
+                    self.compute_hashes(doc.text, doc.metadata.get("query", None))
+                )
+
+        # parse data from lighteval defined tasks
+        from lighteval.tasks.lighteval_task import LightevalTask
+        from lighteval.tasks.registry import Registry
+
+        task_dict = Registry(cache_dir=os.getenv("HF_HOME")).get_task_dict(
+            self.lighteval_tasks, custom_tasks=self.custom_lighteval_tasks
+        )
+        LightevalTask.load_datasets(task_dict.values())
+
+        for task_name, task in task_dict.items():
+            for eval_doc in task.eval_docs():
+                try:
+                    golds = eval_doc.get_golds()
+                    query = eval_doc.query
+                except Exception as e:
+                    logger.warning(f"Error while fetching doc data: {e}")
+                    continue
+                for gold in golds:
+                    hashes[task_name].update(self.compute_hashes(gold, query))
+
+        for task_name, task_hashes in hashes.items():
+            hashes_array = np.array(list(task_hashes), dtype="<u8")
+            logger.info(f"Saving {len(task_hashes)} hashes for {task_name}")
+            with self.output_folder.open(f"{task_name.replace(' ', '_')}.index.hashes", mode="wb") as f:
+                if self.output_folder.is_local():
+                    hashes_array.tofile(f)
+                else:
+                    f.write(hashes_array.tobytes())
+
+
+class NGramsDecontFilter(BaseFilter):
+    """
+    Loads the list of hashes created by the Indexer step.
+    For each document in the block's input, we will check if any of its ngrams are part of the reference eval tasks.
+    If so, they will be removed. The contaminated ngram and the task where it was found will be saved in the removed
+    document's metadata.
+    """
+
+    type = "🦠 - DECONT"
+    name = "💥 N-grams decontaminate"
+    _requires_dependencies = ["nltk", "xxhash"]
+
+    def __init__(
+        self,
+        index_folder: DataFolderLike,
+        config: NGramsDecontConfig = DEFAULT_NGRAMS_DECONT_CONFIG,
+        exclusion_writer: DiskWriter = None,
+        language: str = "english",
+    ):
+        super().__init__()
+        self.index_folder = get_datafolder(index_folder)
+        self.config = config
+        self.exclusion_writer = exclusion_writer
+        self.language = language
+        self._index_hashes = None
+
+    def load_index_hashes(self):
+        def load_index_from_file(file):
+            with self.index_folder.open(file, mode="rb") as f:
+                return file, read_np_from_file(f, np.dtype("<u8"), self.index_folder.is_local()).tolist()
+
+        with ThreadPoolExecutor() as pool:
+            hashes = pool.map(load_index_from_file, self.index_folder.list_files())
+
+        self._index_hashes = {}
+        for filename, hashlist in hashes:
+            taskname = filename.removesuffix(".index.hashes")
+            logger.info(f"Loading {len(hashlist)} hashes for {taskname}")
+            for hash in hashlist:
+                self._index_hashes[hash] = taskname
+
+    def filter(self, doc: Document) -> bool | Tuple[bool, str]:
+        if self._index_hashes is None:
+            self.load_index_hashes()
+
+        from nltk import ngrams
+        from nltk.tokenize import word_tokenize
+
+        text_tokens = word_tokenize(simplify_text(doc.text, self.config.norm_config), language=self.language)
+        ngrams_to_compute = list(ngrams(text_tokens, self.config.n_grams))
+        for n_gram in map(" ".join, ngrams_to_compute):
+            task = self._index_hashes.get(xxhash64(n_gram), None)
+            if task is not None:
+                doc.metadata["contaminated_ngram"] = n_gram
+                doc.metadata["contaminated_task"] = task
+                self.stat_update(f"contaminated_{task}")
+                if ":" in task:
+                    self.stat_update(f"contaminated_tg_{task[:task.index(':')]}")
+                return False, "contaminated"
+        return True
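
Not part of the commit: a minimal sketch of how the two new steps could be chained, assuming the decont extra is installed (pip install "datatrove[decont]") and NLTK's punkt tokenizer data is available. The folders, reader settings and task names below are placeholders.

from datatrove.executor import LocalPipelineExecutor
from datatrove.pipeline.decont import NGramsDecontConfig, NGramsDecontFilter, NGramsDecontIndexer
from datatrove.pipeline.readers import JsonlReader
from datatrove.pipeline.writers import JsonlWriter

config = NGramsDecontConfig(n_grams=12, find_overlap_ngrams=True)

# step 1: build the hash index from the reference eval tasks (single task, as the indexer requires)
build_index = LocalPipelineExecutor(
    pipeline=[
        NGramsDecontIndexer(
            output_folder="/data/decont_index",       # placeholder path
            lighteval_tasks=["lighteval|gsm8k"],      # placeholder "suite|task" entry
            config=config,
        )
    ],
    tasks=1,
)

# step 2: filter the training data against that index
decontaminate = LocalPipelineExecutor(
    pipeline=[
        JsonlReader("/data/training_data"),           # placeholder input folder
        NGramsDecontFilter(
            index_folder="/data/decont_index",
            config=config,
            exclusion_writer=JsonlWriter("/data/removed"),  # keep contaminated docs for inspection
        ),
        JsonlWriter("/data/training_data_clean"),
    ],
    tasks=4,
)

if __name__ == "__main__":
    build_index.run()
    decontaminate.run()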

Diff for: src/datatrove/pipeline/dedup/minhash.py (+1 -1)
@@ -50,7 +50,7 @@ class MinhashConfig:
     num_buckets: int = 14
     hashes_per_bucket: int = 8
 
-    use_64bit_hashes: bool = False
+    use_64bit_hashes: bool = True
     seed: int = 1
 
     norm_config: TextNormConfig = field(default_factory=TextNormConfig)
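
This flips the default to 64-bit hashes. Not part of the diff: a short sketch of pinning the old 32-bit behaviour explicitly, e.g. to stay comparable with signatures computed before this change; the output path is a placeholder.

from datatrove.pipeline.dedup import MinhashDedupSignature
from datatrove.pipeline.dedup.minhash import MinhashConfig

# explicitly keep the previous 32-bit behaviour so new signatures match ones built with the old default
legacy_config = MinhashConfig(use_64bit_hashes=False)
signatures = MinhashDedupSignature(output_folder="/data/minhash/signatures", config=legacy_config)  # placeholder path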

Diff for: src/datatrove/pipeline/filters/__init__.py (+1 -1)
@@ -1,4 +1,4 @@
-from .c4_quality_filter import C4ParagraphFilter, C4QualityFilter
+from .c4_filters import C4BadWordsFilter, C4ParagraphFilter, C4QualityFilter
 from .fasttext_filter import FastTextClassifierFilter
 from .fineweb_quality_filter import FineWebQualityFilter
 from .gopher_quality_filter import GopherQualityFilter
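
Not part of the diff: the module rename only affects direct module imports; package-level imports keep working and now also expose the new bad-words filter.

# old direct-module import (module apparently renamed to c4_filters in this commit):
# from datatrove.pipeline.filters.c4_quality_filter import C4QualityFilter

# the package-level import is unchanged, and C4BadWordsFilter is newly exported:
from datatrove.pipeline.filters import C4BadWordsFilter, C4ParagraphFilter, C4QualityFilter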
