diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml index 996290ed3ba..588a2716fe1 100644 --- a/backends/v3/Cargo.toml +++ b/backends/v3/Cargo.toml @@ -71,6 +71,7 @@ prost-build = "0.12.1" [dev-dependencies] criterion = "0.3" itertools = "0.13" +rustc-hash = "2" [features] default = ["ngrok"] diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index 532ec6ddcc8..e697dc01bd4 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -208,6 +208,7 @@ impl Allocator for RadixAllocator { } } +#[derive(Debug)] struct RadixAllocation { prefix_node: NodeId, cached_prefix_len: usize, @@ -631,6 +632,12 @@ fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize { mod tests { use std::sync::Arc; + use rand::{ + distributions::Uniform, prelude::Distribution, rngs::SmallRng, seq::SliceRandom, + SeedableRng, + }; + use rustc_hash::FxHashSet; + use super::*; #[test] @@ -873,4 +880,114 @@ mod tests { // Clear out the whole trie. assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]); } + + #[test] + fn invariants_hold_on_many_insertions() { + const VOCAB_SIZE: u32 = 2; + const DATA_LEN: usize = 1_000_000; + + //const MAX_PREFILL_LEN: usize = 2usize.pow(12); + const MAX_PREFILL_LEN: usize = 2usize.pow(14); + const MAX_DECODE_LEN: usize = 2usize.pow(4); + + let vocab_range = Uniform::new(0, VOCAB_SIZE); + let data_range = Uniform::new(0, DATA_LEN); + let prefill_len_range = Uniform::new(0, MAX_PREFILL_LEN); + let decode_len_range = Uniform::new(0, MAX_DECODE_LEN); + + let mut rng = SmallRng::seed_from_u64(64); + let data = (0..DATA_LEN) + .map(|_| vocab_range.sample(&mut rng)) + .collect::>(); + let mut allocator = RadixAllocator::new(1, 4_000_000, None); + + let mut allocations = Vec::new(); + + for i in 0..500 { + // Allocate until all blocks are used. + 'allocation: loop { + // Use offset 0 half of the times for prefix sharing. + let prefill_offset = *[0, 0, 1, data_range.sample(&mut rng)] + .choose(&mut rng) + .unwrap(); + let prefill_len = prefill_len_range.sample(&mut rng); + let decode_len = decode_len_range.sample(&mut rng); + + let prefill = + data[prefill_offset..data.len().min(prefill_offset + prefill_len)].to_vec(); + + let allocation = match allocator + .allocate((prefill.len() + decode_len) as u32, Some(Arc::new(prefill))) + { + Some(allocation) => allocation, + None => break 'allocation, + }; + let prefix_blocks = allocation.blocks[..allocation.prefix_len as usize] + .iter() + .copied() + .collect::>(); + let non_prefix_blocks = allocation.blocks[allocation.prefix_len as usize..] + .iter() + .copied() + .collect::>(); + let blockset = allocation.blocks.iter().copied().collect::>(); + + allocations.push((allocation, blockset, prefix_blocks, non_prefix_blocks)); + } + + // Check invariants. Skip first iteration, since there is no prefix sharing yet. + if i > 1 { + check_allocation_invariants(&allocations); + } + + // Remove 20% of the allocations, randomly. + allocations.shuffle(&mut rng); + let remove_index = (allocations.len() as f64 * 0.8) as usize; + for (allocation, _, _, _) in allocations.drain(remove_index..) { + allocator.free(allocation.blocks.clone(), allocation.allocation_id); + } + } + } + + fn check_allocation_invariants( + allocations: &[( + BlockAllocation, + FxHashSet, + FxHashSet, + FxHashSet, + )], + ) { + for i in 0..allocations.len() { + let (allocation, blockset, prefix_blocks, non_prefix_blocks) = &allocations[i]; + + // 0 is used for health checks, must not be used. + assert!(!blockset.contains(&0), "Block 0 must not be allocated"); + + // No duplicate blocks in an allocation. + assert_eq!( + allocation.blocks.len(), + blockset.len(), + "Duplicate blocks in allocation" + ); + + //eprintln!( + // "Prefix blocks, non-prefix blocks: {} {}", + // prefix_blocks.len(), + // non_prefix_blocks.len() + //); + + for (_, _, other_prefix_blocks, other_non_prefix_blocks) in &allocations[i + 1..] { + if !other_non_prefix_blocks.is_disjoint(non_prefix_blocks) { + eprintln!( + "overlapping: {:?}", + non_prefix_blocks.intersection(other_non_prefix_blocks), + ); + } + assert!( + other_non_prefix_blocks.is_disjoint(non_prefix_blocks), + "Allocations share non-prefix blocks" + ) + } + } + } }