Skip to content

Commit 5ee3bdd

Browse files
authored
refactor: Clean up par_map a bit (#742)
* Adjust safety argument of `par_map` * More parallel APIs * assert `ParallelDb` `Send` promise * fix: Fix `par_map` unsoundness * Add more parallel API tests
1 parent 35cdd67 commit 5ee3bdd

File tree

8 files changed

+304
-58
lines changed

8 files changed

+304
-58
lines changed

Diff for: src/lib.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ mod key;
2121
mod memo_ingredient_indices;
2222
mod nonce;
2323
#[cfg(feature = "rayon")]
24-
mod par_map;
24+
mod parallel;
2525
mod revision;
2626
mod runtime;
2727
mod salsa_struct;
@@ -52,7 +52,7 @@ pub use self::update::Update;
5252
pub use self::zalsa::IngredientIndex;
5353
pub use crate::attach::with_attached_database;
5454
#[cfg(feature = "rayon")]
55-
pub use par_map::par_map;
55+
pub use parallel::{join, par_map, scope, Scope};
5656
#[cfg(feature = "macros")]
5757
pub use salsa_macros::{accumulator, db, input, interned, tracked, Supertype, Update};
5858

Diff for: src/par_map.rs

-54
This file was deleted.

Diff for: src/parallel.rs

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator};
2+
3+
use crate::Database;
4+
5+
pub fn par_map<Db, F, T, R, C>(db: &Db, inputs: impl IntoParallelIterator<Item = T>, op: F) -> C
6+
where
7+
Db: Database + ?Sized,
8+
F: Fn(&Db, T) -> R + Sync + Send,
9+
T: Send,
10+
R: Send + Sync,
11+
C: FromParallelIterator<R>,
12+
{
13+
inputs
14+
.into_par_iter()
15+
.map_with(DbForkOnClone(db.fork_db()), |db, element| {
16+
op(db.0.as_view(), element)
17+
})
18+
.collect()
19+
}
20+
21+
struct DbForkOnClone(Box<dyn Database>);
22+
23+
impl Clone for DbForkOnClone {
24+
fn clone(&self) -> Self {
25+
DbForkOnClone(self.0.fork_db())
26+
}
27+
}
28+
29+
pub struct Scope<'scope, 'local, Db: Database + ?Sized> {
30+
db: &'local Db,
31+
base: &'local rayon::Scope<'scope>,
32+
}
33+
34+
impl<'scope, 'local, Db: Database + ?Sized> Scope<'scope, 'local, Db> {
35+
pub fn spawn<BODY>(&self, body: BODY)
36+
where
37+
BODY: for<'l> FnOnce(&'l Scope<'scope, 'l, Db>) + Send + 'scope,
38+
{
39+
let db = self.db.fork_db();
40+
self.base.spawn(move |scope| {
41+
let scope = Scope {
42+
db: db.as_view::<Db>(),
43+
base: scope,
44+
};
45+
body(&scope)
46+
})
47+
}
48+
49+
pub fn db(&self) -> &'local Db {
50+
self.db
51+
}
52+
}
53+
54+
pub fn scope<'scope, Db: Database + ?Sized, OP, R>(db: &Db, op: OP) -> R
55+
where
56+
OP: FnOnce(&Scope<'scope, '_, Db>) -> R + Send,
57+
R: Send,
58+
{
59+
rayon::in_place_scope(move |s| op(&Scope { db, base: s }))
60+
}
61+
62+
pub fn join<A, B, RA, RB, Db: Database + ?Sized>(db: &Db, a: A, b: B) -> (RA, RB)
63+
where
64+
A: FnOnce(&Db) -> RA + Send,
65+
B: FnOnce(&Db) -> RB + Send,
66+
RA: Send,
67+
RB: Send,
68+
{
69+
// we need to fork eagerly, as `rayon::join_context` gives us no option to tell whether we get
70+
// moved to another thread before the closure is executed
71+
let db_a = db.fork_db();
72+
let db_b = db.fork_db();
73+
rayon::join(
74+
move || a(db_a.as_view::<Db>()),
75+
move || b(db_b.as_view::<Db>()),
76+
)
77+
}

Diff for: tests/parallel/main.rs

