From 4e47df89c561061bfec04872a81f31214905939b Mon Sep 17 00:00:00 2001 From: Abhishek Tripathi Date: Fri, 9 May 2025 17:43:17 +0530 Subject: [PATCH 1/3] fix: unit tests for python values --- Cargo.toml | 2 +- src/base/value.rs | 8 ++--- src/py/convert.rs | 86 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 91 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3bec286a..87e08fa8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ name = "cocoindex_engine" crate-type = ["cdylib"] [dependencies] -pyo3 = { version = "0.24.1", features = ["chrono"] } +pyo3 = { version = "0.24.1", features = ["chrono", "auto-initialize"] } pythonize = "0.24.0" pyo3-async-runtimes = { version = "0.24.0", features = ["tokio-runtime"] } diff --git a/src/base/value.rs b/src/base/value.rs index c3266fb7..50cd7300 100644 --- a/src/base/value.rs +++ b/src/base/value.rs @@ -340,7 +340,7 @@ impl KeyValue { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum BasicValue { Bytes(Bytes), Str(Arc), @@ -511,7 +511,7 @@ impl BasicValue { } } -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, PartialEq)] pub enum Value { #[default] Null, @@ -747,7 +747,7 @@ impl Value { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct FieldValues { pub fields: Vec>, } @@ -821,7 +821,7 @@ where } } -#[derive(Debug, Clone, Serialize)] +#[derive(Debug, Clone, Serialize, PartialEq)] pub struct ScopeValue(pub FieldValues); impl Deref for ScopeValue { diff --git a/src/py/convert.rs b/src/py/convert.rs index 327ba828..ca8e0513 100644 --- a/src/py/convert.rs +++ b/src/py/convert.rs @@ -226,3 +226,89 @@ pub fn value_from_py_object<'py>( }; Ok(result) } + + +#[cfg(test)] +mod tests { + use super::*; + use crate::base::{schema, value}; + use pyo3::Python; + use std::sync::Arc; + + #[test] + fn test_roundtrip_basic_values() { + Python::with_gil(|py| { + // Test Int64 + let int_value = 
value::Value::Basic(value::BasicValue::Int64(42)); + let int_type = schema::ValueType::Basic(schema::BasicValueType::Int64); + let py_obj = value_to_py_object(py, &int_value).unwrap(); + let roundtrip_value = value_from_py_object(&int_type, &py_obj).unwrap(); + assert_eq!(int_value, roundtrip_value, "Int64 roundtrip failed"); + + // Test String + let str_value = value::Value::Basic(value::BasicValue::Str(Arc::from("test string"))); + let str_type = schema::ValueType::Basic(schema::BasicValueType::Str); + let py_obj = value_to_py_object(py, &str_value).unwrap(); + let roundtrip_value = value_from_py_object(&str_type, &py_obj).unwrap(); + assert_eq!(str_value, roundtrip_value, "String roundtrip failed"); + + // Test Bool + let bool_value = value::Value::Basic(value::BasicValue::Bool(true)); + let bool_type = schema::ValueType::Basic(schema::BasicValueType::Bool); + let py_obj = value_to_py_object(py, &bool_value).unwrap(); + let roundtrip_value = value_from_py_object(&bool_type, &py_obj).unwrap(); + assert_eq!(bool_value, roundtrip_value, "Bool roundtrip failed"); + }); + } + + #[test] + fn test_roundtrip_struct() { + Python::with_gil(|py| { + // Create a struct schema with multiple fields + let struct_schema = schema::StructSchema { + description: Some(Arc::from("Test struct")), + fields: Arc::new(vec![ + schema::FieldSchema { + name: "id".into(), + value_type: schema::EnrichedValueType { + typ: schema::ValueType::Basic(schema::BasicValueType::Int64), + nullable: false, + attrs: Default::default(), + }, + }, + schema::FieldSchema { + name: "name".into(), + value_type: schema::EnrichedValueType { + typ: schema::ValueType::Basic(schema::BasicValueType::Str), + nullable: false, + attrs: Default::default(), + }, + }, + schema::FieldSchema { + name: "active".into(), + value_type: schema::EnrichedValueType { + typ: schema::ValueType::Basic(schema::BasicValueType::Bool), + nullable: false, + attrs: Default::default(), + }, + }, + ]), + }; + + // Create a struct value 
matching the schema + let struct_value = value::Value::Struct(value::FieldValues { + fields: vec![ + value::Value::Basic(value::BasicValue::Int64(1)), + value::Value::Basic(value::BasicValue::Str(Arc::from("test"))), + value::Value::Basic(value::BasicValue::Bool(true)), + ], + }); + + // Perform roundtrip conversion + let struct_type = schema::ValueType::Struct(struct_schema); + let py_obj = value_to_py_object(py, &struct_value).unwrap(); + let roundtrip_value = value_from_py_object(&struct_type, &py_obj).unwrap(); + assert_eq!(struct_value, roundtrip_value, "Struct roundtrip failed"); + }); + } +} From 42c3edd58944624ff804c82627a2b18aa43bf5bd Mon Sep 17 00:00:00 2001 From: Abhishek Tripathi Date: Wed, 14 May 2025 12:40:53 +0530 Subject: [PATCH 2/3] fix: KTable test --- src/base/value.rs | 10 +- src/py/convert.rs | 324 ++++++++++++++++++++++++++++++++++++---------- src/server.rs | 2 +- src/settings.rs | 4 +- 4 files changed, 264 insertions(+), 76 deletions(-) diff --git a/src/base/value.rs b/src/base/value.rs index 50cd7300..a1c2e21d 100644 --- a/src/base/value.rs +++ b/src/base/value.rs @@ -71,7 +71,7 @@ impl<'de> Deserialize<'de> for RangeValue { } /// Value of key. 
-#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize)] pub enum KeyValue { Bytes(Bytes), Str(Arc), @@ -340,7 +340,7 @@ impl KeyValue { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Deserialize)] pub enum BasicValue { Bytes(Bytes), Str(Arc), @@ -511,7 +511,7 @@ impl BasicValue { } } -#[derive(Debug, Clone, Default, PartialEq)] +#[derive(Debug, Clone, Default, PartialEq, Deserialize)] pub enum Value { #[default] Null, @@ -747,7 +747,7 @@ impl Value { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Deserialize)] pub struct FieldValues { pub fields: Vec>, } @@ -821,7 +821,7 @@ where } } -#[derive(Debug, Clone, Serialize, PartialEq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ScopeValue(pub FieldValues); impl Deref for ScopeValue { diff --git a/src/py/convert.rs b/src/py/convert.rs index ca8e0513..4267f4b6 100644 --- a/src/py/convert.rs +++ b/src/py/convert.rs @@ -12,6 +12,7 @@ use std::sync::Arc; use super::IntoPyResult; use crate::base::{schema, value}; +#[derive(Debug)] pub struct Pythonized(pub T); impl<'py, T: DeserializeOwned> FromPyObject<'py> for Pythonized { @@ -168,6 +169,7 @@ fn field_values_from_py_object<'py>( list.len() ))); } + Ok(value::FieldValues { fields: schema .fields @@ -198,6 +200,7 @@ pub fn value_from_py_object<'py>( .into_iter() .map(|v| field_values_from_py_object(&schema.row, &v)) .collect::>>()?; + match schema.kind { schema::TableKind::UTable => { value::Value::UTable(values.into_iter().map(|v| v.into()).collect()) @@ -205,6 +208,7 @@ pub fn value_from_py_object<'py>( schema::TableKind::LTable => { value::Value::LTable(values.into_iter().map(|v| v.into()).collect()) } + schema::TableKind::KTable => value::Value::KTable( values .into_iter() @@ -227,88 +231,272 @@ pub fn value_from_py_object<'py>( Ok(result) } - +// `value_from_py_object` and `value_to_py_object` are functions 
internal to this module. `Pythonized` is the API exposed by the module. Ideally we want to test on the behavior of the public API. #[cfg(test)] mod tests { - use super::*; - use crate::base::{schema, value}; + use super::*; // To bring Pythonized into scope + use crate::base::schema; + use crate::base::value; + use crate::base::value::ScopeValue; // Changed import from GeneralValueSpec use pyo3::Python; + use std::collections::BTreeMap; use std::sync::Arc; - #[test] - fn test_roundtrip_basic_values() { + fn assert_roundtrip_conversion(original_value: &value::Value, value_type: &schema::ValueType) { Python::with_gil(|py| { - // Test Int64 - let int_value = value::Value::Basic(value::BasicValue::Int64(42)); - let int_type = schema::ValueType::Basic(schema::BasicValueType::Int64); - let py_obj = value_to_py_object(py, &int_value).unwrap(); - let roundtrip_value = value_from_py_object(&int_type, &py_obj).unwrap(); - assert_eq!(int_value, roundtrip_value, "Int64 roundtrip failed"); - - // Test String - let str_value = value::Value::Basic(value::BasicValue::Str(Arc::from("test string"))); - let str_type = schema::ValueType::Basic(schema::BasicValueType::Str); - let py_obj = value_to_py_object(py, &str_value).unwrap(); - let roundtrip_value = value_from_py_object(&str_type, &py_obj).unwrap(); - assert_eq!(str_value, roundtrip_value, "String roundtrip failed"); - - // Test Bool - let bool_value = value::Value::Basic(value::BasicValue::Bool(true)); - let bool_type = schema::ValueType::Basic(schema::BasicValueType::Bool); - let py_obj = value_to_py_object(py, &bool_value).unwrap(); - let roundtrip_value = value_from_py_object(&bool_type, &py_obj).unwrap(); - assert_eq!(bool_value, roundtrip_value, "Bool roundtrip failed"); + // Convert Rust value to Python object + let pythonized_value = Pythonized(original_value.clone()); + let py_object = pythonized_value.into_pyobject(py).unwrap_or_else(|e| { + panic!("Failed to convert Rust value to Python object: {:?}", e) + }); + + 
println!("Python object: {:?}", py_object); + // Convert Python object back to Rust value + // let roundtripped_value = Pythonized::::extract_bound(&py_object) + // .unwrap_or_else(|e| panic!("Failed to convert Python object back to Rust value: {:?}", e)); + let roundtripped_value = + value_from_py_object(value_type, &py_object).unwrap_or_else(|e| { + panic!( + "Failed to convert Python object back to Rust value: {:?}", + e + ) + }); + + println!("Roundtripped value: {:?}", roundtripped_value); + // Compare values + match (&original_value, &roundtripped_value) { + (value::Value::Basic(orig), value::Value::Basic(round)) => { + assert_eq!(orig, round, "BasicValue mismatch"); + } + (value::Value::Struct(orig), value::Value::Struct(round)) => { + assert_eq!( + orig.fields.len(), + round.fields.len(), + "Struct field count mismatch" + ); + for (o, r) in orig.fields.iter().zip(round.fields.iter()) { + assert_eq!(o, r, "Struct field value mismatch"); + } + } + (value::Value::UTable(orig), value::Value::UTable(round)) => { + assert_eq!(orig.len(), round.len(), "UTable row count mismatch"); + for (o, r) in orig.iter().zip(round.iter()) { + assert_eq!( + o.fields.len(), + r.fields.len(), + "UTable field count mismatch" + ); + for (of, rf) in o.fields.iter().zip(r.fields.iter()) { + assert_eq!(of, rf, "UTable field value mismatch"); + } + } + } + (value::Value::LTable(orig), value::Value::LTable(round)) => { + assert_eq!(orig.len(), round.len(), "LTable row count mismatch"); + for (o, r) in orig.iter().zip(round.iter()) { + assert_eq!( + o.fields.len(), + r.fields.len(), + "LTable field count mismatch" + ); + for (of, rf) in o.fields.iter().zip(r.fields.iter()) { + assert_eq!(of, rf, "LTable field value mismatch"); + } + } + } + (value::Value::KTable(orig), value::Value::KTable(round)) => { + assert_eq!(orig.len(), round.len(), "KTable entry count mismatch"); + for (ok, ov) in orig.iter() { + let rv = round + .get(ok) + .unwrap_or_else(|| panic!("Missing key in KTable 
roundtrip: {:?}", ok)); + assert_eq!( + ov.fields.len(), + rv.fields.len(), + "KTable field count mismatch" + ); + for (of, rf) in ov.fields.iter().zip(rv.fields.iter()) { + assert_eq!(of, rf, "KTable field value mismatch"); + } + } + } + _ => panic!( + "Value type mismatch: expected {:?}, got {:?}", + original_value, roundtripped_value + ), + } }); } + #[test] + fn test_roundtrip_basic_values() { + let values_and_types = vec![ + ( + value::Value::Basic(value::BasicValue::Int64(42)), + schema::ValueType::Basic(schema::BasicValueType::Int64), + ), + ( + value::Value::Basic(value::BasicValue::Float64(3.14)), + schema::ValueType::Basic(schema::BasicValueType::Float64), + ), + ( + value::Value::Basic(value::BasicValue::Str(Arc::from("hello"))), + schema::ValueType::Basic(schema::BasicValueType::Str), + ), + ( + value::Value::Basic(value::BasicValue::Bool(true)), + schema::ValueType::Basic(schema::BasicValueType::Bool), + ), + ]; + + for (val, typ) in values_and_types { + assert_roundtrip_conversion(&val, &typ); + } + } + #[test] fn test_roundtrip_struct() { - Python::with_gil(|py| { - // Create a struct schema with multiple fields - let struct_schema = schema::StructSchema { - description: Some(Arc::from("Test struct")), - fields: Arc::new(vec![ - schema::FieldSchema { - name: "id".into(), - value_type: schema::EnrichedValueType { - typ: schema::ValueType::Basic(schema::BasicValueType::Int64), - nullable: false, - attrs: Default::default(), - }, + let struct_schema = schema::StructSchema { + description: Some(Arc::from("Test struct description")), + fields: Arc::new(vec![ + schema::FieldSchema { + name: "a".to_string(), + value_type: schema::EnrichedValueType { + typ: schema::ValueType::Basic(schema::BasicValueType::Int64), + nullable: false, + attrs: Default::default(), }, - schema::FieldSchema { - name: "name".into(), - value_type: schema::EnrichedValueType { - typ: schema::ValueType::Basic(schema::BasicValueType::Str), - nullable: false, - attrs: Default::default(), 
- }, + }, + schema::FieldSchema { + name: "b".to_string(), + value_type: schema::EnrichedValueType { + typ: schema::ValueType::Basic(schema::BasicValueType::Str), + nullable: false, + attrs: Default::default(), }, - schema::FieldSchema { - name: "active".into(), - value_type: schema::EnrichedValueType { - typ: schema::ValueType::Basic(schema::BasicValueType::Bool), - nullable: false, - attrs: Default::default(), - }, - }, - ]), - }; - - // Create a struct value matching the schema - let struct_value = value::Value::Struct(value::FieldValues { - fields: vec![ - value::Value::Basic(value::BasicValue::Int64(1)), - value::Value::Basic(value::BasicValue::Str(Arc::from("test"))), - value::Value::Basic(value::BasicValue::Bool(true)), - ], - }); + }, + ]), + }; + + let struct_val_data = value::FieldValues { + fields: vec![ + value::Value::Basic(value::BasicValue::Int64(10)), + value::Value::Basic(value::BasicValue::Str(Arc::from("world"))), + ], + }; - // Perform roundtrip conversion - let struct_type = schema::ValueType::Struct(struct_schema); - let py_obj = value_to_py_object(py, &struct_value).unwrap(); - let roundtrip_value = value_from_py_object(&struct_type, &py_obj).unwrap(); - assert_eq!(struct_value, roundtrip_value, "Struct roundtrip failed"); + let struct_val = value::Value::Struct(struct_val_data); + let struct_typ = schema::ValueType::Struct(struct_schema); // No clone needed + + assert_roundtrip_conversion(&struct_val, &struct_typ); + } + + #[test] + fn test_roundtrip_table_types() { + let row_schema_struct = Arc::new(schema::StructSchema { + description: Some(Arc::from("Test table row description")), + fields: Arc::new(vec![ + schema::FieldSchema { + name: "key_col".to_string(), // Will be used as key for KTable implicitly + value_type: schema::EnrichedValueType { + typ: schema::ValueType::Basic(schema::BasicValueType::Int64), + nullable: false, + attrs: Default::default(), + }, + }, + schema::FieldSchema { + name: "data_col_1".to_string(), + value_type: 
schema::EnrichedValueType { + typ: schema::ValueType::Basic(schema::BasicValueType::Str), + nullable: false, + attrs: Default::default(), + }, + }, + schema::FieldSchema { + name: "data_col_2".to_string(), + value_type: schema::EnrichedValueType { + typ: schema::ValueType::Basic(schema::BasicValueType::Bool), + nullable: false, + attrs: Default::default(), + }, + }, + ]), }); + + let row1_fields = value::FieldValues { + fields: vec![ + value::Value::Basic(value::BasicValue::Int64(1)), + value::Value::Basic(value::BasicValue::Str(Arc::from("row1_data"))), + value::Value::Basic(value::BasicValue::Bool(true)), + ], + }; + let row1_scope_val: value::ScopeValue = row1_fields.into(); + + let row2_fields = value::FieldValues { + fields: vec![ + value::Value::Basic(value::BasicValue::Int64(2)), + value::Value::Basic(value::BasicValue::Str(Arc::from("row2_data"))), + value::Value::Basic(value::BasicValue::Bool(false)), + ], + }; + let row2_scope_val: value::ScopeValue = row2_fields.into(); + + // UTable + let utable_schema = schema::TableSchema { + kind: schema::TableKind::UTable, + row: (*row_schema_struct).clone(), + }; + let utable_val = value::Value::UTable(vec![row1_scope_val.clone(), row2_scope_val.clone()]); + let utable_typ = schema::ValueType::Table(utable_schema); + assert_roundtrip_conversion(&utable_val, &utable_typ); + + // LTable + let ltable_schema = schema::TableSchema { + kind: schema::TableKind::LTable, + row: (*row_schema_struct).clone(), + }; + let ltable_val = value::Value::LTable(vec![row1_scope_val.clone(), row2_scope_val.clone()]); + let ltable_typ = schema::ValueType::Table(ltable_schema); + assert_roundtrip_conversion(&ltable_val, &ltable_typ); + + // KTable + let ktable_schema = schema::TableSchema { + kind: schema::TableKind::KTable, + row: (*row_schema_struct).clone(), + }; + let mut ktable_data = BTreeMap::new(); + + // Create KTable entries where the ScopeValue doesn't include the key field + // This matches how the Python code will 
serialize/deserialize + let row1_fields = value::FieldValues { + fields: vec![ + value::Value::Basic(value::BasicValue::Str(Arc::from("row1_data"))), + value::Value::Basic(value::BasicValue::Bool(true)), + ], + }; + let row1_scope_val: value::ScopeValue = row1_fields.into(); + + let row2_fields = value::FieldValues { + fields: vec![ + value::Value::Basic(value::BasicValue::Str(Arc::from("row2_data"))), + value::Value::Basic(value::BasicValue::Bool(false)), + ], + }; + let row2_scope_val: value::ScopeValue = row2_fields.into(); + + // For KTable, the key is extracted from the first field of ScopeValue based on current serialization + let key1 = value::Value::<ScopeValue>::Basic(value::BasicValue::Int64(1)) + .into_key() + .unwrap(); + let key2 = value::Value::<ScopeValue>::Basic(value::BasicValue::Int64(2)) + .into_key() + .unwrap(); + + ktable_data.insert(key1, row1_scope_val.clone()); + ktable_data.insert(key2, row2_scope_val.clone()); + + let ktable_val = value::Value::KTable(ktable_data); + let ktable_typ = schema::ValueType::Table(ktable_schema); + assert_roundtrip_conversion(&ktable_val, &ktable_typ); } } diff --git a/src/server.rs b/src/server.rs index ef2c1853..988c318e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -8,7 +8,7 @@ use tower_http::{ trace::TraceLayer, }; -#[derive(Deserialize, Debug)] +#[derive(Deserialize, Debug, Clone)] pub struct ServerSettings { pub address: String, #[serde(default)] diff --git a/src/settings.rs b/src/settings.rs index 2cbcf146..350ec7ba 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -1,13 +1,13 @@ use serde::Deserialize; -#[derive(Deserialize, Debug)] +#[derive(Deserialize, Debug, Clone)] pub struct DatabaseConnectionSpec { pub url: String, pub user: Option<String>, pub password: Option<String>, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, Debug, Clone)] pub struct Settings { pub database: DatabaseConnectionSpec, } From 1791a469b729933d702796586fffd21dc477cb65 Mon Sep 17 00:00:00 2001 From: Abhishek Tripathi Date: Wed, 14 May 2025 12:43:53 
+0530 Subject: [PATCH 3/3] chore: cleanup --- src/py/convert.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/py/convert.rs b/src/py/convert.rs index 4267f4b6..a2347650 100644 --- a/src/py/convert.rs +++ b/src/py/convert.rs @@ -231,13 +231,12 @@ pub fn value_from_py_object<'py>( Ok(result) } -// `value_from_py_object` and `value_to_py_object` are functions internal to this module. `Pythonized` is the API exposed by the module. Ideally we want to test on the behavior of the public API. #[cfg(test)] mod tests { - use super::*; // To bring Pythonized into scope + use super::*; use crate::base::schema; use crate::base::value; - use crate::base::value::ScopeValue; // Changed import from GeneralValueSpec + use crate::base::value::ScopeValue; use pyo3::Python; use std::collections::BTreeMap; use std::sync::Arc; @@ -251,9 +250,6 @@ mod tests { }); println!("Python object: {:?}", py_object); - // Convert Python object back to Rust value - // let roundtripped_value = Pythonized::::extract_bound(&py_object) - // .unwrap_or_else(|e| panic!("Failed to convert Python object back to Rust value: {:?}", e)); let roundtripped_value = value_from_py_object(value_type, &py_object).unwrap_or_else(|e| { panic!(