Skip to content

Commit 1cb81a8

Browse files
committed
Added tests and par_fold_hashmap lint
1 parent 6d628e1 commit 1cb81a8

9 files changed

+247
-6
lines changed

lints/fold/ui/main.fixed

+20
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// run-rustfix
22
fn main() {
33
warn_fold_simple();
4+
warn_fold_vec();
5+
warn_fold_hashmap();
46
get_upload_file_total_size();
57
}
68

@@ -13,6 +15,24 @@ fn warn_fold_simple() {
1315
println!("Sum: {}", sum);
1416
}
1517

18+
fn warn_fold_vec() {
19+
let mut data = vec![];
20+
let numbers = vec![1, 2, 3, 4, 5];
21+
data = numbers.iter().fold(data, |mut data, &num| { data.push(num * 3); data });
22+
23+
println!("Data: {:?}", data);
24+
}
25+
26+
fn warn_fold_hashmap() {
27+
use std::collections::HashMap;
28+
29+
let mut data = HashMap::new();
30+
let numbers = vec![1, 2, 3, 4, 5];
31+
data = numbers.iter().fold(data, |mut data, &num| { data.insert(num, num.to_string()); data });
32+
33+
println!("Data: {:?}", data);
34+
}
35+
1636
fn get_upload_file_total_size() -> u64 {
1737
let some_num = vec![0; 10];
1838
let mut file_total_size = 0;

lints/fold/ui/main.rs

+24
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// run-rustfix
22
fn main() {
33
warn_fold_simple();
4+
warn_fold_vec();
5+
warn_fold_hashmap();
46
get_upload_file_total_size();
57
}
68

@@ -15,6 +17,28 @@ fn warn_fold_simple() {
1517
println!("Sum: {}", sum);
1618
}
1719

20+
fn warn_fold_vec() {
21+
let mut data = vec![];
22+
let numbers = vec![1, 2, 3, 4, 5];
23+
numbers.iter().for_each(|&num| {
24+
data.push(num * 3);
25+
});
26+
27+
println!("Data: {:?}", data);
28+
}
29+
30+
fn warn_fold_hashmap() {
31+
use std::collections::HashMap;
32+
33+
let mut data = HashMap::new();
34+
let numbers = vec![1, 2, 3, 4, 5];
35+
numbers.iter().for_each(|&num| {
36+
data.insert(num, num.to_string());
37+
});
38+
39+
println!("Data: {:?}", data);
40+
}
41+
1842
fn get_upload_file_total_size() -> u64 {
1943
let some_num = vec![0; 10];
2044
let mut file_total_size = 0;

lints/fold/ui/main.stderr

+25-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
error: implicit fold
2-
--> $DIR/main.rs:11:5
2+
--> $DIR/main.rs:13:5
33
|
44
LL | / numbers.iter().for_each(|&num| {
55
LL | | sum += num;
@@ -10,13 +10,35 @@ LL | | });
1010
= help: to override `-D warnings` add `#[allow(fold_simple)]`
1111

1212
error: implicit fold
13-
--> $DIR/main.rs:21:5
13+
--> $DIR/main.rs:23:5
14+
|
15+
LL | / numbers.iter().for_each(|&num| {
16+
LL | | data.push(num * 3);
17+
LL | | });
18+
| |______^ help: try using `fold` instead: `data = numbers.iter().fold(data, |mut data, &num| { data.push(num * 3); data })`
19+
|
20+
= note: `-D fold-vec` implied by `-D warnings`
21+
= help: to override `-D warnings` add `#[allow(fold_vec)]`
22+
23+
error: implicit fold
24+
--> $DIR/main.rs:35:5
25+
|
26+
LL | / numbers.iter().for_each(|&num| {
27+
LL | | data.insert(num, num.to_string());
28+
LL | | });
29+
| |______^ help: try using `fold` instead: `data = numbers.iter().fold(data, |mut data, &num| { data.insert(num, num.to_string()); data })`
30+
|
31+
= note: `-D fold-hashmap` implied by `-D warnings`
32+
= help: to override `-D warnings` add `#[allow(fold_hashmap)]`
33+
34+
error: implicit fold
35+
--> $DIR/main.rs:45:5
1436
|
1537
LL | / (0..some_num.len()).into_iter().for_each(|_| {
1638
LL | | let (_, upload_size) = (true, 99);
1739
LL | | file_total_size += upload_size;
1840
LL | | });
1941
| |______^ help: try using `fold` instead: `file_total_size += (0..some_num.len()).into_iter().map(|_| {let (_, upload_size) = (true, 99); upload_size}).fold(0, |mut file_total_size, v| { file_total_size += v; file_total_size })`
2042

21-
error: aborting due to 2 previous errors
43+
error: aborting due to 4 previous errors
2244

lints/par_fold/src/hashmap.rs

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
use rustc_errors::Applicability;
2+
use rustc_hir::{Expr, ExprKind, StmtKind};
3+
use rustc_lint::{LateContext, LateLintPass, LintContext};
4+
use rustc_session::{declare_lint, declare_lint_pass};
5+
use rustc_span::{sym, Symbol};
6+
use utils::span_to_snippet_macro;
7+
8+
declare_lint! {
9+
pub WARN_PAR_FOLD_HASHMAP,
10+
Warn,
11+
"suggest using parallel fold"
12+
}
13+
14+
declare_lint_pass!(ParFoldHashMap => [WARN_PAR_FOLD_HASHMAP]);
15+
impl<'tcx> LateLintPass<'tcx> for ParFoldHashMap {
16+
fn check_expr(&mut self, cx: &LateContext<'tcx>, expr: &'tcx Expr<'_>) {
17+
if let ExprKind::MethodCall(path, recv, args, _span) = &expr.kind
18+
&& path.ident.name == Symbol::intern("fold")
19+
{
20+
assert_eq!(args.len(), 2);
21+
let id_expr = args[0];
22+
let op_expr = args[1];
23+
24+
// Check the penultimate statement of the fold for a `c.push(v)`
25+
// Quite a specific target, can we be more general?
26+
let ExprKind::Closure(op_cls) = op_expr.kind else {
27+
return;
28+
};
29+
let hir_map = cx.tcx.hir();
30+
let cls_body = hir_map.body(op_cls.body);
31+
32+
let Ok(StmtKind::Semi(fold_op)) =
33+
utils::get_penult_stmt(cls_body.value).map(|s| s.kind)
34+
else {
35+
return;
36+
};
37+
38+
let ExprKind::MethodCall(path, _, _, _) = fold_op.kind else {
39+
return;
40+
};
41+
if path.ident.name != Symbol::intern("insert") {
42+
return;
43+
}
44+
45+
// Check that this method is on a hashmap
46+
let base_ty = cx
47+
.tcx
48+
.typeck(expr.hir_id.owner.def_id)
49+
.node_type(id_expr.hir_id);
50+
let Some(adt) = base_ty.ty_adt_def() else {
51+
return;
52+
};
53+
if !cx.tcx.is_diagnostic_item(sym::HashMap, adt.did()) {
54+
return;
55+
}
56+
57+
// Assume that if we make it here, we can apply the pattern.
58+
let src_map = cx.sess().source_map();
59+
let cls_snip = span_to_snippet_macro(src_map, op_expr.span);
60+
let recv_snip = span_to_snippet_macro(src_map, recv.span);
61+
let id_snip = span_to_snippet_macro(src_map, id_expr.span);
62+
63+
let fold_snip = format!("fold(|| HashMap::new(), {cls_snip})");
64+
let reduce_snip = "reduce(|| HashMap::new(), |mut a, b| { a.extend(b); a })";
65+
let mut extend_snip =
66+
format!("{{ {id_snip}.extend({recv_snip}.{fold_snip}.{reduce_snip}); {id_snip} }}");
67+
extend_snip = extend_snip.replace(".iter()", ".par_iter()");
68+
extend_snip = extend_snip.replace(".iter_mut()", ".par_iter_mut()");
69+
extend_snip = extend_snip.replace(".into_iter()", ".into_par_iter()");
70+
71+
cx.span_lint(WARN_PAR_FOLD_HASHMAP, expr.span, "sequential fold", |diag| {
72+
diag.span_suggestion(
73+
expr.span,
74+
"try using a parallel fold on the iterator",
75+
extend_snip,
76+
Applicability::MachineApplicable,
77+
);
78+
});
79+
}
80+
}
81+
}

lints/par_fold/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@ extern crate rustc_span;
1515

1616
mod par_fold_simple;
1717
mod vec;
18+
mod hashmap;
1819

1920
#[allow(clippy::no_mangle_with_rust_abi)]
2021
#[cfg_attr(not(feature = "rlib"), no_mangle)]
2122
pub fn register_lints(_sess: &rustc_session::Session, lint_store: &mut rustc_lint::LintStore) {
2223
lint_store.register_late_pass(|_| Box::new(par_fold_simple::ParFoldSimple));
2324
lint_store.register_late_pass(|_| Box::new(vec::ParFoldVec));
25+
lint_store.register_late_pass(|_| Box::new(hashmap::ParFoldHashMap));
2426
}
2527

2628
#[test]

lints/par_fold/ui/par_fold_simple.fixed

+26
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ use rayon::prelude::*;
55

66
fn main() {
77
warn_fold_simple();
8+
warn_fold_vec();
9+
warn_fold_hashmap();
810
}
911

1012
fn warn_fold_simple() {
@@ -17,3 +19,27 @@ fn warn_fold_simple() {
1719

1820
println!("Sum: {}", sum);
1921
}
22+
23+
fn warn_fold_vec() {
24+
let mut data = vec![];
25+
let numbers = vec![1, 2, 3, 4, 5];
26+
data = { data.extend(numbers.par_iter().fold(|| Vec::new(), |mut data, &num| {
27+
data.push(num * 3);
28+
data
29+
}).reduce(|| Vec::new(), |mut a, b| { a.extend(b); a })); data };
30+
31+
println!("Data: {:?}", data);
32+
}
33+
34+
fn warn_fold_hashmap() {
35+
use std::collections::HashMap;
36+
37+
let mut data = HashMap::new();
38+
let numbers = vec![1, 2, 3, 4, 5];
39+
data = { data.extend(numbers.par_iter().fold(|| HashMap::new(), |mut data, &num| {
40+
data.insert(num, num.to_string());
41+
data
42+
}).reduce(|| HashMap::new(), |mut a, b| { a.extend(b); a })); data };
43+
44+
println!("Data: {:?}", data);
45+
}

lints/par_fold/ui/par_fold_simple.rs

+26
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ use rayon::prelude::*;
55

66
fn main() {
77
warn_fold_simple();
8+
warn_fold_vec();
9+
warn_fold_hashmap();
810
}
911

1012
fn warn_fold_simple() {
@@ -17,3 +19,27 @@ fn warn_fold_simple() {
1719

1820
println!("Sum: {}", sum);
1921
}
22+
23+
fn warn_fold_vec() {
24+
let mut data = vec![];
25+
let numbers = vec![1, 2, 3, 4, 5];
26+
data = numbers.iter().fold(data, |mut data, &num| {
27+
data.push(num * 3);
28+
data
29+
});
30+
31+
println!("Data: {:?}", data);
32+
}
33+
34+
fn warn_fold_hashmap() {
35+
use std::collections::HashMap;
36+
37+
let mut data = HashMap::new();
38+
let numbers = vec![1, 2, 3, 4, 5];
39+
data = numbers.iter().fold(data, |mut data, &num| {
40+
data.insert(num, num.to_string());
41+
data
42+
});
43+
44+
println!("Data: {:?}", data);
45+
}

lints/par_fold/ui/par_fold_simple.stderr

+42-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
error: sequential fold
2-
--> $DIR/par_fold_simple.rs:13:12
2+
--> $DIR/par_fold_simple.rs:15:12
33
|
44
LL | sum += numbers.iter().map(|&num| num).fold(0, |mut sum, v| {
55
| ____________^
@@ -15,5 +15,45 @@ help: try using a parallel fold on the iterator
1515
LL | sum += numbers.par_iter().map(|&num| num).reduce(|| 0, |mut sum, v| {
1616
| ~~~~~~~~ ~~~~~~ ~~~~
1717

18-
error: aborting due to 1 previous error
18+
error: sequential fold
19+
--> $DIR/par_fold_simple.rs:26:12
20+
|
21+
LL | data = numbers.iter().fold(data, |mut data, &num| {
22+
| ____________^
23+
LL | | data.push(num * 3);
24+
LL | | data
25+
LL | | });
26+
| |______^
27+
|
28+
= note: `-D warn-par-fold-vec` implied by `-D warnings`
29+
= help: to override `-D warnings` add `#[allow(warn_par_fold_vec)]`
30+
help: try using a parallel fold on the iterator
31+
|
32+
LL ~ data = { data.extend(numbers.par_iter().fold(|| Vec::new(), |mut data, &num| {
33+
LL + data.push(num * 3);
34+
LL + data
35+
LL ~ }).reduce(|| Vec::new(), |mut a, b| { a.extend(b); a })); data };
36+
|
37+
38+
error: sequential fold
39+
--> $DIR/par_fold_simple.rs:39:12
40+
|
41+
LL | data = numbers.iter().fold(data, |mut data, &num| {
42+
| ____________^
43+
LL | | data.insert(num, num.to_string());
44+
LL | | data
45+
LL | | });
46+
| |______^
47+
|
48+
= note: `-D warn-par-fold-hashmap` implied by `-D warnings`
49+
= help: to override `-D warnings` add `#[allow(warn_par_fold_hashmap)]`
50+
help: try using a parallel fold on the iterator
51+
|
52+
LL ~ data = { data.extend(numbers.par_iter().fold(|| HashMap::new(), |mut data, &num| {
53+
LL + data.insert(num, num.to_string());
54+
LL + data
55+
LL ~ }).reduce(|| HashMap::new(), |mut a, b| { a.extend(b); a })); data };
56+
|
57+
58+
error: aborting due to 3 previous errors
1959

rust-toolchain

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[toolchain]
22
channel = "nightly-2024-02-22"
3-
components = ["llvm-tools-preview", "rustc-dev"]
3+
components = ["llvm-tools-preview", "rustc-dev"]

0 commit comments

Comments
 (0)