+2
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,7 @@ mod cycle_ab_peeping_c;
55
mod cycle_nested_three_threads;
66
mod cycle_panic;
77
mod parallel_cancellation;
8+
mod parallel_join;
89
mod parallel_map;
10+
mod parallel_scope;
911
mod signal;

Diff for: tests/parallel/parallel_join.rs

+105
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
#![cfg(feature = "rayon")]
2+
// test for rayon-like join interactions.
3+
4+
use salsa::Cancelled;
5+
use salsa::Setter;
6+
7+
use crate::setup::Knobs;
8+
use crate::setup::KnobsDatabase;
9+
10+
#[salsa::input]
11+
struct ParallelInput {
12+
a: u32,
13+
b: u32,
14+
}
15+
16+
#[salsa::tracked]
17+
fn tracked_fn(db: &dyn salsa::Database, input: ParallelInput) -> (u32, u32) {
18+
salsa::join(db, |db| input.a(db) + 1, |db| input.b(db) - 1)
19+
}
20+
21+
#[salsa::tracked]
22+
fn a1(db: &dyn KnobsDatabase, input: ParallelInput) -> (u32, u32) {
23+
db.signal(1);
24+
salsa::join(
25+
db,
26+
|db| {
27+
db.wait_for(2);
28+
input.a(db) + dummy(db)
29+
},
30+
|db| {
31+
db.wait_for(2);
32+
input.b(db) + dummy(db)
33+
},
34+
)
35+
}
36+
37+
#[salsa::tracked]
38+
fn dummy(_db: &dyn KnobsDatabase) -> u32 {
39+
panic!("should never get here!")
40+
}
41+
42+
#[test]
43+
#[cfg_attr(miri, ignore)]
44+
fn execute() {
45+
let db = salsa::DatabaseImpl::new();
46+
47+
let input = ParallelInput::new(&db, 10, 20);
48+
49+
tracked_fn(&db, input);
50+
}
51+
52+
// we expect this to panic, as `salsa::par_map` needs to be called from a query.
53+
#[test]
54+
#[cfg_attr(miri, ignore)]
55+
#[should_panic]
56+
fn direct_calls_panic() {
57+
let db = salsa::DatabaseImpl::new();
58+
59+
let input = ParallelInput::new(&db, 10, 20);
60+
let (_, _) = salsa::join(&db, |db| input.a(db) + 1, |db| input.b(db) - 1);
61+
}
62+
63+
// Cancellation signalling test
64+
//
65+
// The pattern is as follows.
66+
//
67+
// Thread A Thread B
68+
// -------- --------
69+
// a1
70+
// | wait for stage 1
71+
// signal stage 1 set input, triggers cancellation
72+
// wait for stage 2 (blocks) triggering cancellation sends stage 2
73+
// |
74+
// (unblocked)
75+
// dummy
76+
// panics
77+
78+
#[test]
79+
#[cfg_attr(miri, ignore)]
80+
fn execute_cancellation() {
81+
let mut db = Knobs::default();
82+
83+
let input = ParallelInput::new(&db, 10, 20);
84+
85+
let thread_a = std::thread::spawn({
86+
let db = db.clone();
87+
move || a1(&db, input)
88+
});
89+
90+
db.signal_on_did_cancel(2);
91+
input.set_a(&mut db).to(30);
92+
93+
// Assert thread A was cancelled
94+
let cancelled = thread_a
95+
.join()
96+
.unwrap_err()
97+
.downcast::<Cancelled>()
98+
.unwrap();
99+
100+
// and inspect the output
101+
expect_test::expect![[r#"
102+
PendingWrite
103+
"#]]
104+
.assert_debug_eq(&cancelled);
105+
}

Diff for: tests/parallel/parallel_map.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#![cfg(feature = "rayon")]
2-
// test for rayon interactions.
2+
// test for rayon-like parallel map interactions.
33

44
use salsa::Cancelled;
55
use salsa::Setter;

Diff for: tests/parallel/parallel_scope.rs

+116
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
#![cfg(feature = "rayon")]
2+
// test for rayon-like scope interactions.
3+
4+
use salsa::Cancelled;
5+
use salsa::Setter;
6+
7+
use crate::setup::Knobs;
8+
use crate::setup::KnobsDatabase;
9+
10+
#[salsa::input]
11+
struct ParallelInput {
12+
a: u32,
13+
b: u32,
14+
}
15+
16+
#[salsa::tracked]
17+
fn tracked_fn(db: &dyn salsa::Database, input: ParallelInput) -> (u32, u32) {
18+
let mut a = None;
19+
let mut b = None;
20+
salsa::scope(db, |scope| {
21+
scope.spawn(|scope| a = Some(input.a(scope.db()) + 1));
22+
scope.spawn(|scope| b = Some(input.b(scope.db()) + 1));
23+
});
24+
(a.unwrap(), b.unwrap())
25+
}
26+
27+
#[salsa::tracked]
28+
fn a1(db: &dyn KnobsDatabase, input: ParallelInput) -> (u32, u32) {
29+
db.signal(1);
30+
let mut a = None;
31+
let mut b = None;
32+
salsa::scope(db, |scope| {
33+
scope.spawn(|scope| {
34+
scope.db().wait_for(2);
35+
a = Some(input.a(scope.db()) + 1)
36+
});
37+
scope.spawn(|scope| {
38+
scope.db().wait_for(2);
39+
b = Some(input.b(scope.db()) + 1)
40+
});
41+
});
42+
(a.unwrap(), b.unwrap())
43+
}
44+
45+
#[salsa::tracked]
46+
fn dummy(_db: &dyn KnobsDatabase) -> u32 {
47+
panic!("should never get here!")
48+
}
49+
50+
#[test]
51+
#[cfg_attr(miri, ignore)]
52+
fn execute() {
53+
let db = salsa::DatabaseImpl::new();
54+
55+
let input = ParallelInput::new(&db, 10, 20);
56+
57+
tracked_fn(&db, input);
58+
}
59+
60+
// we expect this to panic, as `salsa::par_map` needs to be called from a query.
61+
#[test]
62+
#[cfg_attr(miri, ignore)]
63+
#[should_panic]
64+
fn direct_calls_panic() {
65+
let db = salsa::DatabaseImpl::new();
66+
67+
let input = ParallelInput::new(&db, 10, 20);
68+
salsa::scope(&db, |scope| {
69+
scope.spawn(|scope| _ = input.a(scope.db()) + 1);
70+
scope.spawn(|scope| _ = input.b(scope.db()) + 1);
71+
});
72+
}
73+
74+
// Cancellation signalling test
75+
//
76+
// The pattern is as follows.
77+
//
78+
// Thread A Thread B
79+
// -------- --------
80+
// a1
81+
// | wait for stage 1
82+
// signal stage 1 set input, triggers cancellation
83+
// wait for stage 2 (blocks) triggering cancellation sends stage 2
84+
// |
85+
// (unblocked)
86+
// dummy
87+
// panics
88+
89+
#[test]
90+
#[cfg_attr(miri, ignore)]
91+
fn execute_cancellation() {
92+
let mut db = Knobs::default();
93+
94+
let input = ParallelInput::new(&db, 10, 20);
95+
96+
let thread_a = std::thread::spawn({
97+
let db = db.clone();
98+
move || a1(&db, input)
99+
});
100+
101+
db.signal_on_did_cancel(2);
102+
input.set_a(&mut db).to(30);
103+
104+
// Assert thread A was cancelled
105+
let cancelled = thread_a
106+
.join()
107+
.unwrap_err()
108+
.downcast::<Cancelled>()
109+
.unwrap();
110+
111+
// and inspect the output
112+
expect_test::expect![[r#"
113+
PendingWrite
114+
"#]]
115+
.assert_debug_eq(&cancelled);
116+
}

Diff for: tests/parallel/setup.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use crate::signal::Signal;
1111
/// a certain behavior.
1212
#[salsa::db]
1313
pub(crate) trait KnobsDatabase: Database {
14-
/// Signal that we are entering stage 1.
14+
/// Signal that we are entering stage `stage`.
1515
fn signal(&self, stage: usize);
1616

1717
/// Wait until we reach stage `stage` (no-op if we have already reached that stage).

0 commit comments

Comments
 (0)