Skip to content

Commit 800ca61

Browse files
bors[bot]snow01
andauthored
Merge #718
718: added get_or_insert_with function r=taiki-e a=snow01 get_or_insert_with function allows lazy creation of default value. Default values may be heavy objects, lazy creation doesn't create un-necessary objects when key is already there in the data structure. Co-authored-by: Shailendra Sharma <shailendra.sharma@gmail.com>
2 parents 311124c + 403e899 commit 800ca61

File tree

4 files changed

+189
-4
lines changed

4 files changed

+189
-4
lines changed

crossbeam-skiplist/src/base.rs

+25-4
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,21 @@ where
471471

472472
/// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist.
473473
pub fn get_or_insert(&self, key: K, value: V, guard: &Guard) -> RefEntry<'_, K, V> {
474+
self.insert_internal(key, || value, false, guard)
475+
}
476+
477+
/// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist,
478+
/// where value is calculated with a function.
479+
///
480+
///
481+
/// <b>Note:</b> Another thread may write key value first, leading to the result of this closure
482+
/// discarded. If closure is modifying some other state (such as shared counters or shared
483+
/// objects), it may lead to <u>undesired behaviour</u> such as counters being changed without
484+
/// result of closure inserted
485+
pub fn get_or_insert_with<F>(&self, key: K, value: F, guard: &Guard) -> RefEntry<'_, K, V>
486+
where
487+
F: FnOnce() -> V,
488+
{
474489
self.insert_internal(key, value, false, guard)
475490
}
476491

@@ -831,13 +846,16 @@ where
831846
/// Inserts an entry with the specified `key` and `value`.
832847
///
833848
/// If `replace` is `true`, then any existing entry with this key will first be removed.
834-
fn insert_internal(
849+
fn insert_internal<F>(
835850
&self,
836851
key: K,
837-
value: V,
852+
value: F,
838853
replace: bool,
839854
guard: &Guard,
840-
) -> RefEntry<'_, K, V> {
855+
) -> RefEntry<'_, K, V>
856+
where
857+
F: FnOnce() -> V,
858+
{
841859
self.check_guard(guard);
842860

843861
unsafe {
@@ -876,6 +894,9 @@ where
876894
}
877895
}
878896

897+
// create value before creating node, so extra allocation doesn't happen if value() function panics
898+
let value = value();
899+
879900
// Create a new node.
880901
let height = self.random_height();
881902
let (node, n) = {
@@ -1061,7 +1082,7 @@ where
10611082
/// If there is an existing entry with this key, it will be removed before inserting the new
10621083
/// one.
10631084
pub fn insert(&self, key: K, value: V, guard: &Guard) -> RefEntry<'_, K, V> {
1064-
self.insert_internal(key, value, true, guard)
1085+
self.insert_internal(key, || value, true, guard)
10651086
}
10661087

10671088
/// Removes an entry with the specified `key` from the map and returns it.

crossbeam-skiplist/src/map.rs

+33
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,39 @@ where
254254
Entry::new(self.inner.get_or_insert(key, value, guard))
255255
}
256256

257+
/// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist,
258+
/// where value is calculated with a function.
259+
///
260+
///
261+
/// <b>Note:</b> Another thread may write key value first, leading to the result of this closure
262+
/// discarded. If closure is modifying some other state (such as shared counters or shared
263+
/// objects), it may lead to <u>undesired behaviour</u> such as counters being changed without
264+
/// result of closure inserted
265+
////
266+
/// This function returns an [`Entry`] which
267+
/// can be used to access the key's associated value.
268+
///
269+
///
270+
/// # Example
271+
/// ```
272+
/// use crossbeam_skiplist::SkipMap;
273+
///
274+
/// let ages = SkipMap::new();
275+
/// let gates_age = ages.get_or_insert_with("Bill Gates", || 64);
276+
/// assert_eq!(*gates_age.value(), 64);
277+
///
278+
/// ages.insert("Steve Jobs", 65);
279+
/// let jobs_age = ages.get_or_insert_with("Steve Jobs", || -1);
280+
/// assert_eq!(*jobs_age.value(), 65);
281+
/// ```
282+
pub fn get_or_insert_with<F>(&self, key: K, value_fn: F) -> Entry<'_, K, V>
283+
where
284+
F: FnOnce() -> V,
285+
{
286+
let guard = &epoch::pin();
287+
Entry::new(self.inner.get_or_insert_with(key, value_fn, guard))
288+
}
289+
257290
/// Returns an iterator over all entries in the map,
258291
/// sorted by key.
259292
///

crossbeam-skiplist/tests/base.rs

+70
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,76 @@ fn get_or_insert() {
431431
assert_eq!(*s.get_or_insert(6, 600, guard).value(), 600);
432432
}
433433

