Skip to content

Commit

Permalink
add unit tests for most common values
Browse files Browse the repository at this point in the history
  • Loading branch information
xx01cyx committed Nov 15, 2024
1 parent f005602 commit be51e7c
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 4 deletions.
3 changes: 2 additions & 1 deletion optd-cost-model/src/stats/counter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize};

/// The Counter structure to track exact frequencies of fixed elements.
#[serde_with::serde_as]
#[derive(Serialize, Deserialize, Debug)]
#[derive(Default, Serialize, Deserialize, Debug)]
pub struct Counter<T: PartialEq + Eq + Hash + Clone + Serialize + DeserializeOwned> {
#[serde_as(as = "HashMap<serde_with::json::JsonString, _>")]
counts: HashMap<T, i32>, // The exact counts of an element T.
Expand Down Expand Up @@ -37,6 +37,7 @@ where
if let Some(frequency) = self.counts.get_mut(&elem) {
*frequency += occ;
}
self.total_count += occ;
}

/// Digests an array of data into the Counter structure.
Expand Down
57 changes: 54 additions & 3 deletions optd-cost-model/src/stats/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,59 @@ impl AttributeCombValueStats {
}
}

impl From<serde_json::Value> for AttributeCombValueStats {
fn from(value: serde_json::Value) -> Self {
serde_json::from_value(value).unwrap()
#[cfg(test)]
mod tests {
use super::{Counter, MostCommonValues};
use crate::{common::values::Value, stats::AttributeCombValue};
use serde_json::json;

#[test]
fn test_most_common_values() {
let elem1 = vec![Some(Value::Int32(1))];
let elem2 = vec![Some(Value::Int32(2))];
let mut counter = Counter::new(&[elem1.clone(), elem2.clone()]);
counter.insert_element(elem1.clone(), 5);
counter.insert_element(elem2.clone(), 15);

let mcvs = MostCommonValues::Counter(counter);
assert_eq!(mcvs.freq(&elem1), Some(0.25));
assert_eq!(mcvs.freq(&elem2), Some(0.75));
assert_eq!(mcvs.total_freq(), 1.0);

let elem1_cloned = elem1.clone();
let pred1 = Box::new(move |x: &AttributeCombValue| x == &elem1_cloned);
let pred2 = Box::new(move |x: &AttributeCombValue| x != &elem1);
assert_eq!(mcvs.freq_over_pred(pred1), 0.25);
assert_eq!(mcvs.freq_over_pred(pred2), 0.75);

assert_eq!(mcvs.cnt(), 2);
}

#[test]
fn test_most_common_values_serde() {
let elem1 = vec![Some(Value::Int32(1))];
let elem2 = vec![Some(Value::Int32(2))];
let mut counter = Counter::new(&[elem1.clone(), elem2.clone()]);
counter.insert_element(elem1.clone(), 5);
counter.insert_element(elem2.clone(), 15);

let mcvs = MostCommonValues::Counter(counter);
let serialized = serde_json::to_value(&mcvs).unwrap();
println!("serialized: {:?}", serialized);

let deserialized: MostCommonValues = serde_json::from_value(serialized).unwrap();
assert_eq!(mcvs.freq(&elem1), Some(0.25));
assert_eq!(mcvs.freq(&elem2), Some(0.75));
assert_eq!(mcvs.total_freq(), 1.0);

let elem1_cloned = elem1.clone();
let pred1 = Box::new(move |x: &AttributeCombValue| x == &elem1_cloned);
let pred2 = Box::new(move |x: &AttributeCombValue| x != &elem1);
assert_eq!(mcvs.freq_over_pred(pred1), 0.25);
assert_eq!(mcvs.freq_over_pred(pred2), 0.75);

assert_eq!(mcvs.cnt(), 2);
}

// TODO: Add tests for Distribution
}

0 comments on commit be51e7c

Please sign in to comment.