Skip to content

Commit 07d0ead

Browse files
committed
return a NonNull instead of a &'db
In old code, we converted to a `&'db` when creating a new tracked struct or interning, but this value in fact persisted beyond the end of `'db` (i.e., into the new revision). We now refactor so that we create the `Foo<'db>` from a `NonNull<T>` instead of a `&'db T`, and then only create safe references when users access fields. This makes miri happy.
1 parent 8c51f37 commit 07d0ead

File tree

9 files changed

+158
-85
lines changed

9 files changed

+158
-85
lines changed

components/salsa-2022-macros/src/interned.rs

+22-13
Original file line numberDiff line numberDiff line change
@@ -160,13 +160,24 @@ impl InternedStruct {
160160
data_ident: &syn::Ident,
161161
config_ident: &syn::Ident,
162162
) -> syn::ItemImpl {
163+
let the_ident = self.the_ident();
163164
let lt_db = &self.named_db_lifetime();
164165
let (_, _, _, type_generics, _) = self.the_ident_and_generics();
165166
parse_quote_spanned!(
166167
config_ident.span() =>
167168

168169
impl salsa::interned::Configuration for #config_ident {
169170
type Data<#lt_db> = #data_ident #type_generics;
171+
172+
type Struct<#lt_db> = #the_ident < #lt_db >;
173+
174+
unsafe fn struct_from_raw<'db>(ptr: std::ptr::NonNull<salsa::interned::ValueStruct<Self>>) -> Self::Struct<'db> {
175+
#the_ident(ptr, std::marker::PhantomData)
176+
}
177+
178+
fn deref_struct<'db>(s: Self::Struct<'db>) -> &'db salsa::interned::ValueStruct<Self> {
179+
unsafe { s.0.as_ref() }
180+
}
170181
}
171182
)
172183
}
@@ -191,21 +202,21 @@ impl InternedStruct {
191202

192203
let field_getters: Vec<syn::ImplItemFn> = self
193204
.all_fields()
194-
.map(|field| {
205+
.map(|field: &crate::salsa_struct::SalsaField| {
195206
let field_name = field.name();
196207
let field_ty = field.ty();
197208
let field_vis = field.vis();
198209
let field_get_name = field.get_name();
199210
if field.is_clone_field() {
200211
parse_quote_spanned! { field_get_name.span() =>
201212
#field_vis fn #field_get_name(self, _db: & #db_lt #db_dyn_ty) -> #field_ty {
202-
std::clone::Clone::clone(&unsafe { &*self.0 }.data().#field_name)
213+
std::clone::Clone::clone(&unsafe { self.0.as_ref() }.data().#field_name)
203214
}
204215
}
205216
} else {
206217
parse_quote_spanned! { field_get_name.span() =>
207218
#field_vis fn #field_get_name(self, _db: & #db_lt #db_dyn_ty) -> & #db_lt #field_ty {
208-
&unsafe { &*self.0 }.data().#field_name
219+
&unsafe { self.0.as_ref() }.data().#field_name
209220
}
210221
}
211222
}
@@ -218,18 +229,15 @@ impl InternedStruct {
218229
let constructor_name = self.constructor_name();
219230
let new_method: syn::ImplItemFn = parse_quote_spanned! { constructor_name.span() =>
220231
#vis fn #constructor_name(
221-
db: &#db_dyn_ty,
232+
db: &#db_lt #db_dyn_ty,
222233
#(#field_names: #field_tys,)*
223234
) -> Self {
224235
let (jar, runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db);
225236
let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #the_ident #type_generics >>::ingredient(jar);
226-
Self(
227-
ingredients.intern(runtime, #data_ident {
228-
#(#field_names,)*
229-
__phantom: std::marker::PhantomData,
230-
}),
231-
std::marker::PhantomData,
232-
)
237+
ingredients.intern(runtime, #data_ident {
238+
#(#field_names,)*
239+
__phantom: std::marker::PhantomData,
240+
})
233241
}
234242
};
235243

@@ -262,6 +270,7 @@ impl InternedStruct {
262270
self.the_ident_and_generics();
263271
let db_dyn_ty = self.db_dyn_ty();
264272
let jar_ty = self.jar_ty();
273+
let db_lt = self.named_db_lifetime();
265274

266275
let field_getters: Vec<syn::ImplItemFn> = self
267276
.all_fields()
@@ -296,7 +305,7 @@ impl InternedStruct {
296305
let constructor_name = self.constructor_name();
297306
let new_method: syn::ImplItemFn = parse_quote_spanned! { constructor_name.span() =>
298307
#vis fn #constructor_name(
299-
db: &#db_dyn_ty,
308+
db: & #db_lt #db_lt #db_dyn_ty,
300309
#(#field_names: #field_tys,)*
301310
) -> Self {
302311
let (jar, runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db);
@@ -384,7 +393,7 @@ impl InternedStruct {
384393
fn lookup_id(id: salsa::Id, db: & #db_lt DB) -> Self {
385394
let (jar, _) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db);
386395
let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor<#ident #type_generics>>::ingredient(jar);
387-
Self(ingredients.interned_value(id), std::marker::PhantomData)
396+
ingredients.interned_value(id)
388397
}
389398
}
390399
})

components/salsa-2022-macros/src/salsa_struct.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ impl<A: AllowedOptions> SalsaStruct<A> {
349349
#(#attrs)*
350350
#[derive(Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
351351
#visibility struct #ident #generics (
352-
*const salsa::#module::ValueStruct < #config_ident >,
352+
std::ptr::NonNull<salsa::#module::ValueStruct < #config_ident >>,
353353
std::marker::PhantomData < & #lifetime salsa::#module::ValueStruct < #config_ident > >
354354
);
355355
})
@@ -360,7 +360,9 @@ impl<A: AllowedOptions> SalsaStruct<A> {
360360
pub(crate) fn access_salsa_id_from_self(&self) -> syn::Expr {
361361
match self.the_struct_kind() {
362362
TheStructKind::Id => parse_quote!(self.0),
363-
TheStructKind::Pointer(_) => parse_quote!(salsa::id::AsId::as_id(unsafe { &*self.0 })),
363+
TheStructKind::Pointer(_) => {
364+
parse_quote!(salsa::id::AsId::as_id(unsafe { self.0.as_ref() }))
365+
}
364366
}
365367
}
366368

@@ -434,7 +436,7 @@ impl<A: AllowedOptions> SalsaStruct<A> {
434436
#where_clause
435437
{
436438
fn as_id(&self) -> salsa::Id {
437-
salsa::id::AsId::as_id(unsafe { &*self.0 })
439+
salsa::id::AsId::as_id(unsafe { self.0.as_ref() })
438440
}
439441
}
440442

components/salsa-2022-macros/src/tracked_fn.rs

+10
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,16 @@ fn interned_configuration_impl(
343343
parse_quote!(
344344
impl salsa::interned::Configuration for #config_ty {
345345
type Data<#db_lt> = #intern_data_ty;
346+
347+
type Struct<#db_lt> = & #db_lt salsa::interned::ValueStruct<Self>;
348+
349+
unsafe fn struct_from_raw<'db>(ptr: std::ptr::NonNull<salsa::interned::ValueStruct<Self>>) -> Self::Struct<'db> {
350+
unsafe { ptr.as_ref() }
351+
}
352+
353+
fn deref_struct<'db>(s: Self::Struct<'db>) -> &'db salsa::interned::ValueStruct<Self> {
354+
s
355+
}
346356
}
347357
)
348358
}

