This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 405644fc4ace [SPARK-50130][SQL][FOLLOWUP] Simplify the implementation of col.outer()
405644fc4ace is described below
commit 405644fc4ace98fdfeb8702ef230b42e042bd24d
Author: Wenchen Fan <[email protected]>
AuthorDate: Thu Nov 28 11:10:39 2024 +0800
[SPARK-50130][SQL][FOLLOWUP] Simplify the implementation of col.outer()
### What changes were proposed in this pull request?
This is a followup of https://github.com/apache/spark/pull/48664 to
simplify the code. The new workflow is:
1. `col.outer()` simply wraps the expression with `LazyExpression`.
2. `QueryExecution` does lazy analysis if its main query contains
`LazyExpression`. Eager analysis is still performed if only subquery
expressions contain `LazyExpression`.
3. The analyzer simply strips away `LazyExpression` at the beginning.
After this simplification, we no longer need the special logic to strip
`LazyOuterReference` on the DataFrame side, and we no longer need the extra flag
in the subquery expressions. It also makes the API easier to understand:
`col.outer()` is just used to trigger lazy analysis for Spark Classic.
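
For reference, a minimal sketch of the intended usage, mirroring the example
added to the `Column.outer()` scaladoc in this patch (the table names `t1` and
`t2` are illustrative only):

    from pyspark.sql import functions as sf

    # `outer()` wraps `t1.col` in `LazyExpression`, so `sub` is analyzed lazily
    # instead of failing eagerly on the unresolvable outer column.
    sub = spark.table("t2").where(sf.col("t2.col") == sf.col("t1.col").outer())

    # `t1.col` is resolved only when `sub` is used as a subquery expression
    # against the outer plan `spark.table("t1")`.
    spark.table("t1").where(sub.exists()).show()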
### Why are the changes needed?
cleanup
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
existing tests
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #48820 from cloud-fan/subquery.
Lead-authored-by: Wenchen Fan <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Co-authored-by: Takuya UESHIN <[email protected]>
Co-authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 6 -
python/pyspark/sql/column.py | 6 +-
python/pyspark/sql/dataframe.py | 32 ++--
python/pyspark/sql/tests/test_subquery.py | 186 ++++++++++-----------
.../main/scala/org/apache/spark/sql/Column.scala | 31 ++--
.../apache/spark/sql/internal/columnNodes.scala | 33 ++--
.../spark/sql/catalyst/analysis/Analyzer.scala | 27 +--
.../sql/catalyst/analysis/CheckAnalysis.scala | 21 +--
.../catalyst/analysis/ColumnResolutionHelper.scala | 44 ++---
.../analysis/EliminateLazyExpression.scala | 34 ++++
.../spark/sql/catalyst/analysis/unresolved.scala | 47 +-----
.../sql/catalyst/expressions/Expression.scala | 16 +-
.../spark/sql/catalyst/expressions/subquery.scala | 15 +-
.../spark/sql/catalyst/optimizer/Optimizer.scala | 2 +-
.../spark/sql/catalyst/optimizer/expressions.scala | 2 +-
.../spark/sql/catalyst/optimizer/subquery.scala | 15 +-
.../spark/sql/catalyst/trees/TreePatterns.scala | 6 +-
.../main/scala/org/apache/spark/sql/Dataset.scala | 14 +-
.../spark/sql/execution/QueryExecution.scala | 8 +-
.../adaptive/PlanAdaptiveSubqueries.scala | 2 +-
.../spark/sql/internal/columnNodeSupport.scala | 16 +-
.../apache/spark/sql/DataFrameSubquerySuite.scala | 173 +++++++++----------
22 files changed, 331 insertions(+), 405 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 3c494704fd71..77437f6c5617 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -4748,12 +4748,6 @@
],
"sqlState" : "42KD9"
},
- "UNANALYZABLE_EXPRESSION" : {
- "message" : [
- "The plan contains an unanalyzable expression <expr> that holds the
analysis."
- ],
- "sqlState" : "03000"
- },
"UNBOUND_SQL_PARAMETER" : {
"message" : [
"Found the unbound parameter: <name>. Please, fix `args` and provide a
mapping of the parameter to either a SQL literal or collection constructor
functions such as `map()`, `array()`, `struct()`."
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 06dd2860fe40..285d30fad3bc 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -1524,7 +1524,11 @@ class Column:
@dispatch_col_method
def outer(self) -> "Column":
"""
- Mark this column reference as an outer reference for subqueries.
+ Mark this column as an outer column if its expression refers to columns from an outer query.
+
+ This is used to trigger lazy analysis of Spark Classic DataFrame, so that we can use it
+ to build subquery expressions. Spark Connect DataFrame is always lazily analyzed and
+ does not need to use this function.
.. versionadded:: 4.0.0
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 8a5b982bc7f2..085a1a629634 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -6522,10 +6522,11 @@ class DataFrame:
in their department.
>>> from pyspark.sql import functions as sf
- >>> employees.where(
+ >>> employees.alias("e1").where(
... sf.col("salary")
- ... > employees.where(sf.col("department_id") == sf.col("department_id").outer())
- ... .select(sf.avg("salary")).scalar()
+ ... > employees.alias("e2").where(
+ ...     sf.col("e2.department_id") == sf.col("e1.department_id").outer()
+ ... ).select(sf.avg("salary")).scalar()
... ).select("name", "salary", "department_id").show()
+-----+------+-------------+
| name|salary|department_id|
@@ -6538,12 +6539,13 @@ class DataFrame:
department.
>>> from pyspark.sql import functions as sf
- >>> employees.select(
+ >>> employees.alias("e1").select(
... "name", "salary", "department_id",
... sf.format_number(
... sf.lit(100) * sf.col("salary") /
- ... employees.where(sf.col("department_id") == sf.col("department_id").outer())
- ... .select(sf.sum("salary")).scalar().alias("avg_salary"),
+ ... employees.alias("e2").where(
+ ...     sf.col("e2.department_id") == sf.col("e1.department_id").outer()
+ ... ).select(sf.sum("salary")).scalar().alias("avg_salary"),
... 1
... ).alias("salary_proportion_in_department")
... ).show()
@@ -6595,8 +6597,10 @@ class DataFrame:
Example 1: Filter for customers who have placed at least one order.
>>> from pyspark.sql import functions as sf
- >>> customers.where(
- ... orders.where(sf.col("customer_id") == sf.col("customer_id").outer()).exists()
+ >>> customers.alias("c").where(
+ ... orders.alias("o").where(
+ ... sf.col("o.customer_id") == sf.col("c.customer_id").outer()
+ ... ).exists()
... ).orderBy("customer_id").show()
+-----------+-------------+-------+
|customer_id|customer_name|country|
@@ -6609,8 +6613,10 @@ class DataFrame:
Example 2: Filter for customers who have never placed an order.
>>> from pyspark.sql import functions as sf
- >>> customers.where(
- ... ~orders.where(sf.col("customer_id") == sf.col("customer_id").outer()).exists()
+ >>> customers.alias("c").where(
+ ... ~orders.alias("o").where(
+ ... sf.col("o.customer_id") == sf.col("c.customer_id").outer()
+ ... ).exists()
... ).orderBy("customer_id").show()
+-----------+-------------+---------+
|customer_id|customer_name| country|
@@ -6621,9 +6627,9 @@ class DataFrame:
Example 3: Find Orders from Customers in the USA.
>>> from pyspark.sql import functions as sf
- >>> orders.where(
- ... customers.where(
- ... (sf.col("customer_id") == sf.col("customer_id").outer())
+ >>> orders.alias("o").where(
+ ... customers.alias("c").where(
+ ... (sf.col("c.customer_id") ==
sf.col("o.customer_id").outer())
... & (sf.col("country") == "USA")
... ).exists()
... ).orderBy("order_id").show()
diff --git a/python/pyspark/sql/tests/test_subquery.py b/python/pyspark/sql/tests/test_subquery.py
index f58ff6364aed..7cc0360c3942 100644
--- a/python/pyspark/sql/tests/test_subquery.py
+++ b/python/pyspark/sql/tests/test_subquery.py
@@ -47,18 +47,21 @@ class SubqueryTestsMixin:
["c", "d"],
)
- def test_unanalyzable_expression(self):
- sub = self.spark.range(1).where(sf.col("id") == sf.col("id").outer())
+ def test_noop_outer(self):
+ assertDataFrameEqual(
+ self.spark.range(1).select(sf.col("id").outer()),
+ self.spark.range(1).select(sf.col("id")),
+ )
with self.assertRaises(AnalysisException) as pe:
- sub.schema
+ self.spark.range(1).select(sf.col("outer_col").outer()).collect()
self.check_error(
exception=pe.exception,
- errorClass="UNANALYZABLE_EXPRESSION",
- messageParameters={"expr": '"outer(id)"'},
+ errorClass="UNRESOLVED_COLUMN.WITH_SUGGESTION",
+ messageParameters={"objectName": "`outer_col`", "proposal": "`id`"},
query_context_type=QueryContextType.DataFrame,
- fragment="outer",
+ fragment="col",
)
def test_simple_uncorrelated_scalar_subquery(self):
@@ -189,7 +192,7 @@ class SubqueryTestsMixin:
"c1",
(
self.spark.table("t2")
- .where(sf.col("c2").outer() == sf.col("c2"))
+ .where(sf.col("t1.c2").outer() == sf.col("t2.c2"))
.select(sf.max("c1"))
.scalar()
),
@@ -205,45 +208,72 @@ class SubqueryTestsMixin:
self.df2.createOrReplaceTempView("r")
with self.subTest("in where"):
- assertDataFrameEqual(
- self.spark.table("l").where(
- sf.col("b")
- < (
- self.spark.table("r")
- .where(sf.col("a").outer() == sf.col("c"))
- .select(sf.max("d"))
- .scalar()
+ for cond in [
+ sf.col("a").outer() == sf.col("c"),
+ (sf.col("a") == sf.col("c")).outer(),
+ sf.expr("a = c").outer(),
+ ]:
+ with self.subTest(cond=cond):
+ assertDataFrameEqual(
+ self.spark.table("l").where(
+ sf.col("b")
+ < self.spark.table("r").where(cond).select(sf.max("d")).scalar()
+ ),
+ self.spark.sql(
+ """select * from l where b < (select max(d)
from r where a = c)"""
+ ),
)
- ),
- self.spark.sql(
- """select * from l where b < (select max(d) from r
where a = c)"""
- ),
- )
with self.subTest("in select"):
+ df1 = self.spark.table("l").alias("t1")
+ df2 = self.spark.table("l").alias("t2")
+
+ for cond in [
+ sf.col("t1.a") == sf.col("t2.a").outer(),
+ (sf.col("t1.a") == sf.col("t2.a")).outer(),
+ sf.expr("t1.a = t2.a").outer(),
+ ]:
+ with self.subTest(cond=cond):
+ assertDataFrameEqual(
+ df1.select(
+ "a",
+ df2.where(cond).select(sf.sum("b")).scalar().alias("sum_b"),
+ ),
+ self.spark.sql(
+ """
+ select
+ a, (select sum(b) from l t2 where t2.a = t1.a) sum_b
+ from l t1
+ """
+ ),
+ )
+
+ with self.subTest("without .outer()"):
assertDataFrameEqual(
self.spark.table("l").select(
"a",
(
- self.spark.table("l")
- .where(sf.col("a") == sf.col("a").outer())
- .select(sf.sum("b"))
+ self.spark.table("r")
+ .where(sf.col("b") == sf.col("a").outer())
+ .select(sf.sum("d"))
.scalar()
- .alias("sum_b")
+ .alias("sum_d")
),
),
self.spark.sql(
- """select a, (select sum(b) from l l2 where l2.a =
l1.a) sum_b from l l1"""
+ """select a, (select sum(d) from r where b = l.a)
sum_d from l"""
),
)
with self.subTest("in select (null safe)"):
+ df1 = self.spark.table("l").alias("t1")
+ df2 = self.spark.table("l").alias("t2")
+
assertDataFrameEqual(
- self.spark.table("l").select(
+ df1.select(
"a",
(
- self.spark.table("l")
- .where(sf.col("a").eqNullSafe(sf.col("a").outer()))
+ df2.where(sf.col("t2.a").eqNullSafe(sf.col("t1.a").outer()))
.select(sf.sum("b"))
.scalar()
.alias("sum_b")
@@ -278,15 +308,13 @@ class SubqueryTestsMixin:
)
with self.subTest("non-aggregated"):
+ df1 = self.spark.table("l").alias("t1")
+ df2 = self.spark.table("l").alias("t2")
+
with self.assertRaises(SparkRuntimeException) as pe:
- self.spark.table("l").select(
+ df1.select(
"a",
- (
- self.spark.table("l")
- .where(sf.col("a") == sf.col("a").outer())
- .select("b")
- .scalar()
- ),
+ df2.where(sf.col("t1.a") ==
sf.col("t2.a").outer()).select("b").scalar(),
).collect()
self.check_error(
@@ -296,19 +324,21 @@ class SubqueryTestsMixin:
)
with self.subTest("non-equal"):
+ df1 = self.spark.table("l").alias("t1")
+ df2 = self.spark.table("l").alias("t2")
+
assertDataFrameEqual(
- self.spark.table("l").select(
+ df1.select(
"a",
(
- self.spark.table("l")
- .where(sf.col("a") < sf.col("a").outer())
+ df2.where(sf.col("t2.a") < sf.col("t1.a").outer())
.select(sf.sum("b"))
.scalar()
.alias("sum_b")
),
),
self.spark.sql(
- """select a, (select sum(b) from l l2 where l2.a <
l1.a) sum_b from l l1"""
+ """select a, (select sum(b) from l t2 where t2.a <
t1.a) sum_b from l t1"""
),
)
@@ -343,26 +373,30 @@ class SubqueryTestsMixin:
self.df2.createOrReplaceTempView("r")
with self.subTest("EXISTS"):
- assertDataFrameEqual(
- self.spark.table("l").where(
- self.spark.table("r").where(sf.col("a").outer() ==
sf.col("c")).exists()
- ),
- self.spark.sql(
- """select * from l where exists (select * from r where
l.a = r.c)"""
- ),
- )
+ for cond in [
+ sf.col("a").outer() == sf.col("c"),
+ (sf.col("a") == sf.col("c")).outer(),
+ sf.expr("a = c").outer(),
+ ]:
+ with self.subTest(cond=cond):
+ assertDataFrameEqual(
+ self.spark.table("l").where(self.spark.table("r").where(cond).exists()),
+ self.spark.sql(
+ """select * from l where exists (select * from
r where l.a = r.c)"""
+ ),
+ )
- assertDataFrameEqual(
- self.spark.table("l").where(
- self.spark.table("r").where(sf.col("a").outer() ==
sf.col("c")).exists()
- & (sf.col("a") <= sf.lit(2))
- ),
- self.spark.sql(
- """
+ assertDataFrameEqual(
+ self.spark.table("l").where(
+ self.spark.table("r").where(cond).exists()
+ & (sf.col("a") <= sf.lit(2))
+ ),
+ self.spark.sql(
+ """
select * from l where exists (select * from r where l.a = r.c) and l.a <= 2
"""
- ),
- )
+ ),
+ )
with self.subTest("NOT EXISTS"):
assertDataFrameEqual(
@@ -450,46 +484,6 @@ class SubqueryTestsMixin:
fragment="col",
)
- with self.subTest("extra `outer()`"):
- with self.assertRaises(AnalysisException) as pe:
- self.spark.table("l").select(
- "a",
- (
- self.spark.table("r")
- .where(sf.col("c").outer() == sf.col("a").outer())
- .select(sf.sum("d"))
- .scalar()
- ),
- ).collect()
-
- self.check_error(
- exception=pe.exception,
- errorClass="UNRESOLVED_COLUMN.WITH_SUGGESTION",
- messageParameters={"objectName": "`c`", "proposal": "`a`,
`b`"},
- query_context_type=QueryContextType.DataFrame,
- fragment="outer",
- )
-
- with self.subTest("missing `outer()` for another outer"):
- with self.assertRaises(AnalysisException) as pe:
- self.spark.table("l").select(
- "a",
- (
- self.spark.table("r")
- .where(sf.col("b") == sf.col("a").outer())
- .select(sf.sum("d"))
- .scalar()
- ),
- ).collect()
-
- self.check_error(
- exception=pe.exception,
- errorClass="UNRESOLVED_COLUMN.WITH_SUGGESTION",
- messageParameters={"objectName": "`b`", "proposal": "`c`,
`d`"},
- query_context_type=QueryContextType.DataFrame,
- fragment="col",
- )
-
class SubqueryTests(SubqueryTestsMixin, ReusedSQLTestCase):
pass
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Column.scala b/sql/api/src/main/scala/org/apache/spark/sql/Column.scala
index 8498ae04d9a2..50ef61d4a7a1 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Column.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.parser.DataTypeParser
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{lit, map}
-import org.apache.spark.sql.internal.{ColumnNode, LazyOuterReference, UnresolvedAttribute}
+import org.apache.spark.sql.internal.ColumnNode
import org.apache.spark.sql.types._
import org.apache.spark.util.ArrayImplicits._
@@ -1383,20 +1383,27 @@ class Column(val node: ColumnNode) extends Logging {
def over(): Column = over(Window.spec)
/**
- * Marks this column reference as an outer reference for subqueries.
+ * Mark this column as an outer column if its expression refers to columns from an outer query.
+ * This is used to trigger lazy analysis of Spark Classic DataFrame, so that we can use it to
+ * build subquery expressions. Spark Connect DataFrame is always lazily analyzed and does not
+ * need to use this function.
*
- * @group subquery
+ * {{{
+ * // Spark can't analyze this `df` now as it doesn't know how to resolve `t1.col`.
+ * val df = spark.table("t2").where($"t2.col" === $"t1.col".outer())
+ *
+ * // Since this `df` is lazily analyzed, you won't see any error until you try to execute it.
+ * df.collect() // Fails with UNRESOLVED_COLUMN error.
+ *
+ * // Now Spark can resolve `t1.col` with the outer plan `spark.table("t1")`.
+ * spark.table("t1").where(df.exists())
+ * }}}
+ *
+ * @group expr_ops
* @since 4.0.0
*/
- def outer(): Column = withOrigin {
- node match {
- case attr: UnresolvedAttribute if !attr.isMetadataColumn =>
- Column(LazyOuterReference(attr.nameParts, attr.planId))
- case _ =>
- throw new IllegalArgumentException(
- "Only unresolved attributes can be used as outer references")
- }
- }
+ def outer(): Column = Column(internal.LazyExpression(node))
+
}
/**
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala
index e3cc320a8b00..f745c152170e 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala
@@ -167,24 +167,6 @@ private[sql] object UnresolvedAttribute {
apply(unparsedIdentifier, None, false, CurrentOrigin.get)
}
-/**
- * Reference to an attribute in the outer context, used for Subqueries.
- *
- * @param nameParts
- * name of the attribute.
- * @param planId
- * id of the plan (Dataframe) that produces the attribute.
- */
-private[sql] case class LazyOuterReference(
- nameParts: Seq[String],
- planId: Option[Long] = None,
- override val origin: Origin = CurrentOrigin.get)
- extends ColumnNode {
- override private[internal] def normalize(): LazyOuterReference =
- copy(planId = None, origin = NO_ORIGIN)
- override def sql: String = nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
-}
-
/**
* Reference to all columns in a namespace (global, a Dataframe, or a nested struct).
*
@@ -593,3 +575,18 @@ private[sql] case class InvokeInlineUserDefinedFunction(
private[sql] trait UserDefinedFunctionLike {
def name: String = SparkClassUtils.getFormattedClassName(this)
}
+
+/**
+ * A marker node to trigger Spark Classic DataFrame lazy analysis.
+ *
+ * @param child
+ * that needs to be lazily analyzed in Spark Classic DataFrame.
+ */
+private[sql] case class LazyExpression(
+ child: ColumnNode,
+ override val origin: Origin = CurrentOrigin.get)
+ extends ColumnNode {
+ override private[internal] def normalize(): ColumnNode =
+ copy(child = child.normalize(), origin = NO_ORIGIN)
+ override def sql: String = "lazy" + argumentsToSql(Seq(child))
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e05f3533ae3c..84b3ca2289f4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -279,7 +279,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
CTESubstitution,
WindowsSubstitution,
EliminateUnions,
- SubstituteUnresolvedOrdinals),
+ SubstituteUnresolvedOrdinals,
+ EliminateLazyExpression),
Batch("Disable Hints", Once,
new ResolveHints.DisableHints),
Batch("Hints", fixedPoint,
@@ -2190,23 +2191,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
* can resolve outer references.
*
* Outer references of the subquery are updated as children of Subquery expression.
- *
- * If hasExplicitOuterRefs is true, the subquery should have an explicit outer reference,
- * instead of common `UnresolvedAttribute`s. In this case, tries to resolve inner and outer
- * references separately.
*/
private def resolveSubQuery(
e: SubqueryExpression,
- outer: LogicalPlan,
- hasExplicitOuterRefs: Boolean = false)(
+ outer: LogicalPlan)(
f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = {
- val newSubqueryPlan = if (hasExplicitOuterRefs) {
- executeSameContext(e.plan).transformAllExpressionsWithPruning(
- _.containsPattern(UNRESOLVED_OUTER_REFERENCE)) {
- case u: UnresolvedOuterReference =>
- resolveOuterReference(u.nameParts, outer).getOrElse(u)
- }
- } else AnalysisContext.withOuterPlan(outer) {
+ val newSubqueryPlan = AnalysisContext.withOuterPlan(outer) {
executeSameContext(e.plan)
}
@@ -2231,11 +2221,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
*/
private def resolveSubQueries(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = {
plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) {
- case s @ ScalarSubquery(sub, _, exprId, _, _, _, _, hasExplicitOuterRefs)
-   if !sub.resolved =>
- resolveSubQuery(s, outer, hasExplicitOuterRefs)(ScalarSubquery(_, _, exprId))
- case e @ Exists(sub, _, exprId, _, _, hasExplicitOuterRefs) if !sub.resolved =>
- resolveSubQuery(e, outer, hasExplicitOuterRefs)(Exists(_, _, exprId))
+ case s @ ScalarSubquery(sub, _, exprId, _, _, _, _) if !sub.resolved =>
+ resolveSubQuery(s, outer)(ScalarSubquery(_, _, exprId))
+ case e @ Exists(sub, _, exprId, _, _) if !sub.resolved =>
+ resolveSubQuery(e, outer)(Exists(_, _, exprId))
case InSubquery(values, l @ ListQuery(_, _, exprId, _, _, _))
if values.forall(_.resolved) && !l.resolved =>
val expr = resolveSubQuery(l, outer)((plan, exprs) => {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 586a0312e150..573619af1b5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -457,11 +457,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
errorClass = "UNBOUND_SQL_PARAMETER",
messageParameters = Map("name" -> p.name))
- case l: LazyAnalysisExpression =>
- l.failAnalysis(
- errorClass = "UNANALYZABLE_EXPRESSION",
- messageParameters = Map("expr" -> toSQLExpr(l)))
-
case _ =>
})
@@ -1067,20 +1062,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
case _ =>
}
- def checkUnresolvedOuterReference(p: LogicalPlan, expr: SubqueryExpression): Unit = {
- expr.plan.foreachUp(_.expressions.foreach(_.foreachUp {
- case o: UnresolvedOuterReference =>
- val cols = p.inputSet.toSeq.map(attr => toSQLId(attr.name)).mkString(", ")
- o.failAnalysis(
- errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION",
- messageParameters = Map("objectName" -> toSQLId(o.name), "proposal" -> cols))
- case _ =>
- }))
- }
-
- // Check if there is unresolved outer attribute in the subquery plan.
- checkUnresolvedOuterReference(plan, expr)
-
// Validate the subquery plan.
checkAnalysis0(expr.plan)
@@ -1088,7 +1069,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
checkOuterReference(plan, expr)
expr match {
- case ScalarSubquery(query, outerAttrs, _, _, _, _, _, _) =>
+ case ScalarSubquery(query, outerAttrs, _, _, _, _, _) =>
// Scalar subquery must return one column as output.
if (query.output.size != 1) {
throw QueryCompilationErrors.subqueryReturnMoreThanOneColumn(query.output.size,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index e869cb281ce0..36fd4d02f8da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -221,35 +221,35 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
val outerPlan = AnalysisContext.get.outerPlan
if (outerPlan.isEmpty) return e
- e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) {
+ def resolve(nameParts: Seq[String]): Option[Expression] = try {
+ outerPlan.get match {
+ // Subqueries in UnresolvedHaving can host grouping expressions and aggregate functions.
+ // We should resolve columns with `agg.output` and the rule `ResolveAggregateFunctions` will
+ // push them down to Aggregate later. This is similar to what we do in `resolveColumns`.
+ case u @ UnresolvedHaving(_, agg: Aggregate) =>
+ agg.resolveChildren(nameParts, conf.resolver)
+ .orElse(u.resolveChildren(nameParts, conf.resolver))
+ .map(wrapOuterReference)
+ case other =>
+ other.resolveChildren(nameParts, conf.resolver).map(wrapOuterReference)
+ }
+ } catch {
+ case ae: AnalysisException =>
+ logDebug(ae.getMessage)
+ None
+ }
+
+ e.transformWithPruning(
+ _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) {
case u: UnresolvedAttribute =>
- resolveOuterReference(u.nameParts, outerPlan.get).getOrElse(u)
+ resolve(u.nameParts).getOrElse(u)
// Re-resolves `TempResolvedColumn` as outer references if it has tried to be resolved with
// Aggregate but failed.
case t: TempResolvedColumn if t.hasTried =>
- resolveOuterReference(t.nameParts, outerPlan.get).getOrElse(t)
+ resolve(t.nameParts).getOrElse(t)
}
}
- protected def resolveOuterReference(
- nameParts: Seq[String], outerPlan: LogicalPlan): Option[Expression] = try {
- outerPlan match {
- // Subqueries in UnresolvedHaving can host grouping expressions and aggregate functions.
- // We should resolve columns with `agg.output` and the rule `ResolveAggregateFunctions` will
- // push them down to Aggregate later. This is similar to what we do in `resolveColumns`.
- case u @ UnresolvedHaving(_, agg: Aggregate) =>
- agg.resolveChildren(nameParts, conf.resolver)
- .orElse(u.resolveChildren(nameParts, conf.resolver))
- .map(wrapOuterReference)
- case other =>
- other.resolveChildren(nameParts, conf.resolver).map(wrapOuterReference)
- }
- } catch {
- case ae: AnalysisException =>
- logDebug(ae.getMessage)
- None
- }
-
def lookupVariable(nameParts: Seq[String]): Option[VariableReference] = {
// The temp variables live in `SYSTEM.SESSION`, and the name can be qualified or not.
def maybeTempVariableName(nameParts: Seq[String]): Boolean = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/EliminateLazyExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/EliminateLazyExpression.scala
new file mode 100644
index 000000000000..68f3f90e193b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/EliminateLazyExpression.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.LAZY_EXPRESSION
+
+/**
+ * `LazyExpression` is a marker node to trigger lazy analysis in DataFrames. It's useless when
+ * entering the analyzer and this rule removes it.
+ */
+object EliminateLazyExpression extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ plan.resolveExpressionsUpWithPruning(_.containsPattern(LAZY_EXPRESSION)) {
+ case l: LazyExpression => l.child
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 40994f42e71d..f366339d95c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -1004,42 +1004,13 @@ case class UnresolvedTranspose(
copy(child = newChild)
}
-case class UnresolvedOuterReference(
- nameParts: Seq[String])
- extends LeafExpression with NamedExpression with Unevaluable {
-
- def name: String =
- nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
-
- override def exprId: ExprId = throw new UnresolvedException("exprId")
- override def dataType: DataType = throw new UnresolvedException("dataType")
- override def nullable: Boolean = throw new UnresolvedException("nullable")
- override def qualifier: Seq[String] = throw new UnresolvedException("qualifier")
- override lazy val resolved = false
-
- override def toAttribute: Attribute = throw new UnresolvedException("toAttribute")
- override def newInstance(): UnresolvedOuterReference = this
-
- final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_OUTER_REFERENCE)
-}
-
-case class LazyOuterReference(
- nameParts: Seq[String])
- extends LeafExpression with NamedExpression with Unevaluable with LazyAnalysisExpression {
-
- def name: String =
- nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
-
- override def exprId: ExprId = throw new UnresolvedException("exprId")
- override def dataType: DataType = throw new UnresolvedException("dataType")
- override def nullable: Boolean = throw new UnresolvedException("nullable")
- override def qualifier: Seq[String] = throw new UnresolvedException("qualifier")
-
- override def toAttribute: Attribute = throw new UnresolvedException("toAttribute")
- override def newInstance(): NamedExpression = LazyOuterReference(nameParts)
-
- override def nodePatternsInternal(): Seq[TreePattern] = Seq(LAZY_OUTER_REFERENCE)
-
- override def prettyName: String = "outer"
- override def sql: String = s"$prettyName($name)"
+// A marker node to indicate that the logical plan containing this expression should be lazily
+// analyzed in the DataFrame. This node will be removed at the beginning of analysis.
+case class LazyExpression(child: Expression) extends UnaryExpression with Unevaluable {
+ override lazy val resolved: Boolean = false
+ override def dataType: DataType = child.dataType
+ override protected def withNewChildInternal(newChild: Expression): Expression = {
+ copy(child = newChild)
+ }
+ final override val nodePatterns: Seq[TreePattern] = Seq(LAZY_EXPRESSION)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index c45479985282..2090aab3b1f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, LeafLike, QuaternaryLike, TernaryLike, TreeNode, UnaryLike}
-import org.apache.spark.sql.catalyst.trees.TreePattern.{LAZY_ANALYSIS_EXPRESSION, RUNTIME_REPLACEABLE, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors}
@@ -410,20 +410,6 @@ trait Unevaluable extends Expression with FoldableUnevaluable {
final override def foldable: Boolean = false
}
-/**
- * An expression that cannot be analyzed. These expressions don't live analysis time or after
- * and should not be evaluated during query planning and execution.
- */
-trait LazyAnalysisExpression extends Expression {
- final override lazy val resolved = false
-
- final override val nodePatterns: Seq[TreePattern] =
- Seq(LAZY_ANALYSIS_EXPRESSION) ++ nodePatternsInternal()
-
- // Subclasses can override this function to provide more TreePatterns.
- def nodePatternsInternal(): Seq[TreePattern] = Seq()
-}
-
/**
* An expression that gets replaced at runtime (currently by the optimizer) into a different
* expression for evaluation. This is mainly used to provide compatibility with other databases.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index bd6f65b61468..0c8253659dd5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -19,11 +19,9 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.sql.catalyst.analysis.{LazyOuterReference, UnresolvedOuterReference}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.trees.TreePattern
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
@@ -374,13 +372,6 @@ object SubExprUtils extends PredicateHelper {
val nonEquivalentGroupByExprs = groupByExprs -- correlatedEquivalentExprs
nonEquivalentGroupByExprs
}
-
- def removeLazyOuterReferences(logicalPlan: LogicalPlan): LogicalPlan = {
- logicalPlan.transformAllExpressionsWithPruning(
- _.containsPattern(TreePattern.LAZY_OUTER_REFERENCE)) {
- case or: LazyOuterReference => UnresolvedOuterReference(or.nameParts)
- }
- }
}
/**
@@ -407,8 +398,7 @@ case class ScalarSubquery(
joinCond: Seq[Expression] = Seq.empty,
hint: Option[HintInfo] = None,
mayHaveCountBug: Option[Boolean] = None,
- needSingleJoin: Option[Boolean] = None,
- hasExplicitOuterRefs: Boolean = false)
+ needSingleJoin: Option[Boolean] = None)
extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable {
override def dataType: DataType = {
if (!plan.schema.fields.nonEmpty) {
@@ -577,8 +567,7 @@ case class Exists(
outerAttrs: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId,
joinCond: Seq[Expression] = Seq.empty,
- hint: Option[HintInfo] = None,
- hasExplicitOuterRefs: Boolean = false)
+ hint: Option[HintInfo] = None)
extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint)
with Predicate
with Unevaluable {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0772c67ea27e..7ec467badce5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -346,7 +346,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
case d: DynamicPruningSubquery => d
case s @ ScalarSubquery(
PhysicalOperation(projections, predicates, a @ Aggregate(group, _, child, _)),
- _, _, _, _, mayHaveCountBug, _, _)
+ _, _, _, _, mayHaveCountBug, _)
if conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG) &&
mayHaveCountBug.nonEmpty && mayHaveCountBug.get =>
// This is a subquery with an aggregate that may suffer from a COUNT bug.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 754fea85ec6d..e867953bcf28 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -90,7 +90,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
}
// Don't replace ScalarSubquery if its plan is an aggregate that may suffer from a COUNT bug.
- case s @ ScalarSubquery(_, _, _, _, _, mayHaveCountBug, _, _)
+ case s @ ScalarSubquery(_, _, _, _, _, mayHaveCountBug, _)
if conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG) &&
mayHaveCountBug.nonEmpty && mayHaveCountBug.get =>
s
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 8c82769dbf4a..5a4e9f37c395 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -131,12 +131,12 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// Filter the plan by applying left semi and left anti joins.
withSubquery.foldLeft(newFilter) {
- case (p, Exists(sub, _, _, conditions, subHint, _)) =>
+ case (p, Exists(sub, _, _, conditions, subHint)) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
val join = buildJoin(outerPlan,
rewriteDomainJoinsIfPresent(outerPlan, sub, joinCond), LeftSemi, joinCond, subHint)
Project(p.output, join)
- case (p, Not(Exists(sub, _, _, conditions, subHint, _))) =>
+ case (p, Not(Exists(sub, _, _, conditions, subHint))) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
val join = buildJoin(outerPlan,
rewriteDomainJoinsIfPresent(outerPlan, sub, joinCond), LeftAnti, joinCond, subHint)
@@ -319,7 +319,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
val introducedAttrs = ArrayBuffer.empty[Attribute]
val newExprs = exprs.map { e =>
e.transformDownWithPruning(_.containsAnyPattern(EXISTS_SUBQUERY, IN_SUBQUERY)) {
- case Exists(sub, _, _, conditions, subHint, _) =>
+ case Exists(sub, _, _, conditions, subHint) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
val existenceJoin = ExistenceJoin(exists)
val newCondition = conditions.reduceLeftOption(And)
@@ -507,7 +507,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
plan.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
case ScalarSubquery(sub, children, exprId, conditions, hint,
- mayHaveCountBugOld, needSingleJoinOld, _)
+ mayHaveCountBugOld, needSingleJoinOld)
if children.nonEmpty =>
def mayHaveCountBugAgg(a: Aggregate): Boolean = {
@@ -560,7 +560,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
}
ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions),
hint, Some(mayHaveCountBug), Some(needSingleJoin))
- case Exists(sub, children, exprId, conditions, hint, _) if children.nonEmpty =>
+ case Exists(sub, children, exprId, conditions, hint) if children.nonEmpty =>
val (newPlan, newCond) = if (SQLConf.get.decorrelateInnerQueryEnabledForExistsIn) {
decorrelate(sub, plan, handleCountBug = true)
} else {
@@ -818,7 +818,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]()
val newChild = subqueries.foldLeft(child) {
case (currentChild, ScalarSubquery(sub, _, _, conditions, subHint, mayHaveCountBug,
- needSingleJoin, _)) =>
+ needSingleJoin)) =>
val query = DecorrelateInnerQuery.rewriteDomainJoins(currentChild, sub, conditions)
val origOutput = query.output.head
// The subquery appears on the right side of the join, hence add its hint to the right
@@ -1064,8 +1064,7 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {
case p: LogicalPlan => p.transformExpressionsUpWithPruning(
_.containsPattern(SCALAR_SUBQUERY)) {
- case s @ ScalarSubquery(
- OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _, _, _)
+ case s @ ScalarSubquery(OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _, _)
if !hasCorrelatedSubquery(s.plan) && s.joinCond.isEmpty =>
assert(p.projectList.size == 1)
stripOuterReferences(p.projectList).head
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 24b787054fb1..7435f4c52703 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -63,6 +63,7 @@ object TreePattern extends Enumeration {
val LAMBDA_VARIABLE: Value = Value
val LATERAL_COLUMN_ALIAS_REFERENCE: Value = Value
val LATERAL_SUBQUERY: Value = Value
+ val LAZY_EXPRESSION: Value = Value
val LIKE_FAMLIY: Value = Value
val LIST_SUBQUERY: Value = Value
val LITERAL: Value = Value
@@ -154,7 +155,6 @@ object TreePattern extends Enumeration {
val UNRESOLVED_HINT: Value = Value
val UNRESOLVED_WINDOW_EXPRESSION: Value = Value
val UNRESOLVED_IDENTIFIER_WITH_CTE: Value = Value
- val UNRESOLVED_OUTER_REFERENCE: Value = Value
// Unresolved Plan patterns (Alphabetically ordered)
val UNRESOLVED_FUNC: Value = Value
@@ -169,8 +169,4 @@ object TreePattern extends Enumeration {
// Execution Plan patterns (alphabetically ordered)
val EXCHANGE: Value = Value
-
- // Lazy analysis expression patterns (alphabetically ordered)
- val LAZY_ANALYSIS_EXPRESSION: Value = Value
- val LAZY_OUTER_REFERENCE: Value = Value
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 4766a74308a1..a74d93b44db9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -280,9 +280,9 @@ class Dataset[T] private[sql](
// The resolved `ExpressionEncoder` which can be used to turn rows to objects of type T, after
// collecting rows to the driver side.
- private lazy val resolvedEnc = {
- exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer)
- }
+ private lazy val resolvedEnc = exprEnc.resolveAndBind(
+ queryExecution.commandExecuted.output, sparkSession.sessionState.analyzer)
+
private implicit def classTag: ClassTag[T] = encoder.clsTag
@@ -996,16 +996,12 @@ class Dataset[T] private[sql](
/** @inheritdoc */
def scalar(): Column = {
- Column(ExpressionColumnNode(
- ScalarSubqueryExpr(SubExprUtils.removeLazyOuterReferences(logicalPlan),
- hasExplicitOuterRefs = true)))
+ Column(ExpressionColumnNode(ScalarSubqueryExpr(logicalPlan)))
}
/** @inheritdoc */
def exists(): Column = {
- Column(ExpressionColumnNode(
- Exists(SubExprUtils.removeLazyOuterReferences(logicalPlan),
- hasExplicitOuterRefs = true)))
+ Column(ExpressionColumnNode(Exists(logicalPlan)))
}
/** @inheritdoc */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 490184c93620..5695ea57e7fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -31,12 +31,11 @@ import org.apache.spark.internal.LogKeys.EXTENDED_EXPLAIN_GENERATOR
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, Row, SparkSession}
import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker}
-import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker
+import org.apache.spark.sql.catalyst.analysis.{LazyExpression, UnsupportedOperationChecker}
import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union}
import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
-import org.apache.spark.sql.catalyst.trees.TreePattern.LAZY_ANALYSIS_EXPRESSION
import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan}
@@ -69,7 +68,10 @@ class QueryExecution(
// TODO: Move the planner an optimizer into here from SessionState.
protected def planner = sparkSession.sessionState.planner
- lazy val isLazyAnalysis: Boolean = logical.containsAnyPattern(LAZY_ANALYSIS_EXPRESSION)
+ lazy val isLazyAnalysis: Boolean = {
+ // Only check the main query as subquery expression can be resolved now with the main query.
+ logical.exists(_.expressions.exists(_.exists(_.isInstanceOf[LazyExpression])))
+ }
def assertAnalyzed(): Unit = {
try {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
index 35a815d83922..5f2638655c37 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
@@ -30,7 +30,7 @@ case class PlanAdaptiveSubqueries(
def apply(plan: SparkPlan): SparkPlan = {
plan.transformAllExpressionsWithPruning(
_.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
- case expressions.ScalarSubquery(_, _, exprId, _, _, _, _, _) =>
+ case expressions.ScalarSubquery(_, _, exprId, _, _, _, _) =>
val subquery = SubqueryExec.createForScalarSubquery(
s"subquery#${exprId.id}", subqueryMap(exprId.id))
execution.ScalarSubquery(subquery, exprId)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala
index 64eacba1c6bf..00e9a01f33c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala
@@ -88,9 +88,6 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres
isDistinct = isDistinct,
isInternal = isInternal)
- case LazyOuterReference(nameParts, planId, _) =>
- convertLazyOuterReference(nameParts, planId)
-
case Alias(child, Seq(name), None, _) =>
expressions.Alias(apply(child), name)(
nonInheritableMetadataKeys = Seq(Dataset.DATASET_ID_KEY, Dataset.COL_POS_KEY))
@@ -193,6 +190,9 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres
case _ => transformed
}
+ case l: LazyExpression =>
+ analysis.LazyExpression(apply(l.child))
+
case node =>
throw SparkException.internalError("Unsupported ColumnNode: " + node)
}
@@ -248,16 +248,6 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres
}
attribute
}
-
- private def convertLazyOuterReference(
- nameParts: Seq[String],
- planId: Option[Long]): analysis.LazyOuterReference = {
- val lazyOuterReference = analysis.LazyOuterReference(nameParts)
- if (planId.isDefined) {
- lazyOuterReference.setTagValue(LogicalPlan.PLAN_ID_TAG, planId.get)
- }
- lazyOuterReference
- }
}
private[sql] object ColumnNodeToExpressionConverter extends ColumnNodeToExpressionConverter {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
index d656c36ce842..2420ad34d9ba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
@@ -53,23 +53,15 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession {
r.createOrReplaceTempView("r")
}
- test("unanalyzable expression") {
- val sub = spark.range(1).select($"id" === $"id".outer())
-
- checkError(
- intercept[AnalysisException](sub.schema),
- condition = "UNANALYZABLE_EXPRESSION",
- parameters = Map("expr" -> "\"outer(id)\""),
- queryContext =
- Array(ExpectedContext(fragment = "outer", callSitePattern =
getCurrentClassCallSitePattern))
- )
-
+ test("noop outer()") {
+ checkAnswer(spark.range(1).select($"id".outer()), Row(0))
checkError(
- intercept[AnalysisException](sub.encoder),
- condition = "UNANALYZABLE_EXPRESSION",
- parameters = Map("expr" -> "\"outer(id)\""),
- queryContext =
- Array(ExpectedContext(fragment = "outer", callSitePattern =
getCurrentClassCallSitePattern))
+ intercept[AnalysisException](spark.range(1).select($"outer_col".outer()).collect()),
+ "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+ parameters = Map("objectName" -> "`outer_col`", "proposal" -> "`id`"),
+ context = ExpectedContext(
+ fragment = "$",
+ callSitePattern = getCurrentClassCallSitePattern)
)
}
@@ -148,6 +140,64 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession {
}
}
+ test("correlated scalar subquery in SELECT with outer() function") {
+ val df1 = spark.table("l").as("t1")
+ val df2 = spark.table("l").as("t2")
+ // We can use the `.outer()` function to wrap either the outer column, or the entire condition,
+ // or the SQL string of the condition.
+ Seq(
+ $"t1.a" === $"t2.a".outer(),
+ ($"t1.a" === $"t2.a").outer(),
+ expr("t1.a = t2.a").outer()).foreach { cond =>
+ checkAnswer(
+ df1.select(
+ $"a",
+ df2.where(cond).select(sum($"b")).scalar().as("sum_b")
+ ),
+ sql("select a, (select sum(b) from l t1 where t1.a = t2.a) sum_b from
l t2")
+ )
+ }
+ }
+
+ test("correlated scalar subquery in WHERE with outer() function") {
+ // We can use the `.outer()` function to wrap either the outer column, or the entire condition,
+ // or the SQL string of the condition.
+ Seq(
+ $"a".outer() === $"c",
+ ($"a" === $"c").outer(),
+ expr("a = c").outer()).foreach { cond =>
+ checkAnswer(
+ spark.table("l").where(
+ $"b" < spark.table("r").where(cond).select(max($"d")).scalar()
+ ),
+ sql("select * from l where b < (select max(d) from r where a = c)")
+ )
+ }
+ }
+
+ test("EXISTS predicate subquery with outer() function") {
+ // We can use the `.outer()` function to wrap either the outer column, or the entire condition,
+ // or the SQL string of the condition.
+ Seq(
+ $"a".outer() === $"c",
+ ($"a" === $"c").outer(),
+ expr("a = c").outer()).foreach { cond =>
+ checkAnswer(
+ spark.table("l").where(
+ spark.table("r").where(cond).exists()
+ ),
+ sql("select * from l where exists (select * from r where l.a = r.c)")
+ )
+
+ checkAnswer(
+ spark.table("l").where(
+ spark.table("r").where(cond).exists() && $"a" <= lit(2)
+ ),
+ sql("select * from l where exists (select * from r where l.a = r.c)
and l.a <= 2")
+ )
+ }
+ }
+
test("SPARK-15677: Queries against local relations with scalar subquery in
Select list") {
withTempView("t1", "t2") {
Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1")
@@ -192,22 +242,6 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession {
}
}
- test("EXISTS predicate subquery") {
- checkAnswer(
- spark.table("l").where(
- spark.table("r").where($"a".outer() === $"c").exists()
- ),
- sql("select * from l where exists (select * from r where l.a = r.c)")
- )
-
- checkAnswer(
- spark.table("l").where(
- spark.table("r").where($"a".outer() === $"c").exists() && $"a" <=
lit(2)
- ),
- sql("select * from l where exists (select * from r where l.a = r.c) and
l.a <= 2")
- )
- }
-
test("NOT EXISTS predicate subquery") {
checkAnswer(
spark.table("l").where(
@@ -244,32 +278,15 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession {
)
}
- test("correlated scalar subquery in where") {
- checkAnswer(
- spark.table("l").where(
- $"b" < spark.table("r").where($"a".outer() ===
$"c").select(max($"d")).scalar()
- ),
- sql("select * from l where b < (select max(d) from r where a = c)")
- )
- }
-
- test("correlated scalar subquery in select") {
+ test("correlated scalar subquery in select (null safe equal)") {
+ val df1 = spark.table("l").as("t1")
+ val df2 = spark.table("l").as("t2")
checkAnswer(
- spark.table("l").select(
+ df1.select(
$"a",
- spark.table("l").where($"a" ===
$"a".outer()).select(sum($"b")).scalar().as("sum_b")
+ df2.where($"t2.a" <=>
$"t1.a".outer()).select(sum($"b")).scalar().as("sum_b")
),
- sql("select a, (select sum(b) from l l2 where l2.a = l1.a) sum_b from l
l1")
- )
- }
-
- test("correlated scalar subquery in select (null safe)") {
- checkAnswer(
- spark.table("l").select(
- $"a",
- spark.table("l").where($"a" <=>
$"a".outer()).select(sum($"b")).scalar().as("sum_b")
- ),
- sql("select a, (select sum(b) from l l2 where l2.a <=> l1.a) sum_b from
l l1")
+ sql("select a, (select sum(b) from l t2 where t2.a <=> t1.a) sum_b from
l t1")
)
}
@@ -300,10 +317,12 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession {
}
test("non-aggregated correlated scalar subquery") {
+ val df1 = spark.table("l").as("t1")
+ val df2 = spark.table("l").as("t2")
val exception1 = intercept[SparkRuntimeException] {
- spark.table("l").select(
+ df1.select(
$"a",
- spark.table("l").where($"a" ===
$"a".outer()).select($"b").scalar().as("sum_b")
+ df2.where($"t1.a" ===
$"t2.a".outer()).select($"b").scalar().as("sum_b")
).collect()
}
checkError(
@@ -313,12 +332,14 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession {
}
test("non-equal correlated scalar subquery") {
+ val df1 = spark.table("l").as("t1")
+ val df2 = spark.table("l").as("t2")
checkAnswer(
- spark.table("l").select(
+ df1.select(
$"a",
- spark.table("l").where($"a" <
$"a".outer()).select(sum($"b")).scalar().as("sum_b")
+ df2.where($"t2.a" <
$"t1.a".outer()).select(sum($"b")).scalar().as("sum_b")
),
- sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l
l1")
+ sql("select a, (select sum(b) from l t2 where t2.a < t1.a) sum_b from l
t1")
)
}
@@ -346,7 +367,7 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession {
spark.table("l").select(
$"a",
spark.table("r").where($"c" === $"a").select(sum($"d")).scalar()
- ).collect()
+ )
}
checkError(
exception1,
@@ -355,35 +376,5 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession {
queryContext =
Array(ExpectedContext(fragment = "$", callSitePattern =
getCurrentClassCallSitePattern))
)
-
- // Extra `outer()`
- val exception2 = intercept[AnalysisException] {
- spark.table("l").select(
- $"a",
- spark.table("r").where($"c".outer() ===
$"a".outer()).select(sum($"d")).scalar()
- ).collect()
- }
- checkError(
- exception2,
- condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION",
- parameters = Map("objectName" -> "`c`", "proposal" -> "`a`, `b`"),
- queryContext =
- Array(ExpectedContext(fragment = "outer", callSitePattern =
getCurrentClassCallSitePattern))
- )
-
- // Missing `outer()` for another outer
- val exception3 = intercept[AnalysisException] {
- spark.table("l").select(
- $"a",
- spark.table("r").where($"b" ===
$"a".outer()).select(sum($"d")).scalar()
- ).collect()
- }
- checkError(
- exception3,
- condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION",
- parameters = Map("objectName" -> "`b`", "proposal" -> "`c`, `d`"),
- queryContext =
- Array(ExpectedContext(fragment = "$", callSitePattern =
getCurrentClassCallSitePattern))
- )
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]