This is an automated email from the ASF dual-hosted git repository.

milenkovicm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new bb10a1be Fix unit tests in tpch.rs (#1195)
bb10a1be is described below

commit bb10a1bebd52ebb91515efa7a2a977df740c2d7a
Author: Ming Chen <[email protected]>
AuthorDate: Thu Mar 20 02:30:43 2025 -0400

    Fix unit tests in tpch.rs (#1195)
---
 benchmarks/queries/q10.sql |   3 +-
 benchmarks/queries/q18.sql |   3 +-
 benchmarks/queries/q2.sql  |   3 +-
 benchmarks/queries/q21.sql |   3 +-
 benchmarks/queries/q3.sql  |   3 +-
 benchmarks/src/bin/tpch.rs | 168 +++++++++++++++++++++++++++++++++++----------
 6 files changed, 142 insertions(+), 41 deletions(-)

diff --git a/benchmarks/queries/q10.sql b/benchmarks/queries/q10.sql
index cf45e434..ef48cafe 100644
--- a/benchmarks/queries/q10.sql
+++ b/benchmarks/queries/q10.sql
@@ -28,4 +28,5 @@ group by
     c_address,
     c_comment
 order by
-    revenue desc;
\ No newline at end of file
+    revenue desc
+limit 20;
\ No newline at end of file
diff --git a/benchmarks/queries/q18.sql b/benchmarks/queries/q18.sql
index 835de28a..c3da5b76 100644
--- a/benchmarks/queries/q18.sql
+++ b/benchmarks/queries/q18.sql
@@ -29,4 +29,5 @@ group by
     o_totalprice
 order by
     o_totalprice desc,
-    o_orderdate;
\ No newline at end of file
+    o_orderdate
+limit 100;
\ No newline at end of file
diff --git a/benchmarks/queries/q2.sql b/benchmarks/queries/q2.sql
index f66af210..4927ff6e 100644
--- a/benchmarks/queries/q2.sql
+++ b/benchmarks/queries/q2.sql
@@ -40,4 +40,5 @@ order by
     s_acctbal desc,
     n_name,
     s_name,
-    p_partkey;
\ No newline at end of file
+    p_partkey
+limit 100;
\ No newline at end of file
diff --git a/benchmarks/queries/q21.sql b/benchmarks/queries/q21.sql
index 9d2fe32c..f318f938 100644
--- a/benchmarks/queries/q21.sql
+++ b/benchmarks/queries/q21.sql
@@ -36,4 +36,5 @@ group by
     s_name
 order by
     numwait desc,
-    s_name;
\ No newline at end of file
+    s_name
+limit 100;
\ No newline at end of file
diff --git a/benchmarks/queries/q3.sql b/benchmarks/queries/q3.sql
index 7dbc6d9e..601f6fe5 100644
--- a/benchmarks/queries/q3.sql
+++ b/benchmarks/queries/q3.sql
@@ -19,4 +19,5 @@ group by
     o_shippriority
 order by
     revenue desc,
-    o_orderdate;
\ No newline at end of file
+    o_orderdate
+limit 10;
\ No newline at end of file
diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs
index a3928d0b..16a59172 100644
--- a/benchmarks/src/bin/tpch.rs
+++ b/benchmarks/src/bin/tpch.rs
@@ -798,37 +798,55 @@ async fn get_table(
     table_format: &str,
     target_partitions: usize,
 ) -> Result<Arc<dyn TableProvider>> {
-    let (format, path, extension): (Arc<dyn FileFormat>, String, &'static str) =
-        match table_format {
-            // dbgen creates .tbl ('|' delimited) files without header
-            "tbl" => {
-                let path = format!("{path}/{table}.tbl");
-
-                let format = CsvFormat::default()
-                    .with_delimiter(b'|')
-                    .with_has_header(false);
-
-                (Arc::new(format), path, ".tbl")
-            }
-            "csv" => {
-                let path = format!("{path}/{table}");
-                let format = CsvFormat::default()
-                    .with_delimiter(b',')
-                    .with_has_header(true);
-
-                (Arc::new(format), path, DEFAULT_CSV_EXTENSION)
-            }
-            "parquet" => {
-                let path = format!("{path}/{table}");
-                let format = ParquetFormat::default().with_enable_pruning(true);
-
-                (Arc::new(format), path, DEFAULT_PARQUET_EXTENSION)
-            }
-            other => {
-                unimplemented!("Invalid file format '{}'", other);
-            }
-        };
-    let schema = Arc::new(get_schema(table));
+    let (format, path, extension, schema): (
+        Arc<dyn FileFormat>,
+        String,
+        &'static str,
+        Schema,
+    ) = match table_format {
+        // dbgen creates .tbl ('|' delimited) files without header
+        "tbl" => {
+            let path = format!("{path}/{table}.tbl");
+
+            let format = CsvFormat::default()
+                .with_delimiter(b'|')
+                .with_has_header(false);
+
+            (
+                Arc::new(format),
+                path,
+                ".tbl",
+                get_tbl_tpch_table_schema(table),
+            )
+        }
+        "csv" => {
+            let path = format!("{path}/{table}");
+            let format = CsvFormat::default()
+                .with_delimiter(b',')
+                .with_has_header(true);
+
+            (
+                Arc::new(format),
+                path,
+                DEFAULT_CSV_EXTENSION,
+                get_schema(table),
+            )
+        }
+        "parquet" => {
+            let path = format!("{path}/{table}");
+            let format = ParquetFormat::default().with_enable_pruning(true);
+
+            (
+                Arc::new(format),
+                path,
+                DEFAULT_PARQUET_EXTENSION,
+                get_schema(table),
+            )
+        }
+        other => {
+            unimplemented!("Invalid file format '{}'", other);
+        }
+    };
 
     let options = ListingOptions {
         format,
@@ -845,7 +863,7 @@ async fn get_table(
     let config = if table_format == "parquet" {
         config.infer_schema(ctx).await?
     } else {
-        config.with_schema(schema)
+        config.with_schema(Arc::new(schema))
     };
 
     Ok(Arc::new(ListingTable::try_new(config)?))
@@ -1138,18 +1156,18 @@ fn get_answer_schema(n: usize) -> Schema {
         7 => Schema::new(vec![
             Field::new("supp_nation", DataType::Utf8, true),
             Field::new("cust_nation", DataType::Utf8, true),
-            Field::new("l_year", DataType::Float64, true),
+            Field::new("l_year", DataType::Int32, true),
             Field::new("revenue", DataType::Decimal128(15, 2), true),
         ]),
 
         8 => Schema::new(vec![
-            Field::new("o_year", DataType::Float64, true),
+            Field::new("o_year", DataType::Int32, true),
             Field::new("mkt_share", DataType::Decimal128(15, 2), true),
         ]),
 
         9 => Schema::new(vec![
             Field::new("nation", DataType::Utf8, true),
-            Field::new("o_year", DataType::Float64, true),
+            Field::new("o_year", DataType::Int32, true),
             Field::new("sum_profit", DataType::Decimal128(15, 2), true),
         ]),
 
@@ -1358,6 +1376,7 @@ mod tests {
         verify_query(14).await
     }
 
+    #[ignore] // TODO: support multiline queries
     #[tokio::test]
     async fn q15() -> Result<()> {
         verify_query(15).await
@@ -1368,6 +1387,15 @@ mod tests {
         verify_query(16).await
     }
 
+    // Python code to reproduce the "348406.05" result in DuckDB:
+    // ```python
+    // import duckdb
+    // lineitem = duckdb.read_csv("data/lineitem.tbl", columns={'l_orderkey':'int64', 'l_partkey':'int64', 'l_suppkey':'int64', 'l_linenumber':'int64', 'l_quantity':'int64', 'l_extendedprice':'decimal(15,2)', 'l_discount':'decimal(15,2)', 'l_tax':'decimal(15,2)', 'l_returnflag':'varchar','l_linestatus':'varchar', 'l_shipdate':'date', 'l_commitdate':'date', 'l_receiptdate':'date', 'l_shipinstruct':'varchar', 'l_shipmode':'varchar', 'l_comment':'varchar'})
+    // part = duckdb.read_csv("data/part.tbl", columns={'p_partkey':'int64', 'p_name':'varchar', 'p_mfgr':'varchar', 'p_brand':'varchar', 'p_type':'varchar', 'p_size':'int64', 'p_container':'varchar', 'p_retailprice':'double', 'p_comment':'varchar'})
+    // duckdb.sql("select sum(l_extendedprice) / 7.0 as avg_yearly from lineitem, part where p_partkey = l_partkey and p_brand = 'Brand#23' and p_container = 'MED BOX' and l_quantity < (select 0.2 * avg(l_quantity) from lineitem where l_partkey = p_partkey )")
+    // ```
+    // That is the same as DataFusion's output.
+    #[ignore = "the expected result is 348406.02 whereas both DataFusion and DuckDB return 348406.05"]
     #[tokio::test]
     async fn q17() -> Result<()> {
         verify_query(17).await
@@ -1534,6 +1562,72 @@ mod tests {
         Ok(())
     }
 
+    // We read the expected results from CSV files so we need to normalize the
+    // query results before we compare them with the expected results for the
+    // following reasons:
+    //
+    // 1. Float numbers have only two digits after the decimal point in CSV so
+    // we need to convert results to Decimal(15, 2) and then back to floats.
+    //
+    // 2. Decimal numbers are fixed as Decimal(15, 2) in CSV.
+    //
+    // 3. Strings may have trailing spaces and need to be trimmed.
+    //
+    // 4. Rename columns using the expected schema to make schema matching
+    // because, for q18, we have aggregate field `sum(l_quantity)` that is
+    // called `sum_l_quantity` in the expected results.
+    async fn normalize_for_verification(
+        batches: Vec<RecordBatch>,
+        expected_schema: Schema,
+    ) -> Result<Vec<RecordBatch>> {
+        if batches.is_empty() {
+            return Ok(vec![]);
+        }
+        let ctx = SessionContext::new();
+        let schema = batches[0].schema();
+        let df = ctx.read_batches(batches)?;
+        let df = df.select(
+            schema
+                .fields()
+                .iter()
+                .zip(expected_schema.fields())
+                .map(|(field, expected_field)| {
+                    match Field::data_type(field) {
+                        // Normalize decimals to Decimal(15, 2)
+                        DataType::Decimal128(_, _) => {
+                            // We convert to float64 and then to decimal(15, 2).
+                            // Directly converting between Decimals caused test
+                            // failures.
+                            let inner_cast = Box::new(Expr::Cast(Cast::new(
+                                Box::new(col(Field::name(field))),
+                                DataType::Float64,
+                            )));
+                            Expr::Cast(Cast::new(inner_cast, DataType::Decimal128(15, 2)))
+                                .alias(Field::name(expected_field))
+                        }
+                        // Normalize floats to have 2 digits after the decimal point
+                        DataType::Float64 => {
+                            let inner_cast = Box::new(Expr::Cast(Cast::new(
+                                Box::new(col(Field::name(field))),
+                                DataType::Decimal128(15, 2),
+                            )));
+                            Expr::Cast(Cast::new(inner_cast, DataType::Float64))
+                                .alias(Field::name(expected_field))
+                        }
+                        // Normalize strings by trimming trailing spaces.
+                        DataType::Utf8 => Expr::Cast(Cast::new(
+                            Box::new(trim(vec![col(Field::name(field))])),
+                            Field::data_type(field).to_owned(),
+                        ))
+                        .alias(Field::name(field)),
+                        _ => col(Field::name(expected_field)),
+                    }
+                })
+                .collect::<Vec<Expr>>(),
+        )?;
+        df.collect().await
+    }
+
     async fn verify_query(n: usize) -> Result<()> {
         if let Ok(path) = env::var("TPCH_DATA") {
             // load expected answers from tpch-dbgen
@@ -1554,8 +1648,10 @@ mod tests {
                 output_path: None,
             };
             let actual = benchmark_datafusion(opt).await?;
+            let expected_schema = get_answer_schema(n);
+            let normalized = normalize_for_verification(actual, expected_schema).await?;
 
-            assert_expected_results(&expected, &actual)
+            assert_expected_results(&expected, &normalized)
         } else {
             println!("TPCH_DATA environment variable not set, skipping test");
         }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to