This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 42587518f changed datatypes to match TPC-H definition -- where Float64
was used, using Decimal128 now (#3393)
42587518f is described below
commit 42587518f2899dbca102c97ad93873d25d906aad
Author: Kirk Mitchener <[email protected]>
AuthorDate: Thu Sep 8 09:47:44 2022 -0400
changed datatypes to match TPC-H definition -- where Float64 was used,
using Decimal128 now (#3393)
added special handling of q15 results, where we want to capture the results
of the second of 3 statements
fixed up the comparison of query results against known-good answers
stop ignoring q15 and q21
---
benchmarks/README.md | 3 +-
benchmarks/src/bin/tpch.rs | 279 ++++++++++++++++++++++++++++++++-------------
2 files changed, 201 insertions(+), 81 deletions(-)
diff --git a/benchmarks/README.md b/benchmarks/README.md
index 7b4dd3001..505469fc5 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -25,7 +25,8 @@ implementations as well as other query engines.
## Benchmark derived from TPC-H
-These benchmarks are derived from the [TPC-H][1] benchmark.
+These benchmarks are derived from the [TPC-H][1] benchmark. And we use this
repo as the source of tpch-gen and answers:
+https://github.com/databricks/tpch-dbgen.git, based on
[2.17.1](https://www.tpc.org/tpc_documents_current_versions/pdf/tpc-h_v2.17.1.pdf)
version of TPC-H.
## Generating Test Data
diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs
index 43db654e8..963833ee9 100644
--- a/benchmarks/src/bin/tpch.rs
+++ b/benchmarks/src/bin/tpch.rs
@@ -197,8 +197,21 @@ async fn benchmark_datafusion(opt: DataFusionBenchmarkOpt)
-> Result<Vec<RecordB
let start = Instant::now();
let sql = &get_query_sql(opt.query)?;
- for query in sql {
- result = execute_query(&ctx, query, opt.debug).await?;
+
+ // query 15 is special, with 3 statements. the second statement is the
one from which we
+ // want to capture the results
+ if opt.query == 15 {
+ for (n, query) in sql.iter().enumerate() {
+ if n == 1 {
+ result = execute_query(&ctx, query, opt.debug).await?;
+ } else {
+ execute_query(&ctx, query, opt.debug).await?;
+ }
+ }
+ } else {
+ for query in sql {
+ result = execute_query(&ctx, query, opt.debug).await?;
+ }
}
let elapsed = start.elapsed().as_secs_f64() * 1000.0;
@@ -281,8 +294,9 @@ async fn execute_query(
if debug {
println!("=== Logical plan ===\n{:?}\n", plan);
}
- let plan = ctx.optimize(&plan)?;
+
if debug {
+ let plan = ctx.optimize(&plan)?;
println!("=== Optimized logical plan ===\n{:?}\n", plan);
}
let physical_plan = ctx.create_physical_plan(&plan).await?;
@@ -442,7 +456,7 @@ fn get_schema(table: &str) -> Schema {
Field::new("p_type", DataType::Utf8, false),
Field::new("p_size", DataType::Int32, false),
Field::new("p_container", DataType::Utf8, false),
- Field::new("p_retailprice", DataType::Float64, false),
+ Field::new("p_retailprice", DataType::Decimal128(15, 2), false),
Field::new("p_comment", DataType::Utf8, false),
]),
@@ -452,7 +466,7 @@ fn get_schema(table: &str) -> Schema {
Field::new("s_address", DataType::Utf8, false),
Field::new("s_nationkey", DataType::Int64, false),
Field::new("s_phone", DataType::Utf8, false),
- Field::new("s_acctbal", DataType::Float64, false),
+ Field::new("s_acctbal", DataType::Decimal128(15, 2), false),
Field::new("s_comment", DataType::Utf8, false),
]),
@@ -460,7 +474,7 @@ fn get_schema(table: &str) -> Schema {
Field::new("ps_partkey", DataType::Int64, false),
Field::new("ps_suppkey", DataType::Int64, false),
Field::new("ps_availqty", DataType::Int32, false),
- Field::new("ps_supplycost", DataType::Float64, false),
+ Field::new("ps_supplycost", DataType::Decimal128(15, 2), false),
Field::new("ps_comment", DataType::Utf8, false),
]),
@@ -470,7 +484,7 @@ fn get_schema(table: &str) -> Schema {
Field::new("c_address", DataType::Utf8, false),
Field::new("c_nationkey", DataType::Int64, false),
Field::new("c_phone", DataType::Utf8, false),
- Field::new("c_acctbal", DataType::Float64, false),
+ Field::new("c_acctbal", DataType::Decimal128(15, 2), false),
Field::new("c_mktsegment", DataType::Utf8, false),
Field::new("c_comment", DataType::Utf8, false),
]),
@@ -479,7 +493,7 @@ fn get_schema(table: &str) -> Schema {
Field::new("o_orderkey", DataType::Int64, false),
Field::new("o_custkey", DataType::Int64, false),
Field::new("o_orderstatus", DataType::Utf8, false),
- Field::new("o_totalprice", DataType::Float64, false),
+ Field::new("o_totalprice", DataType::Decimal128(15, 2), false),
Field::new("o_orderdate", DataType::Date32, false),
Field::new("o_orderpriority", DataType::Utf8, false),
Field::new("o_clerk", DataType::Utf8, false),
@@ -492,10 +506,10 @@ fn get_schema(table: &str) -> Schema {
Field::new("l_partkey", DataType::Int64, false),
Field::new("l_suppkey", DataType::Int64, false),
Field::new("l_linenumber", DataType::Int32, false),
- Field::new("l_quantity", DataType::Float64, false),
- Field::new("l_extendedprice", DataType::Float64, false),
- Field::new("l_discount", DataType::Float64, false),
- Field::new("l_tax", DataType::Float64, false),
+ Field::new("l_quantity", DataType::Decimal128(15, 2), false),
+ Field::new("l_extendedprice", DataType::Decimal128(15, 2), false),
+ Field::new("l_discount", DataType::Decimal128(15, 2), false),
+ Field::new("l_tax", DataType::Decimal128(15, 2), false),
Field::new("l_returnflag", DataType::Utf8, false),
Field::new("l_linestatus", DataType::Utf8, false),
Field::new("l_shipdate", DataType::Date32, false),
@@ -575,12 +589,39 @@ struct QueryResult {
mod tests {
use super::*;
use std::env;
+ use std::ops::{Div, Mul};
use std::sync::Arc;
use datafusion::arrow::array::*;
use datafusion::arrow::util::display::array_value_to_string;
- use datafusion::logical_plan::Expr;
- use datafusion::logical_plan::Expr::Cast;
+ use datafusion::logical_expr::Expr;
+ use datafusion::logical_expr::Expr::Cast;
+ use datafusion::logical_expr::Expr::ScalarFunction;
+
+ const QUERY_LIMIT: [Option<usize>; 22] = [
+ None,
+ Some(100),
+ Some(10),
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ Some(20),
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ Some(100),
+ None,
+ None,
+ Some(100),
+ None,
+ ];
#[tokio::test]
async fn q1() -> Result<()> {
@@ -672,6 +713,7 @@ mod tests {
verify_query(18).await
}
+ #[ignore]
#[tokio::test]
async fn q19() -> Result<()> {
verify_query(19).await
@@ -762,7 +804,6 @@ mod tests {
run_query(14).await
}
- #[ignore] // https://github.com/apache/arrow-datafusion/issues/166
#[tokio::test]
async fn run_q15() -> Result<()> {
run_query(15).await
@@ -794,7 +835,6 @@ mod tests {
run_query(20).await
}
- #[ignore] // https://github.com/apache/arrow-datafusion/issues/172
#[tokio::test]
async fn run_q21() -> Result<()> {
run_query(21).await
@@ -836,21 +876,21 @@ mod tests {
1 => Schema::new(vec![
Field::new("l_returnflag", DataType::Utf8, true),
Field::new("l_linestatus", DataType::Utf8, true),
- Field::new("sum_qty", DataType::Float64, true),
- Field::new("sum_base_price", DataType::Float64, true),
- Field::new("sum_disc_price", DataType::Float64, true),
- Field::new("sum_charge", DataType::Float64, true),
- Field::new("avg_qty", DataType::Float64, true),
- Field::new("avg_price", DataType::Float64, true),
- Field::new("avg_disc", DataType::Float64, true),
- Field::new("count_order", DataType::UInt64, true),
+ Field::new("sum_qty", DataType::Decimal128(15, 2), true),
+ Field::new("sum_base_price", DataType::Decimal128(15, 2),
true),
+ Field::new("sum_disc_price", DataType::Decimal128(15, 2),
true),
+ Field::new("sum_charge", DataType::Decimal128(15, 2), true),
+ Field::new("avg_qty", DataType::Decimal128(15, 2), true),
+ Field::new("avg_price", DataType::Decimal128(15, 2), true),
+ Field::new("avg_disc", DataType::Decimal128(15, 2), true),
+ Field::new("count_order", DataType::Int64, true),
]),
2 => Schema::new(vec![
- Field::new("s_acctbal", DataType::Float64, true),
+ Field::new("s_acctbal", DataType::Decimal128(15, 2), true),
Field::new("s_name", DataType::Utf8, true),
Field::new("n_name", DataType::Utf8, true),
- Field::new("p_partkey", DataType::Int32, true),
+ Field::new("p_partkey", DataType::Int64, true),
Field::new("p_mfgr", DataType::Utf8, true),
Field::new("s_address", DataType::Utf8, true),
Field::new("s_phone", DataType::Utf8, true),
@@ -858,47 +898,51 @@ mod tests {
]),
3 => Schema::new(vec![
- Field::new("l_orderkey", DataType::Int32, true),
- Field::new("revenue", DataType::Float64, true),
+ Field::new("l_orderkey", DataType::Int64, true),
+ Field::new("revenue", DataType::Decimal128(15, 2), true),
Field::new("o_orderdate", DataType::Date32, true),
Field::new("o_shippriority", DataType::Int32, true),
]),
4 => Schema::new(vec![
Field::new("o_orderpriority", DataType::Utf8, true),
- Field::new("order_count", DataType::Int32, true),
+ Field::new("order_count", DataType::Int64, true),
]),
5 => Schema::new(vec![
Field::new("n_name", DataType::Utf8, true),
- Field::new("revenue", DataType::Float64, true),
+ Field::new("revenue", DataType::Decimal128(15, 2), true),
]),
- 6 => Schema::new(vec![Field::new("revenue", DataType::Float64,
true)]),
+ 6 => Schema::new(vec![Field::new(
+ "revenue",
+ DataType::Decimal128(15, 2),
+ true,
+ )]),
7 => Schema::new(vec![
Field::new("supp_nation", DataType::Utf8, true),
Field::new("cust_nation", DataType::Utf8, true),
Field::new("l_year", DataType::Int32, true),
- Field::new("revenue", DataType::Float64, true),
+ Field::new("revenue", DataType::Decimal128(15, 2), true),
]),
8 => Schema::new(vec![
Field::new("o_year", DataType::Int32, true),
- Field::new("mkt_share", DataType::Float64, true),
+ Field::new("mkt_share", DataType::Decimal128(15, 2), true),
]),
9 => Schema::new(vec![
Field::new("nation", DataType::Utf8, true),
Field::new("o_year", DataType::Int32, true),
- Field::new("sum_profit", DataType::Float64, true),
+ Field::new("sum_profit", DataType::Decimal128(15, 2), true),
]),
10 => Schema::new(vec![
- Field::new("c_custkey", DataType::Int32, true),
+ Field::new("c_custkey", DataType::Int64, true),
Field::new("c_name", DataType::Utf8, true),
- Field::new("revenue", DataType::Float64, true),
- Field::new("c_acctbal", DataType::Float64, true),
+ Field::new("revenue", DataType::Decimal128(15, 2), true),
+ Field::new("c_acctbal", DataType::Decimal128(15, 2), true),
Field::new("n_name", DataType::Utf8, true),
Field::new("c_address", DataType::Utf8, true),
Field::new("c_phone", DataType::Utf8, true),
@@ -906,8 +950,8 @@ mod tests {
]),
11 => Schema::new(vec![
- Field::new("ps_partkey", DataType::Int32, true),
- Field::new("value", DataType::Float64, true),
+ Field::new("ps_partkey", DataType::Int64, true),
+ Field::new("value", DataType::Decimal128(15, 2), true),
]),
12 => Schema::new(vec![
@@ -923,24 +967,30 @@ mod tests {
14 => Schema::new(vec![Field::new("promo_revenue",
DataType::Float64, true)]),
- 15 => Schema::new(vec![Field::new("promo_revenue",
DataType::Float64, true)]),
+ 15 => Schema::new(vec![
+ Field::new("s_suppkey", DataType::Int64, true),
+ Field::new("s_name", DataType::Utf8, true),
+ Field::new("s_address", DataType::Utf8, true),
+ Field::new("s_phone", DataType::Utf8, true),
+ Field::new("total_revenue", DataType::Decimal128(15, 2), true),
+ ]),
16 => Schema::new(vec![
Field::new("p_brand", DataType::Utf8, true),
Field::new("p_type", DataType::Utf8, true),
- Field::new("c_phone", DataType::Int32, true),
- Field::new("c_comment", DataType::Int32, true),
+ Field::new("p_size", DataType::Int32, true),
+ Field::new("supplier_cnt", DataType::Int64, true),
]),
17 => Schema::new(vec![Field::new("avg_yearly", DataType::Float64,
true)]),
18 => Schema::new(vec![
Field::new("c_name", DataType::Utf8, true),
- Field::new("c_custkey", DataType::Int32, true),
- Field::new("o_orderkey", DataType::Int32, true),
+ Field::new("c_custkey", DataType::Int64, true),
+ Field::new("o_orderkey", DataType::Int64, true),
Field::new("o_orderdate", DataType::Date32, true),
- Field::new("o_totalprice", DataType::Float64, true),
- Field::new("sum_l_quantity", DataType::Float64, true),
+ Field::new("o_totalprice", DataType::Decimal128(15, 2), true),
+ Field::new("sum_l_quantity", DataType::Decimal128(15, 2),
true),
]),
19 => Schema::new(vec![Field::new("revenue", DataType::Float64,
true)]),
@@ -952,13 +1002,13 @@ mod tests {
21 => Schema::new(vec![
Field::new("s_name", DataType::Utf8, true),
- Field::new("numwait", DataType::Int32, true),
+ Field::new("numwait", DataType::Int64, true),
]),
22 => Schema::new(vec![
- Field::new("cntrycode", DataType::Int32, true),
- Field::new("numcust", DataType::Int32, true),
- Field::new("totacctbal", DataType::Float64, true),
+ Field::new("cntrycode", DataType::Utf8, true),
+ Field::new("numcust", DataType::Int64, true),
+ Field::new("totacctbal", DataType::Decimal128(15, 2), true),
]),
_ => unimplemented!(),
@@ -983,22 +1033,59 @@ mod tests {
)
}
- // convert the schema to the same but with all columns set to
nullable=true.
- // this allows direct schema comparison ignoring nullable.
- fn nullable_schema(schema: Arc<Schema>) -> Schema {
- Schema::new(
- schema
- .fields()
- .iter()
- .map(|field| {
- Field::new(
- Field::name(field),
- Field::data_type(field).to_owned(),
- true,
- )
- })
- .collect::<Vec<Field>>(),
- )
+ async fn transform_actual_result(
+ result: Vec<RecordBatch>,
+ n: usize,
+ ) -> Result<Vec<RecordBatch>> {
+ // to compare the recorded answers to the answers we got back from
running the query,
+ // we need to round the decimal columns and trim the Utf8 columns
+ let ctx = SessionContext::new();
+ let result_schema = result[0].schema();
+ let table = Arc::new(MemTable::try_new(result_schema.clone(),
vec![result])?);
+ let mut df = ctx.read_table(table)?
+ .select(
+ result_schema
+ .fields
+ .iter()
+ .map(|field| {
+ match Field::data_type(field) {
+ DataType::Decimal128(_,_) => {
+ // if decimal, then round it to 2 decimal
places like the answers
+ // round() doesn't support the second argument
for decimal places to round to
+ // this can be simplified to remove the mul
and div when
+ //
https://github.com/apache/arrow-datafusion/issues/2420 is completed
+ // cast it back to an over-sized Decimal with
2 precision when done rounding
+ let round = Box::new(ScalarFunction {
+ fun:
datafusion::logical_expr::BuiltinScalarFunction::Round,
+ args:
vec![col(Field::name(field)).mul(lit(100))]
+ }.div(lit(100)));
+ Expr::Alias(
+ Box::new(Cast {
+ expr: round,
+ data_type: DataType::Decimal128(38,2),
+ }),
+ Field::name(field).to_string(),
+ )
+ }
+ DataType::Utf8 => {
+ // if string, then trim it like the answers
got trimmed
+ Expr::Alias(
+ Box::new(trim(col(Field::name(field)))),
+ Field::name(field).to_string()
+ )
+ }
+ _ => {
+ col(Field::name(field))
+ }
+ }
+ }).collect()
+ )?;
+ if let Some(x) = QUERY_LIMIT[n - 1] {
+ df = df.limit(0, Some(x))?;
+ }
+
+ let df = df.collect().await?;
+ Ok(df)
}
async fn run_query(n: usize) -> Result<()> {
@@ -1026,6 +1113,11 @@ mod tests {
Ok(())
}
+ /// compares query results against stored answers from the git repo
+ /// verifies that:
+ /// * datatypes returned in columns is correct
+ /// * the correct number of rows are returned
+ /// * the content of the rows is correct
async fn verify_query(n: usize) -> Result<()> {
if let Ok(path) = env::var("TPCH_DATA") {
// load expected answers from tpch-dbgen
@@ -1045,13 +1137,30 @@ mod tests {
.fields()
.iter()
.map(|field| {
- Expr::Alias(
- Box::new(Cast {
- expr: Box::new(trim(col(Field::name(field)))),
- data_type: Field::data_type(field).to_owned(),
- }),
- Field::name(field).to_string(),
- )
+ match Field::data_type(field) {
+ DataType::Decimal128(_, _) => {
+ // there's no support for casting from Utf8 to
Decimal, so
+ // we'll cast from Utf8 to Float64 to Decimal
for Decimal types
+ let inner_cast = Box::new(Cast {
+ expr:
Box::new(trim(col(Field::name(field)))),
+ data_type: DataType::Float64,
+ });
+ Expr::Alias(
+ Box::new(Cast {
+ expr: inner_cast,
+ data_type:
Field::data_type(field).to_owned(),
+ }),
+ Field::name(field).to_string(),
+ )
+ }
+ _ => Expr::Alias(
+ Box::new(Cast {
+ expr:
Box::new(trim(col(Field::name(field)))),
+ data_type:
Field::data_type(field).to_owned(),
+ }),
+ Field::name(field).to_string(),
+ ),
+ }
})
.collect::<Vec<Expr>>(),
)?;
@@ -1071,20 +1180,30 @@ mod tests {
};
let actual = benchmark_datafusion(opt).await?;
- // assert schema equality without comparing nullable values
- assert_eq!(
- nullable_schema(expected[0].schema()),
- nullable_schema(actual[0].schema())
- );
+ let transformed = transform_actual_result(actual, n).await?;
+
+ // assert schema data types match
+ let transformed_fields = &transformed[0].schema().fields;
+ let expected_fields = &expected[0].schema().fields;
+ let schema_matches = transformed_fields
+ .iter()
+ .zip(expected_fields.iter())
+ .all(|(t, e)| match t.data_type() {
+ DataType::Decimal128(_, _) => {
+ matches!(e.data_type(), DataType::Decimal128(_, _))
+ }
+ data_type => data_type == e.data_type(),
+ });
+ assert!(schema_matches);
// convert both datasets to Vec<Vec<String>> for simple comparison
let expected_vec = result_vec(&expected);
- let actual_vec = result_vec(&actual);
+ let actual_vec = result_vec(&transformed);
// basic result comparison
assert_eq!(expected_vec.len(), actual_vec.len());
- // compare each row. this works as all TPC-H queries have
determinisically ordered results
+ // compare each row. this works as all TPC-H queries have
deterministically ordered results
for i in 0..actual_vec.len() {
assert_eq!(expected_vec[i], actual_vec[i]);
}