Skip to content

Commit 296a8c7

Browse files
authored
allow reuse of cached provisional memos within the same cycle iteration (#786)
* test for caching provisional values * add iteration-count to cycle heads * CycleHeads insert/extend checks iteration count match * update iteration count in cycle heads * all tests passing * remove debug prints * just walk active query stack once * switch to tracking active cycle iterations on ZalsaLocal * Revert "switch to tracking active cycle iterations on ZalsaLocal" This reverts commit 4ea3d85. * Revert "just walk active query stack once" This reverts commit 2d79486. * make ActiveQuery::iteration_count private with accessor * iterate active query stack in reverse * use tracing::trace! in hot path * try a cold annotation on validate_same_iteration * Revert "try a cold annotation on validate_same_iteration" This reverts commit 49ceb84.
1 parent 395b29d commit 296a8c7

11 files changed

+180
-51
lines changed

src/active_query.rs

+23-5
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ pub(crate) struct ActiveQuery {
6060

6161
/// Provisional cycle results that this query depends on.
6262
cycle_heads: CycleHeads,
63+
64+
/// If this query is a cycle head, iteration count of that cycle.
65+
iteration_count: u32,
6366
}
6467

6568
impl ActiveQuery {
@@ -126,10 +129,14 @@ impl ActiveQuery {
126129
changed_at: self.changed_at,
127130
}
128131
}
132+
133+
pub(super) fn iteration_count(&self) -> u32 {
134+
self.iteration_count
135+
}
129136
}
130137

