This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new bec385668 Enable more benchmark verification tests (#4044)
bec385668 is described below
commit bec3856688ac3dc3fdda927f26640dc028023436
Author: Andy Grove <[email protected]>
AuthorDate: Mon Oct 31 12:16:25 2022 -0600
Enable more benchmark verification tests (#4044)
* Fix Decimal and Floating type coerce rule
* Enable more queries in benchmark verification tests
* update comparison_binary_numeric_coercion
* revert type coercion change in comparison_binary_numeric_coercion
* smaller tolerance
Co-authored-by: Liang-Chi Hsieh <[email protected]>
---
benchmarks/src/bin/tpch.rs | 35 ++++++++++++++++++++-------
benchmarks/src/tpch.rs | 59 ++++++++++++++++++++++++++++++----------------
2 files changed, 66 insertions(+), 28 deletions(-)
diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs
index b9afe4d6a..df64537bd 100644
--- a/benchmarks/src/bin/tpch.rs
+++ b/benchmarks/src/bin/tpch.rs
@@ -668,7 +668,6 @@ mod tests {
}
#[cfg(feature = "ci")]
- #[ignore] // TODO produces correct result but has rounding error
#[tokio::test]
async fn verify_q9() -> Result<()> {
verify_query(9).await
@@ -681,7 +680,6 @@ mod tests {
}
#[cfg(feature = "ci")]
- #[ignore] // https://github.com/apache/arrow-datafusion/issues/4023
#[tokio::test]
async fn verify_q11() -> Result<()> {
verify_query(11).await
@@ -700,7 +698,6 @@ mod tests {
}
#[cfg(feature = "ci")]
- #[ignore] // https://github.com/apache/arrow-datafusion/issues/4025
#[tokio::test]
async fn verify_q14() -> Result<()> {
verify_query(14).await
@@ -719,7 +716,6 @@ mod tests {
}
#[cfg(feature = "ci")]
- #[ignore] // https://github.com/apache/arrow-datafusion/issues/4026
#[tokio::test]
async fn verify_q17() -> Result<()> {
verify_query(17).await
@@ -896,8 +892,8 @@ mod tests {
#[cfg(feature = "ci")]
async fn verify_query(n: usize) -> Result<()> {
use datafusion::arrow::datatypes::{DataType, Field};
+ use datafusion::common::ScalarValue;
use datafusion::logical_expr::expr::Cast;
- use datafusion::logical_expr::Expr;
use std::env;
let path =
env::var("TPCH_DATA").unwrap_or("benchmarks/data".to_string());
@@ -990,7 +986,12 @@ mod tests {
}
data_type => data_type == e.data_type(),
});
- assert!(schema_matches);
+ if !schema_matches {
+ panic!(
+ "expected_fields: {:?}\ntransformed_fields: {:?}",
+ expected_fields, transformed_fields
+ )
+ }
// convert both datasets to Vec<Vec<String>> for simple comparison
let expected_vec = result_vec(&expected);
@@ -1000,8 +1001,26 @@ mod tests {
assert_eq!(expected_vec.len(), actual_vec.len());
        // compare each row. this works as all TPC-H queries have deterministically ordered results
- for i in 0..actual_vec.len() {
- assert_eq!(expected_vec[i], actual_vec[i]);
+ for i in 0..expected_vec.len() {
+ let expected_row = &expected_vec[i];
+ let actual_row = &actual_vec[i];
+ assert_eq!(expected_row.len(), actual_row.len());
+
+ for j in 0..expected.len() {
+ match (&expected_row[j], &actual_row[j]) {
+                (ScalarValue::Float64(Some(l)), ScalarValue::Float64(Some(r))) => {
+                        // allow for rounding errors until we move to decimal types
+ let tolerance = 0.1;
+ if (l - r).abs() > tolerance {
+ panic!(
+ "Expected: {}; Actual: {}; Tolerance: {}",
+ l, r, tolerance
+ )
+ }
+ }
+                    (l, r) => assert_eq!(format!("{:?}", l), format!("{:?}", r)),
+ }
+ }
}
Ok(())
diff --git a/benchmarks/src/tpch.rs b/benchmarks/src/tpch.rs
index 46c53edf1..ad61de8a3 100644
--- a/benchmarks/src/tpch.rs
+++ b/benchmarks/src/tpch.rs
@@ -15,7 +15,10 @@
// specific language governing permissions and limitations
// under the License.
-use arrow::array::ArrayRef;
+use arrow::array::{
+    Array, ArrayRef, Date32Array, Decimal128Array, Float64Array, Int32Array, Int64Array,
+ StringArray,
+};
use arrow::record_batch::RecordBatch;
use std::fs;
use std::ops::{Div, Mul};
@@ -23,7 +26,7 @@ use std::path::Path;
use std::sync::Arc;
use std::time::Instant;
-use datafusion::arrow::util::display::array_value_to_string;
+use datafusion::common::ScalarValue;
use datafusion::logical_expr::Cast;
use datafusion::prelude::*;
use datafusion::{
@@ -229,11 +232,7 @@ pub fn get_answer_schema(n: usize) -> Schema {
Field::new("custdist", DataType::Int64, true),
]),
- 14 => Schema::new(vec![Field::new(
- "promo_revenue",
- DataType::Decimal128(38, 2),
- true,
- )]),
+        14 => Schema::new(vec![Field::new("promo_revenue", DataType::Float64, true)]),
15 => Schema::new(vec![
Field::new("s_suppkey", DataType::Int64, true),
@@ -250,11 +249,7 @@ pub fn get_answer_schema(n: usize) -> Schema {
Field::new("supplier_cnt", DataType::Int64, true),
]),
- 17 => Schema::new(vec![Field::new(
- "avg_yearly",
- DataType::Decimal128(38, 2),
- true,
- )]),
+        17 => Schema::new(vec![Field::new("avg_yearly", DataType::Float64, true)]),
18 => Schema::new(vec![
Field::new("c_name", DataType::Utf8, true),
@@ -389,14 +384,14 @@ pub async fn convert_tbl(
/// Converts the results into a 2d array of strings, `result[row][column]`
/// Special cases nulls to NULL for testing
-pub fn result_vec(results: &[RecordBatch]) -> Vec<Vec<String>> {
+pub fn result_vec(results: &[RecordBatch]) -> Vec<Vec<ScalarValue>> {
let mut result = vec![];
for batch in results {
for row_index in 0..batch.num_rows() {
let row_vec = batch
.columns()
.iter()
- .map(|column| col_str(column, row_index))
+ .map(|column| col_to_scalar(column, row_index))
.collect();
result.push(row_vec);
}
@@ -422,13 +417,37 @@ pub fn string_schema(schema: Schema) -> Schema {
)
}
-/// Specialised String representation
-fn col_str(column: &ArrayRef, row_index: usize) -> String {
+fn col_to_scalar(column: &ArrayRef, row_index: usize) -> ScalarValue {
if column.is_null(row_index) {
- return "NULL".to_string();
+ return ScalarValue::Null;
+ }
+ match column.data_type() {
+ DataType::Int32 => {
+ let array = column.as_any().downcast_ref::<Int32Array>().unwrap();
+ ScalarValue::Int32(Some(array.value(row_index)))
+ }
+ DataType::Int64 => {
+ let array = column.as_any().downcast_ref::<Int64Array>().unwrap();
+ ScalarValue::Int64(Some(array.value(row_index)))
+ }
+ DataType::Float64 => {
+            let array = column.as_any().downcast_ref::<Float64Array>().unwrap();
+ ScalarValue::Float64(Some(array.value(row_index)))
+ }
+ DataType::Decimal128(p, s) => {
+            let array = column.as_any().downcast_ref::<Decimal128Array>().unwrap();
+ ScalarValue::Decimal128(Some(array.value(row_index)), *p, *s)
+ }
+ DataType::Date32 => {
+ let array = column.as_any().downcast_ref::<Date32Array>().unwrap();
+ ScalarValue::Date32(Some(array.value(row_index)))
+ }
+ DataType::Utf8 => {
+ let array = column.as_any().downcast_ref::<StringArray>().unwrap();
+ ScalarValue::Utf8(Some(array.value(row_index).to_string()))
+ }
+ other => panic!("unexpected data type in benchmark: {}", other),
}
-
- array_value_to_string(column, row_index).unwrap()
}
pub async fn transform_actual_result(
@@ -460,7 +479,7 @@ pub async fn transform_actual_result(
Expr::Alias(
Box::new(Expr::Cast(Cast::new(
round,
- DataType::Decimal128(38, 2),
+ DataType::Decimal128(15, 2),
))),
Field::name(field).to_string(),
)