@@ -431,6 +431,76 @@ fn get_or_insert() {
431
431
assert_eq ! ( * s. get_or_insert( 6 , 600 , guard) . value( ) , 600 ) ;
432
432
}
433
433
434
+ #[ test]
435
+ fn get_or_insert_with ( ) {
436
+ let guard = & epoch:: pin ( ) ;
437
+ let s = SkipList :: new ( epoch:: default_collector ( ) . clone ( ) ) ;
438
+ s. insert ( 3 , 3 , guard) ;
439
+ s. insert ( 5 , 5 , guard) ;
440
+ s. insert ( 1 , 1 , guard) ;
441
+ s. insert ( 4 , 4 , guard) ;
442
+ s. insert ( 2 , 2 , guard) ;
443
+
444
+ assert_eq ! ( * s. get( & 4 , guard) . unwrap( ) . value( ) , 4 ) ;
445
+ assert_eq ! ( * s. insert( 4 , 40 , guard) . value( ) , 40 ) ;
446
+ assert_eq ! ( * s. get( & 4 , guard) . unwrap( ) . value( ) , 40 ) ;
447
+
448
+ assert_eq ! ( * s. get_or_insert_with( 4 , || 400 , guard) . value( ) , 40 ) ;
449
+ assert_eq ! ( * s. get( & 4 , guard) . unwrap( ) . value( ) , 40 ) ;
450
+ assert_eq ! ( * s. get_or_insert_with( 6 , || 600 , guard) . value( ) , 600 ) ;
451
+ }
452
+
453
+ #[ test]
454
+ fn get_or_insert_with_panic ( ) {
455
+ use std:: panic;
456
+
457
+ let s = SkipList :: new ( epoch:: default_collector ( ) . clone ( ) ) ;
458
+ let res = panic:: catch_unwind ( panic:: AssertUnwindSafe ( || {
459
+ let guard = & epoch:: pin ( ) ;
460
+ s. get_or_insert_with ( 4 , || panic ! ( ) , guard) ;
461
+ } ) ) ;
462
+ assert ! ( res. is_err( ) ) ;
463
+ assert ! ( s. is_empty( ) ) ;
464
+ let guard = & epoch:: pin ( ) ;
465
+ assert_eq ! ( * s. get_or_insert_with( 4 , || 40 , guard) . value( ) , 40 ) ;
466
+ assert_eq ! ( s. len( ) , 1 ) ;
467
+ }
468
+
469
+ #[ test]
470
+ fn get_or_insert_with_parallel_run ( ) {
471
+ use std:: sync:: { Arc , Mutex } ;
472
+
473
+ let s = Arc :: new ( SkipList :: new ( epoch:: default_collector ( ) . clone ( ) ) ) ;
474
+ let s2 = s. clone ( ) ;
475
+ let called = Arc :: new ( Mutex :: new ( false ) ) ;
476
+ let called2 = called. clone ( ) ;
477
+ let handle = std:: thread:: spawn ( move || {
478
+ let guard = & epoch:: pin ( ) ;
479
+ assert_eq ! (
480
+ * s2. get_or_insert_with(
481
+ 7 ,
482
+ || {
483
+ * called2. lock( ) . unwrap( ) = true ;
484
+
485
+ // allow main thread to run before we return result
486
+ std:: thread:: sleep( std:: time:: Duration :: from_secs( 4 ) ) ;
487
+ 70
488
+ } ,
489
+ guard,
490
+ )
491
+ . value( ) ,
492
+ 700
493
+ ) ;
494
+ } ) ;
495
+ std:: thread:: sleep ( std:: time:: Duration :: from_secs ( 2 ) ) ;
496
+ let guard = & epoch:: pin ( ) ;
497
+
498
+ // main thread writes the value first
499
+ assert_eq ! ( * s. get_or_insert( 7 , 700 , guard) . value( ) , 700 ) ;
500
+ handle. join ( ) . unwrap ( ) ;
501
+ assert ! ( * called. lock( ) . unwrap( ) ) ;
502
+ }
503
+
434
504
#[ test]
435
505
fn get_next_prev ( ) {
436
506
let guard = & epoch:: pin ( ) ;
0 commit comments