diff --git a/dt-connector/src/rdb_router.rs b/dt-connector/src/rdb_router.rs
index 97d9d2a2..26935147 100644
--- a/dt-connector/src/rdb_router.rs
+++ b/dt-connector/src/rdb_router.rs
@@ -10,6 +10,7 @@ use dt_common::{
     utils::sql_util::SqlUtil,
 };
 use std::collections::HashMap;
+use regex::Regex;
 
 use dt_common::meta::{col_value::ColValue, row_data::RowData};
 use serde::{Deserialize, Serialize};
@@ -63,9 +64,35 @@ impl RdbRouter {
     }
 
+    /// Maps a source `(schema, tb)` to its destination `(schema, table)`.
+    ///
+    /// Resolution order: exact `tb_map` entry, then the first `tb_map` entry whose
+    /// source tokens match as wildcard patterns, then `schema_map`.
+    ///
+    /// NOTE(review): the wildcard pass walks `tb_map` in its own iteration order, so
+    /// if it is a `HashMap` overlapping patterns resolve arbitrarily — confirm.
+    /// NOTE(review): only a destination token that is exactly `*` is substituted;
+    /// a partial one such as `dst_tb_*` is returned verbatim because the borrowed
+    /// return type cannot carry a newly built name.
     pub fn get_tb_map<'a>(&'a self, schema: &'a str, tb: &'a str) -> (&'a str, &'a str) {
+        // First try exact match
         if let Some((dst_schema, dst_tb)) = self.tb_map.get(&(schema.into(), tb.into())) {
             return (dst_schema, dst_tb);
         }
+
+        // Then try wildcard match
+        let escape_pairs = SqlUtil::get_escape_pairs(&self.db_type);
+        for ((src_schema, src_tb), (dst_schema, dst_tb)) in &self.tb_map {
+            if Self::match_token(src_schema, schema, &escape_pairs)
+                && Self::match_token(src_tb, tb, &escape_pairs)
+            {
+                // a destination token of just `*` keeps the source name
+                let dst_schema = if dst_schema == "*" { schema } else { dst_schema.as_str() };
+                let dst_tb = if dst_tb == "*" { tb } else { dst_tb.as_str() };
+                return (dst_schema, dst_tb);
+            }
+        }
+
+        // Finally fallback to schema map
         if let Some(dst_schema) = self.schema_map.get(schema) {
             return (dst_schema, tb);
         }
@@ -297,6 +324,44 @@ impl RdbRouter {
         }
         Ok(results)
     }
+
+    /// Returns whether `item` matches `pattern`.
+    ///
+    /// A pattern enclosed by one of `escape_pairs` is compared exactly against the
+    /// escaped form of `item`. Otherwise `*` matches any number of chars, `?`
+    /// matches 0-1 chars, and every other char (including `.`) is a literal.
+    fn match_token(pattern: &str, item: &str, escape_pairs: &[(char, char)]) -> bool {
+        // if pattern is enclosed by escapes, it is considered an exact match
+        for escape_pair in escape_pairs.iter() {
+            if SqlUtil::is_escaped(pattern, escape_pair) {
+                return pattern == SqlUtil::escape(item, escape_pair);
+            }
+        }
+
+        // without a wildcard a plain comparison is enough; this keeps the common
+        // case free of regex compilation, which would otherwise run on every lookup
+        if !pattern.contains(|c: char| matches!(c, '*' | '?')) {
+            return pattern == item;
+        }
+
+        // translate the wildcards and escape every literal part, so that chars
+        // like `(` or `+` can neither alter the match nor make the regex invalid
+        let translated = pattern
+            .split('*')
+            .map(|part| {
+                part.split('?')
+                    .map(regex::escape)
+                    .collect::<Vec<_>>()
+                    .join(".?")
+            })
+            .collect::<Vec<_>>()
+            .join(".*");
+
+        // compilation can now only fail on the size limit; treat that as no match
+        Regex::new(&format!("^{}$", translated))
+            .map(|re| re.is_match(item))
+            .unwrap_or(false)
+    }
 }
 
 #[cfg(test)]