131138
impl ActiveQuery {
132-
fn new(database_key_index: DatabaseKeyIndex) -> Self {
139+
fn new(database_key_index: DatabaseKeyIndex, iteration_count: u32) -> Self {
133140
ActiveQuery {
134141
database_key_index,
135142
durability: Durability::MAX,
@@ -141,6 +148,7 @@ impl ActiveQuery {
141148
accumulated: Default::default(),
142149
accumulated_inputs: Default::default(),
143150
cycle_heads: Default::default(),
151+
iteration_count,
144152
}
145153
}
146154

@@ -156,6 +164,7 @@ impl ActiveQuery {
156164
ref mut accumulated,
157165
accumulated_inputs,
158166
ref mut cycle_heads,
167+
iteration_count: _,
159168
} = self;
160169

161170
let edges = QueryEdges::new(input_outputs.drain(..));
@@ -196,15 +205,17 @@ impl ActiveQuery {
196205
accumulated,
197206
accumulated_inputs: _,
198207
cycle_heads,
208+
iteration_count,
199209
} = self;
200210
input_outputs.clear();
201211
disambiguator_map.clear();
202212
tracked_struct_ids.clear();
203213
accumulated.clear();
204214
*cycle_heads = Default::default();
215+
*iteration_count = 0;
205216
}
206217

207-
fn reset_for(&mut self, new_database_key_index: DatabaseKeyIndex) {
218+
fn reset_for(&mut self, new_database_key_index: DatabaseKeyIndex, new_iteration_count: u32) {
208219
let Self {
209220
database_key_index,
210221
durability,
@@ -216,12 +227,14 @@ impl ActiveQuery {
216227
accumulated,
217228
accumulated_inputs,
218229
cycle_heads,
230+
iteration_count,
219231
} = self;
220232
*database_key_index = new_database_key_index;
221233
*durability = Durability::MAX;
222234
*changed_at = Revision::start();
223235
*untracked_read = false;
224236
*accumulated_inputs = Default::default();
237+
*iteration_count = new_iteration_count;
225238
debug_assert!(
226239
input_outputs.is_empty(),
227240
"`ActiveQuery::clear` or `ActiveQuery::into_revisions` should've been called"
@@ -266,11 +279,16 @@ impl ops::DerefMut for QueryStack {
266279
}
267280

268281
impl QueryStack {
269-
pub(crate) fn push_new_query(&mut self, database_key_index: DatabaseKeyIndex) {
282+
pub(crate) fn push_new_query(
283+
&mut self,
284+
database_key_index: DatabaseKeyIndex,
285+
iteration_count: u32,
286+
) {
270287
if self.len < self.stack.len() {
271-
self.stack[self.len].reset_for(database_key_index);
288+
self.stack[self.len].reset_for(database_key_index, iteration_count);
272289
} else {
273-
self.stack.push(ActiveQuery::new(database_key_index));
290+
self.stack
291+
.push(ActiveQuery::new(database_key_index, iteration_count));
274292
}
275293
self.len += 1;
276294
}

src/cycle.rs

+63-26
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,19 @@ pub enum CycleRecoveryStrategy {
8686
/// A "cycle head" is the query at which we encounter a cycle; that is, if A -> B -> C -> A, then A
8787
/// would be the cycle head. It returns an "initial value" when the cycle is encountered (if
8888
/// fixpoint iteration is enabled for that query), and then is responsible for re-iterating the
89-
/// cycle until it converges. Any provisional value generated by any query in the cycle will track
90-
/// the cycle head(s) (can be plural in case of nested cycles) representing the cycles it is part
91-
/// of. This struct tracks these cycle heads.
89+
/// cycle until it converges.
90+
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
91+
pub struct CycleHead {
92+
pub database_key_index: DatabaseKeyIndex,
93+
pub iteration_count: u32,
94+
}
95+
96+
/// Any provisional value generated by any query in a cycle will track the cycle head(s) (can be
97+
/// plural in case of nested cycles) representing the cycles it is part of, and the current
98+
/// iteration count for each cycle head. This struct tracks these cycle heads.
9299
#[derive(Clone, Debug, Default)]
93100
#[allow(clippy::box_collection)]
94-
pub struct CycleHeads(Option<Box<Vec<DatabaseKeyIndex>>>);
101+
pub struct CycleHeads(Option<Box<Vec<CycleHead>>>);
95102

96103
impl CycleHeads {
97104
pub(crate) fn is_empty(&self) -> bool {
@@ -100,15 +107,25 @@ impl CycleHeads {
100107
self.0.is_none()
101108
}
102109

110+
pub(crate) fn initial(database_key_index: DatabaseKeyIndex) -> Self {
111+
Self(Some(Box::new(vec![CycleHead {
112+
database_key_index,
113+
iteration_count: 0,
114+
}])))
115+
}
116+
103117
pub(crate) fn contains(&self, value: &DatabaseKeyIndex) -> bool {
104-
self.0.as_ref().is_some_and(|heads| heads.contains(value))
118+
self.into_iter()
119+
.any(|head| head.database_key_index == *value)
105120
}
106121

107122
pub(crate) fn remove(&mut self, value: &DatabaseKeyIndex) -> bool {
108123
let Some(cycle_heads) = &mut self.0 else {
109124
return false;
110125
};
111-
let found = cycle_heads.iter().position(|&head| head == *value);
126+
let found = cycle_heads
127+
.iter()
128+
.position(|&head| head.database_key_index == *value);
112129
let Some(found) = found else { return false };
113130
cycle_heads.swap_remove(found);
114131
if cycle_heads.is_empty() {
@@ -117,43 +134,63 @@ impl CycleHeads {
117134
true
118135
}
119136

137+
pub(crate) fn update_iteration_count(
138+
&mut self,
139+
cycle_head_index: DatabaseKeyIndex,
140+
new_iteration_count: u32,
141+
) {
142+
if let Some(cycle_head) = self.0.as_mut().and_then(|cycle_heads| {
143+
cycle_heads
144+
.iter_mut()
145+
.find(|cycle_head| cycle_head.database_key_index == cycle_head_index)
146+
}) {
147+
cycle_head.iteration_count = new_iteration_count;
148+
}
149+
}
150+
120151
#[inline]
121-
pub(crate) fn insert_into(self, cycle_heads: &mut Vec<DatabaseKeyIndex>) {
152+
pub(crate) fn insert_into(self, cycle_heads: &mut Vec<CycleHead>) {
122153
if let Some(heads) = self.0 {
123-
for head in *heads {
124-
if !cycle_heads.contains(&head) {
125-
cycle_heads.push(head);
126-
}
127-
}
154+
insert_into_impl(&heads, cycle_heads);
128155
}
129156
}
130157

131158
pub(crate) fn extend(&mut self, other: &Self) {
132159
if let Some(other) = &other.0 {
133160
let heads = &mut **self.0.get_or_insert_with(|| Box::new(Vec::new()));
134-
heads.reserve(other.len());
135-
other.iter().for_each(|&head| {
136-
if !heads.contains(&head) {
137-
heads.push(head);
138-
}
139-
});
161+
insert_into_impl(other, heads);
162+
}
163+
}
164+
}
165+
166+
#[inline]
167+
fn insert_into_impl(insert_from: &Vec<CycleHead>, insert_into: &mut Vec<CycleHead>) {
168+
insert_into.reserve(insert_from.len());
169+
for head in insert_from {
170+
if let Some(existing) = insert_into
171+
.iter()
172+
.find(|candidate| candidate.database_key_index == head.database_key_index)
173+
{
174+
assert!(existing.iteration_count == head.iteration_count);
175+
} else {
176+
insert_into.push(*head);
140177
}
141178
}
142179
}
143180

144181
impl IntoIterator for CycleHeads {
145-
type Item = DatabaseKeyIndex;
182+
type Item = CycleHead;
146183
type IntoIter = <Vec<Self::Item> as IntoIterator>::IntoIter;
147184

148185
fn into_iter(self) -> Self::IntoIter {
149186
self.0.map(|heads| *heads).unwrap_or_default().into_iter()
150187
}
151188
}
152189

153-
pub struct CycleHeadsIter<'a>(std::slice::Iter<'a, DatabaseKeyIndex>);
190+
pub struct CycleHeadsIter<'a>(std::slice::Iter<'a, CycleHead>);
154191

155192
impl Iterator for CycleHeadsIter<'_> {
156-
type Item = DatabaseKeyIndex;
193+
type Item = CycleHead;
157194

158195
fn next(&mut self) -> Option<Self::Item> {
159196
self.0.next().copied()
@@ -167,7 +204,7 @@ impl Iterator for CycleHeadsIter<'_> {
167204
impl std::iter::FusedIterator for CycleHeadsIter<'_> {}
168205

169206
impl<'a> std::iter::IntoIterator for &'a CycleHeads {
170-
type Item = DatabaseKeyIndex;
207+
type Item = CycleHead;
171208
type IntoIter = CycleHeadsIter<'a>;
172209

173210
fn into_iter(self) -> Self::IntoIter {
@@ -180,14 +217,14 @@ impl<'a> std::iter::IntoIterator for &'a CycleHeads {
180217
}
181218
}
182219

183-
impl From<DatabaseKeyIndex> for CycleHeads {
184-
fn from(value: DatabaseKeyIndex) -> Self {
220+
impl From<CycleHead> for CycleHeads {
221+
fn from(value: CycleHead) -> Self {
185222
Self(Some(Box::new(vec![value])))
186223
}
187224
}
188225

189-
impl From<Vec<DatabaseKeyIndex>> for CycleHeads {
190-
fn from(value: Vec<DatabaseKeyIndex>) -> Self {
226+
impl From<Vec<CycleHead>> for CycleHeads {
227+
fn from(value: Vec<CycleHead>) -> Self {
191228
Self(if value.is_empty() {
192229
None
193230
} else {

src/function.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,11 @@ where
243243
fn is_provisional_cycle_head<'db>(&'db self, db: &'db dyn Database, input: Id) -> bool {
244244
let zalsa = db.zalsa();
245245
self.get_memo_from_table_for(zalsa, input, self.memo_ingredient_index(zalsa, input))
246-
.is_some_and(|memo| memo.cycle_heads().contains(&self.database_key_index(input)))
246+
.is_some_and(|memo| {
247+
memo.cycle_heads()
248+
.into_iter()
249+
.any(|head| head.database_key_index == self.database_key_index(input))
250+
})
247251
}
248252

249253
/// Attempts to claim `key_index`, returning `false` if a cycle occurs.

src/function/execute.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -123,14 +123,19 @@ where
123123
if iteration_count > MAX_ITERATIONS {
124124
panic!("{database_key_index:?}: execute: too many cycle iterations");
125125
}
126+
revisions
127+
.cycle_heads
128+
.update_iteration_count(database_key_index, iteration_count);
126129
opt_last_provisional = Some(self.insert_memo(
127130
zalsa,
128131
id,
129132
Memo::new(Some(new_value), revision_now, revisions),
130133
memo_ingredient_index,
131134
));
132135

133-
active_query = db.zalsa_local().push_query(database_key_index);
136+
active_query = db
137+
.zalsa_local()
138+
.push_query(database_key_index, iteration_count);
134139

135140
continue;
136141
}

src/function/fetch.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ where
8383
if let Some(memo) = memo_guard {
8484
let database_key_index = self.database_key_index(id);
8585
if memo.value.is_some()
86-
&& self.validate_may_be_provisional(db, zalsa, database_key_index, memo)
86+
&& (self.validate_may_be_provisional(db, zalsa, database_key_index, memo)
87+
|| self.validate_same_iteration(db, database_key_index, memo))
8788
&& self.shallow_verify_memo(db, zalsa, database_key_index, memo)
8889
{
8990
// SAFETY: memo is present in memo_map and we have verified that it is
@@ -158,7 +159,7 @@ where
158159
};
159160

160161
// Push the query on the stack.
161-
let active_query = db.zalsa_local().push_query(database_key_index);
162+
let active_query = db.zalsa_local().push_query(database_key_index, 0);
162163

163164
// Now that we've claimed the item, check again to see if there's a "hot" value.
164165
let opt_old_memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index);

src/function/maybe_changed_after.rs

+36-6
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ where
112112
CycleRecoveryStrategy::Fixpoint => {
113113
return Some(VerifyResult::Unchanged(
114114
InputAccumulatedValues::Empty,
115-
CycleHeads::from(database_key_index),
115+
CycleHeads::initial(database_key_index),
116116
));
117117
}
118118
},
@@ -131,7 +131,7 @@ where
131131
);
132132

133133
// Check if the inputs are still valid. We can just compare `changed_at`.
134-
let active_query = db.zalsa_local().push_query(database_key_index);
134+
let active_query = db.zalsa_local().push_query(database_key_index, 0);
135135
if let VerifyResult::Unchanged(_, cycle_heads) =
136136
self.deep_verify_memo(db, zalsa, old_memo, &active_query)
137137
{
@@ -243,14 +243,17 @@ where
243243
database_key_index: DatabaseKeyIndex,
244244
memo: &Memo<C::Output<'_>>,
245245
) -> bool {
246-
tracing::debug!(
246+
tracing::trace!(
247247
"{database_key_index:?}: validate_provisional(memo = {memo:#?})",
248248
memo = memo.tracing_debug()
249249
);
250250
if (&memo.revisions.cycle_heads).into_iter().any(|cycle_head| {
251251
zalsa
252-
.lookup_ingredient(cycle_head.ingredient_index())
253-
.is_provisional_cycle_head(db.as_dyn_database(), cycle_head.key_index())
252+
.lookup_ingredient(cycle_head.database_key_index.ingredient_index())
253+
.is_provisional_cycle_head(
254+
db.as_dyn_database(),
255+
cycle_head.database_key_index.key_index(),
256+
)
254257
}) {
255258
return false;
256259
}
@@ -260,6 +263,33 @@ where
260263
true
261264
}
262265

266+
/// If this is a provisional memo, validate that it was cached in the same iteration of the
267+
/// same cycle(s) that we are still executing. If so, it is valid for reuse. This avoids
268+
/// runaway re-execution of the same queries within a fixpoint iteration.
269+
pub(super) fn validate_same_iteration(
270+
&self,
271+
db: &C::DbView,
272+
database_key_index: DatabaseKeyIndex,
273+
memo: &Memo<C::Output<'_>>,
274+
) -> bool {
275+
tracing::trace!(
276+
"{database_key_index:?}: validate_same_iteration(memo = {memo:#?})",
277+
memo = memo.tracing_debug()
278+
);
279+
for cycle_head in &memo.revisions.cycle_heads {
280+
if !db.zalsa_local().with_query_stack(|stack| {
281+
stack.iter().rev().any(|entry| {
282+
entry.database_key_index == cycle_head.database_key_index
283+
&& entry.iteration_count() == cycle_head.iteration_count
284+
})
285+
}) {
286+
return false;
287+
}
288+
}
289+
290+
true
291+
}
292+
263293
/// VerifyResult::Unchanged if the memo's value and `changed_at` time is up-to-date in the
264294
/// current revision. When this returns Unchanged with no cycle heads, it also updates the
265295
/// memo's `verified_at` field if needed to make future calls cheaper.
@@ -390,7 +420,7 @@ where
390420

391421
let in_heads = cycle_heads
392422
.iter()
393-
.position(|&head| head == database_key_index)
423+
.position(|&head| head.database_key_index == database_key_index)
394424
.inspect(|&head| _ = cycle_heads.swap_remove(head))
395425
.is_some();
396426

0 commit comments

Comments
 (0)