From 7a5495974121e2e33713d7feec173965d54f0dd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 28 Feb 2025 17:08:07 +0000 Subject: [PATCH] Fixd it? --- Cargo.lock | 31 ++++++++++++++ backends/v3/Cargo.toml | 1 + backends/v3/src/radix.rs | 90 +++++++++++++++++++++++++++++++++------- 3 files changed, 108 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4603f77d206..e33e4b8d418 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1226,6 +1226,29 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "env_filter" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcaee3d8e3cfc3fd92428d477bc97fc29ec8716d180c0d74c643bb26166660e0" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "humantime", + "log", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -1784,6 +1807,12 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + [[package]] name = "hyper" version = "0.14.32" @@ -4826,6 +4855,7 @@ dependencies = [ "base64 0.22.1", "clap 4.5.30", "criterion", + "env_logger", "futures", "futures-util", "grpc-metadata", @@ -4847,6 +4877,7 @@ dependencies = [ "rand 0.8.5", "regex", "reqwest 0.11.27", + "rustc-hash 2.1.1", "serde", "serde_json", "slotmap", diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml index 588a2716fe1..b71213466b0 100644 --- a/backends/v3/Cargo.toml +++ b/backends/v3/Cargo.toml @@ -70,6 +70,7 @@ prost-build = "0.12.1" [dev-dependencies] criterion = "0.3" +env_logger = "0.11" itertools = "0.13" rustc-hash = "2" diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index e697dc01bd4..7cf244c9e26 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -99,6 +99,10 @@ impl Allocator for RadixAllocator { self.cache_blocks .incref(prefix_node) .expect("Failed to increment refcount"); + tracing::debug!( + "Increasing refcount of {prefix_node:?} with blocks {:?} to avoid deallocation", + self.cache_blocks.nodes[prefix_node].blocks + ); let prefix_len = blocks.len() * self.block_size as usize; let suffix_len = tokens - prefix_len as u32; @@ -106,6 +110,7 @@ impl Allocator for RadixAllocator { let suffix_blocks = suffix_len.div_ceil(self.block_size); tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}"); + tracing::debug!("Prefix blocks: {:?}:", blocks); match self.alloc_or_reclaim(suffix_blocks as usize) { Some(suffix_blocks) => blocks.extend(suffix_blocks), @@ -286,19 +291,28 @@ impl RadixTrie { /// Find worker. fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec) -> NodeId { let node = &self.nodes[node_id]; + tracing::debug!("Finding key {:?} in node {:?}", key, node_id); if key.len() >= self.block_size { let node_key = hash(&key[..self.block_size]); if let Some(&child_id) = node.children.get(&node_key) { self.update_access_time(child_id); let child = self.nodes.get(child_id).expect("Invalid child identifier"); + tracing::debug!("Found child {:?} with key {:?}", child_id, child.key); let shared_prefix_len = shared_prefix(&child.key, key, self.block_size); + tracing::debug!("Shared prefix len: {shared_prefix_len}"); assert_eq!(shared_prefix_len % self.block_size, 0); + tracing::debug!( + "Adding blocks: {:?}", + &child.blocks[..shared_prefix_len / self.block_size] + ); blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]); let key = &key[shared_prefix_len..]; if !key.is_empty() { node_id = self.find_(child_id, key, blocks); + } else { + node_id = child_id; } } } @@ -384,6 +398,8 @@ impl RadixTrie { let node = self.remove_node(node_id); evicted.extend(node.blocks); + tracing::debug!("Evicted {evicted:?}"); + if evicted.len() >= n_blocks { break; } @@ -398,6 +414,7 @@ impl RadixTrie { node.key.truncate(truncate_tokens); evicted.extend(node.blocks.split_off(truncate_blocks)); self.leaves.insert((last_access, node_id)); + tracing::debug!("Evicted {evicted:?}"); break; } } @@ -883,12 +900,14 @@ mod tests { #[test] fn invariants_hold_on_many_insertions() { + let _ = env_logger::builder().is_test(true).try_init(); + const VOCAB_SIZE: u32 = 2; - const DATA_LEN: usize = 1_000_000; + const DATA_LEN: usize = 1_000; //const MAX_PREFILL_LEN: usize = 2usize.pow(12); - const MAX_PREFILL_LEN: usize = 2usize.pow(14); - const MAX_DECODE_LEN: usize = 2usize.pow(4); + const MAX_PREFILL_LEN: usize = 4; + const MAX_DECODE_LEN: usize = 4; let vocab_range = Uniform::new(0, VOCAB_SIZE); let data_range = Uniform::new(0, DATA_LEN); @@ -899,19 +918,24 @@ mod tests { let data = (0..DATA_LEN) .map(|_| vocab_range.sample(&mut rng)) .collect::>(); - let mut allocator = RadixAllocator::new(1, 4_000_000, None); + let mut allocator = RadixAllocator::new(1, 50, None); let mut allocations = Vec::new(); + let mut allocation_generation = HashMap::new(); + for i in 0..500 { // Allocate until all blocks are used. 'allocation: loop { // Use offset 0 half of the times for prefix sharing. - let prefill_offset = *[0, 0, 1, data_range.sample(&mut rng)] + let prefill_offset = *[data_range.sample(&mut rng)] + //let prefill_offset = *[0, 0, 1, data_range.sample(&mut rng)] .choose(&mut rng) .unwrap(); - let prefill_len = prefill_len_range.sample(&mut rng); - let decode_len = decode_len_range.sample(&mut rng); + //let prefill_len = prefill_len_range.sample(&mut rng); + let prefill_len = 2; + //let decode_len = decode_len_range.sample(&mut rng); + let decode_len = 2; let prefill = data[prefill_offset..data.len().min(prefill_offset + prefill_len)].to_vec(); @@ -932,20 +956,37 @@ mod tests { .collect::>(); let blockset = allocation.blocks.iter().copied().collect::>(); + // No duplicate blocks in an allocation. + if allocation.blocks.len() != blockset.len() { + eprintln!("allocation: {:?}", allocation); + eprintln!("blockset: {:?}", blockset); + } + assert_eq!( + allocation.blocks.len(), + blockset.len(), + "Duplicate blocks in allocation" + ); + + allocation_generation.insert(allocation.allocation_id, i); + allocations.push((allocation, blockset, prefix_blocks, non_prefix_blocks)); } // Check invariants. Skip first iteration, since there is no prefix sharing yet. if i > 1 { - check_allocation_invariants(&allocations); + check_allocation_invariants(&allocations, &allocation_generation); } // Remove 20% of the allocations, randomly. - allocations.shuffle(&mut rng); - let remove_index = (allocations.len() as f64 * 0.8) as usize; - for (allocation, _, _, _) in allocations.drain(remove_index..) { - allocator.free(allocation.blocks.clone(), allocation.allocation_id); - } + //allocations.shuffle(&mut rng); + //let remove_index = (allocations.len() as f64 * 0.8) as usize; + //for (allocation, _, _, _) in allocations.drain(remove_index..) { + // allocator.free(allocation.blocks.clone(), allocation.allocation_id); + //} + allocations.into_iter().for_each(|(allocation, _, _, _)| { + allocator.free(allocation.blocks.clone(), allocation.allocation_id) + }); + allocations = Vec::new(); } } @@ -956,6 +997,7 @@ mod tests { FxHashSet, FxHashSet, )], + allocation_generation: &HashMap, ) { for i in 0..allocations.len() { let (allocation, blockset, prefix_blocks, non_prefix_blocks) = &allocations[i]; @@ -964,6 +1006,10 @@ mod tests { assert!(!blockset.contains(&0), "Block 0 must not be allocated"); // No duplicate blocks in an allocation. + if allocation.blocks.len() != blockset.len() { + eprintln!("allocation: {:?}", allocation); + eprintln!("blockset: {:?}", blockset); + } assert_eq!( allocation.blocks.len(), blockset.len(), @@ -976,8 +1022,24 @@ mod tests { // non_prefix_blocks.len() //); - for (_, _, other_prefix_blocks, other_non_prefix_blocks) in &allocations[i + 1..] { + for (other_allocation, _, other_prefix_blocks, other_non_prefix_blocks) in + &allocations[i + 1..] + { if !other_non_prefix_blocks.is_disjoint(non_prefix_blocks) { + eprintln!("allocation: {:?}", allocation); + eprintln!("other allocation: {:?}", other_allocation); + eprintln!( + "allocation generation: {:?}", + allocation_generation + .get(&allocation.allocation_id) + .unwrap() + ); + eprintln!( + "other allocation generation: {:?}", + allocation_generation + .get(&other_allocation.allocation_id) + .unwrap() + ); eprintln!( "overlapping: {:?}", non_prefix_blocks.intersection(other_non_prefix_blocks),