This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 94a6192f6b fix: allow placeholders to be substituted when coercible 
(#8977)
94a6192f6b is described below

commit 94a6192f6be30b7f6d009bc936a866bf5dcb280c
Author: Adam Curtis <[email protected]>
AuthorDate: Wed Jan 24 14:19:32 2024 -0500

    fix: allow placeholders to be substituted when coercible (#8977)
    
    * fix: allow placeholders to be substituted when coercible
    
    * fix clippy
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/common/src/param_value.rs     |  24 +-----
 datafusion/core/tests/sql/select.rs      | 137 +++++++++++++++++++++++++++++++
 datafusion/expr/src/logical_plan/plan.rs |   5 +-
 3 files changed, 141 insertions(+), 25 deletions(-)

diff --git a/datafusion/common/src/param_value.rs 
b/datafusion/common/src/param_value.rs
index 3fe2ba99ab..c614098713 100644
--- a/datafusion/common/src/param_value.rs
+++ b/datafusion/common/src/param_value.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::error::{_internal_err, _plan_err};
+use crate::error::_plan_err;
 use crate::{DataFusionError, Result, ScalarValue};
 use arrow_schema::DataType;
 use std::collections::HashMap;
@@ -65,11 +65,7 @@ impl ParamValues {
         }
     }
 
-    pub fn get_placeholders_with_values(
-        &self,
-        id: &str,
-        data_type: Option<&DataType>,
-    ) -> Result<ScalarValue> {
+    pub fn get_placeholders_with_values(&self, id: &str) -> 
Result<ScalarValue> {
         match self {
             ParamValues::List(list) => {
                 if id.is_empty() {
@@ -90,14 +86,6 @@ impl ParamValues {
                         "No value found for placeholder with id {id}"
                     ))
                 })?;
-                // check if the data type of the value matches the data type 
of the placeholder
-                if Some(&value.data_type()) != data_type {
-                    return _internal_err!(
-                        "Placeholder value type mismatch: expected {:?}, got 
{:?}",
-                        data_type,
-                        value.data_type()
-                    );
-                }
                 Ok(value.clone())
             }
             ParamValues::Map(map) => {
@@ -109,14 +97,6 @@ impl ParamValues {
                         "No value found for placeholder with name {id}"
                     ))
                 })?;
-                // check if the data type of the value matches the data type 
of the placeholder
-                if Some(&value.data_type()) != data_type {
-                    return _internal_err!(
-                        "Placeholder value type mismatch: expected {:?}, got 
{:?}",
-                        data_type,
-                        value.data_type()
-                    );
-                }
                 Ok(value.clone())
             }
         }