434+
#[test]
435+
fn get_or_insert_with() {
436+
let guard = &epoch::pin();
437+
let s = SkipList::new(epoch::default_collector().clone());
438+
s.insert(3, 3, guard);
439+
s.insert(5, 5, guard);
440+
s.insert(1, 1, guard);
441+
s.insert(4, 4, guard);
442+
s.insert(2, 2, guard);
443+
444+
assert_eq!(*s.get(&4, guard).unwrap().value(), 4);
445+
assert_eq!(*s.insert(4, 40, guard).value(), 40);
446+
assert_eq!(*s.get(&4, guard).unwrap().value(), 40);
447+
448+
assert_eq!(*s.get_or_insert_with(4, || 400, guard).value(), 40);
449+
assert_eq!(*s.get(&4, guard).unwrap().value(), 40);
450+
assert_eq!(*s.get_or_insert_with(6, || 600, guard).value(), 600);
451+
}
452+
453+
#[test]
454+
fn get_or_insert_with_panic() {
455+
use std::panic;
456+
457+
let s = SkipList::new(epoch::default_collector().clone());
458+
let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
459+
let guard = &epoch::pin();
460+
s.get_or_insert_with(4, || panic!(), guard);
461+
}));
462+
assert!(res.is_err());
463+
assert!(s.is_empty());
464+
let guard = &epoch::pin();
465+
assert_eq!(*s.get_or_insert_with(4, || 40, guard).value(), 40);
466+
assert_eq!(s.len(), 1);
467+
}
468+
469+
#[test]
470+
fn get_or_insert_with_parallel_run() {
471+
use std::sync::{Arc, Mutex};
472+
473+
let s = Arc::new(SkipList::new(epoch::default_collector().clone()));
474+
let s2 = s.clone();
475+
let called = Arc::new(Mutex::new(false));
476+
let called2 = called.clone();
477+
let handle = std::thread::spawn(move || {
478+
let guard = &epoch::pin();
479+
assert_eq!(
480+
*s2.get_or_insert_with(
481+
7,
482+
|| {
483+
*called2.lock().unwrap() = true;
484+
485+
// allow main thread to run before we return result
486+
std::thread::sleep(std::time::Duration::from_secs(4));
487+
70
488+
},
489+
guard,
490+
)
491+
.value(),
492+
700
493+
);
494+
});
495+
std::thread::sleep(std::time::Duration::from_secs(2));
496+
let guard = &epoch::pin();
497+
498+
// main thread writes the value first
499+
assert_eq!(*s.get_or_insert(7, 700, guard).value(), 700);
500+
handle.join().unwrap();
501+
assert!(*called.lock().unwrap());
502+
}
503+
434504
#[test]
435505
fn get_next_prev() {
436506
let guard = &epoch::pin();

crossbeam-skiplist/tests/map.rs

+61
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,67 @@ fn get_or_insert() {
370370
assert_eq!(*s.get_or_insert(6, 600).value(), 600);
371371
}
372372

373+
#[test]
374+
fn get_or_insert_with() {
375+
let s = SkipMap::new();
376+
s.insert(3, 3);
377+
s.insert(5, 5);
378+
s.insert(1, 1);
379+
s.insert(4, 4);
380+
s.insert(2, 2);
381+
382+
assert_eq!(*s.get(&4).unwrap().value(), 4);
383+
assert_eq!(*s.insert(4, 40).value(), 40);
384+
assert_eq!(*s.get(&4).unwrap().value(), 40);
385+
386+
assert_eq!(*s.get_or_insert_with(4, || 400).value(), 40);
387+
assert_eq!(*s.get(&4).unwrap().value(), 40);
388+
assert_eq!(*s.get_or_insert_with(6, || 600).value(), 600);
389+
}
390+
391+
#[test]
392+
fn get_or_insert_with_panic() {
393+
use std::panic;
394+
395+
let s = SkipMap::new();
396+
let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
397+
s.get_or_insert_with(4, || panic!());
398+
}));
399+
assert!(res.is_err());
400+
assert!(s.is_empty());
401+
assert_eq!(*s.get_or_insert_with(4, || 40).value(), 40);
402+
assert_eq!(s.len(), 1);
403+
}
404+
405+
#[test]
406+
fn get_or_insert_with_parallel_run() {
407+
use std::sync::{Arc, Mutex};
408+
409+
let s = Arc::new(SkipMap::new());
410+
let s2 = s.clone();
411+
let called = Arc::new(Mutex::new(false));
412+
let called2 = called.clone();
413+
let handle = std::thread::spawn(move || {
414+
assert_eq!(
415+
*s2.get_or_insert_with(7, || {
416+
*called2.lock().unwrap() = true;
417+
418+
// allow main thread to run before we return result
419+
std::thread::sleep(std::time::Duration::from_secs(4));
420+
70
421+
})
422+
.value(),
423+
700
424+
);
425+
});
426+
std::thread::sleep(std::time::Duration::from_secs(2));
427+
428+
// main thread writes the value first
429+
assert_eq!(*s.get_or_insert(7, 700).value(), 700);
430+
handle.join().unwrap();
431+
assert!(*called.lock().unwrap());
432+
}
433+
373434
#[test]
374435
fn get_next_prev() {
375436
let s = SkipMap::new();

0 commit comments

Comments
 (0)