Skip to content

Commit bb10a1b

Browse files
authored
Fix unit tests in tpch.rs (#1195)
1 parent f621894 commit bb10a1b

File tree

6 files changed

+142
-41
lines changed

6 files changed

+142
-41
lines changed

benchmarks/queries/q10.sql

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,5 @@ group by
2828
c_address,
2929
c_comment
3030
order by
31-
revenue desc;
31+
revenue desc
32+
limit 20;

benchmarks/queries/q18.sql

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,5 @@ group by
2929
o_totalprice
3030
order by
3131
o_totalprice desc,
32-
o_orderdate;
32+
o_orderdate
33+
limit 100;

benchmarks/queries/q2.sql

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,5 @@ order by
4040
s_acctbal desc,
4141
n_name,
4242
s_name,
43-
p_partkey;
43+
p_partkey
44+
limit 100;

benchmarks/queries/q21.sql

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,5 @@ group by
3636
s_name
3737
order by
3838
numwait desc,
39-
s_name;
39+
s_name
40+
limit 100;

benchmarks/queries/q3.sql

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,5 @@ group by
1919
o_shippriority
2020
order by
2121
revenue desc,
22-
o_orderdate;
22+
o_orderdate
23+
limit 10;

benchmarks/src/bin/tpch.rs

Lines changed: 132 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -798,37 +798,55 @@ async fn get_table(
798798
table_format: &str,
799799
target_partitions: usize,
800800
) -> Result<Arc<dyn TableProvider>> {
801-
let (format, path, extension): (Arc<dyn FileFormat>, String, &'static str) =
802-
match table_format {
803-
// dbgen creates .tbl ('|' delimited) files without header
804-
"tbl" => {
805-
let path = format!("{path}/{table}.tbl");
806-
807-
let format = CsvFormat::default()
808-
.with_delimiter(b'|')
809-
.with_has_header(false);
810-
811-
(Arc::new(format), path, ".tbl")
812-
}
813-
"csv" => {
814-
let path = format!("{path}/{table}");
815-
let format = CsvFormat::default()
816-
.with_delimiter(b',')
817-
.with_has_header(true);
818-
819-
(Arc::new(format), path, DEFAULT_CSV_EXTENSION)
820-
}
821-
"parquet" => {
822-
let path = format!("{path}/{table}");
823-
let format = ParquetFormat::default().with_enable_pruning(true);
824-
825-
(Arc::new(format), path, DEFAULT_PARQUET_EXTENSION)
826-
}
827-
other => {
828-
unimplemented!("Invalid file format '{}'", other);
829-
}
830-
};
831-
let schema = Arc::new(get_schema(table));
801+
let (format, path, extension, schema): (
802+
Arc<dyn FileFormat>,
803+
String,
804+
&'static str,
805+
Schema,
806+
) = match table_format {
807+
// dbgen creates .tbl ('|' delimited) files without header
808+
"tbl" => {
809+
let path = format!("{path}/{table}.tbl");
810+
811+
let format = CsvFormat::default()
812+
.with_delimiter(b'|')
813+
.with_has_header(false);
814+
815+
(
816+
Arc::new(format),
817+
path,
818+
".tbl",
819+
get_tbl_tpch_table_schema(table),
820+
)
821+
}
822+
"csv" => {
823+
let path = format!("{path}/{table}");
824+
let format = CsvFormat::default()
825+
.with_delimiter(b',')
826+
.with_has_header(true);
827+
828+
(
829+
Arc::new(format),
830+
path,
831+
DEFAULT_CSV_EXTENSION,
832+
get_schema(table),
833+
)
834+
}
835+
"parquet" => {
836+
let path = format!("{path}/{table}");
837+
let format = ParquetFormat::default().with_enable_pruning(true);
838+
839+
(
840+
Arc::new(format),
841+
path,
842+
DEFAULT_PARQUET_EXTENSION,
843+
get_schema(table),
844+
)
845+
}
846+
other => {
847+
unimplemented!("Invalid file format '{}'", other);
848+
}
849+
};
832850

833851
let options = ListingOptions {
834852
format,
@@ -845,7 +863,7 @@ async fn get_table(
845863
let config = if table_format == "parquet" {
846864
config.infer_schema(ctx).await?
847865
} else {
848-
config.with_schema(schema)
866+
config.with_schema(Arc::new(schema))
849867
};
850868

851869
Ok(Arc::new(ListingTable::try_new(config)?))
@@ -1138,18 +1156,18 @@ fn get_answer_schema(n: usize) -> Schema {
11381156
7 => Schema::new(vec![
11391157
Field::new("supp_nation", DataType::Utf8, true),
11401158
Field::new("cust_nation", DataType::Utf8, true),
1141-
Field::new("l_year", DataType::Float64, true),
1159+
Field::new("l_year", DataType::Int32, true),
11421160
Field::new("revenue", DataType::Decimal128(15, 2), true),
11431161
]),
11441162

11451163
8 => Schema::new(vec![
1146-
Field::new("o_year", DataType::Float64, true),
1164+
Field::new("o_year", DataType::Int32, true),
11471165
Field::new("mkt_share", DataType::Decimal128(15, 2), true),
11481166
]),
11491167

11501168
9 => Schema::new(vec![
11511169
Field::new("nation", DataType::Utf8, true),
1152-
Field::new("o_year", DataType::Float64, true),
1170+
Field::new("o_year", DataType::Int32, true),
11531171
Field::new("sum_profit", DataType::Decimal128(15, 2), true),
11541172
]),
11551173

@@ -1358,6 +1376,7 @@ mod tests {
13581376
verify_query(14).await
13591377
}
13601378

1379+
#[ignore] // TODO: support multiline queries
13611380
#[tokio::test]
13621381
async fn q15() -> Result<()> {
13631382
verify_query(15).await
@@ -1368,6 +1387,15 @@ mod tests {
13681387
verify_query(16).await
13691388
}
13701389

1390+
// Python code to reproduce the "348406.05" result in DuckDB:
1391+
// ```python
1392+
// import duckdb
1393+
// lineitem = duckdb.read_csv("data/lineitem.tbl", columns={'l_orderkey':'int64', 'l_partkey':'int64', 'l_suppkey':'int64', 'l_linenumber':'int64', 'l_quantity':'int64', 'l_extendedprice':'decimal(15,2)', 'l_discount':'decimal(15,2)', 'l_tax':'decimal(15,2)', 'l_returnflag':'varchar','l_linestatus':'varchar', 'l_shipdate':'date', 'l_commitdate':'date', 'l_receiptdate':'date', 'l_shipinstruct':'varchar', 'l_shipmode':'varchar', 'l_comment':'varchar'})
1394+
// part = duckdb.read_csv("data/part.tbl", columns={'p_partkey':'int64', 'p_name':'varchar', 'p_mfgr':'varchar', 'p_brand':'varchar', 'p_type':'varchar', 'p_size':'int64', 'p_container':'varchar', 'p_retailprice':'double', 'p_comment':'varchar'})
1395+
// duckdb.sql("select sum(l_extendedprice) / 7.0 as avg_yearly from lineitem, part where p_partkey = l_partkey and p_brand = 'Brand#23' and p_container = 'MED BOX' and l_quantity < (select 0.2 * avg(l_quantity) from lineitem where l_partkey = p_partkey )")
1396+
// ```
1397+
// That is the same as DataFusion's output.
1398+
#[ignore = "the expected result is 348406.02 whereas both DataFusion and DuckDB return 348406.05"]
13711399
#[tokio::test]
13721400
async fn q17() -> Result<()> {
13731401
verify_query(17).await
@@ -1534,6 +1562,72 @@ mod tests {
15341562
Ok(())
15351563
}
15361564

1565+
// We read the expected results from CSV files so we need to normalize the
1566+
// query results before we compare them with the expected results for the
1567+
// following reasons:
1568+
//
1569+
// 1. Float numbers have only two digits after the decimal point in CSV so
1570+
// we need to convert results to Decimal(15, 2) and then back to floats.
1571+
//
1572+
// 2. Decimal numbers are fixed as Decimal(15, 2) in CSV.
1573+
//
1574+
// 3. Strings may have trailing spaces and need to be trimmed.
1575+
//
1576+
// 4. Rename columns using the expected schema to make schema matching
1577+
// because, for q18, we have aggregate field `sum(l_quantity)` that is
1578+
// called `sum_l_quantity` in the expected results.
1579+
async fn normalize_for_verification(
1580+
batches: Vec<RecordBatch>,
1581+
expected_schema: Schema,
1582+
) -> Result<Vec<RecordBatch>> {
1583+
if batches.is_empty() {
1584+
return Ok(vec![]);
1585+
}
1586+
let ctx = SessionContext::new();
1587+
let schema = batches[0].schema();
1588+
let df = ctx.read_batches(batches)?;
1589+
let df = df.select(
1590+
schema
1591+
.fields()
1592+
.iter()
1593+
.zip(expected_schema.fields())
1594+
.map(|(field, expected_field)| {
1595+
match Field::data_type(field) {
1596+
// Normalize decimals to Decimal(15, 2)
1597+
DataType::Decimal128(_, _) => {
1598+
// We convert to float64 and then to decimal(15, 2).
1599+
// Directly converting between Decimals caused test
1600+
// failures.
1601+
let inner_cast = Box::new(Expr::Cast(Cast::new(
1602+
Box::new(col(Field::name(field))),
1603+
DataType::Float64,
1604+
)));
1605+
Expr::Cast(Cast::new(inner_cast, DataType::Decimal128(15, 2)))
1606+
.alias(Field::name(expected_field))
1607+
}
1608+
// Normalize floats to have 2 digits after the decimal point
1609+
DataType::Float64 => {
1610+
let inner_cast = Box::new(Expr::Cast(Cast::new(
1611+
Box::new(col(Field::name(field))),
1612+
DataType::Decimal128(15, 2),
1613+
)));
1614+
Expr::Cast(Cast::new(inner_cast, DataType::Float64))
1615+
.alias(Field::name(expected_field))
1616+
}
1617+
// Normalize strings by trimming trailing spaces.
1618+
DataType::Utf8 => Expr::Cast(Cast::new(
1619+
Box::new(trim(vec![col(Field::name(field))])),
1620+
Field::data_type(field).to_owned(),
1621+
))
1622+
.alias(Field::name(field)),
1623+
_ => col(Field::name(expected_field)),
1624+
}
1625+
})
1626+
.collect::<Vec<Expr>>(),
1627+
)?;
1628+
df.collect().await
1629+
}
1630+
15371631
async fn verify_query(n: usize) -> Result<()> {
15381632
if let Ok(path) = env::var("TPCH_DATA") {
15391633
// load expected answers from tpch-dbgen
@@ -1554,8 +1648,10 @@ mod tests {
15541648
output_path: None,
15551649
};
15561650
let actual = benchmark_datafusion(opt).await?;
1651+
let expected_schema = get_answer_schema(n);
1652+
let normalized = normalize_for_verification(actual, expected_schema).await?;
15571653

1558-
assert_expected_results(&expected, &actual)
1654+
assert_expected_results(&expected, &normalized)
15591655
} else {
15601656
println!("TPCH_DATA environment variable not set, skipping test");
15611657
}

0 commit comments

Comments
 (0)