This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new a8c1579ba update TPCH-mimicking tests to Decimal data type from Float, matching the benchmark (#3438)
a8c1579ba is described below
commit a8c1579baf403c2c556e632e4deacc4ac98f38df
Author: Kirk Mitchener <[email protected]>
AuthorDate: Mon Sep 12 09:13:01 2022 -0400
update TPCH-mimicking tests to Decimal data type from Float, matching the benchmark (#3438)
* update tests to Decimal data type from Float, matching the benchmark itself
* add trim, remove sloppiness
---
datafusion/core/tests/sql/explain_analyze.rs | 2 +-
datafusion/core/tests/sql/mod.rs | 35 ++++++++++++++++------------
datafusion/core/tests/sql/predicates.rs | 8 +++----
datafusion/core/tests/sql/subqueries.rs | 27 ++++++++++-----------
4 files changed, 39 insertions(+), 33 deletions(-)
diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs
index d5509cf65..91dd9401e 100644
--- a/datafusion/core/tests/sql/explain_analyze.rs
+++ b/datafusion/core/tests/sql/explain_analyze.rs
@@ -653,7 +653,7 @@ order by
let expected = "\
Sort: #revenue DESC NULLS FIRST\
\n Projection: #customer.c_custkey, #customer.c_name,
#SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue,
#customer.c_acctbal, #nation.n_name, #customer.c_address, #customer.c_phone,
#customer.c_comment\
- \n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name,
#customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address,
#customer.c_comment]], aggr=[[SUM(#lineitem.l_extendedprice * CAST(Int64(1) AS
Float64) - #lineitem.l_discount)]]\
+ \n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name,
#customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address,
#customer.c_comment]], aggr=[[SUM(CAST(#lineitem.l_extendedprice AS
Decimal128(38, 4)) * CAST(CAST(Int64(1) AS Decimal128(23, 2)) -
CAST(#lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)))]]\
\n Inner Join: #customer.c_nationkey = #nation.n_nationkey\
\n Inner Join: #orders.o_orderkey = #lineitem.l_orderkey\
\n Inner Join: #customer.c_custkey = #orders.o_custkey\
diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs
index ad4800b70..c16386c5d 100644
--- a/datafusion/core/tests/sql/mod.rs
+++ b/datafusion/core/tests/sql/mod.rs
@@ -453,7 +453,7 @@ fn get_tpch_table_schema(table: &str) -> Schema {
Field::new("c_address", DataType::Utf8, false),
Field::new("c_nationkey", DataType::Int64, false),
Field::new("c_phone", DataType::Utf8, false),
- Field::new("c_acctbal", DataType::Float64, false),
+ Field::new("c_acctbal", DataType::Decimal128(15, 2), false),
Field::new("c_mktsegment", DataType::Utf8, false),
Field::new("c_comment", DataType::Utf8, false),
]),
@@ -462,7 +462,7 @@ fn get_tpch_table_schema(table: &str) -> Schema {
Field::new("o_orderkey", DataType::Int64, false),
Field::new("o_custkey", DataType::Int64, false),
Field::new("o_orderstatus", DataType::Utf8, false),
- Field::new("o_totalprice", DataType::Float64, false),
+ Field::new("o_totalprice", DataType::Decimal128(15, 2), false),
Field::new("o_orderdate", DataType::Date32, false),
Field::new("o_orderpriority", DataType::Utf8, false),
Field::new("o_clerk", DataType::Utf8, false),
@@ -475,10 +475,10 @@ fn get_tpch_table_schema(table: &str) -> Schema {
Field::new("l_partkey", DataType::Int64, false),
Field::new("l_suppkey", DataType::Int64, false),
Field::new("l_linenumber", DataType::Int32, false),
- Field::new("l_quantity", DataType::Float64, false),
- Field::new("l_extendedprice", DataType::Float64, false),
- Field::new("l_discount", DataType::Float64, false),
- Field::new("l_tax", DataType::Float64, false),
+ Field::new("l_quantity", DataType::Decimal128(15, 2), false),
+ Field::new("l_extendedprice", DataType::Decimal128(15, 2), false),
+ Field::new("l_discount", DataType::Decimal128(15, 2), false),
+ Field::new("l_tax", DataType::Decimal128(15, 2), false),
Field::new("l_returnflag", DataType::Utf8, false),
Field::new("l_linestatus", DataType::Utf8, false),
Field::new("l_shipdate", DataType::Date32, false),
@@ -502,7 +502,7 @@ fn get_tpch_table_schema(table: &str) -> Schema {
Field::new("s_address", DataType::Utf8, false),
Field::new("s_nationkey", DataType::Int64, false),
Field::new("s_phone", DataType::Utf8, false),
- Field::new("s_acctbal", DataType::Float64, false),
+ Field::new("s_acctbal", DataType::Decimal128(15, 2), false),
Field::new("s_comment", DataType::Utf8, false),
]),
@@ -510,7 +510,7 @@ fn get_tpch_table_schema(table: &str) -> Schema {
Field::new("ps_partkey", DataType::Int64, false),
Field::new("ps_suppkey", DataType::Int64, false),
Field::new("ps_availqty", DataType::Int32, false),
- Field::new("ps_supplycost", DataType::Float64, false),
+ Field::new("ps_supplycost", DataType::Decimal128(15, 2), false),
Field::new("ps_comment", DataType::Utf8, false),
]),
@@ -522,7 +522,7 @@ fn get_tpch_table_schema(table: &str) -> Schema {
Field::new("p_type", DataType::Utf8, false),
Field::new("p_size", DataType::Int32, false),
Field::new("p_container", DataType::Utf8, false),
- Field::new("p_retailprice", DataType::Float64, false),
+ Field::new("p_retailprice", DataType::Decimal128(15, 2), false),
Field::new("p_comment", DataType::Utf8, false),
]),
@@ -573,9 +573,9 @@ async fn register_tpch_csv_data(
DataType::Int64 => {
cols.push(Box::new(Int64Builder::with_capacity(records.len())))
}
- DataType::Float64 => {
- cols.push(Box::new(Float64Builder::with_capacity(records.len())))
- }
+ DataType::Decimal128(p, s) => cols.push(Box::new(
+ Decimal128Builder::with_capacity(records.len(), *p, *s),
+ )),
_ => {
let msg = format!("Not implemented: {}", field.data_type());
Err(DataFusionError::Plan(msg))?
@@ -606,9 +606,14 @@ async fn register_tpch_csv_data(
let sb =
col.as_any_mut().downcast_mut::<Int64Builder>().unwrap();
sb.append_value(val.trim().parse().unwrap());
}
- DataType::Float64 => {
- let sb =
col.as_any_mut().downcast_mut::<Float64Builder>().unwrap();
- sb.append_value(val.trim().parse().unwrap());
+ DataType::Decimal128(_, _) => {
+ let sb = col
+ .as_any_mut()
+ .downcast_mut::<Decimal128Builder>()
+ .unwrap();
+ let val = val.trim().replace('.', "");
+ let value_i128 = val.parse::<i128>().unwrap();
+ sb.append_value(value_i128)?;
}
_ => Err(DataFusionError::Plan(format!(
"Not implemented: {}",
diff --git a/datafusion/core/tests/sql/predicates.rs b/datafusion/core/tests/sql/predicates.rs
index 32365090a..5b57bc971 100644
--- a/datafusion/core/tests/sql/predicates.rs
+++ b/datafusion/core/tests/sql/predicates.rs
@@ -427,10 +427,10 @@ async fn multiple_or_predicates() -> Result<()> {
let expected =vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #lineitem.l_partkey [l_partkey:Int64]",
- " Projection: #part.p_partkey = #lineitem.l_partkey AS
#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey,
#part.p_size >= Int32(1) AS #part.p_size >= Int32(1)Int32(1)#part.p_size,
#lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size
[#part.p_partkey =
#lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N, #part.p_size
>= Int32(1)Int32(1)#part.p_size:Boolean;N, l_partkey:Int64, l_quantity:Float64,
p_brand:Utf8, p_size:Int32]",
- " Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand
= Utf8(\"Brand#12\") AND #lineitem.l_quantity >= CAST(Int64(1) AS Float64) AND
#lineitem.l_quantity <= CAST(Int64(11) AS Float64) AND #part.p_size <= Int32(5)
OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >=
CAST(Int64(10) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(20) AS
Float64) AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\")
AND #lineitem.l_quantity >= CAST(Int64 [...]
- " CrossJoin: [l_partkey:Int64, l_quantity:Float64,
p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
- " TableScan: lineitem projection=[l_partkey, l_quantity]
[l_partkey:Int64, l_quantity:Float64]",
+ " Projection: #part.p_partkey = #lineitem.l_partkey AS
#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey,
#part.p_size >= Int32(1) AS #part.p_size >= Int32(1)Int32(1)#part.p_size,
#lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size
[#part.p_partkey =
#lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N, #part.p_size
>= Int32(1)Int32(1)#part.p_size:Boolean;N, l_partkey:Int64,
l_quantity:Decimal128(15, 2), p_brand:Utf8, p_size: [...]
+ " Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand
= Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Decimal128(Some(100),15,2) AND
#lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND #part.p_size <=
Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >=
Decimal128(Some(1000),15,2) AND #lineitem.l_quantity <=
Decimal128(Some(2000),15,2) AND #part.p_size <= Int32(10) OR #part.p_brand =
Utf8(\"Brand#34\") AND #lineitem.l_quantity >= Decima [...]
+ " CrossJoin: [l_partkey:Int64, l_quantity:Decimal128(15, 2),
p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
+ " TableScan: lineitem projection=[l_partkey, l_quantity]
[l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
" Filter: #part.p_size >= Int32(1) [p_partkey:Int64,
p_brand:Utf8, p_size:Int32]",
" TableScan: part projection=[p_partkey, p_brand, p_size],
partial_filters=[#part.p_size >= Int32(1)] [p_partkey:Int64, p_brand:Utf8,
p_size:Int32]",
];
diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs
index 58561de12..0d9fe37f9 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -52,12 +52,12 @@ where c_acctbal < (
let actual = format!("{}", plan.display_indent());
let expected = r#"Sort: #customer.c_custkey ASC NULLS LAST
Projection: #customer.c_custkey
- Filter: #customer.c_acctbal < #__sq_2.__value
+ Filter: CAST(#customer.c_acctbal AS Decimal128(25, 2)) < #__sq_2.__value
Inner Join: #customer.c_custkey = #__sq_2.o_custkey
TableScan: customer projection=[c_custkey, c_acctbal]
Projection: #orders.o_custkey, #SUM(orders.o_totalprice) AS __value,
alias=__sq_2
Aggregate: groupBy=[[#orders.o_custkey]],
aggr=[[SUM(#orders.o_totalprice)]]
- Filter: #orders.o_totalprice < #__sq_1.__value
+ Filter: CAST(#orders.o_totalprice AS Decimal128(25, 2)) <
#__sq_1.__value
Inner Join: #orders.o_orderkey = #__sq_1.l_orderkey
TableScan: orders projection=[o_orderkey, o_custkey,
o_totalprice]
Projection: #lineitem.l_orderkey,
#SUM(lineitem.l_extendedprice) AS price AS __value, alias=__sq_1
@@ -229,6 +229,7 @@ async fn tpch_q4_correlated() -> Result<()> {
Ok(())
}
+#[ignore] // https://github.com/apache/arrow-datafusion/issues/3437
#[tokio::test]
async fn tpch_q17_correlated() -> Result<()> {
let parts = r#"63700,goldenrod lavender spring chocolate
lace,Manufacturer#1,Brand#23,PROMO BURNISHED COPPER,7,MED BOX,901.00,ly. slyly
ironi
@@ -260,15 +261,15 @@ async fn tpch_q17_correlated() -> Result<()> {
.map_err(|e| format!("{:?} at {}", e, "error"))
.unwrap();
let actual = format!("{}", plan.display_indent());
- let expected = r#"Projection: #SUM(lineitem.l_extendedprice) / Float64(7)
AS avg_yearly
+ let expected = r#"Projection: CAST(#SUM(lineitem.l_extendedprice) AS
Decimal128(38, 33)) / CAST(Float64(7) AS Decimal128(38, 33)) AS avg_yearly
Aggregate: groupBy=[[]], aggr=[[SUM(#lineitem.l_extendedprice)]]
- Filter: #lineitem.l_quantity < #__sq_1.__value
+ Filter: CAST(#lineitem.l_quantity AS Decimal128(38, 21)) < #__sq_1.__value
Inner Join: #part.p_partkey = #__sq_1.l_partkey
Inner Join: #lineitem.l_partkey = #part.p_partkey
TableScan: lineitem projection=[l_partkey, l_quantity,
l_extendedprice]
Filter: #part.p_brand = Utf8("Brand#23") AND #part.p_container =
Utf8("MED BOX")
TableScan: part projection=[p_partkey, p_brand, p_container]
- Projection: #lineitem.l_partkey, Float64(0.2) *
#AVG(lineitem.l_quantity) AS __value, alias=__sq_1
+ Projection: #lineitem.l_partkey, CAST(Float64(0.2) AS Decimal128(38,
21)) * CAST(#AVG(lineitem.l_quantity) AS Decimal128(38, 21)) AS __value,
alias=__sq_1
Aggregate: groupBy=[[#lineitem.l_partkey]],
aggr=[[AVG(#lineitem.l_quantity)]]
TableScan: lineitem projection=[l_partkey, l_quantity,
l_extendedprice]"#
.to_string();
@@ -328,14 +329,14 @@ order by s_name;
Filter: #nation.n_name = Utf8("CANADA")
TableScan: nation projection=[n_nationkey, n_name],
partial_filters=[#nation.n_name = Utf8("CANADA")]
Projection: #partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2
- Filter: CAST(#partsupp.ps_availqty AS Float64) > #__sq_3.__value
+ Filter: CAST(#partsupp.ps_availqty AS Decimal128(38, 17)) >
#__sq_3.__value
Inner Join: #partsupp.ps_partkey = #__sq_3.l_partkey,
#partsupp.ps_suppkey = #__sq_3.l_suppkey
Semi Join: #partsupp.ps_partkey = #__sq_1.p_partkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey,
ps_availqty]
Projection: #part.p_partkey AS p_partkey, alias=__sq_1
Filter: #part.p_name LIKE Utf8("forest%")
TableScan: part projection=[p_partkey, p_name],
partial_filters=[#part.p_name LIKE Utf8("forest%")]
- Projection: #lineitem.l_partkey, #lineitem.l_suppkey, Float64(0.5)
* #SUM(lineitem.l_quantity) AS __value, alias=__sq_3
+ Projection: #lineitem.l_partkey, #lineitem.l_suppkey,
CAST(Float64(0.5) AS Decimal128(38, 17)) * CAST(#SUM(lineitem.l_quantity) AS
Decimal128(38, 17)) AS __value, alias=__sq_3
Aggregate: groupBy=[[#lineitem.l_partkey, #lineitem.l_suppkey]],
aggr=[[SUM(#lineitem.l_quantity)]]
Filter: #lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS
Date32)
TableScan: lineitem projection=[l_partkey, l_suppkey,
l_quantity, l_shipdate], partial_filters=[#lineitem.l_shipdate >=
CAST(Utf8("1994-01-01") AS Date32)]"#
@@ -384,7 +385,7 @@ order by cntrycode;"#;
Aggregate: groupBy=[[#custsale.cntrycode]], aggr=[[COUNT(UInt8(1)),
SUM(#custsale.c_acctbal)]]
Projection: #custsale.cntrycode, #custsale.c_acctbal, alias=custsale
Projection: substr(#customer.c_phone, Int64(1), Int64(2)) AS
cntrycode, #customer.c_acctbal, alias=custsale
- Filter: #customer.c_acctbal > #__sq_1.__value
+ Filter: CAST(#customer.c_acctbal AS Decimal128(19, 6)) >
#__sq_1.__value
CrossJoin:
Anti Join: #customer.c_custkey = #orders.o_custkey
Filter: substr(#customer.c_phone, Int64(1), Int64(2)) IN
([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"),
Utf8("17")])
@@ -392,7 +393,7 @@ order by cntrycode;"#;
TableScan: orders projection=[o_custkey]
Projection: #AVG(customer.c_acctbal) AS __value, alias=__sq_1
Aggregate: groupBy=[[]], aggr=[[AVG(#customer.c_acctbal)]]
- Filter: #customer.c_acctbal > Float64(0) AND
substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"),
Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
+ Filter: CAST(#customer.c_acctbal AS Decimal128(30, 15)) >
CAST(Float64(0) AS Decimal128(30, 15)) AND substr(#customer.c_phone, Int64(1),
Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"),
Utf8("18"), Utf8("17")])
TableScan: customer projection=[c_phone, c_acctbal],
partial_filters=[#customer.c_acctbal > Float64(0), substr(#customer.c_phone,
Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"),
Utf8("30"), Utf8("18"), Utf8("17")])]"#
.to_string();
assert_eq!(actual, expected);
@@ -443,17 +444,17 @@ order by value desc;
let actual = format!("{}", plan.display_indent());
let expected = r#"Sort: #value DESC NULLS FIRST
Projection: #partsupp.ps_partkey, #SUM(partsupp.ps_supplycost *
partsupp.ps_availqty) AS value
- Filter: #SUM(partsupp.ps_supplycost * partsupp.ps_availqty) >
#__sq_1.__value
+ Filter: CAST(#SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS
Decimal128(38, 17)) > #__sq_1.__value
CrossJoin:
- Aggregate: groupBy=[[#partsupp.ps_partkey]],
aggr=[[SUM(#partsupp.ps_supplycost * CAST(#partsupp.ps_availqty AS Float64))]]
+ Aggregate: groupBy=[[#partsupp.ps_partkey]],
aggr=[[SUM(CAST(#partsupp.ps_supplycost AS Decimal128(26, 2)) *
CAST(#partsupp.ps_availqty AS Decimal128(26, 2)))]]
Inner Join: #supplier.s_nationkey = #nation.n_nationkey
Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey,
ps_availqty, ps_supplycost]
TableScan: supplier projection=[s_suppkey, s_nationkey]
Filter: #nation.n_name = Utf8("GERMANY")
TableScan: nation projection=[n_nationkey, n_name],
partial_filters=[#nation.n_name = Utf8("GERMANY")]
- Projection: #SUM(partsupp.ps_supplycost * partsupp.ps_availqty) *
Float64(0.0001) AS __value, alias=__sq_1
- Aggregate: groupBy=[[]], aggr=[[SUM(#partsupp.ps_supplycost *
CAST(#partsupp.ps_availqty AS Float64))]]
+ Projection: CAST(#SUM(partsupp.ps_supplycost * partsupp.ps_availqty)
AS Decimal128(38, 17)) * CAST(Float64(0.0001) AS Decimal128(38, 17)) AS
__value, alias=__sq_1
+ Aggregate: groupBy=[[]], aggr=[[SUM(CAST(#partsupp.ps_supplycost AS
Decimal128(26, 2)) * CAST(#partsupp.ps_availqty AS Decimal128(26, 2)))]]
Inner Join: #supplier.s_nationkey = #nation.n_nationkey
Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey,
ps_availqty, ps_supplycost]