This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new f300168791 fix: detect non-recursive CTEs in the recursive `WITH` clause (#9836)
f300168791 is described below
commit f300168791b261e1162ac7fab47b329c9e5467f3
Author: Jonah Gao <[email protected]>
AuthorDate: Mon Apr 1 23:36:14 2024 +0800
fix: detect non-recursive CTEs in the recursive `WITH` clause (#9836)
* move cte related logic to its own mod
* fix check cte self reference
* add tests
* fix test
* move test to slt
---
datafusion/sql/src/cte.rs | 212 +++++++++++++++++++++++++++++
datafusion/sql/src/lib.rs | 1 +
datafusion/sql/src/planner.rs | 5 +
datafusion/sql/src/query.rs | 144 +-------------------
datafusion/sql/src/set_expr.rs | 81 ++++++-----
datafusion/sql/tests/sql_integration.rs | 10 --
datafusion/sqllogictest/test_files/cte.slt | 88 ++++++++++++
7 files changed, 356 insertions(+), 185 deletions(-)
diff --git a/datafusion/sql/src/cte.rs b/datafusion/sql/src/cte.rs
new file mode 100644
index 0000000000..5b1f81e820
--- /dev/null
+++ b/datafusion/sql/src/cte.rs
@@ -0,0 +1,212 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
+
+use arrow::datatypes::Schema;
+use datafusion_common::{
+ not_impl_err, plan_err,
+ tree_node::{TreeNode, TreeNodeRecursion},
+ Result,
+};
+use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, TableSource};
+use sqlparser::ast::{Query, SetExpr, SetOperator, With};
+
+impl<'a, S: ContextProvider> SqlToRel<'a, S> {
+    /// Plans every CTE of a `WITH` clause and registers each resulting plan
+    /// in `planner_context`, so that later CTEs of the same clause and the
+    /// outer query can reference it by name.
+    ///
+    /// CTEs are processed from top to bottom. In a `WITH RECURSIVE` clause
+    /// every CTE is routed through `recursive_cte`, which falls back to
+    /// ordinary planning for CTEs that never reference themselves.
+    ///
+    /// # Errors
+    ///
+    /// Returns a planning error when the same CTE name is already registered
+    /// in `planner_context`, and propagates any error raised while planning a
+    /// CTE body or applying its column alias.
+    pub(super) fn plan_with_clause(
+        &self,
+        with: With,
+        planner_context: &mut PlannerContext,
+    ) -> Result<()> {
+        let is_recursive = with.recursive;
+        // Process CTEs from top to bottom
+        for cte in with.cte_tables {
+            // A `WITH` block can't use the same name more than once
+            let cte_name = self.normalizer.normalize(cte.alias.name.clone());
+            if planner_context.contains_cte(&cte_name) {
+                return plan_err!(
+                    "WITH query name {cte_name:?} specified more than once"
+                );
+            }
+
+            // Create a logical plan for the CTE
+            let cte_plan = if is_recursive {
+                self.recursive_cte(cte_name.clone(), *cte.query, planner_context)?
+            } else {
+                self.non_recursive_cte(*cte.query, planner_context)?
+            };
+
+            // Each `WITH` block can change the column names in the last
+            // projection (e.g. "WITH table(t1, t2) AS SELECT 1, 2").
+            let final_plan = self.apply_table_alias(cte_plan, cte.alias)?;
+            // Export the CTE to the outer query. For a recursive CTE this also
+            // overwrites the temporary work-table entry that `recursive_cte`
+            // registered under the same name while planning its recursive term.
+            planner_context.insert_cte(cte_name, final_plan);
+        }
+        Ok(())
+    }
+
+    /// Plans a CTE body as an ordinary query (no self-reference).
+    ///
+    /// The body is planned against a clone of `planner_context`: CTEs already
+    /// registered (e.g. earlier ones in the same `WITH` clause) are visible,
+    /// but nothing registered while planning the body leaks back to the caller.
+    fn non_recursive_cte(
+        &self,
+        cte_query: Query,
+        planner_context: &mut PlannerContext,
+    ) -> Result<LogicalPlan> {
+        // A CTE body does not need to extend the outer query schema, so it is
+        // planned against a fresh clone of the planner context.
+        let mut cte_planner_context = planner_context.clone();
+        self.query_to_plan(cte_query, &mut cte_planner_context)
+    }
+
+    /// Plans the body of a CTE declared in a `WITH RECURSIVE` clause.
+    ///
+    /// A body of the form `static_term UNION [ALL] recursive_term` whose
+    /// recursive term scans the CTE's own work table is planned as a recursive
+    /// query. Any other body (one that is not a `UNION`, or whose right-hand
+    /// side never references the CTE) is planned as a regular query.
+    ///
+    /// # Errors
+    ///
+    /// Returns a "not implemented" error when
+    /// `datafusion.execution.enable_recursive_ctes` is disabled, and propagates
+    /// planning errors from either term.
+    fn recursive_cte(
+        &self,
+        cte_name: String,
+        mut cte_query: Query,
+        planner_context: &mut PlannerContext,
+    ) -> Result<LogicalPlan> {
+        if !self
+            .context_provider
+            .options()
+            .execution
+            .enable_recursive_ctes
+        {
+            return not_impl_err!("Recursive CTEs are not enabled");
+        }
+
+        let (left_expr, right_expr, set_quantifier) = match *cte_query.body {
+            SetExpr::SetOperation {
+                op: SetOperator::Union,
+                left,
+                right,
+                set_quantifier,
+            } => (left, right, set_quantifier),
+            other => {
+                // If the query is not a UNION, then it is not a recursive CTE
+                cte_query.body = Box::new(other);
+                return self.non_recursive_cte(cte_query, planner_context);
+            }
+        };
+
+        // Each recursive CTE consists of two parts in the logical plan:
+        // 1. A static term (the left hand side in the SQL, where
+        //    referencing the same CTE is not allowed)
+        //
+        // 2. A recursive term (the right hand side, and the recursive
+        //    part)
+
+        // Since the static term does not have any specific properties, it can
+        // be compiled as if it was a regular query. This will allow us to
+        // infer the schema to be used in the recursive term.
+
+        // ---------- Step 1: Compile the static term ------------------
+        let static_plan =
+            self.set_expr_to_plan(*left_expr, &mut planner_context.clone())?;
+
+        // Since the recursive CTEs include a component that references a
+        // table with its name, like the example below:
+        //
+        // WITH RECURSIVE values(n) AS (
+        //      SELECT 1 as n -- static term
+        //    UNION ALL
+        //      SELECT n + 1
+        //      FROM values -- self reference
+        //      WHERE n < 100
+        // )
+        //
+        // We need a temporary 'relation' to be referenced and used. PostgreSQL
+        // calls this a 'working table', but it is entirely an implementation
+        // detail and a 'real' table with that name might not even exist (as
+        // in the case of DataFusion).
+        //
+        // Since we can't simply register a table during planning stage (it is
+        // an execution problem), we'll use a relation object that preserves the
+        // schema of the input perfectly and also knows which recursive CTE it is
+        // bound to.
+
+        // ---------- Step 2: Create a temporary relation ------------------
+        // Step 2.1: Create a table source for the temporary relation
+        let work_table_source = self.context_provider.create_cte_work_table(
+            &cte_name,
+            Arc::new(Schema::from(static_plan.schema().as_ref())),
+        )?;
+
+        // Step 2.2: Create a temporary relation logical plan that will be used
+        // as the input to the recursive term
+        let work_table_plan = LogicalPlanBuilder::scan(
+            cte_name.clone(),
+            work_table_source.clone(),
+            None,
+        )?
+        .build()?;
+
+        // Step 2.3: Register the temporary relation in the planning context
+        // For all the self references in the recursive term, we'll replace it
+        // with the temporary relation we created above by temporarily
+        // registering it as a CTE. This temporary relation in the planning
+        // context will be replaced by the actual CTE plan once we're done with
+        // the planning.
+        planner_context.insert_cte(cte_name.clone(), work_table_plan);
+
+        // ---------- Step 3: Compile the recursive term ------------------
+        // Self-references resolve to the work table relation registered above,
+        // so the recursive term reads from it (and thus from the 'continuance'
+        // physical plan) as its input and source.
+        let recursive_plan =
+            self.set_expr_to_plan(*right_expr, &mut planner_context.clone())?;
+
+        // Check if the recursive term references the CTE itself,
+        // if not, it is a non-recursive CTE
+        if !has_work_table_reference(&recursive_plan, &work_table_source) {
+            // Remove the work table plan from the context
+            planner_context.remove_cte(&cte_name);
+            // Compile it as a non-recursive CTE
+            return self.set_operation_to_plan(
+                SetOperator::Union,
+                static_plan,
+                recursive_plan,
+                set_quantifier,
+            );
+        }
+
+        // ---------- Step 4: Create the final plan ------------------
+        // Step 4.1: Compile the final plan
+        let distinct = !Self::is_union_all(set_quantifier)?;
+        LogicalPlanBuilder::from(static_plan)
+            .to_recursive_query(cte_name, recursive_plan, distinct)?
+            .build()
+    }
+}
+
+/// Returns `true` if `plan` contains a `TableScan` whose source is the very
+/// same `TableSource` instance as `work_table_source`. Sources are compared
+/// with `Arc::ptr_eq`, i.e. by identity rather than by name or schema.
+///
+/// NOTE(review): the walk uses `TreeNode::apply` over the plan's inputs; it
+/// presumably does not descend into subqueries embedded in expressions
+/// (e.g. `WHERE EXISTS (SELECT ... FROM cte)`), so such a self-reference
+/// would go undetected. Confirm whether that case needs handling.
+fn has_work_table_reference(
+    plan: &LogicalPlan,
+    work_table_source: &Arc<dyn TableSource>,
+) -> bool {
+    let mut has_reference = false;
+    plan.apply(&mut |node| {
+        if let LogicalPlan::TableScan(scan) = node {
+            if Arc::ptr_eq(&scan.source, work_table_source) {
+                has_reference = true;
+                // One match is enough; stop walking the rest of the tree.
+                return Ok(TreeNodeRecursion::Stop);
+            }
+        }
+        Ok(TreeNodeRecursion::Continue)
+    })
+    .expect("the visitor closure never returns an error");
+    has_reference
+}
diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs
index 12d6a46696..1040cc61c7 100644
--- a/datafusion/sql/src/lib.rs
+++ b/datafusion/sql/src/lib.rs
@@ -28,6 +28,7 @@
//! [`SqlToRel`]: planner::SqlToRel
//! [`LogicalPlan`]: datafusion_expr::logical_plan::LogicalPlan
+mod cte;
mod expr;
pub mod parser;
pub mod planner;
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index f94c6ec4e8..d2182962b9 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -213,6 +213,11 @@ impl PlannerContext {
pub fn get_cte(&self, cte_name: &str) -> Option<&LogicalPlan> {
self.ctes.get(cte_name).map(|cte| cte.as_ref())
}
+
+    /// Drops the CTE / subquery plan registered under `cte_name`, if there is
+    /// one; the removed plan (if any) is discarded.
+    pub(super) fn remove_cte(&mut self, cte_name: &str) {
+        let _previous_plan = self.ctes.remove(cte_name);
+    }
}
/// SQL query planner
diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs
index eda8398c43..ba876d052f 100644
--- a/datafusion/sql/src/query.rs
+++ b/datafusion/sql/src/query.rs
@@ -19,21 +19,15 @@ use std::sync::Arc;
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
-use arrow::datatypes::Schema;
-use datafusion_common::{
- not_impl_err, plan_err, sql_err, Constraints, DataFusionError, Result,
ScalarValue,
-};
+use datafusion_common::{plan_err, Constraints, Result, ScalarValue};
use datafusion_expr::{
CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan,
LogicalPlanBuilder,
Operator,
};
use sqlparser::ast::{
- Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr,
SetOperator,
- SetQuantifier, Value,
+ Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, Value,
};
-use sqlparser::parser::ParserError::ParserError;
-
impl<'a, S: ContextProvider> SqlToRel<'a, S> {
/// Generate a logical plan from an SQL query
pub(crate) fn query_to_plan(
@@ -54,139 +48,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
) -> Result<LogicalPlan> {
let set_expr = query.body;
if let Some(with) = query.with {
- // Process CTEs from top to bottom
- let is_recursive = with.recursive;
-
- for cte in with.cte_tables {
- // A `WITH` block can't use the same name more than once
- let cte_name =
self.normalizer.normalize(cte.alias.name.clone());
- if planner_context.contains_cte(&cte_name) {
- return sql_err!(ParserError(format!(
- "WITH query name {cte_name:?} specified more than once"
- )));
- }
-
- if is_recursive {
- if !self
- .context_provider
- .options()
- .execution
- .enable_recursive_ctes
- {
- return not_impl_err!("Recursive CTEs are not enabled");
- }
-
- match *cte.query.body {
- SetExpr::SetOperation {
- op: SetOperator::Union,
- left,
- right,
- set_quantifier,
- } => {
- let distinct = set_quantifier !=
SetQuantifier::All;
-
- // Each recursive CTE consists from two parts in
the logical plan:
- // 1. A static term (the left hand side on the
SQL, where the
- // referencing to the same
CTE is not allowed)
- //
- // 2. A recursive term (the right hand side, and
the recursive
- // part)
-
- // Since static term does not have any specific
properties, it can
- // be compiled as if it was a regular expression.
This will
- // allow us to infer the schema to be used in the
recursive term.
-
- // ---------- Step 1: Compile the static term
------------------
- let static_plan = self
- .set_expr_to_plan(*left, &mut
planner_context.clone())?;
-
- // Since the recursive CTEs include a component
that references a
- // table with its name, like the example below:
- //
- // WITH RECURSIVE values(n) AS (
- // SELECT 1 as n -- static term
- // UNION ALL
- // SELECT n + 1
- // FROM values -- self reference
- // WHERE n < 100
- // )
- //
- // We need a temporary 'relation' to be referenced
and used. PostgreSQL
- // calls this a 'working table', but it is
entirely an implementation
- // detail and a 'real' table with that name might
not even exist (as
- // in the case of DataFusion).
- //
- // Since we can't simply register a table during
planning stage (it is
- // an execution problem), we'll use a relation
object that preserves the
- // schema of the input perfectly and also knows
which recursive CTE it is
- // bound to.
-
- // ---------- Step 2: Create a temporary relation
------------------
- // Step 2.1: Create a table source for the
temporary relation
- let work_table_source =
- self.context_provider.create_cte_work_table(
- &cte_name,
-
Arc::new(Schema::from(static_plan.schema().as_ref())),
- )?;
-
- // Step 2.2: Create a temporary relation logical
plan that will be used
- // as the input to the recursive term
- let work_table_plan = LogicalPlanBuilder::scan(
- cte_name.to_string(),
- work_table_source,
- None,
- )?
- .build()?;
-
- let name = cte_name.clone();
-
- // Step 2.3: Register the temporary relation in
the planning context
- // For all the self references in the variadic
term, we'll replace it
- // with the temporary relation we created above by
temporarily registering
- // it as a CTE. This temporary relation in the
planning context will be
- // replaced by the actual CTE plan once we're done
with the planning.
- planner_context.insert_cte(cte_name.clone(),
work_table_plan);
-
- // ---------- Step 3: Compile the recursive term
------------------
- // this uses the named_relation we inserted above
to resolve the
- // relation. This ensures that the recursive term
uses the named relation logical plan
- // and thus the 'continuance' physical plan as its
input and source
- let recursive_plan = self
- .set_expr_to_plan(*right, &mut
planner_context.clone())?;
-
- // ---------- Step 4: Create the final plan
------------------
- // Step 4.1: Compile the final plan
- let logical_plan =
LogicalPlanBuilder::from(static_plan)
- .to_recursive_query(name, recursive_plan,
distinct)?
- .build()?;
-
- let final_plan =
- self.apply_table_alias(logical_plan,
cte.alias)?;
-
- // Step 4.2: Remove the temporary relation from
the planning context and replace it
- // with the final plan.
- planner_context.insert_cte(cte_name.clone(),
final_plan);
- }
- _ => {
- return Err(DataFusionError::SQL(
- ParserError(format!("Unsupported CTE: {cte}")),
- None,
- ));
- }
- };
- } else {
- // create logical plan & pass backreferencing CTEs
- // CTE expr don't need extend outer_query_schema
- let logical_plan =
- self.query_to_plan(*cte.query, &mut
planner_context.clone())?;
-
- // Each `WITH` block can change the column names in the
last
- // projection (e.g. "WITH table(t1, t2) AS SELECT 1, 2").
- let logical_plan = self.apply_table_alias(logical_plan,
cte.alias)?;
-
- planner_context.insert_cte(cte_name, logical_plan);
- }
- }
+ self.plan_with_clause(with, planner_context)?;
}
let plan = self.set_expr_to_plan(*(set_expr.clone()),
planner_context)?;
let plan = self.order_by(plan, query.order_by, planner_context)?;
diff --git a/datafusion/sql/src/set_expr.rs b/datafusion/sql/src/set_expr.rs
index 2cbb68368f..cbe41c33c7 100644
--- a/datafusion/sql/src/set_expr.rs
+++ b/datafusion/sql/src/set_expr.rs
@@ -35,45 +35,58 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
right,
set_quantifier,
} => {
- let all = match set_quantifier {
- SetQuantifier::All => true,
- SetQuantifier::Distinct | SetQuantifier::None => false,
- SetQuantifier::ByName => {
- return not_impl_err!("UNION BY NAME not implemented");
- }
- SetQuantifier::AllByName => {
- return not_impl_err!("UNION ALL BY NAME not
implemented")
- }
- SetQuantifier::DistinctByName => {
- return not_impl_err!("UNION DISTINCT BY NAME not
implemented")
- }
- };
-
let left_plan = self.set_expr_to_plan(*left, planner_context)?;
let right_plan = self.set_expr_to_plan(*right,
planner_context)?;
- match (op, all) {
- (SetOperator::Union, true) =>
LogicalPlanBuilder::from(left_plan)
- .union(right_plan)?
- .build(),
- (SetOperator::Union, false) =>
LogicalPlanBuilder::from(left_plan)
- .union_distinct(right_plan)?
- .build(),
- (SetOperator::Intersect, true) => {
- LogicalPlanBuilder::intersect(left_plan, right_plan,
true)
- }
- (SetOperator::Intersect, false) => {
- LogicalPlanBuilder::intersect(left_plan, right_plan,
false)
- }
- (SetOperator::Except, true) => {
- LogicalPlanBuilder::except(left_plan, right_plan, true)
- }
- (SetOperator::Except, false) => {
- LogicalPlanBuilder::except(left_plan, right_plan,
false)
- }
- }
+ self.set_operation_to_plan(op, left_plan, right_plan,
set_quantifier)
}
SetExpr::Query(q) => self.query_to_plan(*q, planner_context),
_ => not_impl_err!("Query {set_expr} not implemented yet"),
}
}
+
+    /// Maps a set-operation quantifier to its `ALL` flag.
+    ///
+    /// `ALL` yields `true`; `DISTINCT` or no quantifier yields `false`.
+    /// The `BY NAME` variants are rejected with a "not implemented" error.
+    pub(super) fn is_union_all(set_quantifier: SetQuantifier) -> Result<bool> {
+        let is_all = match set_quantifier {
+            SetQuantifier::All => true,
+            SetQuantifier::None | SetQuantifier::Distinct => false,
+            SetQuantifier::ByName => {
+                return not_impl_err!("UNION BY NAME not implemented");
+            }
+            SetQuantifier::AllByName => {
+                return not_impl_err!("UNION ALL BY NAME not implemented");
+            }
+            SetQuantifier::DistinctByName => {
+                return not_impl_err!("UNION DISTINCT BY NAME not implemented");
+            }
+        };
+        Ok(is_all)
+    }
+
+    /// Builds the logical plan for `left_plan <op> [ALL | DISTINCT] right_plan`.
+    ///
+    /// For `UNION`, the `ALL` flag selects `union` versus `union_distinct`;
+    /// for `INTERSECT` and `EXCEPT` it is forwarded to the builder unchanged.
+    /// Fails when `set_quantifier` is one of the `BY NAME` variants.
+    pub(super) fn set_operation_to_plan(
+        &self,
+        op: SetOperator,
+        left_plan: LogicalPlan,
+        right_plan: LogicalPlan,
+        set_quantifier: SetQuantifier,
+    ) -> Result<LogicalPlan> {
+        let is_all = Self::is_union_all(set_quantifier)?;
+        match op {
+            SetOperator::Union => {
+                let builder = LogicalPlanBuilder::from(left_plan);
+                let builder = if is_all {
+                    builder.union(right_plan)?
+                } else {
+                    builder.union_distinct(right_plan)?
+                };
+                builder.build()
+            }
+            SetOperator::Intersect => {
+                LogicalPlanBuilder::intersect(left_plan, right_plan, is_all)
+            }
+            SetOperator::Except => {
+                LogicalPlanBuilder::except(left_plan, right_plan, is_all)
+            }
+        }
+    }
}
diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs
index 101c31039c..a34f8f07fe 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -2994,16 +2994,6 @@ fn join_with_aliases() {
quick_test(sql, expected);
}
-#[test]
-fn cte_use_same_name_multiple_times() {
- let sql =
- "with a as (select * from person), a as (select * from orders) select
* from a;";
- let expected =
- "SQL error: ParserError(\"WITH query name \\\"a\\\" specified more
than once\")";
- let result = logical_plan(sql).err().unwrap();
- assert_eq!(result.strip_backtrace(), expected);
-}
-
#[test]
fn negative_interval_plus_interval_in_projection() {
let sql = "select -interval '2 days' + interval '5 days';";
diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt
index e33dfabaf2..eec7eb0e33 100644
--- a/datafusion/sqllogictest/test_files/cte.slt
+++ b/datafusion/sqllogictest/test_files/cte.slt
@@ -39,6 +39,37 @@ physical_plan
ProjectionExec: expr=[1 as a, 2 as b, 3 as c]
--PlaceholderRowExec
+# cte_use_same_name_multiple_times
+statement error DataFusion error: Error during planning: WITH query name "a" specified more than once
+WITH a AS (SELECT 1), a AS (SELECT 2) SELECT * FROM a;
+
+# Test disabling recursive CTE
+statement ok
+set datafusion.execution.enable_recursive_ctes = false;
+
+query error DataFusion error: This feature is not implemented: Recursive CTEs are not enabled
+WITH RECURSIVE nodes AS (
+ SELECT 1 as id
+ UNION ALL
+ SELECT id + 1 as id
+ FROM nodes
+ WHERE id < 3
+) SELECT * FROM nodes
+
+statement ok
+set datafusion.execution.enable_recursive_ctes = true;
+
+
+# DISTINCT UNION is not supported
+query error DataFusion error: This feature is not implemented: Recursive queries with a distinct 'UNION' \(in which the previous iteration's results will be de\-duplicated\) is not supported
+WITH RECURSIVE nodes AS (
+ SELECT 1 as id
+ UNION
+ SELECT id + 1 as id
+ FROM nodes
+ WHERE id < 3
+) SELECT * FROM nodes
+
# trivial recursive CTE works
query I rowsort
@@ -744,3 +775,60 @@ WITH RECURSIVE my_cte AS (
UNION ALL
SELECT 'abc' FROM my_cte WHERE CAST(a AS text) !='abc'
) SELECT * FROM my_cte;
+
+# Define a non-recursive CTE in the recursive WITH clause.
+# Test issue: https://github.com/apache/arrow-datafusion/issues/9804
+query I
+WITH RECURSIVE cte AS (
+ SELECT a FROM (VALUES(1)) AS t(a) WHERE a > 2
+ UNION ALL
+ SELECT 2
+) SELECT * FROM cte;
+----
+2
+
+# Define a non-recursive CTE in the recursive WITH clause.
+# UNION ALL
+query I rowsort
+WITH RECURSIVE cte AS (
+ SELECT 1
+ UNION ALL
+ SELECT 2
+) SELECT * FROM cte;
+----
+1
+2
+
+# Define a non-recursive CTE in the recursive WITH clause.
+# DISTINCT UNION
+query I
+WITH RECURSIVE cte AS (
+ SELECT 2
+ UNION
+ SELECT 2
+) SELECT * FROM cte;
+----
+2
+
+# Define a non-recursive CTE in the recursive WITH clause.
+# UNION is not present.
+query I
+WITH RECURSIVE cte AS (
+ SELECT 1
+) SELECT * FROM cte;
+----
+1
+
+# Define a recursive CTE and a non-recursive CTE at the same time.
+query II rowsort
+WITH RECURSIVE
+non_recursive_cte AS (
+ SELECT 1
+),
+recursive_cte AS (
+ SELECT 1 AS a UNION ALL SELECT a+2 FROM recursive_cte WHERE a < 3
+)
+SELECT * FROM non_recursive_cte, recursive_cte;
+----
+1 1
+1 3