@@ -24,6 +24,7 @@ use futures_util::stream::{BoxStream, TryStreamExt};
24
24
use futures_util:: TryFutureExt ;
25
25
use futures_util:: { Future , FutureExt , StreamExt } ;
26
26
use std:: borrow:: Cow ;
27
+ use std:: collections:: HashMap ;
27
28
use std:: sync:: Arc ;
28
29
use tokio:: sync:: broadcast;
29
30
use tokio:: sync:: oneshot;
@@ -367,9 +368,28 @@ impl AsyncPgConnection {
367
368
//
368
369
// We apply this workaround to prevent requiring all the diesel
369
370
// serialization code to beeing async
370
- let mut metadata_lookup = PgAsyncMetadataLookup :: new ( ) ;
371
- let collect_bind_result =
372
- query. collect_binds ( & mut bind_collector, & mut metadata_lookup, & Pg ) ;
371
+ let mut dummy_lookup = SameOidEveryTime {
372
+ first_byte : 0 ,
373
+ } ;
374
+ let mut bind_collector_0 = RawBytesBindCollector :: < diesel:: pg:: Pg > :: new ( ) ;
375
+ let collect_bind_result_0 = query. collect_binds ( & mut bind_collector_0, & mut dummy_lookup, & Pg ) ;
376
+
377
+ dummy_lookup. first_byte = 1 ;
378
+ let mut bind_collector_1 = RawBytesBindCollector :: < diesel:: pg:: Pg > :: new ( ) ;
379
+ let collect_bind_result_1 = query. collect_binds ( & mut bind_collector_1, & mut dummy_lookup, & Pg ) ;
380
+
381
+ let mut metadata_lookup = PgAsyncMetadataLookup :: new ( & bind_collector_0. metadata ) ;
382
+ let collect_bind_result = query. collect_binds ( & mut bind_collector, & mut metadata_lookup, & Pg ) ;
383
+
384
+ let fake_oid_locations = std:: iter:: zip ( bind_collector_0. binds , bind_collector_1. binds )
385
+ . enumerate ( )
386
+ . flat_map ( |( bind_index, ( bytes_0, bytes_1) ) |) {
387
+ std:: iter:: zip ( bytes_0. unwrap_or_default ( ) , bytes_1. unwrap_or_default ( ) )
388
+ . enumerate ( )
389
+ . filter_map ( |( byte_index, bytes) | ( bytes == ( 0 , 1 ) ) . then_some ( ( bind_index, byte_index) ) )
390
+ }
391
+ // Avoid storing the bind collectors in the returned Future
392
+ . collect :: < Vec < _ > > ( ) ;
373
393
374
394
let raw_connection = self . conn . clone ( ) ;
375
395
let stmt_cache = self . stmt_cache . clone ( ) ;
@@ -379,41 +399,67 @@ impl AsyncPgConnection {
379
399
async move {
380
400
let sql = sql?;
381
401
let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
402
+ collect_bind_result_0?;
403
+ collect_bind_result_1?;
382
404
collect_bind_result?;
383
405
// Check whether we need to resolve some types at all
384
406
//
385
407
// If the user doesn't use custom types there is no need
386
408
// to borther with that at all
387
409
if !metadata_lookup. unresolved_types . is_empty ( ) {
388
410
let metadata_cache = & mut * metadata_cache. lock ( ) . await ;
389
- let mut next_unresolved = metadata_lookup. unresolved_types . into_iter ( ) ;
390
- for m in & mut bind_collector. metadata {
411
+ let real_oids = HashMap :: < u32 , u32 > :: new ( ) ;
412
+
413
+ for ( index, ( ref schema, ref lookup_type_name) in metadata_lookup. unresolved_types . into_iter ( ) . enumerate ( ) {
391
414
// for each unresolved item
392
415
// we check whether it's arleady in the cache
393
416
// or perform a lookup and insert it into the cache
394
- if m. oid ( ) . is_err ( ) {
395
- if let Some ( ( ref schema, ref lookup_type_name) ) = next_unresolved. next ( ) {
396
- let cache_key = PgMetadataCacheKey :: new (
397
- schema. as_ref ( ) . map ( Into :: into) ,
398
- lookup_type_name. into ( ) ,
399
- ) ;
400
- if let Some ( entry) = metadata_cache. lookup_type ( & cache_key) {
401
- * m = entry;
402
- } else {
403
- let type_metadata = lookup_type (
404
- schema. clone ( ) ,
405
- lookup_type_name. clone ( ) ,
406
- & raw_connection,
407
- )
408
- . await ?;
409
- * m = PgTypeMetadata :: from_result ( Ok ( type_metadata) ) ;
410
-
411
- metadata_cache. store_type ( cache_key, type_metadata) ;
412
- }
413
- } else {
414
- break ;
415
- }
416
- }
417
+ let cache_key = PgMetadataCacheKey :: new (
418
+ schema. as_ref ( ) . map ( Into :: into) ,
419
+ lookup_type_name. into ( ) ,
420
+ ) ;
421
+ let real_metadata = if let Some ( type_metadata) = metadata_cache. lookup_type ( & cache_key) {
422
+ type_metadata
423
+ } else {
424
+ let type_metadata = lookup_type (
425
+ schema. clone ( ) ,
426
+ lookup_type_name. clone ( ) ,
427
+ & raw_connection,
428
+ )
429
+ . await ?;
430
+ metadata_cache. store_type ( cache_key, type_metadata) ;
431
+
432
+ PgTypeMetadata :: from_result ( Ok ( type_metadata) )
433
+ } ;
434
+ let ( fake_oid, fake_array_oid) = metadata_lookup. fake_oids ( index) ;
435
+ real_oids. extend ( [
436
+ ( fake_oid, real_metadata. oid ( ) ?) ,
437
+ ( fake_array_oid, real_metadata. array_oid ( ) ?) ,
438
+ ] ) ;
439
+ }
440
+
441
+ // Replace fake OIDs with real OIDs in `bind_collector.metadata`
442
+ for m in & mut bind_collector. metadata {
443
+ let [ oid, array_oid] = [ m. oid ( ) ?, m. array_oid ( ) ?]
444
+ . map ( |oid| {
445
+ real_oids
446
+ . get ( & oid)
447
+ . copied ( )
448
+ // If `oid` is not a key in `real_oids`, then `HasSqlType::metadata` returned it as a
449
+ // hardcoded value instead of being lied to by `PgAsyncMetadataLookup`. In this case,
450
+ // the existing value is already the real OID, so it's kept.
451
+ . unwrap_or ( oid)
452
+ } ) ;
453
+ * m = PgTypeMetadata :: new ( oid, array_oid) ;
454
+ }
455
+ // Replace fake OIDs with real OIDs in `bind_collector.binds`
456
+ for location in fake_oid_locations {
457
+ replace_fake_oid ( & mut bind_collector. binds , & real_oids, location)
458
+ . ok_or_else ( || {
459
+ Error :: SerializationError (
460
+ format ! ( "diesel_async failed to replace a type OID serialized in bind value {bind_index}" ) . into ( ) ,
461
+ )
462
+ } ) ;
417
463
}
418
464
}
419
465
let key = match query_id {
@@ -452,16 +498,30 @@ impl AsyncPgConnection {
452
498
}
453
499
}
454
500
501
+ /// Collects types that need to be looked up, and causes fake OIDs to be written into the bind collector
502
+ /// so they can be replaced with asynchronously fetched OIDs after the original query is dropped
455
503
struct PgAsyncMetadataLookup {
456
504
unresolved_types : Vec < ( Option < String > , String ) > ,
505
+ min_fake_oid : u32 ,
457
506
}
458
507
459
508
impl PgAsyncMetadataLookup {
460
- fn new ( ) -> Self {
509
+ fn new ( metadata_0 : & [ PgTypeMetadata ] ) -> Self {
510
+ let max_hardcoded_oid = metadata_0
511
+ . iter ( )
512
+ . flat_map ( |m| [ m. oid ( ) . unwrap_or ( 0 ) , m. array_oid ( ) . unwrap_or ( 0 ) ] )
513
+ . max ( )
514
+ . unwrap_or ( 0 ) ;
461
515
Self {
462
516
unresolved_types : Vec :: new ( ) ,
517
+ min_fake_oid : max_hardcoded_oid + 1 ,
463
518
}
464
519
}
520
+
521
+ fn fake_oids ( & self , index : usize ) -> ( u32 , u32 ) {
522
+ let oid = self . min_fake_oid + ( ( index as u32 ) * 2 ) ;
523
+ ( oid, oid + 1 )
524
+ }
465
525
}
466
526
467
527
impl PgMetadataLookup for PgAsyncMetadataLookup {
@@ -470,9 +530,24 @@ impl PgMetadataLookup for PgAsyncMetadataLookup {
470
530
PgMetadataCacheKey :: new ( schema. map ( Cow :: Borrowed ) , Cow :: Borrowed ( type_name) ) ;
471
531
472
532
let cache_key = cache_key. into_owned ( ) ;
533
+ let index = self . unresolved_types . len ( ) ;
473
534
self . unresolved_types
474
535
. push ( ( schema. map ( ToOwned :: to_owned) , type_name. to_owned ( ) ) ) ;
475
- PgTypeMetadata :: from_result ( Err ( FailedToLookupTypeError :: new ( cache_key) ) )
536
+ PgTypeMetadata :: from_result ( Ok ( self . fake_oids ( index) ) )
537
+ }
538
+ }
539
+
540
+ /// Allows unambiguously determining:
541
+ /// * where OIDs are written in `bind_collector.binds` after being returned by `lookup_type`
542
+ /// * determining the maximum hardcoded OID in `bind_collector.metadata`
543
+ struct SameOidEveryTime {
544
+ first_byte : u8 ,
545
+ }
546
+
547
+ impl PgMetadataLookup for SameOidEveryTime {
548
+ fn lookup_type ( & mut self , _type_name : & str , _schema : Option < & str > ) -> PgTypeMetadata {
549
+ let oid = u32:: from_be_bytes ( [ self . first_byte , 0 , 0 , 0 ] ) ;
550
+ PgTypeMetadata :: new ( oid, oid)
476
551
}
477
552
}
478
553
@@ -506,6 +581,20 @@ async fn lookup_type(
506
581
Ok ( ( r. get ( 0 ) , r. get ( 1 ) ) )
507
582
}
508
583
584
+ fn replace_fake_oid (
585
+ binds : & mut Vec < Option < Vec < u8 > > > ,
586
+ real_oids : HashMap < u32 , u32 > ,
587
+ ( bind_index, byte_index) : ( u32 , u32 ) ,
588
+ ) -> Option < ( ) > {
589
+ let serialized_oid = binds
590
+ . get_mut ( bind_index) ?
591
+ . as_mut ( ) ?
592
+ . get_mut ( byte_index..) ?
593
+ . first_chunk_mut :: < 4 > ( ) ?;
594
+ * serialized_oid = real_oids. get ( & u32:: from_be_bytes ( * serialized_oid) ) ?. to_be_bytes ( ) ;
595
+ Some ( ( ) )
596
+ }
597
+
509
598
async fn drive_future < R > (
510
599
connection_future : Option < broadcast:: Receiver < Arc < tokio_postgres:: Error > > > ,
511
600
client_future : impl Future < Output = Result < R , diesel:: result:: Error > > ,
0 commit comments