components/salsa-2022-macros/src/tracked_struct.rs

+17-11
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ impl TrackedStruct {
9292
let field_tys: Vec<_> = self.all_fields().map(SalsaField::ty).collect();
9393
let id_field_indices = self.id_field_indices();
9494
let arity = self.all_field_count();
95+
let the_ident = self.the_ident();
9596
let lt_db = &self.named_db_lifetime();
9697

9798
// Create the function body that will update the revisions for each field.
@@ -132,8 +133,19 @@ impl TrackedStruct {
132133
parse_quote! {
133134
impl salsa::tracked_struct::Configuration for #config_ident {
134135
type Fields<#lt_db> = ( #(#field_tys,)* );
136+
137+
type Struct<#lt_db> = #the_ident<#lt_db>;
138+
135139
type Revisions = [salsa::Revision; #arity];
136140

141+
unsafe fn struct_from_raw<'db>(ptr: std::ptr::NonNull<salsa::tracked_struct::ValueStruct<Self>>) -> Self::Struct<'db> {
142+
#the_ident(ptr, std::marker::PhantomData)
143+
}
144+
145+
fn deref_struct<'db>(s: Self::Struct<'db>) -> &'db salsa::tracked_struct::ValueStruct<Self> {
146+
unsafe { s.0.as_ref() }
147+
}
148+
137149
#[allow(clippy::unused_unit)]
138150
fn id_fields(fields: &Self::Fields<'_>) -> impl std::hash::Hash {
139151
( #( &fields.#id_field_indices ),* )
@@ -205,7 +217,7 @@ impl TrackedStruct {
205217
#field_vis fn #field_get_name(self, __db: & #lt_db #db_dyn_ty) -> & #lt_db #field_ty
206218
{
207219
let (_, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);
208-
let fields = unsafe { &*self.0 }.field(__runtime, #field_index);
220+
let fields = unsafe { self.0.as_ref() }.field(__runtime, #field_index);
209221
&fields.#field_index
210222
}
211223
}
@@ -214,7 +226,7 @@ impl TrackedStruct {
214226
#field_vis fn #field_get_name(self, __db: & #lt_db #db_dyn_ty) -> #field_ty
215227
{
216228
let (_, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);
217-
let fields = unsafe { &*self.0 }.field(__runtime, #field_index);
229+
let fields = unsafe { self.0.as_ref() }.field(__runtime, #field_index);
218230
fields.#field_index.clone()
219231
}
220232
}
@@ -232,11 +244,6 @@ impl TrackedStruct {
232244

233245
let salsa_id = self.access_salsa_id_from_self();
234246

235-
let ctor = match the_kind {
236-
TheStructKind::Id => quote!(salsa::id::FromId::from_as_id(#data)),
237-
TheStructKind::Pointer(_) => quote!(Self(#data, std::marker::PhantomData)),
238-
};
239-
240247
let lt_db = self.maybe_elided_db_lifetime();
241248
parse_quote! {
242249
#[allow(dead_code, clippy::pedantic, clippy::complexity, clippy::style)]
@@ -246,11 +253,10 @@ impl TrackedStruct {
246253
{
247254
let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);
248255
let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< Self >>::ingredient(__jar);
249-
let #data = __ingredients.0.new_struct(
256+
__ingredients.0.new_struct(
250257
__runtime,
251258
(#(#field_names,)*),
252-
);
253-
#ctor
259+
)
254260
}
255261

256262
pub fn salsa_id(&self) -> salsa::Id {
@@ -354,7 +360,7 @@ impl TrackedStruct {
354360
fn lookup_id(id: salsa::Id, db: & #db_lt DB) -> Self {
355361
let (jar, runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db);
356362
let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor<#ident #type_generics>>::ingredient(jar);
357-
Self(ingredients.#tracked_struct_ingredient.lookup_struct(runtime, id), std::marker::PhantomData)
363+
ingredients.#tracked_struct_ingredient.lookup_struct(runtime, id)
358364
}
359365
}
360366
})

components/salsa-2022/src/alloc.rs

+12-14
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,28 @@ impl<T> Alloc<T> {
1313
data: unsafe { NonNull::new_unchecked(data) },
1414
}
1515
}
16-
}
1716

18-
impl<T> Drop for Alloc<T> {
19-
fn drop(&mut self) {
20-
let data: *mut T = self.data.as_ptr();
21-
let data: Box<T> = unsafe { Box::from_raw(data) };
22-
drop(data);
17+
pub fn as_raw(&self) -> NonNull<T> {
18+
self.data
2319
}
24-
}
2520

26-
impl<T> std::ops::Deref for Alloc<T> {
27-
type Target = T;
28-
29-
fn deref(&self) -> &Self::Target {
21+
pub unsafe fn as_ref(&self) -> &T {
3022
unsafe { self.data.as_ref() }
3123
}
32-
}
3324

34-
impl<T> std::ops::DerefMut for Alloc<T> {
35-
fn deref_mut(&mut self) -> &mut Self::Target {
25+
pub unsafe fn as_mut(&mut self) -> &mut T {
3626
unsafe { self.data.as_mut() }
3727
}
3828
}
3929

30+
impl<T> Drop for Alloc<T> {
31+
fn drop(&mut self) {
32+
let data: *mut T = self.data.as_ptr();
33+
let data: Box<T> = unsafe { Box::from_raw(data) };
34+
drop(data);
35+
}
36+
}
37+
4038
unsafe impl<T> Send for Alloc<T> where T: Send {}
4139

4240
unsafe impl<T> Sync for Alloc<T> where T: Sync {}

components/salsa-2022/src/interned.rs

+30-13
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use crossbeam::atomic::AtomicCell;
22
use std::fmt;
33
use std::hash::Hash;
44
use std::marker::PhantomData;
5+
use std::ptr::NonNull;
56

67
use crate::alloc::Alloc;
78
use crate::durability::Durability;
@@ -18,8 +19,28 @@ use super::ingredient::Ingredient;
1819
use super::routes::IngredientIndex;
1920
use super::Revision;
2021

21-
pub trait Configuration {
22+
pub trait Configuration: Sized {
2223
type Data<'db>: InternedData;
24+
type Struct<'db>: Copy;
25+
26+
/// Create an end-user struct from the underlying raw pointer.
27+
///
28+
/// This call is an "end-step" to the tracked struct lookup/creation
29+
/// process in a given revision: it occurs only when the struct is newly
30+
/// created or, if a struct is being reused, after we have updated its
31+
/// fields (or confirmed it is green and no updates are required).
32+
///
33+
/// # Unsafety
34+
///
35+
/// Requires that `ptr` represents a "confirmed" value in this revision,
36+
/// which means that it will remain valid and immutable for the remainder of this
37+
/// revision, represented by the lifetime `'db`.
38+
unsafe fn struct_from_raw<'db>(ptr: NonNull<ValueStruct<Self>>) -> Self::Struct<'db>;
39+
40+
/// Deref the struct to yield the underlying value struct.
41+
/// Since we are still part of the `'db` lifetime in which the struct was created,
42+
/// this deref is safe, and the value-struct fields are immutable and verified.
43+
fn deref_struct<'db>(s: Self::Struct<'db>) -> &'db ValueStruct<Self>;
2344
}
2445

2546
pub trait InternedData: Sized + Eq + Hash + Clone {}
@@ -83,15 +104,11 @@ where
83104
}
84105

85106
pub fn intern_id<'db>(&'db self, runtime: &'db Runtime, data: C::Data<'db>) -> crate::Id {
86-
self.intern(runtime, data).as_id()
107+
C::deref_struct(self.intern(runtime, data)).as_id()
87108
}
88109

89110
/// Intern data to a unique reference.
90-
pub fn intern<'db>(
91-
&'db self,
92-
runtime: &'db Runtime,
93-
data: C::Data<'db>,
94-
) -> &'db ValueStruct<C> {
111+
pub fn intern<'db>(&'db self, runtime: &'db Runtime, data: C::Data<'db>) -> C::Struct<'db> {
95112
runtime.report_tracked_read(
96113
DependencyIndex::for_table(self.ingredient_index),
97114
Durability::MAX,
@@ -126,27 +143,27 @@ where
126143
id: next_id,
127144
fields: internal_data,
128145
}));
129-
// SAFETY: Items are only removed from the `value_map` with an `&mut self` reference.
130-
let value_ref = unsafe { transmute_lifetime(self, &**value) };
146+
let value_raw = value.as_raw();
131147
drop(value);
132148
entry.insert(next_id);
133-
value_ref
149+
// SAFETY: Items are only removed from the `value_map` with an `&mut self` reference.
150+
unsafe { C::struct_from_raw(value_raw) }
134151
}
135152
}
136153
}
137154

138-
pub fn interned_value<'db>(&'db self, id: Id) -> &'db ValueStruct<C> {
155+
pub fn interned_value<'db>(&'db self, id: Id) -> C::Struct<'db> {
139156
let r = self.value_map.get(&id).unwrap();
140157

141158
// SAFETY: Items are only removed from the `value_map` with an `&mut self` reference.
142-
unsafe { transmute_lifetime(self, &**r) }
159+
unsafe { C::struct_from_raw(r.as_raw()) }
143160
}
144161

145162
/// Lookup the data for an interned value based on its id.
146163
/// Rarely used since end-users generally carry a struct with a pointer directly
147164
/// to the interned item.
148165
pub fn data<'db>(&'db self, id: Id) -> &'db C::Data<'db> {
149-
self.interned_value(id).data()
166+
C::deref_struct(self.interned_value(id)).data()
150167
}
151168

152169
/// Variant of `data` that takes a (unnecessary) database argument.

0 commit comments

Comments
 (0)