This is an automated email from the ASF dual-hosted git repository.
houqp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 91a450f feat: support joins on Float32/Float64 columns (#1054)
91a450f is described below
commit 91a450fcc3a08a65a521b555707a3a2893a85697
Author: Francis Du <[email protected]>
AuthorDate: Mon Sep 27 01:12:09 2021 +0800
feat: support joins on Float32/Float64 columns (#1054)
---
datafusion/src/physical_plan/hash_join.rs | 6 ++-
datafusion/tests/sql.rs | 82 +++++++++++++++++++++++++++++++
python/Cargo.lock | 4 +-
3 files changed, 88 insertions(+), 4 deletions(-)
diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs
index d7aba9e..c51cf3b 100644
--- a/datafusion/src/physical_plan/hash_join.rs
+++ b/datafusion/src/physical_plan/hash_join.rs
@@ -44,8 +44,8 @@ use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use arrow::array::{
- Int16Array, Int32Array, Int64Array, Int8Array, StringArray, UInt16Array, UInt32Array,
- UInt64Array, UInt8Array,
+ Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
+ StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
use hashbrown::raw::RawTable;
@@ -778,6 +778,8 @@ fn equal_rows(
DataType::UInt16 => equal_rows_elem!(UInt16Array, l, r, left, right),
DataType::UInt32 => equal_rows_elem!(UInt32Array, l, r, left, right),
DataType::UInt64 => equal_rows_elem!(UInt64Array, l, r, left, right),
+ DataType::Float32 => equal_rows_elem!(Float32Array, l, r, left, right),
+ DataType::Float64 => equal_rows_elem!(Float64Array, l, r, left, right),
DataType::Timestamp(_, None) => {
equal_rows_elem!(Int64Array, l, r, left, right)
}
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index 55f719b..ab4462a 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -2124,6 +2124,88 @@ async fn cross_join_unbalanced() {
);
}
+#[tokio::test]
+async fn test_join_float32() -> Result<()> {
+ let mut ctx = ExecutionContext::new();
+
+ // register population table
+ let population_schema = Arc::new(Schema::new(vec![
+ Field::new("city", DataType::Utf8, true),
+ Field::new("population", DataType::Float32, true),
+ ]));
+ let population_data = RecordBatch::try_new(
+ population_schema.clone(),
+ vec![
+ Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")])),
+ Arc::new(Float32Array::from(vec![838.698, 1778.934, 626.443])),
+ ],
+ )?;
+ let population_table =
+ MemTable::try_new(population_schema, vec![vec![population_data]])?;
+ ctx.register_table("population", Arc::new(population_table))?;
+
+ let sql = "SELECT * \
+ FROM population as a \
+ JOIN (SELECT * FROM population as b) \
+ ON a.population = b.population \
+ ORDER BY a.population";
+ let actual = execute_to_batches(&mut ctx, sql).await;
+
+ let expected = vec![
+ "+------+------------+------+------------+",
+ "| city | population | city | population |",
+ "+------+------------+------+------------+",
+ "| c | 626.443 | c | 626.443 |",
+ "| a | 838.698 | a | 838.698 |",
+ "| b | 1778.934 | b | 1778.934 |",
+ "+------+------------+------+------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_join_float64() -> Result<()> {
+ let mut ctx = ExecutionContext::new();
+
+ // register population table
+ let population_schema = Arc::new(Schema::new(vec![
+ Field::new("city", DataType::Utf8, true),
+ Field::new("population", DataType::Float64, true),
+ ]));
+ let population_data = RecordBatch::try_new(
+ population_schema.clone(),
+ vec![
+ Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")])),
+ Arc::new(Float64Array::from(vec![838.698, 1778.934, 626.443])),
+ ],
+ )?;
+ let population_table =
+ MemTable::try_new(population_schema, vec![vec![population_data]])?;
+ ctx.register_table("population", Arc::new(population_table))?;
+
+ let sql = "SELECT * \
+ FROM population as a \
+ JOIN (SELECT * FROM population as b) \
+ ON a.population = b.population \
+ ORDER BY a.population";
+ let actual = execute_to_batches(&mut ctx, sql).await;
+
+ let expected = vec![
+ "+------+------------+------+------------+",
+ "| city | population | city | population |",
+ "+------+------------+------+------------+",
+ "| c | 626.443 | c | 626.443 |",
+ "| a | 838.698 | a | 838.698 |",
+ "| b | 1778.934 | b | 1778.934 |",
+ "+------+------------+------+------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ Ok(())
+}
+
fn create_join_context(
column_left: &str,
column_right: &str,
diff --git a/python/Cargo.lock b/python/Cargo.lock
index 9acde43..9b811ac 100644
--- a/python/Cargo.lock
+++ b/python/Cargo.lock
@@ -1195,9 +1195,9 @@ checksum = "45456094d1983e2ee2a18fdfebce3189fa451699d0502cb8e3b49dba5ba41451"
[[package]]
name = "sqlparser"
-version = "0.10.0"
+version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7dbcb26aebf44a0993c3c95e41d13860131b6cbf52edb2c53230056baa4d733f"
+checksum = "10e1ce16b71375ad72d28d111131069ce0d5f8603f4f86d8acd3456b41b57a51"
dependencies = [
"log",
]