This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 94a6192f6b fix: allow placeholders to be substituted when coercible (#8977)
94a6192f6b is described below
commit 94a6192f6be30b7f6d009bc936a866bf5dcb280c
Author: Adam Curtis <[email protected]>
AuthorDate: Wed Jan 24 14:19:32 2024 -0500
fix: allow placeholders to be substituted when coercible (#8977)
* fix: allow placeholders to be substituted when coercible
* fix clippy
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/common/src/param_value.rs | 24 +-----
datafusion/core/tests/sql/select.rs | 137 +++++++++++++++++++++++++++++++
datafusion/expr/src/logical_plan/plan.rs | 5 +-
3 files changed, 141 insertions(+), 25 deletions(-)
diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs
index 3fe2ba99ab..c614098713 100644
--- a/datafusion/common/src/param_value.rs
+++ b/datafusion/common/src/param_value.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use crate::error::{_internal_err, _plan_err};
+use crate::error::_plan_err;
use crate::{DataFusionError, Result, ScalarValue};
use arrow_schema::DataType;
use std::collections::HashMap;
@@ -65,11 +65,7 @@ impl ParamValues {
}
}
- pub fn get_placeholders_with_values(
- &self,
- id: &str,
- data_type: Option<&DataType>,
- ) -> Result<ScalarValue> {
+ pub fn get_placeholders_with_values(&self, id: &str) -> Result<ScalarValue> {
match self {
ParamValues::List(list) => {
if id.is_empty() {
@@ -90,14 +86,6 @@ impl ParamValues {
"No value found for placeholder with id {id}"
))
})?;
- // check if the data type of the value matches the data type of the placeholder
- if Some(&value.data_type()) != data_type {
- return _internal_err!(
- "Placeholder value type mismatch: expected {:?}, got
{:?}",
- data_type,
- value.data_type()
- );
- }
Ok(value.clone())
}
ParamValues::Map(map) => {
@@ -109,14 +97,6 @@ impl ParamValues {
"No value found for placeholder with name {id}"
))
})?;
- // check if the data type of the value matches the data type of the placeholder
- if Some(&value.data_type()) != data_type {
- return _internal_err!(
- "Placeholder value type mismatch: expected {:?}, got
{:?}",
- data_type,
- value.data_type()
- );
- }
Ok(value.clone())
}
}
diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs
index 4a782e54b0..71369c7300 100644
--- a/datafusion/core/tests/sql/select.rs
+++ b/datafusion/core/tests/sql/select.rs
@@ -525,6 +525,89 @@ async fn test_prepare_statement() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn prepared_statement_type_coercion() -> Result<()> {
+ let ctx = SessionContext::new();
+ let signed_ints: Int32Array = vec![-1, 0, 1].into();
+ let unsigned_ints: UInt64Array = vec![1, 2, 3].into();
+ let batch = RecordBatch::try_from_iter(vec![
+ ("signed", Arc::new(signed_ints) as ArrayRef),
+ ("unsigned", Arc::new(unsigned_ints) as ArrayRef),
+ ])?;
+ ctx.register_batch("test", batch)?;
+ let results = ctx.sql("PREPARE my_plan(BIGINT, INT, TEXT) AS SELECT
signed, unsigned FROM test WHERE $1 >= signed AND signed <= $2 AND unsigned =
$3")
+ .await?
+ .with_param_values(vec![
+ ScalarValue::from(1_i64),
+ ScalarValue::from(-1_i32),
+ ScalarValue::from("1"),
+ ])?
+ .collect()
+ .await?;
+ let expected = [
+ "+--------+----------+",
+ "| signed | unsigned |",
+ "+--------+----------+",
+ "| -1 | 1 |",
+ "+--------+----------+",
+ ];
+ assert_batches_sorted_eq!(expected, &results);
+ Ok(())
+}
+
+#[tokio::test]
+async fn prepared_statement_invalid_types() -> Result<()> {
+ let ctx = SessionContext::new();
+ let signed_ints: Int32Array = vec![-1, 0, 1].into();
+ let unsigned_ints: UInt64Array = vec![1, 2, 3].into();
+ let batch = RecordBatch::try_from_iter(vec![
+ ("signed", Arc::new(signed_ints) as ArrayRef),
+ ("unsigned", Arc::new(unsigned_ints) as ArrayRef),
+ ])?;
+ ctx.register_batch("test", batch)?;
+ let results = ctx
+ .sql("PREPARE my_plan(INT) AS SELECT signed FROM test WHERE signed =
$1")
+ .await?
+ .with_param_values(vec![ScalarValue::from("1")]);
+ assert_eq!(
+ results.unwrap_err().strip_backtrace(),
+ "Error during planning: Expected parameter of type Int32, got Utf8 at
index 0"
+ );
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_list_query_parameters() -> Result<()> {
+ let tmp_dir = TempDir::new()?;
+ let partition_count = 4;
+ let ctx = create_ctx_with_partition(&tmp_dir, partition_count).await?;
+
+ let results = ctx
+ .sql("SELECT * FROM test WHERE c1 = $1")
+ .await?
+ .with_param_values(vec![ScalarValue::from(3i32)])?
+ .collect()
+ .await?;
+ let expected = vec![
+ "+----+----+-------+",
+ "| c1 | c2 | c3 |",
+ "+----+----+-------+",
+ "| 3 | 1 | false |",
+ "| 3 | 10 | true |",
+ "| 3 | 2 | true |",
+ "| 3 | 3 | false |",
+ "| 3 | 4 | true |",
+ "| 3 | 5 | false |",
+ "| 3 | 6 | true |",
+ "| 3 | 7 | false |",
+ "| 3 | 8 | true |",
+ "| 3 | 9 | false |",
+ "+----+----+-------+",
+ ];
+ assert_batches_sorted_eq!(expected, &results);
+ Ok(())
+}
+
#[tokio::test]
async fn test_named_query_parameters() -> Result<()> {
let tmp_dir = TempDir::new()?;
@@ -572,6 +655,60 @@ async fn test_named_query_parameters() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn test_parameter_type_coercion() -> Result<()> {
+ let ctx = SessionContext::new();
+ let signed_ints: Int32Array = vec![-1, 0, 1].into();
+ let unsigned_ints: UInt64Array = vec![1, 2, 3].into();
+ let batch = RecordBatch::try_from_iter(vec![
+ ("signed", Arc::new(signed_ints) as ArrayRef),
+ ("unsigned", Arc::new(unsigned_ints) as ArrayRef),
+ ])?;
+ ctx.register_batch("test", batch)?;
+ let results = ctx.sql("SELECT signed, unsigned FROM test WHERE $foo >=
signed AND signed <= $bar AND unsigned <= $baz AND unsigned = $str")
+ .await?
+ .with_param_values(vec![
+ ("foo", ScalarValue::from(1_u64)),
+ ("bar", ScalarValue::from(-1_i64)),
+ ("baz", ScalarValue::from(2_i32)),
+ ("str", ScalarValue::from("1")),
+ ])?
+ .collect().await?;
+ let expected = [
+ "+--------+----------+",
+ "| signed | unsigned |",
+ "+--------+----------+",
+ "| -1 | 1 |",
+ "+--------+----------+",
+ ];
+ assert_batches_sorted_eq!(expected, &results);
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_parameter_invalid_types() -> Result<()> {
+ let ctx = SessionContext::new();
+ let list_array = ListArray::from_iter_primitive::<Int32Type, _,
_>(vec![Some(vec![
+ Some(1),
+ Some(2),
+ Some(3),
+ ])]);
+ let batch =
+ RecordBatch::try_from_iter(vec![("list", Arc::new(list_array) as
ArrayRef)])?;
+ ctx.register_batch("test", batch)?;
+ let results = ctx
+ .sql("SELECT list FROM test WHERE list = $1")
+ .await?
+ .with_param_values(vec![ScalarValue::from(4_i32)])?
+ .collect()
+ .await;
+ assert_eq!(
+ results.unwrap_err().strip_backtrace(),
+ "Arrow error: Invalid argument error: Invalid comparison operation:
List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0,
dict_is_ordered: false, metadata: {} }) == List(Field { name: \"item\",
data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata:
{} })"
+);
+ Ok(())
+}
+
#[tokio::test]
async fn parallel_query_with_filter() -> Result<()> {
let tmp_dir = TempDir::new()?;
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 5ab8a9c99c..aee3a59dd2 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -1243,9 +1243,8 @@ impl LogicalPlan {
) -> Result<Expr> {
expr.transform(&|expr| {
match &expr {
- Expr::Placeholder(Placeholder { id, data_type }) => {
- let value = param_values
- .get_placeholders_with_values(id, data_type.as_ref())?;
+ Expr::Placeholder(Placeholder { id, .. }) => {
+ let value = param_values.get_placeholders_with_values(id)?;
// Replace the placeholder with the value
Ok(Transformed::Yes(Expr::Literal(value)))
}