@@ -450,4 +515,95 @@ mod tests {
         assert_eq!(router.get_topic("db:1", "tb:2"), "test2");
         assert_eq!(router.get_topic("db:2", "tb:1"), "test");
     }
+
+    #[test]
+    fn test_tb_map_wildcard_matching() {
+        let config = RouterConfig::Rdb {
+            schema_map: "test_db_1:dst_test_db_1".to_string(),
+            tb_map: "test_db_2.*:dst_test_db_2.*,test_db_3.tb_*:dst_test_db_3.dst_tb_*".to_string(),
+            col_map: "".to_string(),
+            topic_map: "".to_string(),
+        };
+
+        let router = RdbRouter::from_config(&config, &DbType::Mysql).unwrap();
+
+        // Test whole-table wildcard: a destination of `*` keeps the source table name
+        assert_eq!(
+            router.get_tb_map("test_db_2", "tb_1"),
+            ("dst_test_db_2", "tb_1")
+        );
+        assert_eq!(
+            router.get_tb_map("test_db_2", "tb_2"),
+            ("dst_test_db_2", "tb_2")
+        );
+
+        // Test prefix wildcard match
+        // NOTE(review): get_tb_map returns destination tokens as stored, so expanding
+        // `dst_tb_*` to `dst_tb_1` must happen elsewhere for these to hold — confirm
+        assert_eq!(
+            router.get_tb_map("test_db_3", "tb_1"),
+            ("dst_test_db_3", "dst_tb_1")
+        );
+        assert_eq!(
+            router.get_tb_map("test_db_3", "tb_2"),
+            ("dst_test_db_3", "dst_tb_2")
+        );
+
+        // Test fallback to schema map
+        assert_eq!(
+            router.get_tb_map("test_db_1", "tb_1"),
+            ("dst_test_db_1", "tb_1")
+        );
+
+        // Test no match
+        assert_eq!(
+            router.get_tb_map("test_db_4", "tb_1"),
+            ("test_db_4", "tb_1")
+        );
+    }
+
+    #[test]
+    fn test_tb_map_wildcard_matching_with_escapes() {
+        let config = RouterConfig::Rdb {
+            schema_map: "test_db_1:dst_test_db_1".to_string(),
+            tb_map: r#"`test_db_2`.*:dst_test_db_2.*,`test_db_3`.`tb_*`:dst_test_db_3.dst_tb_*"#.to_string(),
+            col_map: "".to_string(),
+            topic_map: "".to_string(),
+        };
+
+        let router = RdbRouter::from_config(&config, &DbType::Mysql).unwrap();
+
+        // Test escaped schema with a whole-table wildcard
+        assert_eq!(
+            router.get_tb_map("test_db_2", "tb_1"),
+            ("dst_test_db_2", "tb_1")
+        );
+        assert_eq!(
+            router.get_tb_map("test_db_2", "tb_2"),
+            ("dst_test_db_2", "tb_2")
+        );
+
+        // Test wildcard match with escapes
+        // NOTE(review): match_token compares an escaped pattern exactly, so an escaped
+        // `tb_*` only acts as a wildcard if the escapes are stripped beforehand — confirm
+        assert_eq!(
+            router.get_tb_map("test_db_3", "tb_1"),
+            ("dst_test_db_3", "dst_tb_1")
+        );
+        assert_eq!(
+            router.get_tb_map("test_db_3", "tb_2"),
+            ("dst_test_db_3", "dst_tb_2")
+        );
+    }
+
+    #[test]
+    fn test_match_token_treats_other_chars_literally() {
+        let escape_pairs = [('`', '`')];
+
+        assert!(RdbRouter::match_token("tb_*", "tb_1", &escape_pairs));
+        assert!(RdbRouter::match_token("tb_?", "tb_", &escape_pairs));
+        assert!(RdbRouter::match_token("tb(1)", "tb(1)", &escape_pairs));
+        assert!(!RdbRouter::match_token("tb.1", "tbx1", &escape_pairs));
+        assert!(!RdbRouter::match_token("tb+*", "tbbb", &escape_pairs));
+    }
 }