This is an automated email from the ASF dual-hosted git repository.

houqp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 91a450f  feat: support joins on Float32/Float64 columns (#1054)
91a450f is described below

commit 91a450fcc3a08a65a521b555707a3a2893a85697
Author: Francis Du <[email protected]>
AuthorDate: Mon Sep 27 01:12:09 2021 +0800

    feat: support joins on Float32/Float64 columns (#1054)
---
 datafusion/src/physical_plan/hash_join.rs |  6 ++-
 datafusion/tests/sql.rs                   | 82 +++++++++++++++++++++++++++++++
 python/Cargo.lock                         |  4 +-
 3 files changed, 88 insertions(+), 4 deletions(-)

diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs
index d7aba9e..c51cf3b 100644
--- a/datafusion/src/physical_plan/hash_join.rs
+++ b/datafusion/src/physical_plan/hash_join.rs
@@ -44,8 +44,8 @@ use arrow::error::Result as ArrowResult;
 use arrow::record_batch::RecordBatch;
 
 use arrow::array::{
-    Int16Array, Int32Array, Int64Array, Int8Array, StringArray, UInt16Array, UInt32Array,
-    UInt64Array, UInt8Array,
+    Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
+    StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
 };
 
 use hashbrown::raw::RawTable;
@@ -778,6 +778,8 @@ fn equal_rows(
             DataType::UInt16 => equal_rows_elem!(UInt16Array, l, r, left, right),
             DataType::UInt32 => equal_rows_elem!(UInt32Array, l, r, left, right),
             DataType::UInt64 => equal_rows_elem!(UInt64Array, l, r, left, right),
+            DataType::Float32 => equal_rows_elem!(Float32Array, l, r, left, right),
+            DataType::Float64 => equal_rows_elem!(Float64Array, l, r, left, right),
             DataType::Timestamp(_, None) => {
                 equal_rows_elem!(Int64Array, l, r, left, right)
             }
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index 55f719b..ab4462a 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -2124,6 +2124,88 @@ async fn cross_join_unbalanced() {
     );
 }
 
+#[tokio::test]
+async fn test_join_float32() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+
+    // register population table
+    let population_schema = Arc::new(Schema::new(vec![
+        Field::new("city", DataType::Utf8, true),
+        Field::new("population", DataType::Float32, true),
+    ]));
+    let population_data = RecordBatch::try_new(
+        population_schema.clone(),
+        vec![
+            Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")])),
+            Arc::new(Float32Array::from(vec![838.698, 1778.934, 626.443])),
+        ],
+    )?;
+    let population_table =
+        MemTable::try_new(population_schema, vec![vec![population_data]])?;
+    ctx.register_table("population", Arc::new(population_table))?;
+
+    let sql = "SELECT * \
+                     FROM population as a \
+                     JOIN (SELECT * FROM population as b) \
+                     ON a.population = b.population \
+                     ORDER BY a.population";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+
+    let expected = vec![
+        "+------+------------+------+------------+",
+        "| city | population | city | population |",
+        "+------+------------+------+------------+",
+        "| c    | 626.443    | c    | 626.443    |",
+        "| a    | 838.698    | a    | 838.698    |",
+        "| b    | 1778.934   | b    | 1778.934   |",
+        "+------+------------+------+------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_join_float64() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+
+    // register population table
+    let population_schema = Arc::new(Schema::new(vec![
+        Field::new("city", DataType::Utf8, true),
+        Field::new("population", DataType::Float64, true),
+    ]));
+    let population_data = RecordBatch::try_new(
+        population_schema.clone(),
+        vec![
+            Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")])),
+            Arc::new(Float64Array::from(vec![838.698, 1778.934, 626.443])),
+        ],
+    )?;
+    let population_table =
+        MemTable::try_new(population_schema, vec![vec![population_data]])?;
+    ctx.register_table("population", Arc::new(population_table))?;
+
+    let sql = "SELECT * \
+                     FROM population as a \
+                     JOIN (SELECT * FROM population as b) \
+                     ON a.population = b.population \
+                     ORDER BY a.population";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+
+    let expected = vec![
+        "+------+------------+------+------------+",
+        "| city | population | city | population |",
+        "+------+------------+------+------------+",
+        "| c    | 626.443    | c    | 626.443    |",
+        "| a    | 838.698    | a    | 838.698    |",
+        "| b    | 1778.934   | b    | 1778.934   |",
+        "+------+------------+------+------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    Ok(())
+}
+
 fn create_join_context(
     column_left: &str,
     column_right: &str,
diff --git a/python/Cargo.lock b/python/Cargo.lock
index 9acde43..9b811ac 100644
--- a/python/Cargo.lock
+++ b/python/Cargo.lock
@@ -1195,9 +1195,9 @@ checksum = "45456094d1983e2ee2a18fdfebce3189fa451699d0502cb8e3b49dba5ba41451"
 
 [[package]]
 name = "sqlparser"
-version = "0.10.0"
+version = "0.11.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7dbcb26aebf44a0993c3c95e41d13860131b6cbf52edb2c53230056baa4d733f"
+checksum = "10e1ce16b71375ad72d28d111131069ce0d5f8603f4f86d8acd3456b41b57a51"
 dependencies = [
  "log",
 ]

Reply via email to