diff --git a/datafusion/core/tests/sql/select.rs 
b/datafusion/core/tests/sql/select.rs
index 4a782e54b0..71369c7300 100644
--- a/datafusion/core/tests/sql/select.rs
+++ b/datafusion/core/tests/sql/select.rs
@@ -525,6 +525,89 @@ async fn test_prepare_statement() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn prepared_statement_type_coercion() -> Result<()> {
+    let ctx = SessionContext::new();
+    let signed_ints: Int32Array = vec![-1, 0, 1].into();
+    let unsigned_ints: UInt64Array = vec![1, 2, 3].into();
+    let batch = RecordBatch::try_from_iter(vec![
+        ("signed", Arc::new(signed_ints) as ArrayRef),
+        ("unsigned", Arc::new(unsigned_ints) as ArrayRef),
+    ])?;
+    ctx.register_batch("test", batch)?;
+    let results = ctx.sql("PREPARE my_plan(BIGINT, INT, TEXT) AS SELECT 
signed, unsigned FROM test WHERE $1 >= signed AND signed <= $2 AND unsigned = 
$3")
+        .await?
+        .with_param_values(vec![
+            ScalarValue::from(1_i64),
+            ScalarValue::from(-1_i32),
+            ScalarValue::from("1"),
+        ])?
+        .collect()
+        .await?;
+    let expected = [
+        "+--------+----------+",
+        "| signed | unsigned |",
+        "+--------+----------+",
+        "| -1     | 1        |",
+        "+--------+----------+",
+    ];
+    assert_batches_sorted_eq!(expected, &results);
+    Ok(())
+}
+
+#[tokio::test]
+async fn prepared_statement_invalid_types() -> Result<()> {
+    let ctx = SessionContext::new();
+    let signed_ints: Int32Array = vec![-1, 0, 1].into();
+    let unsigned_ints: UInt64Array = vec![1, 2, 3].into();
+    let batch = RecordBatch::try_from_iter(vec![
+        ("signed", Arc::new(signed_ints) as ArrayRef),
+        ("unsigned", Arc::new(unsigned_ints) as ArrayRef),
+    ])?;
+    ctx.register_batch("test", batch)?;
+    let results = ctx
+        .sql("PREPARE my_plan(INT) AS SELECT signed FROM test WHERE signed = 
$1")
+        .await?
+        .with_param_values(vec![ScalarValue::from("1")]);
+    assert_eq!(
+        results.unwrap_err().strip_backtrace(),
+        "Error during planning: Expected parameter of type Int32, got Utf8 at 
index 0"
+    );
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_list_query_parameters() -> Result<()> {
+    let tmp_dir = TempDir::new()?;
+    let partition_count = 4;
+    let ctx = create_ctx_with_partition(&tmp_dir, partition_count).await?;
+
+    let results = ctx
+        .sql("SELECT * FROM test WHERE c1 = $1")
+        .await?
+        .with_param_values(vec![ScalarValue::from(3i32)])?
+        .collect()
+        .await?;
+    let expected = vec![
+        "+----+----+-------+",
+        "| c1 | c2 | c3    |",
+        "+----+----+-------+",
+        "| 3  | 1  | false |",
+        "| 3  | 10 | true  |",
+        "| 3  | 2  | true  |",
+        "| 3  | 3  | false |",
+        "| 3  | 4  | true  |",
+        "| 3  | 5  | false |",
+        "| 3  | 6  | true  |",
+        "| 3  | 7  | false |",
+        "| 3  | 8  | true  |",
+        "| 3  | 9  | false |",
+        "+----+----+-------+",
+    ];
+    assert_batches_sorted_eq!(expected, &results);
+    Ok(())
+}
+
 #[tokio::test]
 async fn test_named_query_parameters() -> Result<()> {
     let tmp_dir = TempDir::new()?;
@@ -572,6 +655,60 @@ async fn test_named_query_parameters() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn test_parameter_type_coercion() -> Result<()> {
+    let ctx = SessionContext::new();
+    let signed_ints: Int32Array = vec![-1, 0, 1].into();
+    let unsigned_ints: UInt64Array = vec![1, 2, 3].into();
+    let batch = RecordBatch::try_from_iter(vec![
+        ("signed", Arc::new(signed_ints) as ArrayRef),
+        ("unsigned", Arc::new(unsigned_ints) as ArrayRef),
+    ])?;
+    ctx.register_batch("test", batch)?;
+    let results = ctx.sql("SELECT signed, unsigned FROM test WHERE $foo >= 
signed AND signed <= $bar AND unsigned <= $baz AND unsigned = $str")
+        .await?
+        .with_param_values(vec![
+            ("foo", ScalarValue::from(1_u64)),
+            ("bar", ScalarValue::from(-1_i64)),
+            ("baz", ScalarValue::from(2_i32)),
+            ("str", ScalarValue::from("1")),
+        ])?
+        .collect().await?;
+    let expected = [
+        "+--------+----------+",
+        "| signed | unsigned |",
+        "+--------+----------+",
+        "| -1     | 1        |",
+        "+--------+----------+",
+    ];
+    assert_batches_sorted_eq!(expected, &results);
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_parameter_invalid_types() -> Result<()> {
+    let ctx = SessionContext::new();
+    let list_array = ListArray::from_iter_primitive::<Int32Type, _, 
_>(vec![Some(vec![
+        Some(1),
+        Some(2),
+        Some(3),
+    ])]);
+    let batch =
+        RecordBatch::try_from_iter(vec![("list", Arc::new(list_array) as 
ArrayRef)])?;
+    ctx.register_batch("test", batch)?;
+    let results = ctx
+        .sql("SELECT list FROM test WHERE list = $1")
+        .await?
+        .with_param_values(vec![ScalarValue::from(4_i32)])?
+        .collect()
+        .await;
+    assert_eq!(
+        results.unwrap_err().strip_backtrace(),
+        "Arrow error: Invalid argument error: Invalid comparison operation: 
List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, 
dict_is_ordered: false, metadata: {} }) == List(Field { name: \"item\", 
data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: 
{} })"
+);
+    Ok(())
+}
+
 #[tokio::test]
 async fn parallel_query_with_filter() -> Result<()> {
     let tmp_dir = TempDir::new()?;
diff --git a/datafusion/expr/src/logical_plan/plan.rs 
b/datafusion/expr/src/logical_plan/plan.rs
index 5ab8a9c99c..aee3a59dd2 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -1243,9 +1243,8 @@ impl LogicalPlan {
     ) -> Result<Expr> {
         expr.transform(&|expr| {
             match &expr {
-                Expr::Placeholder(Placeholder { id, data_type }) => {
-                    let value = param_values
-                        .get_placeholders_with_values(id, data_type.as_ref())?;
+                Expr::Placeholder(Placeholder { id, .. }) => {
+                    let value = param_values.get_placeholders_with_values(id)?;
                     // Replace the placeholder with the value
                     Ok(Transformed::Yes(Expr::Literal(value)))
                 }

Reply via email to