This is an automated email from the ASF dual-hosted git repository.
milenkovicm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git
The following commit(s) were added to refs/heads/main by this push:
new bb10a1be Fix unit tests in tpch.rs (#1195)
bb10a1be is described below
commit bb10a1bebd52ebb91515efa7a2a977df740c2d7a
Author: Ming Chen <[email protected]>
AuthorDate: Thu Mar 20 02:30:43 2025 -0400
Fix unit tests in tpch.rs (#1195)
---
benchmarks/queries/q10.sql | 3 +-
benchmarks/queries/q18.sql | 3 +-
benchmarks/queries/q2.sql | 3 +-
benchmarks/queries/q21.sql | 3 +-
benchmarks/queries/q3.sql | 3 +-
benchmarks/src/bin/tpch.rs | 168 +++++++++++++++++++++++++++++++++++----------
6 files changed, 142 insertions(+), 41 deletions(-)
diff --git a/benchmarks/queries/q10.sql b/benchmarks/queries/q10.sql
index cf45e434..ef48cafe 100644
--- a/benchmarks/queries/q10.sql
+++ b/benchmarks/queries/q10.sql
@@ -28,4 +28,5 @@ group by
c_address,
c_comment
order by
- revenue desc;
\ No newline at end of file
+ revenue desc
+limit 20;
\ No newline at end of file
diff --git a/benchmarks/queries/q18.sql b/benchmarks/queries/q18.sql
index 835de28a..c3da5b76 100644
--- a/benchmarks/queries/q18.sql
+++ b/benchmarks/queries/q18.sql
@@ -29,4 +29,5 @@ group by
o_totalprice
order by
o_totalprice desc,
- o_orderdate;
\ No newline at end of file
+ o_orderdate
+limit 100;
\ No newline at end of file
diff --git a/benchmarks/queries/q2.sql b/benchmarks/queries/q2.sql
index f66af210..4927ff6e 100644
--- a/benchmarks/queries/q2.sql
+++ b/benchmarks/queries/q2.sql
@@ -40,4 +40,5 @@ order by
s_acctbal desc,
n_name,
s_name,
- p_partkey;
\ No newline at end of file
+ p_partkey
+limit 100;
\ No newline at end of file
diff --git a/benchmarks/queries/q21.sql b/benchmarks/queries/q21.sql
index 9d2fe32c..f318f938 100644
--- a/benchmarks/queries/q21.sql
+++ b/benchmarks/queries/q21.sql
@@ -36,4 +36,5 @@ group by
s_name
order by
numwait desc,
- s_name;
\ No newline at end of file
+ s_name
+limit 100;
\ No newline at end of file
diff --git a/benchmarks/queries/q3.sql b/benchmarks/queries/q3.sql
index 7dbc6d9e..601f6fe5 100644
--- a/benchmarks/queries/q3.sql
+++ b/benchmarks/queries/q3.sql
@@ -19,4 +19,5 @@ group by
o_shippriority
order by
revenue desc,
- o_orderdate;
\ No newline at end of file
+ o_orderdate
+limit 10;
\ No newline at end of file
diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs
index a3928d0b..16a59172 100644
--- a/benchmarks/src/bin/tpch.rs
+++ b/benchmarks/src/bin/tpch.rs
@@ -798,37 +798,55 @@ async fn get_table(
table_format: &str,
target_partitions: usize,
) -> Result<Arc<dyn TableProvider>> {
- let (format, path, extension): (Arc<dyn FileFormat>, String, &'static str)
=
- match table_format {
- // dbgen creates .tbl ('|' delimited) files without header
- "tbl" => {
- let path = format!("{path}/{table}.tbl");
-
- let format = CsvFormat::default()
- .with_delimiter(b'|')
- .with_has_header(false);
-
- (Arc::new(format), path, ".tbl")
- }
- "csv" => {
- let path = format!("{path}/{table}");
- let format = CsvFormat::default()
- .with_delimiter(b',')
- .with_has_header(true);
-
- (Arc::new(format), path, DEFAULT_CSV_EXTENSION)
- }
- "parquet" => {
- let path = format!("{path}/{table}");
- let format =
ParquetFormat::default().with_enable_pruning(true);
-
- (Arc::new(format), path, DEFAULT_PARQUET_EXTENSION)
- }
- other => {
- unimplemented!("Invalid file format '{}'", other);
- }
- };
- let schema = Arc::new(get_schema(table));
+ let (format, path, extension, schema): (
+ Arc<dyn FileFormat>,
+ String,
+ &'static str,
+ Schema,
+ ) = match table_format {
+ // dbgen creates .tbl ('|' delimited) files without header
+ "tbl" => {
+ let path = format!("{path}/{table}.tbl");
+
+ let format = CsvFormat::default()
+ .with_delimiter(b'|')
+ .with_has_header(false);
+
+ (
+ Arc::new(format),
+ path,
+ ".tbl",
+ get_tbl_tpch_table_schema(table),
+ )
+ }
+ "csv" => {
+ let path = format!("{path}/{table}");
+ let format = CsvFormat::default()
+ .with_delimiter(b',')
+ .with_has_header(true);
+
+ (
+ Arc::new(format),
+ path,
+ DEFAULT_CSV_EXTENSION,
+ get_schema(table),
+ )
+ }
+ "parquet" => {
+ let path = format!("{path}/{table}");
+ let format = ParquetFormat::default().with_enable_pruning(true);
+
+ (
+ Arc::new(format),
+ path,
+ DEFAULT_PARQUET_EXTENSION,
+ get_schema(table),
+ )
+ }
+ other => {
+ unimplemented!("Invalid file format '{}'", other);
+ }
+ };
let options = ListingOptions {
format,
@@ -845,7 +863,7 @@ async fn get_table(
let config = if table_format == "parquet" {
config.infer_schema(ctx).await?
} else {
- config.with_schema(schema)
+ config.with_schema(Arc::new(schema))
};
Ok(Arc::new(ListingTable::try_new(config)?))
@@ -1138,18 +1156,18 @@ fn get_answer_schema(n: usize) -> Schema {
7 => Schema::new(vec![
Field::new("supp_nation", DataType::Utf8, true),
Field::new("cust_nation", DataType::Utf8, true),
- Field::new("l_year", DataType::Float64, true),
+ Field::new("l_year", DataType::Int32, true),
Field::new("revenue", DataType::Decimal128(15, 2), true),
]),
8 => Schema::new(vec![
- Field::new("o_year", DataType::Float64, true),
+ Field::new("o_year", DataType::Int32, true),
Field::new("mkt_share", DataType::Decimal128(15, 2), true),
]),
9 => Schema::new(vec![
Field::new("nation", DataType::Utf8, true),
- Field::new("o_year", DataType::Float64, true),
+ Field::new("o_year", DataType::Int32, true),
Field::new("sum_profit", DataType::Decimal128(15, 2), true),
]),
@@ -1358,6 +1376,7 @@ mod tests {
verify_query(14).await
}
+ #[ignore] // TODO: support multiline queries
#[tokio::test]
async fn q15() -> Result<()> {
verify_query(15).await
@@ -1368,6 +1387,15 @@ mod tests {
verify_query(16).await
}
+ // Python code to reproduce the "348406.05" result in DuckDB:
+ // ```python
+ // import duckdb
+ // lineitem = duckdb.read_csv("data/lineitem.tbl",
columns={'l_orderkey':'int64', 'l_partkey':'int64', 'l_suppkey':'int64',
'l_linenumber':'int64', 'l_quantity':'int64',
'l_extendedprice':'decimal(15,2)', 'l_discount':'decimal(15,2)',
'l_tax':'decimal(15,2)', 'l_returnflag':'varchar','l_linestatus':'varchar',
'l_shipdate':'date', 'l_commitdate':'date', 'l_receiptdate':'date',
'l_shipinstruct':'varchar', 'l_shipmode':'varchar', 'l_comment':'varchar'})
+ // part = duckdb.read_csv("data/part.tbl", columns={'p_partkey':'int64',
'p_name':'varchar', 'p_mfgr':'varchar', 'p_brand':'varchar',
'p_type':'varchar', 'p_size':'int64', 'p_container':'varchar',
'p_retailprice':'double', 'p_comment':'varchar'})
+ // duckdb.sql("select sum(l_extendedprice) / 7.0 as avg_yearly from
lineitem, part where p_partkey = l_partkey and p_brand = 'Brand#23' and
p_container = 'MED BOX' and l_quantity < (select 0.2 * avg(l_quantity) from
lineitem where l_partkey = p_partkey )")
+ // ```
+ // That is the same as DataFusion's output.
+ #[ignore = "the expected result is 348406.02 whereas both DataFusion and
DuckDB return 348406.05"]
#[tokio::test]
async fn q17() -> Result<()> {
verify_query(17).await
@@ -1534,6 +1562,72 @@ mod tests {
Ok(())
}
+ // We read the expected results from CSV files so we need to normalize the
+ // query results before we compare them with the expected results for the
+ // following reasons:
+ //
+ // 1. Float numbers have only two digits after the decimal point in CSV so
+ // we need to convert results to Decimal(15, 2) and then back to floats.
+ //
+ // 2. Decimal numbers are fixed as Decimal(15, 2) in CSV.
+ //
+ // 3. Strings may have trailing spaces and need to be trimmed.
+ //
+ // 4. Rename columns using the expected schema so that the schemas match,
+ // because, for q18, we have an aggregate field `sum(l_quantity)` that is
+ // called `sum_l_quantity` in the expected results.
+ async fn normalize_for_verification(
+ batches: Vec<RecordBatch>,
+ expected_schema: Schema,
+ ) -> Result<Vec<RecordBatch>> {
+ if batches.is_empty() {
+ return Ok(vec![]);
+ }
+ let ctx = SessionContext::new();
+ let schema = batches[0].schema();
+ let df = ctx.read_batches(batches)?;
+ let df = df.select(
+ schema
+ .fields()
+ .iter()
+ .zip(expected_schema.fields())
+ .map(|(field, expected_field)| {
+ match Field::data_type(field) {
+ // Normalize decimals to Decimal(15, 2)
+ DataType::Decimal128(_, _) => {
+ // We convert to float64 and then to decimal(15,
2).
+ // Directly converting between Decimals caused test
+ // failures.
+ let inner_cast = Box::new(Expr::Cast(Cast::new(
+ Box::new(col(Field::name(field))),
+ DataType::Float64,
+ )));
+ Expr::Cast(Cast::new(inner_cast,
DataType::Decimal128(15, 2)))
+ .alias(Field::name(expected_field))
+ }
+ // Normalize floats to have 2 digits after the decimal
point
+ DataType::Float64 => {
+ let inner_cast = Box::new(Expr::Cast(Cast::new(
+ Box::new(col(Field::name(field))),
+ DataType::Decimal128(15, 2),
+ )));
+ Expr::Cast(Cast::new(inner_cast,
DataType::Float64))
+ .alias(Field::name(expected_field))
+ }
+ // Normalize strings by trimming trailing spaces.
+ DataType::Utf8 => Expr::Cast(Cast::new(
+ Box::new(trim(vec![col(Field::name(field))])),
+ Field::data_type(field).to_owned(),
+ ))
+ .alias(Field::name(field)),
+ _ => col(Field::name(expected_field)),
+ }
+ })
+ .collect::<Vec<Expr>>(),
+ )?;
+ df.collect().await
+ }
+
async fn verify_query(n: usize) -> Result<()> {
if let Ok(path) = env::var("TPCH_DATA") {
// load expected answers from tpch-dbgen
@@ -1554,8 +1648,10 @@ mod tests {
output_path: None,
};
let actual = benchmark_datafusion(opt).await?;
+ let expected_schema = get_answer_schema(n);
+ let normalized = normalize_for_verification(actual,
expected_schema).await?;
- assert_expected_results(&expected, &actual)
+ assert_expected_results(&expected, &normalized)
} else {
println!("TPCH_DATA environment variable not set, skipping test");
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]