This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new a8c1579ba update TPCH-mimicking tests to Decimal data type from Float, matching the benchmark (#3438)
a8c1579ba is described below

commit a8c1579baf403c2c556e632e4deacc4ac98f38df
Author: Kirk Mitchener <[email protected]>
AuthorDate: Mon Sep 12 09:13:01 2022 -0400

    update TPCH-mimicking tests to Decimal data type from Float, matching the benchmark (#3438)
    
    * update tests to Decimal data type from Float, matching the benchmark itself
    
    * add trim, remove sloppiness
---
 datafusion/core/tests/sql/explain_analyze.rs |  2 +-
 datafusion/core/tests/sql/mod.rs             | 35 ++++++++++++++++------------
 datafusion/core/tests/sql/predicates.rs      |  8 +++----
 datafusion/core/tests/sql/subqueries.rs      | 27 ++++++++++-----------
 4 files changed, 39 insertions(+), 33 deletions(-)

diff --git a/datafusion/core/tests/sql/explain_analyze.rs 
b/datafusion/core/tests/sql/explain_analyze.rs
index d5509cf65..91dd9401e 100644
--- a/datafusion/core/tests/sql/explain_analyze.rs
+++ b/datafusion/core/tests/sql/explain_analyze.rs
@@ -653,7 +653,7 @@ order by
     let expected = "\
     Sort: #revenue DESC NULLS FIRST\
     \n  Projection: #customer.c_custkey, #customer.c_name, 
#SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, 
#customer.c_acctbal, #nation.n_name, #customer.c_address, #customer.c_phone, 
#customer.c_comment\
-    \n    Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, 
#customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, 
#customer.c_comment]], aggr=[[SUM(#lineitem.l_extendedprice * CAST(Int64(1) AS 
Float64) - #lineitem.l_discount)]]\
+    \n    Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, 
#customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, 
#customer.c_comment]], aggr=[[SUM(CAST(#lineitem.l_extendedprice AS 
Decimal128(38, 4)) * CAST(CAST(Int64(1) AS Decimal128(23, 2)) - 
CAST(#lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)))]]\
     \n      Inner Join: #customer.c_nationkey = #nation.n_nationkey\
     \n        Inner Join: #orders.o_orderkey = #lineitem.l_orderkey\
     \n          Inner Join: #customer.c_custkey = #orders.o_custkey\
diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs
index ad4800b70..c16386c5d 100644
--- a/datafusion/core/tests/sql/mod.rs
+++ b/datafusion/core/tests/sql/mod.rs
@@ -453,7 +453,7 @@ fn get_tpch_table_schema(table: &str) -> Schema {
             Field::new("c_address", DataType::Utf8, false),
             Field::new("c_nationkey", DataType::Int64, false),
             Field::new("c_phone", DataType::Utf8, false),
-            Field::new("c_acctbal", DataType::Float64, false),
+            Field::new("c_acctbal", DataType::Decimal128(15, 2), false),
             Field::new("c_mktsegment", DataType::Utf8, false),
             Field::new("c_comment", DataType::Utf8, false),
         ]),
@@ -462,7 +462,7 @@ fn get_tpch_table_schema(table: &str) -> Schema {
             Field::new("o_orderkey", DataType::Int64, false),
             Field::new("o_custkey", DataType::Int64, false),
             Field::new("o_orderstatus", DataType::Utf8, false),
-            Field::new("o_totalprice", DataType::Float64, false),
+            Field::new("o_totalprice", DataType::Decimal128(15, 2), false),
             Field::new("o_orderdate", DataType::Date32, false),
             Field::new("o_orderpriority", DataType::Utf8, false),
             Field::new("o_clerk", DataType::Utf8, false),
@@ -475,10 +475,10 @@ fn get_tpch_table_schema(table: &str) -> Schema {
             Field::new("l_partkey", DataType::Int64, false),
             Field::new("l_suppkey", DataType::Int64, false),
             Field::new("l_linenumber", DataType::Int32, false),
-            Field::new("l_quantity", DataType::Float64, false),
-            Field::new("l_extendedprice", DataType::Float64, false),
-            Field::new("l_discount", DataType::Float64, false),
-            Field::new("l_tax", DataType::Float64, false),
+            Field::new("l_quantity", DataType::Decimal128(15, 2), false),
+            Field::new("l_extendedprice", DataType::Decimal128(15, 2), false),
+            Field::new("l_discount", DataType::Decimal128(15, 2), false),
+            Field::new("l_tax", DataType::Decimal128(15, 2), false),
             Field::new("l_returnflag", DataType::Utf8, false),
             Field::new("l_linestatus", DataType::Utf8, false),
             Field::new("l_shipdate", DataType::Date32, false),
@@ -502,7 +502,7 @@ fn get_tpch_table_schema(table: &str) -> Schema {
             Field::new("s_address", DataType::Utf8, false),
             Field::new("s_nationkey", DataType::Int64, false),
             Field::new("s_phone", DataType::Utf8, false),
-            Field::new("s_acctbal", DataType::Float64, false),
+            Field::new("s_acctbal", DataType::Decimal128(15, 2), false),
             Field::new("s_comment", DataType::Utf8, false),
         ]),
 
@@ -510,7 +510,7 @@ fn get_tpch_table_schema(table: &str) -> Schema {
             Field::new("ps_partkey", DataType::Int64, false),
             Field::new("ps_suppkey", DataType::Int64, false),
             Field::new("ps_availqty", DataType::Int32, false),
-            Field::new("ps_supplycost", DataType::Float64, false),
+            Field::new("ps_supplycost", DataType::Decimal128(15, 2), false),
             Field::new("ps_comment", DataType::Utf8, false),
         ]),
 
@@ -522,7 +522,7 @@ fn get_tpch_table_schema(table: &str) -> Schema {
             Field::new("p_type", DataType::Utf8, false),
             Field::new("p_size", DataType::Int32, false),
             Field::new("p_container", DataType::Utf8, false),
-            Field::new("p_retailprice", DataType::Float64, false),
+            Field::new("p_retailprice", DataType::Decimal128(15, 2), false),
             Field::new("p_comment", DataType::Utf8, false),
         ]),
 
