Skip to content

Commit c8e37c1

Browse files
authored
attempt fix
1 parent 56e95f6 commit c8e37c1

File tree

1 file changed

+119
-30
lines changed

1 file changed

+119
-30
lines changed

src/pg/mod.rs

Lines changed: 119 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use futures_util::stream::{BoxStream, TryStreamExt};
2424
use futures_util::TryFutureExt;
2525
use futures_util::{Future, FutureExt, StreamExt};
2626
use std::borrow::Cow;
27+
use std::collections::HashMap;
2728
use std::sync::Arc;
2829
use tokio::sync::broadcast;
2930
use tokio::sync::oneshot;
@@ -367,9 +368,28 @@ impl AsyncPgConnection {
367368
//
368369
// We apply this workaround to prevent requiring all the diesel
369370
// serialization code to beeing async
370-
let mut metadata_lookup = PgAsyncMetadataLookup::new();
371-
let collect_bind_result =
372-
query.collect_binds(&mut bind_collector, &mut metadata_lookup, &Pg);
371+
let mut dummy_lookup = SameOidEveryTime {
372+
first_byte: 0,
373+
};
374+
let mut bind_collector_0 = RawBytesBindCollector::<diesel::pg::Pg>::new();
375+
let collect_bind_result_0 = query.collect_binds(&mut bind_collector_0, &mut dummy_lookup, &Pg);
376+
377+
dummy_lookup.first_byte = 1;
378+
let mut bind_collector_1 = RawBytesBindCollector::<diesel::pg::Pg>::new();
379+
let collect_bind_result_1 = query.collect_binds(&mut bind_collector_1, &mut dummy_lookup, &Pg);
380+
381+
let mut metadata_lookup = PgAsyncMetadataLookup::new(&bind_collector_0.metadata);
382+
let collect_bind_result = query.collect_binds(&mut bind_collector, &mut metadata_lookup, &Pg);
383+
384+
let fake_oid_locations = std::iter::zip(bind_collector_0.binds, bind_collector_1.binds)
385+
.enumerate()
386+
.flat_map(|(bind_index, (bytes_0, bytes_1))|) {
387+
std::iter::zip(bytes_0.unwrap_or_default(), bytes_1.unwrap_or_default())
388+
.enumerate()
389+
.filter_map(|(byte_index, bytes)| (bytes == (0, 1)).then_some((bind_index, byte_index)))
390+
}
391+
// Avoid storing the bind collectors in the returned Future
392+
.collect::<Vec<_>>();
373393

374394
let raw_connection = self.conn.clone();
375395
let stmt_cache = self.stmt_cache.clone();
@@ -379,41 +399,67 @@ impl AsyncPgConnection {
379399
async move {
380400
let sql = sql?;
381401
let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
402+
collect_bind_result_0?;
403+
collect_bind_result_1?;
382404
collect_bind_result?;
383405
// Check whether we need to resolve some types at all
384406
//
385407
// If the user doesn't use custom types there is no need
386408
// to borther with that at all
387409
if !metadata_lookup.unresolved_types.is_empty() {
388410
let metadata_cache = &mut *metadata_cache.lock().await;
389-
let mut next_unresolved = metadata_lookup.unresolved_types.into_iter();
390-
for m in &mut bind_collector.metadata {
411+
let real_oids = HashMap::<u32, u32>::new();
412+
413+
for (index, (ref schema, ref lookup_type_name) in metadata_lookup.unresolved_types.into_iter().enumerate() {
391414
// for each unresolved item
392415
// we check whether it's arleady in the cache
393416
// or perform a lookup and insert it into the cache
394-
if m.oid().is_err() {
395-
if let Some((ref schema, ref lookup_type_name)) = next_unresolved.next() {
396-
let cache_key = PgMetadataCacheKey::new(
397-
schema.as_ref().map(Into::into),
398-
lookup_type_name.into(),
399-
);
400-
if let Some(entry) = metadata_cache.lookup_type(&cache_key) {
401-
*m = entry;
402-
} else {
403-
let type_metadata = lookup_type(
404-
schema.clone(),
405-
lookup_type_name.clone(),
406-
&raw_connection,
407-
)
408-
.await?;
409-
*m = PgTypeMetadata::from_result(Ok(type_metadata));
410-
411-
metadata_cache.store_type(cache_key, type_metadata);
412-
}
413-
} else {
414-
break;
415-
}
416-
}
417+
let cache_key = PgMetadataCacheKey::new(
418+
schema.as_ref().map(Into::into),
419+
lookup_type_name.into(),
420+
);
421+
let real_metadata = if let Some(type_metadata) = metadata_cache.lookup_type(&cache_key) {
422+
type_metadata
423+
} else {
424+
let type_metadata = lookup_type(
425+
schema.clone(),
426+
lookup_type_name.clone(),
427+
&raw_connection,
428+
)
429+
.await?;
430+
metadata_cache.store_type(cache_key, type_metadata);
431+
432+
PgTypeMetadata::from_result(Ok(type_metadata))
433+
};
434+
let (fake_oid, fake_array_oid) = metadata_lookup.fake_oids(index);
435+
real_oids.extend([
436+
(fake_oid, real_metadata.oid()?),
437+
(fake_array_oid, real_metadata.array_oid()?),
438+
]);
439+
}
440+
441+
// Replace fake OIDs with real OIDs in `bind_collector.metadata`
442+
for m in &mut bind_collector.metadata {
443+
let [oid, array_oid] = [m.oid()?, m.array_oid()?]
444+
.map(|oid| {
445+
real_oids
446+
.get(&oid)
447+
.copied()
448+
// If `oid` is not a key in `real_oids`, then `HasSqlType::metadata` returned it as a
449+
// hardcoded value instead of being lied to by `PgAsyncMetadataLookup`. In this case,
450+
// the existing value is already the real OID, so it's kept.
451+
.unwrap_or(oid)
452+
});
453+
*m = PgTypeMetadata::new(oid, array_oid);
454+
}
455+
// Replace fake OIDs with real OIDs in `bind_collector.binds`
456+
for location in fake_oid_locations {
457+
replace_fake_oid(&mut bind_collector.binds, &real_oids, location)
458+
.ok_or_else(|| {
459+
Error::SerializationError(
460+
format!("diesel_async failed to replace a type OID serialized in bind value {bind_index}").into(),
461+
)
462+
});
417463
}
418464
}
419465
let key = match query_id {
@@ -452,16 +498,30 @@ impl AsyncPgConnection {
452498
}
453499
}
454500

501+
/// Collects types that need to be looked up, and causes fake OIDs to be written into the bind collector
502+
/// so they can be replaced with asynchronously fetched OIDs after the original query is dropped
455503
struct PgAsyncMetadataLookup {
456504
unresolved_types: Vec<(Option<String>, String)>,
505+
min_fake_oid: u32,
457506
}
458507

459508
impl PgAsyncMetadataLookup {
460-
fn new() -> Self {
509+
fn new(metadata_0: &[PgTypeMetadata]) -> Self {
510+
let max_hardcoded_oid = metadata_0
511+
.iter()
512+
.flat_map(|m| [m.oid().unwrap_or(0), m.array_oid().unwrap_or(0)])
513+
.max()
514+
.unwrap_or(0);
461515
Self {
462516
unresolved_types: Vec::new(),
517+
min_fake_oid: max_hardcoded_oid + 1,
463518
}
464519
}
520+
521+
fn fake_oids(&self, index: usize) -> (u32, u32) {
522+
let oid = self.min_fake_oid + ((index as u32) * 2);
523+
(oid, oid + 1)
524+
}
465525
}
466526

467527
impl PgMetadataLookup for PgAsyncMetadataLookup {
@@ -470,9 +530,24 @@ impl PgMetadataLookup for PgAsyncMetadataLookup {
470530
PgMetadataCacheKey::new(schema.map(Cow::Borrowed), Cow::Borrowed(type_name));
471531

472532
let cache_key = cache_key.into_owned();
533+
let index = self.unresolved_types.len();
473534
self.unresolved_types
474535
.push((schema.map(ToOwned::to_owned), type_name.to_owned()));
475-
PgTypeMetadata::from_result(Err(FailedToLookupTypeError::new(cache_key)))
536+
PgTypeMetadata::from_result(Ok(self.fake_oids(index)))
537+
}
538+
}
539+
540+
/// Allows unambiguously determining:
541+
/// * where OIDs are written in `bind_collector.binds` after being returned by `lookup_type`
542+
/// * determining the maximum hardcoded OID in `bind_collector.metadata`
543+
struct SameOidEveryTime {
544+
first_byte: u8,
545+
}
546+
547+
impl PgMetadataLookup for SameOidEveryTime {
548+
fn lookup_type(&mut self, _type_name: &str, _schema: Option<&str>) -> PgTypeMetadata {
549+
let oid = u32::from_be_bytes([self.first_byte, 0, 0, 0]);
550+
PgTypeMetadata::new(oid, oid)
476551
}
477552
}
478553

@@ -506,6 +581,20 @@ async fn lookup_type(
506581
Ok((r.get(0), r.get(1)))
507582
}
508583

584+
fn replace_fake_oid(
585+
binds: &mut Vec<Option<Vec<u8>>>,
586+
real_oids: HashMap<u32, u32>,
587+
(bind_index, byte_index): (u32, u32),
588+
) -> Option<()> {
589+
let serialized_oid = binds
590+
.get_mut(bind_index)?
591+
.as_mut()?
592+
.get_mut(byte_index..)?
593+
.first_chunk_mut::<4>()?;
594+
*serialized_oid = real_oids.get(&u32::from_be_bytes(*serialized_oid))?.to_be_bytes();
595+
Some(())
596+
}
597+
509598
async fn drive_future<R>(
510599
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
511600
client_future: impl Future<Output = Result<R, diesel::result::Error>>,

0 commit comments

Comments
 (0)