Add script to decontaminate datasets against benchmark datasets #416

Merged: 11 commits, Feb 24, 2025
README.md (31 additions, 0 deletions)
@@ -211,6 +211,37 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con
--config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code.yaml
```

#### Data decontamination

Following [s1: Simple test-time scaling](https://arxiv.org/abs/2501.19393), the data can be decontaminated with [scripts/decontaminate.py](./scripts/decontaminate.py), which decontaminates a dataset using 8-grams and deduplicates the data. Sample run:

```shell
python scripts/decontaminate.py \
--dataset "open-r1/verifiable-coding-problems-python-10k" \
--problem_column problem \
-c
```

The script checks the dataset against the benchmark datasets and then removes the contaminated samples. If `--new_dataset_name` is not provided, the dataset is pushed under its original name with a `_decontaminated` suffix. The check runs against the prompt column, which for this dataset is `problem`; a different column can be selected with `--problem_column`. A minimal sketch of the underlying check is shown after the argument listing below.

Arguments for the script:

```shell
usage: decontaminate.py [-h] --dataset DATASET [--split SPLIT] [--ngram_size NGRAM_SIZE] [--problem_column PROBLEM_COLUMN] [--cleanup] [--new_dataset_name NEW_DATASET_NAME]

options:
-h, --help show this help message and exit
--dataset DATASET Name of the dataset to check for contamination.
--split SPLIT Split to check for contamination, defaults to `train`.
--ngram_size NGRAM_SIZE
Size of n-grams to build, defaults to 8.
--problem_column PROBLEM_COLUMN
Name of the column containing the problem (prompt).
  --cleanup, -c         Whether to remove the contaminated rows before pushing the dataset.
--new_dataset_name NEW_DATASET_NAME
                        New name for the dataset. If not provided, the original name is reused with a `_decontaminated` suffix.
```
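
The contamination test is a word-level n-gram overlap: the script builds 8-grams from every benchmark prompt and flags any training prompt that shares at least one of them. Below is a minimal sketch of the same idea on made-up prompts (not the script's API, just the mechanism):

```python
def ngrams(text: str, n: int = 8) -> set[str]:
    # Lowercase, split on whitespace, then build word-level n-grams
    words = text.lower().split()
    return {" ".join(words[i : i + n]) for i in range(len(words) - n + 1)}

# Toy benchmark prompt and two candidate training prompts (illustrative only)
benchmark_prompt = "find the sum of all positive integers n such that n squared divides n cubed plus eight"
candidates = [
    "please find the sum of all positive integers n such that n squared divides n cubed plus eight exactly",
    "compute the area of a triangle with legs of length three and four",
]

benchmark_ngrams = ngrams(benchmark_prompt)
for prompt in candidates:
    contaminated = any(g in benchmark_ngrams for g in ngrams(prompt))
    print(contaminated)  # True for the first prompt, False for the second
```

The script builds one such lookup per benchmark (AIME 2024, AIME 2025, MATH-500, GPQA Diamond and LiveCodeBench), adds a `contaminated_<benchmark>` column for each, and `--cleanup` then drops the flagged rows and removes those columns.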

### Launching jobs on a Slurm cluster

If you have access to a Slurm cluster, we provide a `slurm/train.slurm` script that will automatically queue training jobs for you. Here's how you can use it:
scripts/decontaminate.py (124 additions, 0 deletions)
@@ -0,0 +1,124 @@
"""
python scripts/decontaminate.py \
--dataset "open-r1/verifiable-coding-problems-python-10k" \
--split train \
--ngram_size 8 \
--problem_column problem \
-c
"""

import collections

from tqdm import tqdm


def normalize_string(text: str) -> str:
"""Basic string normalization."""
# Convert to lowercase and normalize whitespace
text = text.lower().strip()
# Replace multiple spaces with single space
text = " ".join(text.split())
return text


def word_ngrams(text: str, n: int) -> list:
"""Generate word-level n-grams from text."""
words = text.split()
return [" ".join(words[i : i + n]) for i in range(len(words) - n + 1)]


def build_ngram_lookup(documents: list[str], ngram_size: int = 8) -> dict[str, set[int]]:
"""Build ngram lookup for documents."""
lookup = collections.defaultdict(set)

for doc_id, document in enumerate(tqdm(documents)):
normalized_text = normalize_string(document)
ngrams = word_ngrams(normalized_text, ngram_size)
for ngram in ngrams:
lookup[ngram].add(doc_id)

return lookup


def build_ngram_single(document: str, ngram_size: int = 8) -> set[str]:
normalized_text = normalize_string(document)
ngrams = word_ngrams(normalized_text, ngram_size)

return set(ngrams)


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, required=True, help="Name of the dataset to check for contamination.")
parser.add_argument("--split", type=str, default="train", help="Split to check for contamination, defaults to `train`.")
parser.add_argument("--ngram_size", type=int, default=8, help="Size of n-grams to build, defaults to 8.")
parser.add_argument(
"--problem_column", type=str, default="problem", help="Name of the column containing the problem (prompt)."
)
parser.add_argument(
"--cleanup",
"-c",
action="store_true",
help="Wether to remove the contaminated rows before pushing the dataset.",
)
parser.add_argument(
"--new_dataset_name",
type=str,
default=None,
help="New name for the dataset. If not provided, will reuse the name and add a `_decontaminated` to the name."
)
args = parser.parse_args()

from datasets import load_dataset, Dataset

# Load the dataset to check for contamination
ds = load_dataset(args.dataset, split=args.split)

eval_datasets = {
"aime_2024": (load_dataset("HuggingFaceH4/aime_2024", split="train"), "problem"),
"aime_2025": (load_dataset("yentinglin/aime_2025", split="train"), "problem"),
"math_500": (load_dataset("HuggingFaceH4/MATH-500", split="test"), "problem"),
"gpqa": (load_dataset("Idavidrein/gpqa", "gpqa_diamond", split="train", trust_remote_code=True), "Question"),
"lcb": (
load_dataset(
"livecodebench/code_generation_lite", split="test", version_tag="v4_v5", trust_remote_code=True
),
"question_content",
),
}
ngram_lookups = {}
for ds_name, (eval_dataset, problem_col) in eval_datasets.items():
ngram_lookups[ds_name] = build_ngram_lookup(eval_dataset[problem_col], ngram_size=args.ngram_size)

    for eval_name, ngram_lookup in ngram_lookups.items():
        # Flag rows whose prompt shares at least one n-gram with the current benchmark
        def find_contaminated(row):
            # Build the prompt's n-grams and check whether any of them appears in the benchmark lookup
            ngrams = build_ngram_single(row[args.problem_column], ngram_size=args.ngram_size)
            row[f"contaminated_{eval_name}"] = any(ngram in ngram_lookup for ngram in ngrams)
            return row

ds = ds.map(find_contaminated, num_proc=8)

# Allow cleaning up via CLI args (removing the contaminated examples and dropping the columns)
def cleanup(dataset: Dataset) -> Dataset:
initial_size = len(dataset)
        contamination_cols = [col for col in dataset.column_names if col.startswith("contaminated_")]
        for col in contamination_cols:
            size_prior = len(dataset)
            dataset = dataset.filter(lambda x: not x[col], num_proc=8)
            if len(dataset) < size_prior:
                print(f"Removed {size_prior - len(dataset)} samples from '{col.replace('contaminated_', '')}'")
        dataset = dataset.remove_columns(contamination_cols)
print(f"Initial size: {initial_size}, Final size: {len(dataset)}")
return dataset

if args.cleanup:
ds = cleanup(ds)

new_ds_name = args.new_dataset_name or f"{args.dataset}_decontaminated"
ds.push_to_hub(new_ds_name, split="train", private=False)
print(f"Decontaminated dataset: {new_ds_name}")
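
As a quick sanity check of the helpers above (toy strings, purely illustrative), a query that reuses a full 8-word span from a benchmark question is flagged, while unrelated text is not:

```python
# Hypothetical toy data to exercise build_ngram_lookup / build_ngram_single
benchmark_questions = [
    "What is the smallest prime factor of two thousand and twenty five",
    "Explain the difference between a list and a tuple in Python",
]
lookup = build_ngram_lookup(benchmark_questions, ngram_size=8)

query = "Quick quiz: what is the smallest prime factor of two thousand and twenty five?"
print(any(ngram in lookup for ngram in build_ngram_single(query, ngram_size=8)))  # True
print(any(ngram in lookup for ngram in build_ngram_single("An unrelated short prompt", ngram_size=8)))  # False
```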