Skip to content

Commit 0d70701

Browse files
committed
Add OSCAR filter
1 parent 979a0c7 commit 0d70701

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

Diff for: src/datatrove/pipeline/filters/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@
99
from .sampler_filter import SamplerFilter
1010
from .unigram_log_probs import UnigramLogProbFilter
1111
from .url_filter import URLFilter
12-
from .multilingual_policy_filter import MultilingualPolicyFilter
12+
from .multilingual_policy_filter import MultilingualPolicyFilter
13+
from .oscar_filter import OSCARFilter

Diff for: src/datatrove/pipeline/filters/oscar_filter.py

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import re
2+
3+
from datatrove.data import Document
4+
from datatrove.pipeline.filters.base_filter import BaseFilter
5+
from datatrove.pipeline.writers.disk_base import DiskWriter
6+
7+
DEFAULT_OSCAR_MIN_HARMFUL_PP = 25.0
8+
DEFAULT_OSCAR_MAX_HARMFUL_PP = 100_000
9+
10+
DEFAULT_EXCLUDE_CATEGORIES = {
11+
# See http://dsi.ut-capitole.fr/blacklists/index_en.php
12+
"agressif",
13+
"adult",
14+
"cryptojacking",
15+
"dangerous_material",
16+
"phishing",
17+
"warez",
18+
"ddos",
19+
"hacking",
20+
"malware",
21+
"mixed_adult",
22+
"sect",
23+
}
24+
25+
26+
class OSCARFilter(BaseFilter):
27+
name = "🗑 OSCAR"
28+
29+
def __init__(self, regex_exp: str,
30+
exclusion_writer: DiskWriter = None,
31+
min_harmful_ppl: float = DEFAULT_OSCAR_MIN_HARMFUL_PP,
32+
max_harmful_ppl: float = DEFAULT_OSCAR_MAX_HARMFUL_PP,
33+
exclude_categories: set = DEFAULT_EXCLUDE_CATEGORIES):
34+
"""
35+
filters if regex finds at least one match
36+
37+
Args:
38+
regex_exp: regex expression
39+
exclusion_writer:
40+
"""
41+
super().__init__(exclusion_writer)
42+
self.regex = re.compile(regex_exp)
43+
self.min_harmful_ppl = min_harmful_ppl
44+
self.max_harmful_ppl = max_harmful_ppl
45+
self.exclude_categories = exclude_categories
46+
47+
def filter(self, doc: Document) -> bool | tuple[bool, str]:
48+
"""Args:
49+
doc: document
50+
51+
Returns:
52+
is_filter
53+
"""
54+
if doc['metadata']['oscar_quality_warnings']:
55+
return False, 'oscar_quality_warning'
56+
if doc['metadata']['harmful_pp'] and doc['metadata']['harmful_pp'] < self.min_harmful_ppl:
57+
return False, 'kenlm_min_harmful_ppl'
58+
if doc['metadata']['harmful_pp'] and doc['metadata']['harmful_pp'] > self.max_harmful_ppl:
59+
return False, 'kenlm_max_harmful_ppl'
60+
if doc['medatdata']['oscar_categories'] and len(set(doc['medatdata']['oscar_categories']) & self.exclude_categories) > 0:
61+
return False, 'oscar_category'
62+
return True

0 commit comments

Comments
 (0)