From be51e7cd9de76e284620c0c65a65f22e24cfec26 Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Fri, 15 Nov 2024 17:22:28 -0500 Subject: [PATCH] add unit tests for most common values --- optd-cost-model/src/stats/counter.rs | 3 +- optd-cost-model/src/stats/mod.rs | 57 ++++++++++++++++++++++++++-- 2 files changed, 56 insertions(+), 4 deletions(-) diff --git a/optd-cost-model/src/stats/counter.rs b/optd-cost-model/src/stats/counter.rs index baa32ab..82b6a34 100644 --- a/optd-cost-model/src/stats/counter.rs +++ b/optd-cost-model/src/stats/counter.rs @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize}; /// The Counter structure to track exact frequencies of fixed elements. #[serde_with::serde_as] -#[derive(Serialize, Deserialize, Debug)] +#[derive(Default, Serialize, Deserialize, Debug)] pub struct Counter { #[serde_as(as = "HashMap")] counts: HashMap, // The exact counts of an element T. @@ -37,6 +37,7 @@ where if let Some(frequency) = self.counts.get_mut(&elem) { *frequency += occ; } + self.total_count += occ; } /// Digests an array of data into the Counter structure. diff --git a/optd-cost-model/src/stats/mod.rs b/optd-cost-model/src/stats/mod.rs index 287b20a..4657330 100644 --- a/optd-cost-model/src/stats/mod.rs +++ b/optd-cost-model/src/stats/mod.rs @@ -116,8 +116,59 @@ impl AttributeCombValueStats { } } -impl From for AttributeCombValueStats { - fn from(value: serde_json::Value) -> Self { - serde_json::from_value(value).unwrap() +#[cfg(test)] +mod tests { + use super::{Counter, MostCommonValues}; + use crate::{common::values::Value, stats::AttributeCombValue}; + use serde_json::json; + + #[test] + fn test_most_common_values() { + let elem1 = vec![Some(Value::Int32(1))]; + let elem2 = vec![Some(Value::Int32(2))]; + let mut counter = Counter::new(&[elem1.clone(), elem2.clone()]); + counter.insert_element(elem1.clone(), 5); + counter.insert_element(elem2.clone(), 15); + + let mcvs = MostCommonValues::Counter(counter); + assert_eq!(mcvs.freq(&elem1), Some(0.25)); + assert_eq!(mcvs.freq(&elem2), Some(0.75)); + assert_eq!(mcvs.total_freq(), 1.0); + + let elem1_cloned = elem1.clone(); + let pred1 = Box::new(move |x: &AttributeCombValue| x == &elem1_cloned); + let pred2 = Box::new(move |x: &AttributeCombValue| x != &elem1); + assert_eq!(mcvs.freq_over_pred(pred1), 0.25); + assert_eq!(mcvs.freq_over_pred(pred2), 0.75); + + assert_eq!(mcvs.cnt(), 2); } + + #[test] + fn test_most_common_values_serde() { + let elem1 = vec![Some(Value::Int32(1))]; + let elem2 = vec![Some(Value::Int32(2))]; + let mut counter = Counter::new(&[elem1.clone(), elem2.clone()]); + counter.insert_element(elem1.clone(), 5); + counter.insert_element(elem2.clone(), 15); + + let mcvs = MostCommonValues::Counter(counter); + let serialized = serde_json::to_value(&mcvs).unwrap(); + println!("serialized: {:?}", serialized); + + let deserialized: MostCommonValues = serde_json::from_value(serialized).unwrap(); + assert_eq!(mcvs.freq(&elem1), Some(0.25)); + assert_eq!(mcvs.freq(&elem2), Some(0.75)); + assert_eq!(mcvs.total_freq(), 1.0); + + let elem1_cloned = elem1.clone(); + let pred1 = Box::new(move |x: &AttributeCombValue| x == &elem1_cloned); + let pred2 = Box::new(move |x: &AttributeCombValue| x != &elem1); + assert_eq!(mcvs.freq_over_pred(pred1), 0.25); + assert_eq!(mcvs.freq_over_pred(pred2), 0.75); + + assert_eq!(mcvs.cnt(), 2); + } + + // TODO: Add tests for Distribution }