@@ -573,9 +573,9 @@ async fn register_tpch_csv_data(
             DataType::Int64 => {
                 cols.push(Box::new(Int64Builder::with_capacity(records.len())))
             }
-            DataType::Float64 => {
-                
cols.push(Box::new(Float64Builder::with_capacity(records.len())))
-            }
+            DataType::Decimal128(p, s) => cols.push(Box::new(
+                Decimal128Builder::with_capacity(records.len(), *p, *s),
+            )),
             _ => {
                 let msg = format!("Not implemented: {}", field.data_type());
                 Err(DataFusionError::Plan(msg))?
@@ -606,9 +606,14 @@ async fn register_tpch_csv_data(
                     let sb = 
col.as_any_mut().downcast_mut::<Int64Builder>().unwrap();
                     sb.append_value(val.trim().parse().unwrap());
                 }
-                DataType::Float64 => {
-                    let sb = 
col.as_any_mut().downcast_mut::<Float64Builder>().unwrap();
-                    sb.append_value(val.trim().parse().unwrap());
+                DataType::Decimal128(_, _) => {
+                    let sb = col
+                        .as_any_mut()
+                        .downcast_mut::<Decimal128Builder>()
+                        .unwrap();
+                    let val = val.trim().replace('.', "");
+                    let value_i128 = val.parse::<i128>().unwrap();
+                    sb.append_value(value_i128)?;
                 }
                 _ => Err(DataFusionError::Plan(format!(
                     "Not implemented: {}",
diff --git a/datafusion/core/tests/sql/predicates.rs 
b/datafusion/core/tests/sql/predicates.rs
index 32365090a..5b57bc971 100644
--- a/datafusion/core/tests/sql/predicates.rs
+++ b/datafusion/core/tests/sql/predicates.rs
@@ -427,10 +427,10 @@ async fn multiple_or_predicates() -> Result<()> {
     let expected =vec![
         "Explain [plan_type:Utf8, plan:Utf8]",
         "  Projection: #lineitem.l_partkey [l_partkey:Int64]",
-        "    Projection: #part.p_partkey = #lineitem.l_partkey AS 
#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey, 
#part.p_size >= Int32(1) AS #part.p_size >= Int32(1)Int32(1)#part.p_size, 
#lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size 
[#part.p_partkey = 
#lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N, #part.p_size 
>= Int32(1)Int32(1)#part.p_size:Boolean;N, l_partkey:Int64, l_quantity:Float64, 
p_brand:Utf8, p_size:Int32]",
-        "      Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand 
= Utf8(\"Brand#12\") AND #lineitem.l_quantity >= CAST(Int64(1) AS Float64) AND 
#lineitem.l_quantity <= CAST(Int64(11) AS Float64) AND #part.p_size <= Int32(5) 
OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= 
CAST(Int64(10) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(20) AS 
Float64) AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") 
AND #lineitem.l_quantity >= CAST(Int64 [...]
-        "        CrossJoin: [l_partkey:Int64, l_quantity:Float64, 
p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
-        "          TableScan: lineitem projection=[l_partkey, l_quantity] 
[l_partkey:Int64, l_quantity:Float64]",
+        "    Projection: #part.p_partkey = #lineitem.l_partkey AS 
#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey, 
#part.p_size >= Int32(1) AS #part.p_size >= Int32(1)Int32(1)#part.p_size, 
#lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size 
[#part.p_partkey = 
#lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N, #part.p_size 
>= Int32(1)Int32(1)#part.p_size:Boolean;N, l_partkey:Int64, 
l_quantity:Decimal128(15, 2), p_brand:Utf8, p_size: [...]
+        "      Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand 
= Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Decimal128(Some(100),15,2) AND 
#lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND #part.p_size <= 
Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= 
Decimal128(Some(1000),15,2) AND #lineitem.l_quantity <= 
Decimal128(Some(2000),15,2) AND #part.p_size <= Int32(10) OR #part.p_brand = 
Utf8(\"Brand#34\") AND #lineitem.l_quantity >= Decima [...]
+        "        CrossJoin: [l_partkey:Int64, l_quantity:Decimal128(15, 2), 
p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
+        "          TableScan: lineitem projection=[l_partkey, l_quantity] 
[l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
         "          Filter: #part.p_size >= Int32(1) [p_partkey:Int64, 
p_brand:Utf8, p_size:Int32]",
         "            TableScan: part projection=[p_partkey, p_brand, p_size], 
partial_filters=[#part.p_size >= Int32(1)] [p_partkey:Int64, p_brand:Utf8, 
p_size:Int32]",
     ];
diff --git a/datafusion/core/tests/sql/subqueries.rs 
b/datafusion/core/tests/sql/subqueries.rs
index 58561de12..0d9fe37f9 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -52,12 +52,12 @@ where c_acctbal < (
     let actual = format!("{}", plan.display_indent());
     let expected = r#"Sort: #customer.c_custkey ASC NULLS LAST
   Projection: #customer.c_custkey
-    Filter: #customer.c_acctbal < #__sq_2.__value
+    Filter: CAST(#customer.c_acctbal AS Decimal128(25, 2)) < #__sq_2.__value
       Inner Join: #customer.c_custkey = #__sq_2.o_custkey
         TableScan: customer projection=[c_custkey, c_acctbal]
         Projection: #orders.o_custkey, #SUM(orders.o_totalprice) AS __value, 
alias=__sq_2
           Aggregate: groupBy=[[#orders.o_custkey]], 
aggr=[[SUM(#orders.o_totalprice)]]
-            Filter: #orders.o_totalprice < #__sq_1.__value
+            Filter: CAST(#orders.o_totalprice AS Decimal128(25, 2)) < 
#__sq_1.__value
               Inner Join: #orders.o_orderkey = #__sq_1.l_orderkey
                 TableScan: orders projection=[o_orderkey, o_custkey, 
o_totalprice]
                 Projection: #lineitem.l_orderkey, 
#SUM(lineitem.l_extendedprice) AS price AS __value, alias=__sq_1
@@ -229,6 +229,7 @@ async fn tpch_q4_correlated() -> Result<()> {
     Ok(())
 }
 
+#[ignore] // https://github.com/apache/arrow-datafusion/issues/3437
 #[tokio::test]
 async fn tpch_q17_correlated() -> Result<()> {
     let parts = r#"63700,goldenrod lavender spring chocolate 
lace,Manufacturer#1,Brand#23,PROMO BURNISHED COPPER,7,MED BOX,901.00,ly. slyly 
ironi
@@ -260,15 +261,15 @@ async fn tpch_q17_correlated() -> Result<()> {
         .map_err(|e| format!("{:?} at {}", e, "error"))
         .unwrap();
     let actual = format!("{}", plan.display_indent());
-    let expected = r#"Projection: #SUM(lineitem.l_extendedprice) / Float64(7) 
AS avg_yearly
+    let expected = r#"Projection: CAST(#SUM(lineitem.l_extendedprice) AS 
Decimal128(38, 33)) / CAST(Float64(7) AS Decimal128(38, 33)) AS avg_yearly
   Aggregate: groupBy=[[]], aggr=[[SUM(#lineitem.l_extendedprice)]]
-    Filter: #lineitem.l_quantity < #__sq_1.__value
+    Filter: CAST(#lineitem.l_quantity AS Decimal128(38, 21)) < #__sq_1.__value
       Inner Join: #part.p_partkey = #__sq_1.l_partkey
         Inner Join: #lineitem.l_partkey = #part.p_partkey
           TableScan: lineitem projection=[l_partkey, l_quantity, 
l_extendedprice]
           Filter: #part.p_brand = Utf8("Brand#23") AND #part.p_container = 
Utf8("MED BOX")
             TableScan: part projection=[p_partkey, p_brand, p_container]
-        Projection: #lineitem.l_partkey, Float64(0.2) * 
#AVG(lineitem.l_quantity) AS __value, alias=__sq_1
+        Projection: #lineitem.l_partkey, CAST(Float64(0.2) AS Decimal128(38, 
21)) * CAST(#AVG(lineitem.l_quantity) AS Decimal128(38, 21)) AS __value, 
alias=__sq_1
           Aggregate: groupBy=[[#lineitem.l_partkey]], 
aggr=[[AVG(#lineitem.l_quantity)]]
             TableScan: lineitem projection=[l_partkey, l_quantity, 
l_extendedprice]"#
         .to_string();
@@ -328,14 +329,14 @@ order by s_name;
         Filter: #nation.n_name = Utf8("CANADA")
           TableScan: nation projection=[n_nationkey, n_name], 
partial_filters=[#nation.n_name = Utf8("CANADA")]
       Projection: #partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2
-        Filter: CAST(#partsupp.ps_availqty AS Float64) > #__sq_3.__value
+        Filter: CAST(#partsupp.ps_availqty AS Decimal128(38, 17)) > 
#__sq_3.__value
           Inner Join: #partsupp.ps_partkey = #__sq_3.l_partkey, 
#partsupp.ps_suppkey = #__sq_3.l_suppkey
             Semi Join: #partsupp.ps_partkey = #__sq_1.p_partkey
               TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_availqty]
               Projection: #part.p_partkey AS p_partkey, alias=__sq_1
                 Filter: #part.p_name LIKE Utf8("forest%")
                   TableScan: part projection=[p_partkey, p_name], 
partial_filters=[#part.p_name LIKE Utf8("forest%")]
-            Projection: #lineitem.l_partkey, #lineitem.l_suppkey, Float64(0.5) 
* #SUM(lineitem.l_quantity) AS __value, alias=__sq_3
+            Projection: #lineitem.l_partkey, #lineitem.l_suppkey, 
CAST(Float64(0.5) AS Decimal128(38, 17)) * CAST(#SUM(lineitem.l_quantity) AS 
Decimal128(38, 17)) AS __value, alias=__sq_3
               Aggregate: groupBy=[[#lineitem.l_partkey, #lineitem.l_suppkey]], 
aggr=[[SUM(#lineitem.l_quantity)]]
                 Filter: #lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS 
Date32)
                   TableScan: lineitem projection=[l_partkey, l_suppkey, 
l_quantity, l_shipdate], partial_filters=[#lineitem.l_shipdate >= 
CAST(Utf8("1994-01-01") AS Date32)]"#
@@ -384,7 +385,7 @@ order by cntrycode;"#;
     Aggregate: groupBy=[[#custsale.cntrycode]], aggr=[[COUNT(UInt8(1)), 
SUM(#custsale.c_acctbal)]]
       Projection: #custsale.cntrycode, #custsale.c_acctbal, alias=custsale
         Projection: substr(#customer.c_phone, Int64(1), Int64(2)) AS 
cntrycode, #customer.c_acctbal, alias=custsale
-          Filter: #customer.c_acctbal > #__sq_1.__value
+          Filter: CAST(#customer.c_acctbal AS Decimal128(19, 6)) > 
#__sq_1.__value
             CrossJoin:
               Anti Join: #customer.c_custkey = #orders.o_custkey
                 Filter: substr(#customer.c_phone, Int64(1), Int64(2)) IN 
([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), 
Utf8("17")])
@@ -392,7 +393,7 @@ order by cntrycode;"#;
                 TableScan: orders projection=[o_custkey]
               Projection: #AVG(customer.c_acctbal) AS __value, alias=__sq_1
                 Aggregate: groupBy=[[]], aggr=[[AVG(#customer.c_acctbal)]]
-                  Filter: #customer.c_acctbal > Float64(0) AND 
substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), 
Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
+                  Filter: CAST(#customer.c_acctbal AS Decimal128(30, 15)) > 
CAST(Float64(0) AS Decimal128(30, 15)) AND substr(#customer.c_phone, Int64(1), 
Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), 
Utf8("18"), Utf8("17")])
                     TableScan: customer projection=[c_phone, c_acctbal], 
partial_filters=[#customer.c_acctbal > Float64(0), substr(#customer.c_phone, 
Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), 
Utf8("30"), Utf8("18"), Utf8("17")])]"#
         .to_string();
     assert_eq!(actual, expected);
@@ -443,17 +444,17 @@ order by value desc;
     let actual = format!("{}", plan.display_indent());
     let expected = r#"Sort: #value DESC NULLS FIRST
   Projection: #partsupp.ps_partkey, #SUM(partsupp.ps_supplycost * 
partsupp.ps_availqty) AS value
-    Filter: #SUM(partsupp.ps_supplycost * partsupp.ps_availqty) > 
#__sq_1.__value
+    Filter: CAST(#SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS 
Decimal128(38, 17)) > #__sq_1.__value
       CrossJoin:
-        Aggregate: groupBy=[[#partsupp.ps_partkey]], 
aggr=[[SUM(#partsupp.ps_supplycost * CAST(#partsupp.ps_availqty AS Float64))]]
+        Aggregate: groupBy=[[#partsupp.ps_partkey]], 
aggr=[[SUM(CAST(#partsupp.ps_supplycost AS Decimal128(26, 2)) * 
CAST(#partsupp.ps_availqty AS Decimal128(26, 2)))]]
           Inner Join: #supplier.s_nationkey = #nation.n_nationkey
             Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
               TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_availqty, ps_supplycost]
               TableScan: supplier projection=[s_suppkey, s_nationkey]
             Filter: #nation.n_name = Utf8("GERMANY")
               TableScan: nation projection=[n_nationkey, n_name], 
partial_filters=[#nation.n_name = Utf8("GERMANY")]
-        Projection: #SUM(partsupp.ps_supplycost * partsupp.ps_availqty) * 
Float64(0.0001) AS __value, alias=__sq_1
-          Aggregate: groupBy=[[]], aggr=[[SUM(#partsupp.ps_supplycost * 
CAST(#partsupp.ps_availqty AS Float64))]]
+        Projection: CAST(#SUM(partsupp.ps_supplycost * partsupp.ps_availqty) 
AS Decimal128(38, 17)) * CAST(Float64(0.0001) AS Decimal128(38, 17)) AS 
__value, alias=__sq_1
+          Aggregate: groupBy=[[]], aggr=[[SUM(CAST(#partsupp.ps_supplycost AS 
Decimal128(26, 2)) * CAST(#partsupp.ps_availqty AS Decimal128(26, 2)))]]
             Inner Join: #supplier.s_nationkey = #nation.n_nationkey
               Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
                 TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_availqty, ps_supplycost]

Reply via email to