This is an automated email from the ASF dual-hosted git repository. avantgardner pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push: new 5d4038a84 Infer values for inserts (#4977) 5d4038a84 is described below commit 5d4038a8463a575328bedbc22b32456f5dcd562c Author: Brent Gardner <bgard...@squarelabs.net> AuthorDate: Mon Jan 23 10:20:21 2023 -0700 Infer values for inserts (#4977) * Infer values for updates Co-authored-by: Andrew Lamb <and...@nerdnetworks.org> --- datafusion/sql/src/statement.rs | 33 ++++++++++++++- datafusion/sql/tests/integration_test.rs | 69 ++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 2 deletions(-) diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 616619a1b..5b7949d2d 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -40,7 +40,7 @@ use datafusion_expr::{ }; use sqlparser::ast; use sqlparser::ast::{ - Assignment, Expr as SQLExpr, Expr, Ident, ObjectName, ObjectType, Query, + Assignment, Expr as SQLExpr, Expr, Ident, ObjectName, ObjectType, Query, SetExpr, ShowCreateObject, ShowStatementFilter, Statement, TableFactor, TableWithJoins, UnaryOperator, Value, }; @@ -762,8 +762,37 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let arrow_schema = (*provider.schema()).clone(); let table_schema = Arc::new(DFSchema::try_from(arrow_schema)?); + // infer types for Values clause... other types should be resolvable the regular way + let mut prepare_param_data_types = BTreeMap::new(); + if let SetExpr::Values(ast::Values { rows, .. }) = (*source.body).clone() { + for row in rows.iter() { + for (idx, val) in row.iter().enumerate() { + if let ast::Expr::Value(Value::Placeholder(name)) = val { + let name = + name.replace('$', "").parse::<usize>().map_err(|_| { + DataFusionError::Plan(format!( + "Can't parse placeholder: {name}" + )) + })? 
- 1; + let col = columns.get(idx).ok_or_else(|| { + DataFusionError::Plan(format!( + "Placeholder ${} refers to a non existent column", + idx + 1 + )) + })?; + let field = + table_schema.field_with_name(None, col.value.as_str())?; + let dt = field.field().data_type().clone(); + let _ = prepare_param_data_types.insert(name, dt); + } + } + } + } + let prepare_param_data_types = prepare_param_data_types.into_values().collect(); + // Projection - let mut planner_context = PlannerContext::new(); + let mut planner_context = + PlannerContext::new_with_prepare_param_data_types(prepare_param_data_types); let source = self.query_to_plan(*source, &mut planner_context)?; if columns.len() != source.schema().fields().len() { Err(DataFusionError::Plan( diff --git a/datafusion/sql/tests/integration_test.rs b/datafusion/sql/tests/integration_test.rs index e93ec8712..c771a3ec5 100644 --- a/datafusion/sql/tests/integration_test.rs +++ b/datafusion/sql/tests/integration_test.rs @@ -3390,6 +3390,75 @@ Dml: op=[Update] table=[person] prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); } +#[test] +fn test_prepare_statement_insert_infer() { + let sql = "insert into person (id, first_name, last_name) values ($1, $2, $3)"; + + let expected_plan = r#" +Dml: op=[Insert] table=[person] + Projection: column1 AS id, column2 AS first_name, column3 AS last_name + Values: ($1, $2, $3) + "# + .trim(); + + let expected_dt = "[Int32]"; + let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + + let actual_types = plan.get_parameter_types().unwrap(); + let expected_types = HashMap::from([ + ("$1".to_string(), Some(DataType::UInt32)), + ("$2".to_string(), Some(DataType::Utf8)), + ("$3".to_string(), Some(DataType::Utf8)), + ]); + assert_eq!(actual_types, expected_types); + + // replace params with values + let param_values = vec![ + ScalarValue::UInt32(Some(1)), + ScalarValue::Utf8(Some("Alan".to_string())), + ScalarValue::Utf8(Some("Turing".to_string())), + ]; + let 
expected_plan = r#" +Dml: op=[Insert] table=[person] + Projection: column1 AS id, column2 AS first_name, column3 AS last_name + Values: (UInt32(1), Utf8("Alan"), Utf8("Turing")) + "# + .trim(); + let plan = plan.replace_params_with_values(&param_values).unwrap(); + + prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); +} + +#[test] +#[should_panic(expected = "Placeholder $4 refers to a non existent column")] +fn test_prepare_statement_insert_infer_gt() { + let sql = "insert into person (id, first_name, last_name) values ($1, $2, $3, $4)"; + + let expected_plan = r#""#.trim(); + let expected_dt = "[Int32]"; + let _ = prepare_stmt_quick_test(sql, expected_plan, expected_dt); +} + +#[test] +#[should_panic(expected = "value: Plan(\"Column count doesn't match insert query!\")")] +fn test_prepare_statement_insert_infer_lt() { + let sql = "insert into person (id, first_name, last_name) values ($1, $2)"; + + let expected_plan = r#""#.trim(); + let expected_dt = "[Int32]"; + let _ = prepare_stmt_quick_test(sql, expected_plan, expected_dt); +} + +#[test] +#[should_panic(expected = "value: Plan(\"Placeholder type could not be resolved\")")] +fn test_prepare_statement_insert_infer_gap() { + let sql = "insert into person (id, first_name, last_name) values ($2, $4, $6)"; + + let expected_plan = r#""#.trim(); + let expected_dt = "[Int32]"; + let _ = prepare_stmt_quick_test(sql, expected_plan, expected_dt); +} + #[test] fn test_prepare_statement_to_plan_one_param() { let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $1";