Skip to content

Commit

Permalink
Fixed it?
Browse files Browse the repository at this point in the history
  • Loading branch information
danieldk committed Feb 28, 2025
1 parent ccbd3e5 commit 7a54959
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 14 deletions.
31 changes: 31 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions backends/v3/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ prost-build = "0.12.1"

[dev-dependencies]
criterion = "0.3"
env_logger = "0.11"
itertools = "0.13"
rustc-hash = "2"

Expand Down
90 changes: 76 additions & 14 deletions backends/v3/src/radix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,18 @@ impl Allocator for RadixAllocator {
self.cache_blocks
.incref(prefix_node)
.expect("Failed to increment refcount");
tracing::debug!(
"Increasing refcount of {prefix_node:?} with blocks {:?} to avoid deallocation",
self.cache_blocks.nodes[prefix_node].blocks
);

let prefix_len = blocks.len() * self.block_size as usize;
let suffix_len = tokens - prefix_len as u32;

let suffix_blocks = suffix_len.div_ceil(self.block_size);

tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
tracing::debug!("Prefix blocks: {:?}:", blocks);

match self.alloc_or_reclaim(suffix_blocks as usize) {
Some(suffix_blocks) => blocks.extend(suffix_blocks),
Expand Down Expand Up @@ -286,19 +291,28 @@ impl RadixTrie {
/// Find worker.
fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
let node = &self.nodes[node_id];
tracing::debug!("Finding key {:?} in node {:?}", key, node_id);

if key.len() >= self.block_size {
let node_key = hash(&key[..self.block_size]);
if let Some(&child_id) = node.children.get(&node_key) {
self.update_access_time(child_id);
let child = self.nodes.get(child_id).expect("Invalid child identifier");
tracing::debug!("Found child {:?} with key {:?}", child_id, child.key);
let shared_prefix_len = shared_prefix(&child.key, key, self.block_size);
tracing::debug!("Shared prefix len: {shared_prefix_len}");
assert_eq!(shared_prefix_len % self.block_size, 0);
tracing::debug!(
"Adding blocks: {:?}",
&child.blocks[..shared_prefix_len / self.block_size]
);
blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);

let key = &key[shared_prefix_len..];
if !key.is_empty() {
node_id = self.find_(child_id, key, blocks);
} else {
node_id = child_id;
}
}
}
Expand Down Expand Up @@ -384,6 +398,8 @@ impl RadixTrie {
let node = self.remove_node(node_id);
evicted.extend(node.blocks);

tracing::debug!("Evicted {evicted:?}");

if evicted.len() >= n_blocks {
break;
}
Expand All @@ -398,6 +414,7 @@ impl RadixTrie {
node.key.truncate(truncate_tokens);
evicted.extend(node.blocks.split_off(truncate_blocks));
self.leaves.insert((last_access, node_id));
tracing::debug!("Evicted {evicted:?}");
break;
}
}
Expand Down Expand Up @@ -883,12 +900,14 @@ mod tests {

#[test]
fn invariants_hold_on_many_insertions() {
let _ = env_logger::builder().is_test(true).try_init();

const VOCAB_SIZE: u32 = 2;
const DATA_LEN: usize = 1_000_000;
const DATA_LEN: usize = 1_000;

//const MAX_PREFILL_LEN: usize = 2usize.pow(12);
const MAX_PREFILL_LEN: usize = 2usize.pow(14);
const MAX_DECODE_LEN: usize = 2usize.pow(4);
const MAX_PREFILL_LEN: usize = 4;
const MAX_DECODE_LEN: usize = 4;

let vocab_range = Uniform::new(0, VOCAB_SIZE);
let data_range = Uniform::new(0, DATA_LEN);
Expand All @@ -899,19 +918,24 @@ mod tests {
let data = (0..DATA_LEN)
.map(|_| vocab_range.sample(&mut rng))
.collect::<Vec<_>>();
let mut allocator = RadixAllocator::new(1, 4_000_000, None);
let mut allocator = RadixAllocator::new(1, 50, None);

let mut allocations = Vec::new();

let mut allocation_generation = HashMap::new();

for i in 0..500 {
// Allocate until all blocks are used.
'allocation: loop {
// Use offset 0 half of the times for prefix sharing.
let prefill_offset = *[0, 0, 1, data_range.sample(&mut rng)]
let prefill_offset = *[data_range.sample(&mut rng)]
//let prefill_offset = *[0, 0, 1, data_range.sample(&mut rng)]
.choose(&mut rng)
.unwrap();
let prefill_len = prefill_len_range.sample(&mut rng);
let decode_len = decode_len_range.sample(&mut rng);
//let prefill_len = prefill_len_range.sample(&mut rng);
let prefill_len = 2;
//let decode_len = decode_len_range.sample(&mut rng);
let decode_len = 2;

let prefill =
data[prefill_offset..data.len().min(prefill_offset + prefill_len)].to_vec();
Expand All @@ -932,20 +956,37 @@ mod tests {
.collect::<FxHashSet<_>>();
let blockset = allocation.blocks.iter().copied().collect::<FxHashSet<_>>();

// No duplicate blocks in an allocation.
if allocation.blocks.len() != blockset.len() {
eprintln!("allocation: {:?}", allocation);
eprintln!("blockset: {:?}", blockset);
}
assert_eq!(
allocation.blocks.len(),
blockset.len(),
"Duplicate blocks in allocation"
);

allocation_generation.insert(allocation.allocation_id, i);

allocations.push((allocation, blockset, prefix_blocks, non_prefix_blocks));
}

// Check invariants. Skip first iteration, since there is no prefix sharing yet.
if i > 1 {
check_allocation_invariants(&allocations);
check_allocation_invariants(&allocations, &allocation_generation);
}

// Remove 20% of the allocations, randomly.
allocations.shuffle(&mut rng);
let remove_index = (allocations.len() as f64 * 0.8) as usize;
for (allocation, _, _, _) in allocations.drain(remove_index..) {
allocator.free(allocation.blocks.clone(), allocation.allocation_id);
}
//allocations.shuffle(&mut rng);
//let remove_index = (allocations.len() as f64 * 0.8) as usize;
//for (allocation, _, _, _) in allocations.drain(remove_index..) {
// allocator.free(allocation.blocks.clone(), allocation.allocation_id);
//}
allocations.into_iter().for_each(|(allocation, _, _, _)| {
allocator.free(allocation.blocks.clone(), allocation.allocation_id)
});
allocations = Vec::new();
}
}

Expand All @@ -956,6 +997,7 @@ mod tests {
FxHashSet<u32>,
FxHashSet<u32>,
)],
allocation_generation: &HashMap<u64, usize>,
) {
for i in 0..allocations.len() {
let (allocation, blockset, prefix_blocks, non_prefix_blocks) = &allocations[i];
Expand All @@ -964,6 +1006,10 @@ mod tests {
assert!(!blockset.contains(&0), "Block 0 must not be allocated");

// No duplicate blocks in an allocation.
if allocation.blocks.len() != blockset.len() {
eprintln!("allocation: {:?}", allocation);
eprintln!("blockset: {:?}", blockset);
}
assert_eq!(
allocation.blocks.len(),
blockset.len(),
Expand All @@ -976,8 +1022,24 @@ mod tests {
// non_prefix_blocks.len()
//);

for (_, _, other_prefix_blocks, other_non_prefix_blocks) in &allocations[i + 1..] {
for (other_allocation, _, other_prefix_blocks, other_non_prefix_blocks) in
&allocations[i + 1..]
{
if !other_non_prefix_blocks.is_disjoint(non_prefix_blocks) {
eprintln!("allocation: {:?}", allocation);
eprintln!("other allocation: {:?}", other_allocation);
eprintln!(
"allocation generation: {:?}",
allocation_generation
.get(&allocation.allocation_id)
.unwrap()
);
eprintln!(
"other allocation generation: {:?}",
allocation_generation
.get(&other_allocation.allocation_id)
.unwrap()
);
eprintln!(
"overlapping: {:?}",
non_prefix_blocks.intersection(other_non_prefix_blocks),
Expand Down

0 comments on commit 7a54959

Please sign in to comment.