This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new aac494e74c60 [SPARK-50134][SPARK-50130][SQL][CONNECT] Support DataFrame API for SCALAR and EXISTS subqueries in Spark Connect
aac494e74c60 is described below
commit aac494e74c60acfd9abd65386a931fdbbf75c433
Author: Takuya Ueshin <[email protected]>
AuthorDate: Thu Dec 26 13:31:02 2024 -0800
[SPARK-50134][SPARK-50130][SQL][CONNECT] Support DataFrame API for SCALAR and EXISTS subqueries in Spark Connect
### What changes were proposed in this pull request?
Supports the DataFrame API for SCALAR and EXISTS subqueries in Spark Connect.
The proto plan is built using `with_relations`:
```
with_relations [id 10]
  root: plan [id 9] using subquery expressions holding plan ids
  reference:
    refs#1: [id 8] plan for subquery 1
    refs#2: [id 5] plan for subquery 2
```
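For illustration, a correlated scalar subquery of the following shape (adapted from the `DataFrameSubquerySuite` added in this patch) produces such a plan: the filter over `l` becomes the root, the plan built for `r` is attached as a reference, and the subquery expression itself only carries the referenced plan id.
```
// Adapted from the added tests; assumes spark.implicits._ and
// org.apache.spark.sql.functions._ are imported and the suite's l/r temp views exist.
// Equivalent to: SELECT * FROM l WHERE b < (SELECT MAX(d) FROM r WHERE a = c)
spark.table("l")
  .where($"b" < spark.table("r").where($"a".outer() === $"c").select(max($"d")).scalar())
```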
### Why are the changes needed?
DataFrame APIs for SCALAR and EXISTS subqueries are missing in Spark
Connect.
### Does this PR introduce _any_ user-facing change?
Yes, the `scalar()` and `exists()` DataFrame APIs for SCALAR and EXISTS subqueries become available in Spark Connect.
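For example (calls adapted from the added tests, assuming the same `l`/`r` temp views):
```
// Scalar subquery in a projection:
//   SELECT a, (SELECT SUM(d) FROM r WHERE c = a) FROM l
spark.table("l")
  .select($"a", spark.table("r").where($"c" === $"a".outer()).select(sum($"d")).scalar())

// EXISTS subquery as a filter predicate:
//   SELECT * FROM l WHERE EXISTS (SELECT * FROM r WHERE l.a = r.c)
spark.table("l")
  .where(spark.table("r").where($"a".outer() === $"c").exists())
```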
### How was this patch tested?
Added related tests for both the Scala and Python Connect clients.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49182 from ueshin/issues/SPARK-50130/scalar_exists_connect.
Authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Takuya Ueshin <[email protected]>
---
.../main/scala/org/apache/spark/sql/Dataset.scala | 68 ++--
.../spark/sql/RelationalGroupedDataset.scala | 2 +-
.../scala/org/apache/spark/sql/SparkSession.scala | 95 ++++-
.../org/apache/spark/sql/TableValuedFunction.scala | 2 +-
.../spark/sql/internal/columnNodeSupport.scala | 29 ++
.../apache/spark/sql/DataFrameSubquerySuite.scala | 440 +++++++++++++++++++++
.../internal/ColumnNodeToProtoConverterSuite.scala | 1 +
.../org/apache/spark/sql/test/QueryTest.scala | 155 ++++++++
.../org/apache/spark/sql/test/SQLHelper.scala | 17 +
python/pyspark/sql/connect/dataframe.py | 21 +-
python/pyspark/sql/connect/expressions.py | 100 ++++-
python/pyspark/sql/connect/plan.py | 234 +++++++----
.../pyspark/sql/connect/proto/expressions_pb2.py | 156 ++++----
.../pyspark/sql/connect/proto/expressions_pb2.pyi | 53 +++
python/pyspark/sql/dataframe.py | 8 +-
.../sql/tests/connect/test_parity_subquery.py | 57 ++-
python/pyspark/sql/tests/test_subquery.py | 121 +++++-
python/pyspark/testing/utils.py | 28 +-
.../apache/spark/sql/internal/columnNodes.scala | 69 +++-
.../spark/sql/catalyst/analysis/unresolved.scala | 15 +
.../spark/sql/catalyst/expressions/subquery.scala | 17 +
.../spark/sql/catalyst/trees/TreePatterns.scala | 1 +
.../main/protobuf/spark/connect/expressions.proto | 15 +
.../sql/connect/planner/SparkConnectPlanner.scala | 62 ++-
.../spark/sql/internal/columnNodeSupport.scala | 2 +
.../apache/spark/sql/DataFrameSubquerySuite.scala | 98 +++++
.../ColumnNodeToExpressionConverterSuite.scala | 1 +
27 files changed, 1593 insertions(+), 274 deletions(-)
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index ffaa8a70cc7c..75df538678a3 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.expressions.SparkUserDefinedFunction
import org.apache.spark.sql.functions.{struct, to_json}
-import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter,
DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, ToScalaUDF,
UDFAdaptors, UnresolvedAttribute, UnresolvedRegex}
+import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter,
DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl,
SubqueryExpressionNode, SubqueryType, ToScalaUDF, UDFAdaptors,
UnresolvedAttribute, UnresolvedRegex}
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types.{Metadata, StructType}
import org.apache.spark.storage.StorageLevel
@@ -288,9 +288,10 @@ class Dataset[T] private[sql] (
/** @inheritdoc */
def stat: DataFrameStatFunctions = new DataFrameStatFunctions(toDF())
- private def buildJoin(right: Dataset[_])(f: proto.Join.Builder => Unit):
DataFrame = {
+ private def buildJoin(right: Dataset[_], cols: Seq[Column] = Seq.empty)(
+ f: proto.Join.Builder => Unit): DataFrame = {
checkSameSparkSession(right)
- sparkSession.newDataFrame { builder =>
+ sparkSession.newDataFrame(cols) { builder =>
val joinBuilder = builder.getJoinBuilder
joinBuilder.setLeft(plan.getRoot).setRight(right.plan.getRoot)
f(joinBuilder)
@@ -334,7 +335,7 @@ class Dataset[T] private[sql] (
/** @inheritdoc */
def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame
= {
- buildJoin(right) { builder =>
+ buildJoin(right, Seq(joinExprs)) { builder =>
builder
.setJoinType(toJoinType(joinType))
.setJoinCondition(joinExprs.expr)
@@ -394,7 +395,7 @@ class Dataset[T] private[sql] (
case _ =>
throw new IllegalArgumentException(s"Unsupported lateral join type
$joinType")
}
- sparkSession.newDataFrame { builder =>
+ sparkSession.newDataFrame(joinExprs.toSeq) { builder =>
val lateralJoinBuilder = builder.getLateralJoinBuilder
lateralJoinBuilder.setLeft(plan.getRoot).setRight(right.plan.getRoot)
joinExprs.foreach(c => lateralJoinBuilder.setJoinCondition(c.expr))
@@ -426,7 +427,7 @@ class Dataset[T] private[sql] (
val sortExprs = sortCols.map { c =>
ColumnNodeToProtoConverter(c.sortOrder).getSortOrder
}
- sparkSession.newDataset(agnosticEncoder) { builder =>
+ sparkSession.newDataset(agnosticEncoder, sortCols) { builder =>
builder.getSortBuilder
.setInput(plan.getRoot)
.setIsGlobal(global)
@@ -502,7 +503,7 @@ class Dataset[T] private[sql] (
* methods and typed select methods is the encoder used to build the return
dataset.
*/
private def selectUntyped(encoder: AgnosticEncoder[_], cols: Seq[Column]):
Dataset[_] = {
- sparkSession.newDataset(encoder) { builder =>
+ sparkSession.newDataset(encoder, cols) { builder =>
builder.getProjectBuilder
.setInput(plan.getRoot)
.addAllExpressions(cols.map(_.typedExpr(this.encoder)).asJava)
@@ -510,29 +511,32 @@ class Dataset[T] private[sql] (
}
/** @inheritdoc */
- def filter(condition: Column): Dataset[T] =
sparkSession.newDataset(agnosticEncoder) {
- builder =>
+ def filter(condition: Column): Dataset[T] = {
+ sparkSession.newDataset(agnosticEncoder, Seq(condition)) { builder =>
builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
+ }
}
private def buildUnpivot(
ids: Array[Column],
valuesOption: Option[Array[Column]],
variableColumnName: String,
- valueColumnName: String): DataFrame = sparkSession.newDataFrame {
builder =>
- val unpivot = builder.getUnpivotBuilder
- .setInput(plan.getRoot)
- .addAllIds(ids.toImmutableArraySeq.map(_.expr).asJava)
- .setVariableColumnName(variableColumnName)
- .setValueColumnName(valueColumnName)
- valuesOption.foreach { values =>
- unpivot.getValuesBuilder
- .addAllValues(values.toImmutableArraySeq.map(_.expr).asJava)
+ valueColumnName: String): DataFrame = {
+ sparkSession.newDataFrame(ids.toSeq ++ valuesOption.toSeq.flatten) {
builder =>
+ val unpivot = builder.getUnpivotBuilder
+ .setInput(plan.getRoot)
+ .addAllIds(ids.toImmutableArraySeq.map(_.expr).asJava)
+ .setVariableColumnName(variableColumnName)
+ .setValueColumnName(valueColumnName)
+ valuesOption.foreach { values =>
+ unpivot.getValuesBuilder
+ .addAllValues(values.toImmutableArraySeq.map(_.expr).asJava)
+ }
}
}
private def buildTranspose(indices: Seq[Column]): DataFrame =
- sparkSession.newDataFrame { builder =>
+ sparkSession.newDataFrame(indices) { builder =>
val transpose = builder.getTransposeBuilder.setInput(plan.getRoot)
indices.foreach { indexColumn =>
transpose.addIndexColumns(indexColumn.expr)
@@ -624,18 +628,15 @@ class Dataset[T] private[sql] (
def transpose(): DataFrame =
buildTranspose(Seq.empty)
- // TODO(SPARK-50134): Support scalar Subquery API in Spark Connect
- // scalastyle:off not.implemented.error.usage
/** @inheritdoc */
def scalar(): Column = {
- ???
+ Column(SubqueryExpressionNode(plan.getRoot, SubqueryType.SCALAR))
}
/** @inheritdoc */
def exists(): Column = {
- ???
+ Column(SubqueryExpressionNode(plan.getRoot, SubqueryType.EXISTS))
}
- // scalastyle:on not.implemented.error.usage
/** @inheritdoc */
def limit(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) {
builder =>
@@ -782,7 +783,7 @@ class Dataset[T] private[sql] (
val aliases = values.zip(names).map { case (value, name) =>
value.name(name).expr.getAlias
}
- sparkSession.newDataFrame { builder =>
+ sparkSession.newDataFrame(values) { builder =>
builder.getWithColumnsBuilder
.setInput(plan.getRoot)
.addAllAliases(aliases.asJava)
@@ -842,10 +843,12 @@ class Dataset[T] private[sql] (
@scala.annotation.varargs
def drop(col: Column, cols: Column*): DataFrame = buildDrop(col +: cols)
- private def buildDrop(cols: Seq[Column]): DataFrame =
sparkSession.newDataFrame { builder =>
- builder.getDropBuilder
- .setInput(plan.getRoot)
- .addAllColumns(cols.map(_.expr).asJava)
+ private def buildDrop(cols: Seq[Column]): DataFrame = {
+ sparkSession.newDataFrame(cols) { builder =>
+ builder.getDropBuilder
+ .setInput(plan.getRoot)
+ .addAllColumns(cols.map(_.expr).asJava)
+ }
}
private def buildDropByNames(cols: Seq[String]): DataFrame =
sparkSession.newDataFrame {
@@ -1015,12 +1018,13 @@ class Dataset[T] private[sql] (
private def buildRepartitionByExpression(
numPartitions: Option[Int],
- partitionExprs: Seq[Column]): Dataset[T] =
sparkSession.newDataset(agnosticEncoder) {
- builder =>
+ partitionExprs: Seq[Column]): Dataset[T] = {
+ sparkSession.newDataset(agnosticEncoder, partitionExprs) { builder =>
val repartitionBuilder = builder.getRepartitionByExpressionBuilder
.setInput(plan.getRoot)
.addAllPartitionExprs(partitionExprs.map(_.expr).asJava)
numPartitions.foreach(repartitionBuilder.setNumPartitions)
+ }
}
/** @inheritdoc */
@@ -1152,7 +1156,7 @@ class Dataset[T] private[sql] (
/** @inheritdoc */
@scala.annotation.varargs
def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = {
- sparkSession.newDataset(agnosticEncoder) { builder =>
+ sparkSession.newDataset(agnosticEncoder, expr +: exprs) { builder =>
builder.getCollectMetricsBuilder
.setInput(plan.getRoot)
.setName(name)
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 5bded40b0d13..0944c88a6790 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -45,7 +45,7 @@ class RelationalGroupedDataset private[sql] (
import df.sparkSession.RichColumn
protected def toDF(aggExprs: Seq[Column]): DataFrame = {
- df.sparkSession.newDataFrame { builder =>
+ df.sparkSession.newDataFrame(groupingExprs ++ aggExprs) { builder =>
val aggBuilder = builder.getAggregateBuilder
.setInput(df.plan.getRoot)
groupingExprs.foreach(c => aggBuilder.addGroupingExpressions(c.expr))
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index b6bba8251913..89519034d07c 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -46,7 +46,7 @@ import org.apache.spark.sql.connect.client.{ClassFinder,
CloseableIterator, Spar
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.functions.lit
-import org.apache.spark.sql.internal.{CatalogImpl, ConnectRuntimeConfig,
SessionCleaner, SessionState, SharedState, SqlApiConf}
+import org.apache.spark.sql.internal.{CatalogImpl, ConnectRuntimeConfig,
SessionCleaner, SessionState, SharedState, SqlApiConf, SubqueryExpressionNode}
import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.{toExpr,
toTypedExpr}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.streaming.DataStreamReader
@@ -324,20 +324,111 @@ class SparkSession private[sql] (
}
}
+ /**
+ * Create a DataFrame including the proto plan built by the given function.
+ *
+ * @param f
+ * The function to build the proto plan.
+ * @return
+ * The DataFrame created from the proto plan.
+ */
@Since("4.0.0")
@DeveloperApi
def newDataFrame(f: proto.Relation.Builder => Unit): DataFrame = {
newDataset(UnboundRowEncoder)(f)
}
+ /**
+ * Create a DataFrame including the proto plan built by the given function.
+ *
+ * Use this method when columns are used to create a new DataFrame. When any of the
+ * columns refer to another Dataset or DataFrame, the plan will be wrapped with a
+ * `WithRelations`.
+ *
+ * {{{
+ * with_relations [id 10]
+ * root: plan [id 9] using columns referring to other Dataset or
DataFrame, holding plan ids
+ * reference:
+ * refs#1: [id 8] plan for the reference 1
+ * refs#2: [id 5] plan for the reference 2
+ * }}}
+ *
+ * @param cols
+ * The columns to be used in the DataFrame.
+ * @param f
+ * The function to build the proto plan.
+ * @return
+ * The DataFrame created from the proto plan.
+ */
+ @Since("4.0.0")
+ @DeveloperApi
+ def newDataFrame(cols: Seq[Column])(f: proto.Relation.Builder => Unit):
DataFrame = {
+ newDataset(UnboundRowEncoder, cols)(f)
+ }
+
+ /**
+ * Create a Dataset including the proto plan built by the given function.
+ *
+ * @param encoder
+ * The encoder for the Dataset.
+ * @param f
+ * The function to build the proto plan.
+ * @return
+ * The Dataset created from the proto plan.
+ */
@Since("4.0.0")
@DeveloperApi
def newDataset[T](encoder: AgnosticEncoder[T])(
f: proto.Relation.Builder => Unit): Dataset[T] = {
+ newDataset[T](encoder, Seq.empty)(f)
+ }
+
+ /**
+ * Create a Dataset including the proto plan built by the given function.
+ *
+ * Use this method when columns are used to create a new Dataset. When any of the
+ * columns refer to another Dataset or DataFrame, the plan will be wrapped with a
+ * `WithRelations`.
+ *
+ * {{{
+ * with_relations [id 10]
+ * root: plan [id 9] using columns referring to other Dataset or
DataFrame, holding plan ids
+ * reference:
+ * refs#1: [id 8] plan for the reference 1
+ * refs#2: [id 5] plan for the reference 2
+ * }}}
+ *
+ * @param encoder
+ * The encoder for the Dataset.
+ * @param cols
+ * The columns to be used in the Dataset.
+ * @param f
+ * The function to build the proto plan.
+ * @return
+ * The Dataset created from the proto plan.
+ */
+ @Since("4.0.0")
+ @DeveloperApi
+ def newDataset[T](encoder: AgnosticEncoder[T], cols: Seq[Column])(
+ f: proto.Relation.Builder => Unit): Dataset[T] = {
+ val references = cols.flatMap(_.node.collect { case n:
SubqueryExpressionNode =>
+ n.relation
+ })
+
val builder = proto.Relation.newBuilder()
f(builder)
builder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement())
- val plan = proto.Plan.newBuilder().setRoot(builder).build()
+
+ val rootBuilder = if (references.length == 0) {
+ builder
+ } else {
+ val rootBuilder = proto.Relation.newBuilder()
+ rootBuilder.getWithRelationsBuilder
+ .setRoot(builder)
+ .addAllReferences(references.asJava)
+ rootBuilder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement())
+ rootBuilder
+ }
+
+ val plan = proto.Plan.newBuilder().setRoot(rootBuilder).build()
new Dataset[T](this, plan, encoder)
}
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala
index 4f2687b53786..2a5afd1d5871 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala
@@ -47,7 +47,7 @@ class TableValuedFunction(sparkSession: SparkSession) extends
api.TableValuedFun
}
private def fn(name: String, args: Seq[Column]): Dataset[Row] = {
- sparkSession.newDataFrame { builder =>
+ sparkSession.newDataFrame(args) { builder =>
builder.getUnresolvedTableValuedFunctionBuilder
.setFunctionName(name)
.addAllArguments(args.map(toExpr).asJava)
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala
index 0e8889e19de2..8d57a8d3efd4 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala
@@ -167,6 +167,15 @@ object ColumnNodeToProtoConverter extends (ColumnNode =>
proto.Expression) {
case LazyExpression(child, _) =>
builder.getLazyExpressionBuilder.setChild(apply(child, e))
+ case SubqueryExpressionNode(relation, subqueryType, _) =>
+ val b = builder.getSubqueryExpressionBuilder
+ b.setSubqueryType(subqueryType match {
+ case SubqueryType.SCALAR =>
proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_SCALAR
+ case SubqueryType.EXISTS =>
proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_EXISTS
+ })
+ assert(relation.hasCommon && relation.getCommon.hasPlanId)
+ b.setPlanId(relation.getCommon.getPlanId)
+
case ProtoColumnNode(e, _) =>
return e
@@ -217,4 +226,24 @@ case class ProtoColumnNode(
override val origin: Origin = CurrentOrigin.get)
extends ColumnNode {
override def sql: String = expr.toString
+ override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
+}
+
+sealed trait SubqueryType
+
+object SubqueryType {
+ case object SCALAR extends SubqueryType
+ case object EXISTS extends SubqueryType
+}
+
+case class SubqueryExpressionNode(
+ relation: proto.Relation,
+ subqueryType: SubqueryType,
+ override val origin: Origin = CurrentOrigin.get)
+ extends ColumnNode {
+ override def sql: String = subqueryType match {
+ case SubqueryType.SCALAR => s"($relation)"
+ case _ => s"$subqueryType ($relation)"
+ }
+ override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
}
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
index 91f60b1fefb9..fc37444f7719 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
@@ -17,12 +17,306 @@
package org.apache.spark.sql
+import org.apache.spark.{SparkException, SparkRuntimeException}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.{QueryTest, RemoteSparkSession}
class DataFrameSubquerySuite extends QueryTest with RemoteSparkSession {
import testImplicits._
+ val row = identity[(java.lang.Integer, java.lang.Double)](_)
+
+ lazy val l = Seq(
+ row((1, 2.0)),
+ row((1, 2.0)),
+ row((2, 1.0)),
+ row((2, 1.0)),
+ row((3, 3.0)),
+ row((null, null)),
+ row((null, 5.0)),
+ row((6, null))).toDF("a", "b")
+
+ lazy val r = Seq(
+ row((2, 3.0)),
+ row((2, 3.0)),
+ row((3, 2.0)),
+ row((4, 1.0)),
+ row((null, null)),
+ row((null, 5.0)),
+ row((6, null))).toDF("c", "d")
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ l.createOrReplaceTempView("l")
+ r.createOrReplaceTempView("r")
+ }
+
+ test("noop outer()") {
+ checkAnswer(spark.range(1).select($"id".outer()), Row(0))
+ checkError(
+
intercept[AnalysisException](spark.range(1).select($"outer_col".outer()).collect()),
+ "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+ parameters = Map("objectName" -> "`outer_col`", "proposal" -> "`id`"))
+ }
+
+ test("simple uncorrelated scalar subquery") {
+ checkAnswer(
+ spark.range(1).select(spark.range(1).select(lit(1)).scalar().as("b")),
+ sql("select (select 1 as b) as b"))
+
+ checkAnswer(
+ spark
+ .range(1)
+ .select(
+ spark.range(1).select(spark.range(1).select(lit(1)).scalar() +
1).scalar() + lit(1)),
+ sql("select (select (select 1) + 1) + 1"))
+
+ // string type
+ checkAnswer(
+ spark.range(1).select(spark.range(1).select(lit("s")).scalar().as("b")),
+ sql("select (select 's' as s) as b"))
+ }
+
+ test("uncorrelated scalar subquery should return null if there is 0 rows") {
+ checkAnswer(
+
spark.range(1).select(spark.range(1).select(lit("s")).limit(0).scalar().as("b")),
+ sql("select (select 's' as s limit 0) as b"))
+ }
+
+ test("uncorrelated scalar subquery on a DataFrame generated query") {
+ withTempView("subqueryData") {
+ val df = Seq((1, "one"), (2, "two"), (3, "three")).toDF("key", "value")
+ df.createOrReplaceTempView("subqueryData")
+
+ checkAnswer(
+ spark
+ .range(1)
+ .select(
+ spark
+ .table("subqueryData")
+ .select($"key")
+ .where($"key" > 2)
+ .orderBy($"key")
+ .limit(1)
+ .scalar() + lit(1)),
+ sql("select (select key from subqueryData where key > 2 order by key
limit 1) + 1"))
+
+ checkAnswer(
+
spark.range(1).select(-spark.table("subqueryData").select(max($"key")).scalar()),
+ sql("select -(select max(key) from subqueryData)"))
+
+ checkAnswer(
+
spark.range(1).select(spark.table("subqueryData").select($"value").limit(0).scalar()),
+ sql("select (select value from subqueryData limit 0)"))
+
+ checkAnswer(
+ spark
+ .range(1)
+ .select(
+ spark
+ .table("subqueryData")
+ .where($"key" ===
spark.table("subqueryData").select(max($"key")).scalar() - lit(1))
+ .select(min($"value"))
+ .scalar()),
+ sql(
+ "select (select min(value) from subqueryData" +
+ " where key = (select max(key) from subqueryData) - 1)"))
+ }
+ }
+
+ test("correlated scalar subquery in SELECT with outer() function") {
+ val df1 = spark.table("l").as("t1")
+ val df2 = spark.table("l").as("t2")
+ // We can use the `.outer()` function to wrap either the outer column, or
the entire condition,
+ // or the SQL string of the condition.
+ Seq($"t1.a" === $"t2.a".outer(), ($"t1.a" === $"t2.a").outer(), expr("t1.a
= t2.a").outer())
+ .foreach { cond =>
+ checkAnswer(
+ df1.select($"a",
df2.where(cond).select(sum($"b")).scalar().as("sum_b")),
+ sql("select a, (select sum(b) from l t1 where t1.a = t2.a) sum_b
from l t2"))
+ }
+ }
+
+ test("correlated scalar subquery in WHERE with outer() function") {
+ // We can use the `.outer()` function to wrap either the outer column, or
the entire condition,
+ // or the SQL string of the condition.
+ Seq($"a".outer() === $"c", ($"a" === $"c").outer(), expr("a =
c").outer()).foreach { cond =>
+ checkAnswer(
+ spark.table("l").where($"b" <
spark.table("r").where(cond).select(max($"d")).scalar()),
+ sql("select * from l where b < (select max(d) from r where a = c)"))
+ }
+ }
+
+ test("EXISTS predicate subquery with outer() function") {
+ // We can use the `.outer()` function to wrap either the outer column, or
the entire condition,
+ // or the SQL string of the condition.
+ Seq($"a".outer() === $"c", ($"a" === $"c").outer(), expr("a =
c").outer()).foreach { cond =>
+ checkAnswer(
+ spark.table("l").where(spark.table("r").where(cond).exists()),
+ sql("select * from l where exists (select * from r where l.a = r.c)"))
+
+ checkAnswer(
+ spark.table("l").where(spark.table("r").where(cond).exists() && $"a"
<= lit(2)),
+ sql("select * from l where exists (select * from r where l.a = r.c)
and l.a <= 2"))
+ }
+ }
+
+ test("SPARK-15677: Queries against local relations with scalar subquery in
Select list") {
+ withTempView("t1", "t2") {
+ Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1")
+ Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2")
+
+ checkAnswer(
+
spark.table("t1").select(spark.range(1).select(lit(1).as("col")).scalar()),
+ sql("SELECT (select 1 as col) from t1"))
+
+ checkAnswer(
+
spark.table("t1").select(spark.table("t2").select(max($"c1")).scalar()),
+ sql("SELECT (select max(c1) from t2) from t1"))
+
+ checkAnswer(
+ spark.table("t1").select(lit(1) +
spark.range(1).select(lit(1).as("col")).scalar()),
+ sql("SELECT 1 + (select 1 as col) from t1"))
+
+ checkAnswer(
+ spark.table("t1").select($"c1",
spark.table("t2").select(max($"c1")).scalar() + $"c2"),
+ sql("SELECT c1, (select max(c1) from t2) + c2 from t1"))
+
+ checkAnswer(
+ spark
+ .table("t1")
+ .select(
+ $"c1",
+ spark.table("t2").where($"t1.c2".outer() ===
$"t2.c2").select(max($"c1")).scalar()),
+ sql("SELECT c1, (select max(c1) from t2 where t1.c2 = t2.c2) from t1"))
+ }
+ }
+
+ test("NOT EXISTS predicate subquery") {
+ checkAnswer(
+ spark.table("l").where(!spark.table("r").where($"a".outer() ===
$"c").exists()),
+ sql("select * from l where not exists (select * from r where l.a =
r.c)"))
+
+ checkAnswer(
+ spark
+ .table("l")
+ .where(!spark.table("r").where($"a".outer() === $"c" && $"b".outer() <
$"d").exists()),
+ sql("select * from l where not exists (select * from r where l.a = r.c
and l.b < r.d)"))
+ }
+
+ test("EXISTS predicate subquery within OR") {
+ checkAnswer(
+ spark
+ .table("l")
+ .where(spark.table("r").where($"a".outer() === $"c").exists() ||
+ spark.table("r").where($"a".outer() === $"c").exists()),
+ sql(
+ "select * from l where exists (select * from r where l.a = r.c)" +
+ " or exists (select * from r where l.a = r.c)"))
+
+ checkAnswer(
+ spark
+ .table("l")
+ .where(!spark.table("r").where($"a".outer() === $"c" && $"b".outer() <
$"d").exists() ||
+ !spark.table("r").where($"a".outer() === $"c").exists()),
+ sql(
+ "select * from l where not exists (select * from r where l.a = r.c and
l.b < r.d)" +
+ " or not exists (select * from r where l.a = r.c)"))
+ }
+
+ test("correlated scalar subquery in select (null safe equal)") {
+ val df1 = spark.table("l").as("t1")
+ val df2 = spark.table("l").as("t2")
+ checkAnswer(
+ df1.select(
+ $"a",
+ df2.where($"t2.a" <=>
$"t1.a".outer()).select(sum($"b")).scalar().as("sum_b")),
+ sql("select a, (select sum(b) from l t2 where t2.a <=> t1.a) sum_b from
l t1"))
+ }
+
+ test("correlated scalar subquery in aggregate") {
+ checkAnswer(
+ spark
+ .table("l")
+ .groupBy(
+ $"a",
+ spark.table("r").where($"a".outer() ===
$"c").select(sum($"d")).scalar().as("sum_d"))
+ .agg(Map.empty[String, String]),
+ sql("select a, (select sum(d) from r where a = c) sum_d from l l1 group
by 1, 2"))
+ }
+
+ test("SPARK-34269: correlated subquery with view in aggregate's grouping
expression") {
+ withTable("tr") {
+ withView("vr") {
+ r.write.saveAsTable("tr")
+ sql("create view vr as select * from tr")
+ checkAnswer(
+ spark
+ .table("l")
+ .groupBy(
+ $"a",
+ spark
+ .table("vr")
+ .where($"a".outer() === $"c")
+ .select(sum($"d"))
+ .scalar()
+ .as("sum_d"))
+ .agg(Map.empty[String, String]),
+ sql("select a, (select sum(d) from vr where a = c) sum_d from l l1
group by 1, 2"))
+ }
+ }
+ }
+
+ test("non-aggregated correlated scalar subquery") {
+ val df1 = spark.table("l").as("t1")
+ val df2 = spark.table("l").as("t2")
+ val exception1 = intercept[SparkRuntimeException] {
+ df1
+ .select($"a", df2.where($"t1.a" ===
$"t2.a".outer()).select($"b").scalar().as("sum_b"))
+ .collect()
+ }
+ checkError(exception1, condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS")
+ }
+
+ test("non-equal correlated scalar subquery") {
+ val df1 = spark.table("l").as("t1")
+ val df2 = spark.table("l").as("t2")
+ checkAnswer(
+ df1.select(
+ $"a",
+ df2.where($"t2.a" <
$"t1.a".outer()).select(sum($"b")).scalar().as("sum_b")),
+ sql("select a, (select sum(b) from l t2 where t2.a < t1.a) sum_b from l
t1"))
+ }
+
+ test("disjunctive correlated scalar subquery") {
+ checkAnswer(
+ spark
+ .table("l")
+ .where(
+ spark
+ .table("r")
+ .where(($"a".outer() === $"c" && $"d" === 2.0) ||
+ ($"a".outer() === $"c" && $"d" === 1.0))
+ .select(count(lit(1)))
+ .scalar() > 0)
+ .select($"a"),
+ sql("""
+ |select a
+ |from l
+ |where (select count(*)
+ | from r
+ | where (a = c and d = 2.0) or (a = c and d = 1.0)) > 0
+ """.stripMargin))
+ }
+
+ test("correlated scalar subquery with missing outer reference") {
+ checkAnswer(
+ spark
+ .table("l")
+ .select($"a", spark.table("r").where($"c" ===
$"a").select(sum($"d")).scalar()),
+ sql("select a, (select sum(d) from r where c = a) from l"))
+ }
+
private def table1() = {
sql("CREATE VIEW t1(c1, c2) AS VALUES (0, 1), (1, 2)")
spark.table("t1")
@@ -182,6 +476,60 @@ class DataFrameSubquerySuite extends QueryTest with
RemoteSparkSession {
}
}
+ test("scalar subquery inside lateral join") {
+ withView("t1", "t2") {
+ val t1 = table1()
+ val t2 = table2()
+
+ // uncorrelated
+ checkAnswer(
+ t1.lateralJoin(spark.range(1).select($"c2".outer(),
t2.select(min($"c2")).scalar()))
+ .toDF("c1", "c2", "c3", "c4"),
+ sql("SELECT * FROM t1, LATERAL (SELECT c2, (SELECT MIN(c2) FROM t2))")
+ .toDF("c1", "c2", "c3", "c4"))
+
+ // correlated
+ checkAnswer(
+ t1.lateralJoin(
+ spark
+ .range(1)
+ .select($"c1".outer().as("a"))
+ .select(t2.where($"c1" ===
$"a".outer()).select(sum($"c2")).scalar())),
+ sql("""
+ |SELECT * FROM t1, LATERAL (
+ | SELECT (SELECT SUM(c2) FROM t2 WHERE c1 = a) FROM (SELECT
c1 AS a)
+ |)
+ |""".stripMargin))
+ }
+ }
+
+ test("lateral join inside subquery") {
+ withView("t1", "t2") {
+ val t1 = table1()
+ val t2 = table2()
+
+ // uncorrelated
+ checkAnswer(
+ t1.where(
+ $"c1" === t2
+ .lateralJoin(spark.range(1).select($"c1".outer().as("a")))
+ .select(min($"a"))
+ .scalar()),
+ sql("SELECT * FROM t1 WHERE c1 = (SELECT MIN(a) FROM t2, LATERAL
(SELECT c1 AS a))"))
+ // correlated
+ checkAnswer(
+ t1.where(
+ $"c1" === t2
+ .lateralJoin(spark.range(1).select($"c1".outer().as("a")))
+ .where($"c1" === $"t1.c1".outer())
+ .select(min($"a"))
+ .scalar()),
+ sql(
+ "SELECT * FROM t1 " +
+ "WHERE c1 = (SELECT MIN(a) FROM t2, LATERAL (SELECT c1 AS a) WHERE
c1 = t1.c1)"))
+ }
+ }
+
test("lateral join with table-valued functions") {
withView("t1", "t3") {
val t1 = table1()
@@ -219,4 +567,96 @@ class DataFrameSubquerySuite extends QueryTest with
RemoteSparkSession {
sql("SELECT * FROM t3 LEFT JOIN LATERAL EXPLODE(c2) t(c3) ON t3.c1 =
c3"))
}
}
+
+ test("subquery with generator / table-valued functions") {
+ withView("t1") {
+ val t1 = table1()
+
+ checkAnswer(
+ spark.range(1).select(explode(t1.select(collect_list("c2")).scalar())),
+ sql("SELECT EXPLODE((SELECT COLLECT_LIST(c2) FROM t1))"))
+ checkAnswer(
+ spark.tvf.explode(t1.select(collect_list("c2")).scalar()),
+ sql("SELECT * FROM EXPLODE((SELECT COLLECT_LIST(c2) FROM t1))"))
+ }
+ }
+
+ test("subquery in join condition") {
+ withView("t1", "t2") {
+ val t1 = table1()
+ val t2 = table2()
+
+ checkAnswer(
+ t1.join(t2, $"t1.c1" === t1.select(max("c1")).scalar()).toDF("c1",
"c2", "c3", "c4"),
+ sql("SELECT * FROM t1 JOIN t2 ON t1.c1 = (SELECT MAX(c1) FROM t1)")
+ .toDF("c1", "c2", "c3", "c4"))
+ }
+ }
+
+ test("subquery in unpivot") {
+ withView("t1", "t2") {
+ val t1 = table1()
+ val t2 = table2()
+
+ checkError(
+ intercept[AnalysisException] {
+ t1.unpivot(Array(t2.exists()), "c1", "c2").collect()
+ },
+
"UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_IN_EXISTS_SUBQUERY",
+ parameters = Map("treeNode" -> "(?s)'Unpivot.*"),
+ matchPVals = true)
+ checkError(
+ intercept[AnalysisException] {
+ t1.unpivot(Array($"c1"), Array(t2.exists()), "c1", "c2").collect()
+ },
+
"UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_IN_EXISTS_SUBQUERY",
+ parameters = Map("treeNode" -> "(?s)Expand.*"),
+ matchPVals = true)
+ }
+ }
+
+ test("subquery in transpose") {
+ withView("t1") {
+ val t1 = table1()
+
+ checkError(
+ intercept[AnalysisException] {
+ t1.transpose(t1.select(max("c1")).scalar()).collect()
+ },
+ "TRANSPOSE_INVALID_INDEX_COLUMN",
+ parameters = Map("reason" -> "Index column must be an atomic
attribute"))
+ }
+ }
+
+ test("subquery in withColumns") {
+ withView("t1") {
+ val t1 = table1()
+
+ // TODO(SPARK-50601): Fix the SparkConnectPlanner to support this case
+ checkError(
+ intercept[SparkException] {
+ t1.withColumn("scalar", spark.range(1).select($"c1".outer() +
$"c2".outer()).scalar())
+ .collect()
+ },
+ "INTERNAL_ERROR",
+ parameters = Map("message" -> "Found the unresolved operator: .*"),
+ matchPVals = true)
+ }
+ }
+
+ test("subquery in drop") {
+ withView("t1") {
+ val t1 = table1()
+
+ checkAnswer(t1.drop(spark.range(1).select(lit("c1")).scalar()), t1)
+ }
+ }
+
+ test("subquery in repartition") {
+ withView("t1") {
+ val t1 = table1()
+
+ checkAnswer(t1.repartition(spark.range(1).select(lit(1)).scalar()), t1)
+ }
+ }
}
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToProtoConverterSuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToProtoConverterSuite.scala
index 2efd39673519..4cb03420c4d0 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToProtoConverterSuite.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToProtoConverterSuite.scala
@@ -431,4 +431,5 @@ class ColumnNodeToProtoConverterSuite extends
ConnectFunSuite {
private[internal] case class Nope(override val origin: Origin =
CurrentOrigin.get)
extends ColumnNode {
override def sql: String = "nope"
+ override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
}
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala
index 8837c76b76ae..f22644074324 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala
@@ -19,8 +19,11 @@ package org.apache.spark.sql.test
import java.util.TimeZone
+import scala.jdk.CollectionConverters._
+
import org.scalatest.Assertions
+import org.apache.spark.{QueryContextType, SparkThrowable}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.util.SparkStringUtils.sideBySide
import org.apache.spark.util.ArrayImplicits._
@@ -53,6 +56,158 @@ abstract class QueryTest extends ConnectFunSuite with
SQLHelper {
checkAnswer(df, expectedAnswer.toImmutableArraySeq)
}
+ case class ExpectedContext(
+ contextType: QueryContextType,
+ objectType: String,
+ objectName: String,
+ startIndex: Int,
+ stopIndex: Int,
+ fragment: String,
+ callSitePattern: String)
+
+ object ExpectedContext {
+ def apply(fragment: String, start: Int, stop: Int): ExpectedContext = {
+ ExpectedContext("", "", start, stop, fragment)
+ }
+
+ def apply(
+ objectType: String,
+ objectName: String,
+ startIndex: Int,
+ stopIndex: Int,
+ fragment: String): ExpectedContext = {
+ new ExpectedContext(
+ QueryContextType.SQL,
+ objectType,
+ objectName,
+ startIndex,
+ stopIndex,
+ fragment,
+ "")
+ }
+
+ def apply(fragment: String, callSitePattern: String): ExpectedContext = {
+ new ExpectedContext(QueryContextType.DataFrame, "", "", -1, -1,
fragment, callSitePattern)
+ }
+ }
+
+ /**
+ * Checks an exception with an error condition against expected results.
+ * @param exception
+ * The exception to check
+ * @param condition
+ * The expected error condition identifying the error
+ * @param sqlState
+ * Optional the expected SQLSTATE, not verified if not supplied
+ * @param parameters
+ * A map of parameter names and values. The names are as defined in the
error-classes file.
+ * @param matchPVals
+ * Optionally treat the parameters value as regular expression pattern.
false if not supplied.
+ */
+ protected def checkError(
+ exception: SparkThrowable,
+ condition: String,
+ sqlState: Option[String] = None,
+ parameters: Map[String, String] = Map.empty,
+ matchPVals: Boolean = false,
+ queryContext: Array[ExpectedContext] = Array.empty): Unit = {
+ assert(exception.getCondition === condition)
+ sqlState.foreach(state => assert(exception.getSqlState === state))
+ val expectedParameters = exception.getMessageParameters.asScala
+ if (matchPVals) {
+ assert(expectedParameters.size === parameters.size)
+ expectedParameters.foreach(exp => {
+ val parm = parameters.getOrElse(
+ exp._1,
+ throw new IllegalArgumentException("Missing parameter " + exp._1))
+ if (!exp._2.matches(parm)) {
+ throw new IllegalArgumentException(
+ "For parameter '" + exp._1 + "' value '" + exp._2 +
+ "' does not match: " + parm)
+ }
+ })
+ } else {
+ assert(expectedParameters === parameters)
+ }
+ val actualQueryContext = exception.getQueryContext()
+ assert(
+ actualQueryContext.length === queryContext.length,
+ "Invalid length of the query context")
+ actualQueryContext.zip(queryContext).foreach { case (actual, expected) =>
+ assert(
+ actual.contextType() === expected.contextType,
+ "Invalid contextType of a query context Actual:" + actual.toString)
+ if (actual.contextType() == QueryContextType.SQL) {
+ assert(
+ actual.objectType() === expected.objectType,
+ "Invalid objectType of a query context Actual:" + actual.toString)
+ assert(
+ actual.objectName() === expected.objectName,
+ "Invalid objectName of a query context. Actual:" + actual.toString)
+ assert(
+ actual.startIndex() === expected.startIndex,
+ "Invalid startIndex of a query context. Actual:" + actual.toString)
+ assert(
+ actual.stopIndex() === expected.stopIndex,
+ "Invalid stopIndex of a query context. Actual:" + actual.toString)
+ assert(
+ actual.fragment() === expected.fragment,
+ "Invalid fragment of a query context. Actual:" + actual.toString)
+ } else if (actual.contextType() == QueryContextType.DataFrame) {
+ assert(
+ actual.fragment() === expected.fragment,
+ "Invalid code fragment of a query context. Actual:" +
actual.toString)
+ if (expected.callSitePattern.nonEmpty) {
+ assert(
+ actual.callSite().matches(expected.callSitePattern),
+ "Invalid callSite of a query context. Actual:" + actual.toString)
+ }
+ }
+ }
+ }
+
+ protected def checkError(
+ exception: SparkThrowable,
+ condition: String,
+ sqlState: String,
+ parameters: Map[String, String]): Unit =
+ checkError(exception, condition, Some(sqlState), parameters)
+
+ protected def checkError(
+ exception: SparkThrowable,
+ condition: String,
+ sqlState: String,
+ parameters: Map[String, String],
+ context: ExpectedContext): Unit =
+ checkError(exception, condition, Some(sqlState), parameters, false,
Array(context))
+
+ protected def checkError(
+ exception: SparkThrowable,
+ condition: String,
+ parameters: Map[String, String],
+ context: ExpectedContext): Unit =
+ checkError(exception, condition, None, parameters, false, Array(context))
+
+ protected def checkError(
+ exception: SparkThrowable,
+ condition: String,
+ sqlState: String,
+ context: ExpectedContext): Unit =
+ checkError(exception, condition, Some(sqlState), Map.empty, false,
Array(context))
+
+ protected def checkError(
+ exception: SparkThrowable,
+ condition: String,
+ sqlState: Option[String],
+ parameters: Map[String, String],
+ context: ExpectedContext): Unit =
+ checkError(exception, condition, sqlState, parameters, false,
Array(context))
+
+ protected def getCurrentClassCallSitePattern: String = {
+ val cs = Thread.currentThread().getStackTrace()(2)
+ s"${cs.getClassName}\\..*\\(${cs.getFileName}:\\d+\\)"
+ }
+
/**
* Evaluates a dataset to make sure that the result of calling collect
matches the given
* expected answer.
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala
index 007d4f0648e4..d9828ae92267 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala
@@ -22,6 +22,7 @@ import java.util.UUID
import org.scalatest.Assertions.fail
import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession,
SQLImplicits}
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.util.{SparkErrorUtils, SparkFileUtils}
trait SQLHelper {
@@ -110,6 +111,22 @@ trait SQLHelper {
finally SparkFileUtils.deleteRecursively(path)
}
+ /**
+ * Drops temporary view `viewNames` after calling `f`.
+ */
+ protected def withTempView(viewNames: String*)(f: => Unit): Unit = {
+ SparkErrorUtils.tryWithSafeFinally(f) {
+ viewNames.foreach { viewName =>
+ try spark.catalog.dropTempView(viewName)
+ catch {
+ // If the test failed part way, we don't want to mask the failure by
failing to remove
+ // temp views that never got created.
+ case _: NoSuchTableException =>
+ }
+ }
+ }
+ }
+
/**
* Drops table `tableName` after calling `f`.
*/
diff --git a/python/pyspark/sql/connect/dataframe.py
b/python/pyspark/sql/connect/dataframe.py
index ee1886b8ef29..185ddc88cd08 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -79,6 +79,7 @@ from pyspark.sql.connect.streaming.readwriter import
DataStreamWriter
from pyspark.sql.column import Column
from pyspark.sql.connect.expressions import (
ColumnReference,
+ SubqueryExpression,
UnresolvedRegex,
UnresolvedStar,
)
@@ -1801,18 +1802,14 @@ class DataFrame(ParentDataFrame):
)
def scalar(self) -> Column:
- # TODO(SPARK-50134): Implement this method
- raise PySparkNotImplementedError(
- errorClass="NOT_IMPLEMENTED",
- messageParameters={"feature": "scalar()"},
- )
+ from pyspark.sql.connect.column import Column as ConnectColumn
+
+ return ConnectColumn(SubqueryExpression(self._plan,
subquery_type="scalar"))
def exists(self) -> Column:
- # TODO(SPARK-50134): Implement this method
- raise PySparkNotImplementedError(
- errorClass="NOT_IMPLEMENTED",
- messageParameters={"feature": "exists()"},
- )
+ from pyspark.sql.connect.column import Column as ConnectColumn
+
+ return ConnectColumn(SubqueryExpression(self._plan,
subquery_type="exists"))
@property
def schema(self) -> StructType:
@@ -2278,10 +2275,6 @@ def _test() -> None:
del pyspark.sql.dataframe.DataFrame.toJSON.__doc__
del pyspark.sql.dataframe.DataFrame.rdd.__doc__
- # TODO(SPARK-50134): Support subquery in connect
- del pyspark.sql.dataframe.DataFrame.scalar.__doc__
- del pyspark.sql.dataframe.DataFrame.exists.__doc__
-
globs["spark"] = (
PySparkSession.builder.appName("sql.connect.dataframe tests")
.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
diff --git a/python/pyspark/sql/connect/expressions.py
b/python/pyspark/sql/connect/expressions.py
index 5d7b348f6d38..413a69181683 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -82,6 +82,7 @@ from pyspark.sql.utils import is_timestamp_ntz_preferred,
enum_to_value
if TYPE_CHECKING:
from pyspark.sql.connect.client import SparkConnectClient
from pyspark.sql.connect.window import WindowSpec
+ from pyspark.sql.connect.plan import LogicalPlan
class Expression:
@@ -128,6 +129,15 @@ class Expression:
plan.common.origin.CopyFrom(self.origin)
return plan
+ @property
+ def children(self) -> Sequence["Expression"]:
+ return []
+
+ def foreach(self, f: Callable[["Expression"], None]) -> None:
+ f(self)
+ for c in self.children:
+ c.foreach(f)
+
class CaseWhen(Expression):
def __init__(
@@ -162,6 +172,16 @@ class CaseWhen(Expression):
return unresolved_function.to_plan(session)
+ @property
+ def children(self) -> Sequence["Expression"]:
+ children = []
+ for branch in self._branches:
+ children.append(branch[0])
+ children.append(branch[1])
+ if self._else_value is not None:
+ children.append(self._else_value)
+ return children
+
def __repr__(self) -> str:
_cases = "".join([f" WHEN {c} THEN {v}" for c, v in self._branches])
_else = f" ELSE {self._else_value}" if self._else_value is not None
else ""
@@ -196,6 +216,10 @@ class ColumnAlias(Expression):
exp.alias.expr.CopyFrom(self._child.to_plan(session))
return exp
+ @property
+ def children(self) -> Sequence["Expression"]:
+ return [self._child]
+
def __repr__(self) -> str:
return f"{self._child} AS {','.join(self._alias)}"
@@ -622,6 +646,10 @@ class SortOrder(Expression):
return sort
+ @property
+ def children(self) -> Sequence["Expression"]:
+ return [self._child]
+
class UnresolvedFunction(Expression):
def __init__(
@@ -649,6 +677,10 @@ class UnresolvedFunction(Expression):
fun.unresolved_function.is_distinct = self._is_distinct
return fun
+ @property
+ def children(self) -> Sequence["Expression"]:
+ return self._args
+
def __repr__(self) -> str:
# Default print handling:
if self._is_distinct:
@@ -730,12 +762,12 @@ class CommonInlineUserDefinedFunction(Expression):
function_name: str,
function: Union[PythonUDF, JavaUDF],
deterministic: bool = False,
- arguments: Sequence[Expression] = [],
+ arguments: Optional[Sequence[Expression]] = None,
):
super().__init__()
self._function_name = function_name
self._deterministic = deterministic
- self._arguments = arguments
+ self._arguments: Sequence[Expression] = arguments or []
self._function = function
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
@@ -770,6 +802,10 @@ class CommonInlineUserDefinedFunction(Expression):
expr.java_udf.CopyFrom(cast(proto.JavaUDF,
self._function.to_plan(session)))
return expr
+ @property
+ def children(self) -> Sequence["Expression"]:
+ return self._arguments
+
def __repr__(self) -> str:
return f"{self._function_name}({', '.join([str(arg) for arg in
self._arguments])})"
@@ -799,6 +835,10 @@ class WithField(Expression):
expr.update_fields.value_expression.CopyFrom(self._valueExpr.to_plan(session))
return expr
+ @property
+ def children(self) -> Sequence["Expression"]:
+ return [self._structExpr, self._valueExpr]
+
def __repr__(self) -> str:
return f"update_field({self._structExpr}, {self._fieldName},
{self._valueExpr})"
@@ -823,6 +863,10 @@ class DropField(Expression):
expr.update_fields.field_name = self._fieldName
return expr
+ @property
+ def children(self) -> Sequence["Expression"]:
+ return [self._structExpr]
+
def __repr__(self) -> str:
return f"drop_field({self._structExpr}, {self._fieldName})"
@@ -847,6 +891,10 @@ class UnresolvedExtractValue(Expression):
expr.unresolved_extract_value.extraction.CopyFrom(self._extraction.to_plan(session))
return expr
+ @property
+ def children(self) -> Sequence["Expression"]:
+ return [self._child, self._extraction]
+
def __repr__(self) -> str:
return f"{self._child}['{self._extraction}']"
@@ -906,6 +954,10 @@ class CastExpression(Expression):
return fun
+ @property
+ def children(self) -> Sequence["Expression"]:
+ return [self._expr]
+
def __repr__(self) -> str:
# We cannot guarantee the string representations be exactly the same,
e.g.
# str(sf.col("a").cast("long")):
@@ -989,6 +1041,10 @@ class LambdaFunction(Expression):
)
return expr
+ @property
+ def children(self) -> Sequence["Expression"]:
+ return [self._function] + self._arguments
+
def __repr__(self) -> str:
return (
f"LambdaFunction({str(self._function)}, "
@@ -1098,6 +1154,12 @@ class WindowExpression(Expression):
return expr
+ @property
+ def children(self) -> Sequence["Expression"]:
+ return (
+ [self._windowFunction] + self._windowSpec._partitionSpec +
self._windowSpec._orderSpec
+ )
+
def __repr__(self) -> str:
return f"WindowExpression({str(self._windowFunction)},
({str(self._windowSpec)}))"
@@ -1128,6 +1190,10 @@ class CallFunction(Expression):
expr.call_function.arguments.extend([arg.to_plan(session) for arg
in self._args])
return expr
+ @property
+ def children(self) -> Sequence["Expression"]:
+ return self._args
+
def __repr__(self) -> str:
if len(self._args) > 0:
return f"CallFunction('{self._name}', {', '.join([str(arg) for arg
in self._args])})"
@@ -1151,6 +1217,10 @@ class NamedArgumentExpression(Expression):
expr.named_argument_expression.value.CopyFrom(self._value.to_plan(session))
return expr
+ @property
+ def children(self) -> Sequence["Expression"]:
+ return [self._value]
+
def __repr__(self) -> str:
return f"{self._key} => {self._value}"
@@ -1166,5 +1236,31 @@ class LazyExpression(Expression):
expr.lazy_expression.child.CopyFrom(self._expr.to_plan(session))
return expr
+ @property
+ def children(self) -> Sequence["Expression"]:
+ return [self._expr]
+
def __repr__(self) -> str:
return f"lazy({self._expr})"
+
+
+class SubqueryExpression(Expression):
+ def __init__(self, plan: "LogicalPlan", subquery_type: str) -> None:
+ assert isinstance(subquery_type, str)
+ assert subquery_type in ("scalar", "exists")
+
+ super().__init__()
+ self._plan = plan
+ self._subquery_type = subquery_type
+
+ def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
+ expr = self._create_proto_expression()
+ expr.subquery_expression.plan_id = self._plan._plan_id
+ if self._subquery_type == "scalar":
+ expr.subquery_expression.subquery_type =
proto.SubqueryExpression.SUBQUERY_TYPE_SCALAR
+ elif self._subquery_type == "exists":
+ expr.subquery_expression.subquery_type =
proto.SubqueryExpression.SUBQUERY_TYPE_EXISTS
+ return expr
+
+ def __repr__(self) -> str:
+ return f"SubqueryExpression({self._plan}, {self._subquery_type})"
diff --git a/python/pyspark/sql/connect/plan.py
b/python/pyspark/sql/connect/plan.py
index c411baf17ce9..02b60381ab93 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -52,7 +52,7 @@ from pyspark.sql.column import Column
from pyspark.sql.connect.logging import logger
from pyspark.sql.connect.proto import base_pb2 as
spark_dot_connect_dot_base__pb2
from pyspark.sql.connect.conversion import storage_level_to_proto
-from pyspark.sql.connect.expressions import Expression
+from pyspark.sql.connect.expressions import Expression, SubqueryExpression
from pyspark.sql.connect.types import pyspark_types_to_proto_types,
UnparsedDataType
from pyspark.errors import (
AnalysisException,
@@ -73,9 +73,30 @@ class LogicalPlan:
INDENT = 2
- def __init__(self, child: Optional["LogicalPlan"]) -> None:
+ def __init__(
+ self, child: Optional["LogicalPlan"], references:
Optional[Sequence["LogicalPlan"]] = None
+ ) -> None:
+ """
+
+ Parameters
+ ----------
+ child : :class:`LogicalPlan`, optional.
+ The child logical plan.
+ references : list of :class:`LogicalPlan`, optional.
+ The list of logical plans that are referenced as subqueries in
this logical plan.
+ """
self._child = child
- self._plan_id = LogicalPlan._fresh_plan_id()
+ self._root_plan_id = LogicalPlan._fresh_plan_id()
+
+ self._references: Sequence["LogicalPlan"] = references or []
+ self._plan_id_with_rel: Optional[int] = None
+ if len(self._references) > 0:
+ assert all(isinstance(r, LogicalPlan) for r in self._references)
+ self._plan_id_with_rel = LogicalPlan._fresh_plan_id()
+
+ @property
+ def _plan_id(self) -> int:
+ return self._plan_id_with_rel or self._root_plan_id
@staticmethod
def _fresh_plan_id() -> int:
@@ -89,7 +110,7 @@ class LogicalPlan:
def _create_proto_relation(self) -> proto.Relation:
plan = proto.Relation()
- plan.common.plan_id = self._plan_id
+ plan.common.plan_id = self._root_plan_id
return plan
def plan(self, session: "SparkConnectClient") -> proto.Relation: # type:
ignore[empty-body]
@@ -136,6 +157,42 @@ class LogicalPlan:
else:
return self._child.observations
+ @staticmethod
+ def _collect_references(
+ cols_or_exprs: Sequence[Union[Column, Expression]]
+ ) -> Sequence["LogicalPlan"]:
+ references: List[LogicalPlan] = []
+
+ def append_reference(e: Expression) -> None:
+ if isinstance(e, SubqueryExpression):
+ references.append(e._plan)
+
+ for col_or_expr in cols_or_exprs:
+ if isinstance(col_or_expr, Column):
+ col_or_expr._expr.foreach(append_reference)
+ else:
+ col_or_expr.foreach(append_reference)
+ return references
+
+ def _with_relations(
+ self, root: proto.Relation, session: "SparkConnectClient"
+ ) -> proto.Relation:
+ if len(self._references) == 0:
+ return root
+ else:
+ # When there are references to other DataFrames, e.g., subqueries, build a new plan like:
+ # with_relations [id 10]
+ # root: plan [id 9]
+ # reference:
+ # refs#1: [id 8]
+ # refs#2: [id 5]
+ plan = proto.Relation()
+ assert isinstance(self._plan_id_with_rel, int)
+ plan.common.plan_id = self._plan_id_with_rel
+ plan.with_relations.root.CopyFrom(root)
+ plan.with_relations.references.extend([ref.plan(session) for ref
in self._references])
+ return plan
+
def _parameters_to_print(self, parameters: Mapping[str, Any]) ->
Mapping[str, Any]:
"""
Extracts the parameters that are able to be printed. It looks up the
signature
@@ -192,6 +249,7 @@ class LogicalPlan:
getattr(a, "__forward_arg__", "").endswith("LogicalPlan")
for a in getattr(tpe.annotation, "__args__", ())
)
+
if (
not is_logical_plan
and not is_forwardref_logical_plan
@@ -473,8 +531,8 @@ class Project(LogicalPlan):
child: Optional["LogicalPlan"],
columns: List[Column],
) -> None:
- super().__init__(child)
assert all(isinstance(c, Column) for c in columns)
+ super().__init__(child, self._collect_references(columns))
self._columns = columns
def plan(self, session: "SparkConnectClient") -> proto.Relation:
@@ -482,7 +540,8 @@ class Project(LogicalPlan):
plan = self._create_proto_relation()
plan.project.input.CopyFrom(self._child.plan(session))
plan.project.expressions.extend([c.to_plan(session) for c in
self._columns])
- return plan
+
+ return self._with_relations(plan, session)
class WithColumns(LogicalPlan):
@@ -495,8 +554,6 @@ class WithColumns(LogicalPlan):
columns: Sequence[Column],
metadata: Optional[Sequence[str]] = None,
) -> None:
- super().__init__(child)
-
assert isinstance(columnNames, list)
assert len(columnNames) > 0
assert all(isinstance(c, str) for c in columnNames)
@@ -513,6 +570,8 @@ class WithColumns(LogicalPlan):
# validate json string
assert m == "" or json.loads(m) is not None
+ super().__init__(child, self._collect_references(columns))
+
self._columnNames = columnNames
self._columns = columns
self._metadata = metadata
@@ -530,7 +589,7 @@ class WithColumns(LogicalPlan):
alias.metadata = self._metadata[i]
plan.with_columns.aliases.append(alias)
- return plan
+ return self._with_relations(plan, session)
class WithWatermark(LogicalPlan):
@@ -608,16 +667,14 @@ class Hint(LogicalPlan):
name: str,
parameters: Sequence[Column],
) -> None:
- super().__init__(child)
-
assert isinstance(name, str)
- self._name = name
-
assert parameters is not None and isinstance(parameters, List)
for param in parameters:
assert isinstance(param, Column)
+ super().__init__(child, self._collect_references(parameters))
+ self._name = name
self._parameters = parameters
def plan(self, session: "SparkConnectClient") -> proto.Relation:
@@ -626,12 +683,12 @@ class Hint(LogicalPlan):
plan.hint.input.CopyFrom(self._child.plan(session))
plan.hint.name = self._name
plan.hint.parameters.extend([param.to_plan(session) for param in
self._parameters])
- return plan
+ return self._with_relations(plan, session)
class Filter(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], filter: Column) -> None:
- super().__init__(child)
+ super().__init__(child, self._collect_references([filter]))
self.filter = filter
def plan(self, session: "SparkConnectClient") -> proto.Relation:
@@ -639,7 +696,7 @@ class Filter(LogicalPlan):
plan = self._create_proto_relation()
plan.filter.input.CopyFrom(self._child.plan(session))
plan.filter.condition.CopyFrom(self.filter.to_plan(session))
- return plan
+ return self._with_relations(plan, session)
class Limit(LogicalPlan):
@@ -712,11 +769,10 @@ class Sort(LogicalPlan):
columns: List[Column],
is_global: bool,
) -> None:
- super().__init__(child)
-
assert all(isinstance(c, Column) for c in columns)
assert isinstance(is_global, bool)
+ super().__init__(child, self._collect_references(columns))
self.columns = columns
self.is_global = is_global
@@ -726,7 +782,7 @@ class Sort(LogicalPlan):
plan.sort.input.CopyFrom(self._child.plan(session))
plan.sort.order.extend([c.to_plan(session).sort_order for c in
self.columns])
plan.sort.is_global = self.is_global
- return plan
+ return self._with_relations(plan, session)
class Drop(LogicalPlan):
@@ -735,9 +791,12 @@ class Drop(LogicalPlan):
child: Optional["LogicalPlan"],
columns: List[Union[Column, str]],
) -> None:
- super().__init__(child)
if len(columns) > 0:
assert all(isinstance(c, (Column, str)) for c in columns)
+
+ super().__init__(
+ child, self._collect_references([c for c in columns if
isinstance(c, Column)])
+ )
self._columns = columns
def plan(self, session: "SparkConnectClient") -> proto.Relation:
@@ -749,7 +808,7 @@ class Drop(LogicalPlan):
plan.drop.columns.append(c.to_plan(session))
else:
plan.drop.column_names.append(c)
- return plan
+ return self._with_relations(plan, session)
class Sample(LogicalPlan):
@@ -792,8 +851,6 @@ class Aggregate(LogicalPlan):
pivot_values: Optional[Sequence[Column]],
grouping_sets: Optional[Sequence[Sequence[Column]]],
) -> None:
- super().__init__(child)
-
assert isinstance(group_type, str) and group_type in [
"groupby",
"rollup",
@@ -801,15 +858,12 @@ class Aggregate(LogicalPlan):
"pivot",
"grouping_sets",
]
- self._group_type = group_type
assert isinstance(grouping_cols, list) and all(isinstance(c, Column)
for c in grouping_cols)
- self._grouping_cols = grouping_cols
assert isinstance(aggregate_cols, list) and all(
isinstance(c, Column) for c in aggregate_cols
)
- self._aggregate_cols = aggregate_cols
if group_type == "pivot":
assert pivot_col is not None and isinstance(pivot_col, Column)
@@ -821,6 +875,19 @@ class Aggregate(LogicalPlan):
assert pivot_values is None
assert grouping_sets is None
+ super().__init__(
+ child,
+ self._collect_references(
+ grouping_cols
+ + aggregate_cols
+ + ([pivot_col] if pivot_col is not None else [])
+ + (pivot_values if pivot_values is not None else [])
+ + ([g for gs in grouping_sets for g in gs] if grouping_sets is
not None else [])
+ ),
+ )
+ self._group_type = group_type
+ self._grouping_cols = grouping_cols
+ self._aggregate_cols = aggregate_cols
self._pivot_col = pivot_col
self._pivot_values = pivot_values
self._grouping_sets = grouping_sets
@@ -859,7 +926,7 @@ class Aggregate(LogicalPlan):
grouping_set=[c.to_plan(session) for c in grouping_set]
)
)
- return plan
+ return self._with_relations(plan, session)
class Join(LogicalPlan):
@@ -870,7 +937,16 @@ class Join(LogicalPlan):
on: Optional[Union[str, List[str], Column, List[Column]]],
how: Optional[str],
) -> None:
- super().__init__(left)
+ super().__init__(
+ left,
+ self._collect_references(
+ []
+ if on is None or isinstance(on, str)
+ else [on]
+ if isinstance(on, Column)
+ else [c for c in on if isinstance(c, Column)]
+ ),
+ )
self.left = cast(LogicalPlan, left)
self.right = right
self.on = on
@@ -942,7 +1018,7 @@ class Join(LogicalPlan):
merge_column = functools.reduce(lambda c1, c2: c1 & c2,
self.on)
plan.join.join_condition.CopyFrom(cast(Column,
merge_column).to_plan(session))
plan.join.join_type = self.how
- return plan
+ return self._with_relations(plan, session)
@property
def observations(self) -> Dict[str, "Observation"]:
@@ -982,7 +1058,20 @@ class AsOfJoin(LogicalPlan):
allow_exact_matches: bool,
direction: str,
) -> None:
- super().__init__(left)
+ super().__init__(
+ left,
+ self._collect_references(
+ [left_as_of, right_as_of]
+ + (
+ []
+ if on is None or isinstance(on, str)
+ else [on]
+ if isinstance(on, Column)
+ else [c for c in on if isinstance(c, Column)]
+ )
+ + ([tolerance] if tolerance is not None else [])
+ ),
+ )
self.left = left
self.right = right
self.left_as_of = left_as_of
@@ -1022,7 +1111,7 @@ class AsOfJoin(LogicalPlan):
plan.as_of_join.allow_exact_matches = self.allow_exact_matches
plan.as_of_join.direction = self.direction
- return plan
+ return self._with_relations(plan, session)
@property
def observations(self) -> Dict[str, "Observation"]:
@@ -1064,7 +1153,7 @@ class LateralJoin(LogicalPlan):
on: Optional[Column],
how: Optional[str],
) -> None:
- super().__init__(left)
+ super().__init__(left, self._collect_references([on] if on is not None
else []))
self.left = cast(LogicalPlan, left)
self.right = right
self.on = on
@@ -1097,7 +1186,7 @@ class LateralJoin(LogicalPlan):
if self.on is not None:
plan.lateral_join.join_condition.CopyFrom(self.on.to_plan(session))
plan.lateral_join.join_type = self.how
- return plan
+ return self._with_relations(plan, session)
@property
def observations(self) -> Dict[str, "Observation"]:
@@ -1225,9 +1314,9 @@ class RepartitionByExpression(LogicalPlan):
num_partitions: Optional[int],
columns: List[Column],
) -> None:
- super().__init__(child)
- self.num_partitions = num_partitions
assert all(isinstance(c, Column) for c in columns)
+ super().__init__(child, self._collect_references(columns))
+ self.num_partitions = num_partitions
self.columns = columns
def plan(self, session: "SparkConnectClient") -> proto.Relation:
@@ -1240,7 +1329,7 @@ class RepartitionByExpression(LogicalPlan):
plan.repartition_by_expression.input.CopyFrom(self._child.plan(session))
if self.num_partitions is not None:
plan.repartition_by_expression.num_partitions = self.num_partitions
- return plan
+ return self._with_relations(plan, session)
class SubqueryAlias(LogicalPlan):
@@ -1286,8 +1375,6 @@ class SQL(LogicalPlan):
named_args: Optional[Dict[str, Column]] = None,
views: Optional[Sequence[SubqueryAlias]] = None,
) -> None:
- super().__init__(None)
-
if args is not None:
assert isinstance(args, List)
assert all(isinstance(arg, Column) for arg in args)
@@ -1301,10 +1388,8 @@ class SQL(LogicalPlan):
if views is not None:
assert isinstance(views, List)
assert all(isinstance(v, SubqueryAlias) for v in views)
- if len(views) > 0:
- # reserved plan id for WithRelations
- self._plan_id_with_rel = LogicalPlan._fresh_plan_id()
+ super().__init__(None, views)
self._query = query
self._args = args
self._named_args = named_args
@@ -1320,20 +1405,7 @@ class SQL(LogicalPlan):
for k, arg in self._named_args.items():
plan.sql.named_arguments[k].CopyFrom(arg.to_plan(session))
- if self._views is not None and len(self._views) > 0:
- # build new plan like
- # with_relations [id 10]
- # root: sql [id 9]
- # reference:
- # view#1: [id 8]
- # view#2: [id 5]
- sql_plan = plan
- plan = proto.Relation()
- plan.common.plan_id = self._plan_id_with_rel
- plan.with_relations.root.CopyFrom(sql_plan)
- plan.with_relations.references.extend([v.plan(session) for v in
self._views])
-
- return plan
+ return self._with_relations(plan, session)
def command(self, session: "SparkConnectClient") -> proto.Command:
cmd = proto.Command()
@@ -1407,7 +1479,7 @@ class Unpivot(LogicalPlan):
variable_column_name: str,
value_column_name: str,
) -> None:
- super().__init__(child)
+ super().__init__(child, self._collect_references(ids + (values or [])))
self.ids = ids
self.values = values
self.variable_column_name = variable_column_name
@@ -1422,7 +1494,7 @@ class Unpivot(LogicalPlan):
plan.unpivot.values.values.extend([v.to_plan(session) for v in
self.values])
plan.unpivot.variable_column_name = self.variable_column_name
plan.unpivot.value_column_name = self.value_column_name
- return plan
+ return self._with_relations(plan, session)
class Transpose(LogicalPlan):
@@ -1433,7 +1505,7 @@ class Transpose(LogicalPlan):
child: Optional["LogicalPlan"],
index_columns: Sequence[Column],
) -> None:
- super().__init__(child)
+ super().__init__(child, self._collect_references(index_columns))
self.index_columns = index_columns
def plan(self, session: "SparkConnectClient") -> proto.Relation:
@@ -1443,12 +1515,12 @@ class Transpose(LogicalPlan):
if self.index_columns is not None and len(self.index_columns) > 0:
for index_column in self.index_columns:
plan.transpose.index_columns.append(index_column.to_plan(session))
- return plan
+ return self._with_relations(plan, session)
class UnresolvedTableValuedFunction(LogicalPlan):
def __init__(self, name: str, args: Sequence[Column]):
- super().__init__(None)
+ super().__init__(None, self._collect_references(args))
self._name = name
self._args = args
@@ -1457,7 +1529,7 @@ class UnresolvedTableValuedFunction(LogicalPlan):
plan.unresolved_table_valued_function.function_name = self._name
for arg in self._args:
plan.unresolved_table_valued_function.arguments.append(arg.to_plan(session))
- return plan
+ return self._with_relations(plan, session)
class CollectMetrics(LogicalPlan):
@@ -1469,9 +1541,9 @@ class CollectMetrics(LogicalPlan):
observation: Union[str, "Observation"],
exprs: List[Column],
) -> None:
- super().__init__(child)
- self._observation = observation
assert all(isinstance(e, Column) for e in exprs)
+ super().__init__(child, self._collect_references(exprs))
+ self._observation = observation
self._exprs = exprs
def plan(self, session: "SparkConnectClient") -> proto.Relation:
@@ -1484,7 +1556,7 @@ class CollectMetrics(LogicalPlan):
else str(self._observation._name)
)
plan.collect_metrics.metrics.extend([e.to_plan(session) for e in
self._exprs])
- return plan
+ return self._with_relations(plan, session)
@property
def observations(self) -> Dict[str, "Observation"]:
@@ -1569,13 +1641,13 @@ class NAReplace(LogicalPlan):
cols: Optional[List[str]],
replacements: Sequence[Tuple[Column, Column]],
) -> None:
- super().__init__(child)
- self.cols = cols
-
assert replacements is not None and isinstance(replacements, List)
for k, v in replacements:
assert k is not None and isinstance(k, Column)
assert v is not None and isinstance(v, Column)
+
+ super().__init__(child, self._collect_references([e for t in
replacements for e in t]))
+ self.cols = cols
self.replacements = replacements
def plan(self, session: "SparkConnectClient") -> proto.Relation:
@@ -1590,7 +1662,7 @@ class NAReplace(LogicalPlan):
replacement.old_value.CopyFrom(old_value.to_plan(session).literal)
replacement.new_value.CopyFrom(new_value.to_plan(session).literal)
plan.replace.replacements.append(replacement)
- return plan
+ return self._with_relations(plan, session)
class StatSummary(LogicalPlan):
@@ -1700,8 +1772,6 @@ class StatSampleBy(LogicalPlan):
fractions: Sequence[Tuple[Column, float]],
seed: int,
) -> None:
- super().__init__(child)
-
assert col is not None and isinstance(col, (Column, str))
assert fractions is not None and isinstance(fractions, List)
@@ -1711,6 +1781,12 @@ class StatSampleBy(LogicalPlan):
assert seed is None or isinstance(seed, int)
+ super().__init__(
+ child,
+ self._collect_references(
+ ([col] if isinstance(col, Column) else []) + [c for c, _ in fractions]
+ ),
+ )
self._col = col
self._fractions = fractions
self._seed = seed
@@ -1727,7 +1803,7 @@ class StatSampleBy(LogicalPlan):
fraction.fraction = float(v)
plan.sample_by.fractions.append(fraction)
plan.sample_by.seed = self._seed
- return plan
+ return self._with_relations(plan, session)
class StatCorr(LogicalPlan):
@@ -2375,7 +2451,7 @@ class GroupMap(LogicalPlan):
):
assert isinstance(grouping_cols, list) and all(isinstance(c, Column)
for c in grouping_cols)
- super().__init__(child)
+ super().__init__(child, self._collect_references(grouping_cols))
self._grouping_cols = grouping_cols
self._function =
function._build_common_inline_user_defined_function(*cols)
@@ -2387,7 +2463,7 @@ class GroupMap(LogicalPlan):
[c.to_plan(session) for c in self._grouping_cols]
)
plan.group_map.func.CopyFrom(self._function.to_plan_udf(session))
- return plan
+ return self._with_relations(plan, session)
class CoGroupMap(LogicalPlan):
@@ -2408,7 +2484,7 @@ class CoGroupMap(LogicalPlan):
isinstance(c, Column) for c in other_grouping_cols
)
- super().__init__(input)
+ super().__init__(input, self._collect_references(input_grouping_cols +
other_grouping_cols))
self._input_grouping_cols = input_grouping_cols
self._other_grouping_cols = other_grouping_cols
self._other = cast(LogicalPlan, other)
@@ -2428,7 +2504,7 @@ class CoGroupMap(LogicalPlan):
[c.to_plan(session) for c in self._other_grouping_cols]
)
plan.co_group_map.func.CopyFrom(self._function.to_plan_udf(session))
- return plan
+ return self._with_relations(plan, session)
class ApplyInPandasWithState(LogicalPlan):
@@ -2447,7 +2523,7 @@ class ApplyInPandasWithState(LogicalPlan):
):
assert isinstance(grouping_cols, list) and all(isinstance(c, Column)
for c in grouping_cols)
- super().__init__(child)
+ super().__init__(child, self._collect_references(grouping_cols))
self._grouping_cols = grouping_cols
self._function =
function._build_common_inline_user_defined_function(*cols)
self._output_schema = output_schema
@@ -2467,7 +2543,7 @@ class ApplyInPandasWithState(LogicalPlan):
plan.apply_in_pandas_with_state.state_schema = self._state_schema
plan.apply_in_pandas_with_state.output_mode = self._output_mode
plan.apply_in_pandas_with_state.timeout_conf = self._timeout_conf
- return plan
+ return self._with_relations(plan, session)
class PythonUDTF:
@@ -2531,7 +2607,7 @@ class CommonInlineUserDefinedTableFunction(LogicalPlan):
deterministic: bool,
arguments: Sequence[Expression],
) -> None:
- super().__init__(None)
+ super().__init__(None, self._collect_references(arguments))
self._function_name = function_name
self._deterministic = deterministic
self._arguments = arguments
@@ -2548,7 +2624,7 @@ class CommonInlineUserDefinedTableFunction(LogicalPlan):
plan.common_inline_user_defined_table_function.python_udtf.CopyFrom(
self._function.to_plan(session)
)
- return plan
+ return self._with_relations(plan, session)
def udtf_plan(
self, session: "SparkConnectClient"
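
Every hunk in `plan.py` above follows the same two-step pattern: the constructor passes its columns to `_collect_references(...)`, and `plan(...)` wraps the proto relation via `self._with_relations(plan, session)`. Those helpers are defined earlier in `plan.py` and are not part of this excerpt; the sketch below (names such as `wrap_with_relations` and `reserved_plan_id` are made up) only illustrates the wrapping step, mirroring the `WITH_RELATIONS` construction that the `SQL` node previously did inline for views (see the removed block above).

```python
# Illustrative sketch only; the real helper lives in LogicalPlan in plan.py.
import pyspark.sql.connect.proto as proto


def wrap_with_relations(
    root: proto.Relation, references: list, reserved_plan_id: int
) -> proto.Relation:
    """Wrap `root` in WITH_RELATIONS so referenced subquery plans travel with it."""
    if not references:
        # No subqueries were referenced by the columns; keep the plan unchanged.
        return root
    wrapped = proto.Relation()
    wrapped.common.plan_id = reserved_plan_id  # reserved plan id for the wrapper node
    wrapped.with_relations.root.CopyFrom(root)
    # Each reference is the proto relation of a referenced subquery plan.
    wrapped.with_relations.references.extend(references)
    return wrapped
```

On the server side, `transformWithRelations` (further down in this patch) resolves each `UnresolvedPlanId` in the root against these references by plan id.
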
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py
b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 87070fd5ad3c..093997a0d0c5 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -40,7 +40,7 @@ from pyspark.sql.connect.proto import common_pb2 as
spark_dot_connect_dot_common
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\x8b\x31\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12
\x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolved
[...]
+
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\xe1\x31\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12
\x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolved
[...]
)
_globals = globals()
@@ -54,79 +54,83 @@ if not _descriptor._USE_C_DESCRIPTORS:
"DESCRIPTOR"
]._serialized_options =
b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated"
_globals["_EXPRESSION"]._serialized_start = 133
- _globals["_EXPRESSION"]._serialized_end = 6416
- _globals["_EXPRESSION_WINDOW"]._serialized_start = 1974
- _globals["_EXPRESSION_WINDOW"]._serialized_end = 2757
- _globals["_EXPRESSION_WINDOW_WINDOWFRAME"]._serialized_start = 2264
- _globals["_EXPRESSION_WINDOW_WINDOWFRAME"]._serialized_end = 2757
- _globals["_EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY"]._serialized_start
= 2531
- _globals["_EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY"]._serialized_end =
2676
- _globals["_EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE"]._serialized_start =
2678
- _globals["_EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE"]._serialized_end = 2757
- _globals["_EXPRESSION_SORTORDER"]._serialized_start = 2760
- _globals["_EXPRESSION_SORTORDER"]._serialized_end = 3185
- _globals["_EXPRESSION_SORTORDER_SORTDIRECTION"]._serialized_start = 2990
- _globals["_EXPRESSION_SORTORDER_SORTDIRECTION"]._serialized_end = 3098
- _globals["_EXPRESSION_SORTORDER_NULLORDERING"]._serialized_start = 3100
- _globals["_EXPRESSION_SORTORDER_NULLORDERING"]._serialized_end = 3185
- _globals["_EXPRESSION_CAST"]._serialized_start = 3188
- _globals["_EXPRESSION_CAST"]._serialized_end = 3503
- _globals["_EXPRESSION_CAST_EVALMODE"]._serialized_start = 3389
- _globals["_EXPRESSION_CAST_EVALMODE"]._serialized_end = 3487
- _globals["_EXPRESSION_LITERAL"]._serialized_start = 3506
- _globals["_EXPRESSION_LITERAL"]._serialized_end = 5069
- _globals["_EXPRESSION_LITERAL_DECIMAL"]._serialized_start = 4341
- _globals["_EXPRESSION_LITERAL_DECIMAL"]._serialized_end = 4458
- _globals["_EXPRESSION_LITERAL_CALENDARINTERVAL"]._serialized_start = 4460
- _globals["_EXPRESSION_LITERAL_CALENDARINTERVAL"]._serialized_end = 4558
- _globals["_EXPRESSION_LITERAL_ARRAY"]._serialized_start = 4561
- _globals["_EXPRESSION_LITERAL_ARRAY"]._serialized_end = 4691
- _globals["_EXPRESSION_LITERAL_MAP"]._serialized_start = 4694
- _globals["_EXPRESSION_LITERAL_MAP"]._serialized_end = 4921
- _globals["_EXPRESSION_LITERAL_STRUCT"]._serialized_start = 4924
- _globals["_EXPRESSION_LITERAL_STRUCT"]._serialized_end = 5053
- _globals["_EXPRESSION_UNRESOLVEDATTRIBUTE"]._serialized_start = 5072
- _globals["_EXPRESSION_UNRESOLVEDATTRIBUTE"]._serialized_end = 5258
- _globals["_EXPRESSION_UNRESOLVEDFUNCTION"]._serialized_start = 5261
- _globals["_EXPRESSION_UNRESOLVEDFUNCTION"]._serialized_end = 5465
- _globals["_EXPRESSION_EXPRESSIONSTRING"]._serialized_start = 5467
- _globals["_EXPRESSION_EXPRESSIONSTRING"]._serialized_end = 5517
- _globals["_EXPRESSION_UNRESOLVEDSTAR"]._serialized_start = 5519
- _globals["_EXPRESSION_UNRESOLVEDSTAR"]._serialized_end = 5643
- _globals["_EXPRESSION_UNRESOLVEDREGEX"]._serialized_start = 5645
- _globals["_EXPRESSION_UNRESOLVEDREGEX"]._serialized_end = 5731
- _globals["_EXPRESSION_UNRESOLVEDEXTRACTVALUE"]._serialized_start = 5734
- _globals["_EXPRESSION_UNRESOLVEDEXTRACTVALUE"]._serialized_end = 5866
- _globals["_EXPRESSION_UPDATEFIELDS"]._serialized_start = 5869
- _globals["_EXPRESSION_UPDATEFIELDS"]._serialized_end = 6056
- _globals["_EXPRESSION_ALIAS"]._serialized_start = 6058
- _globals["_EXPRESSION_ALIAS"]._serialized_end = 6178
- _globals["_EXPRESSION_LAMBDAFUNCTION"]._serialized_start = 6181
- _globals["_EXPRESSION_LAMBDAFUNCTION"]._serialized_end = 6339
- _globals["_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE"]._serialized_start =
6341
- _globals["_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE"]._serialized_end =
6403
- _globals["_EXPRESSIONCOMMON"]._serialized_start = 6418
- _globals["_EXPRESSIONCOMMON"]._serialized_end = 6483
- _globals["_COMMONINLINEUSERDEFINEDFUNCTION"]._serialized_start = 6486
- _globals["_COMMONINLINEUSERDEFINEDFUNCTION"]._serialized_end = 6850
- _globals["_PYTHONUDF"]._serialized_start = 6853
- _globals["_PYTHONUDF"]._serialized_end = 7057
- _globals["_SCALARSCALAUDF"]._serialized_start = 7060
- _globals["_SCALARSCALAUDF"]._serialized_end = 7274
- _globals["_JAVAUDF"]._serialized_start = 7277
- _globals["_JAVAUDF"]._serialized_end = 7426
- _globals["_TYPEDAGGREGATEEXPRESSION"]._serialized_start = 7428
- _globals["_TYPEDAGGREGATEEXPRESSION"]._serialized_end = 7527
- _globals["_CALLFUNCTION"]._serialized_start = 7529
- _globals["_CALLFUNCTION"]._serialized_end = 7637
- _globals["_NAMEDARGUMENTEXPRESSION"]._serialized_start = 7639
- _globals["_NAMEDARGUMENTEXPRESSION"]._serialized_end = 7731
- _globals["_MERGEACTION"]._serialized_start = 7734
- _globals["_MERGEACTION"]._serialized_end = 8246
- _globals["_MERGEACTION_ASSIGNMENT"]._serialized_start = 7956
- _globals["_MERGEACTION_ASSIGNMENT"]._serialized_end = 8062
- _globals["_MERGEACTION_ACTIONTYPE"]._serialized_start = 8065
- _globals["_MERGEACTION_ACTIONTYPE"]._serialized_end = 8232
- _globals["_LAZYEXPRESSION"]._serialized_start = 8248
- _globals["_LAZYEXPRESSION"]._serialized_end = 8313
+ _globals["_EXPRESSION"]._serialized_end = 6502
+ _globals["_EXPRESSION_WINDOW"]._serialized_start = 2060
+ _globals["_EXPRESSION_WINDOW"]._serialized_end = 2843
+ _globals["_EXPRESSION_WINDOW_WINDOWFRAME"]._serialized_start = 2350
+ _globals["_EXPRESSION_WINDOW_WINDOWFRAME"]._serialized_end = 2843
+ _globals["_EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY"]._serialized_start
= 2617
+ _globals["_EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY"]._serialized_end =
2762
+ _globals["_EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE"]._serialized_start =
2764
+ _globals["_EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE"]._serialized_end = 2843
+ _globals["_EXPRESSION_SORTORDER"]._serialized_start = 2846
+ _globals["_EXPRESSION_SORTORDER"]._serialized_end = 3271
+ _globals["_EXPRESSION_SORTORDER_SORTDIRECTION"]._serialized_start = 3076
+ _globals["_EXPRESSION_SORTORDER_SORTDIRECTION"]._serialized_end = 3184
+ _globals["_EXPRESSION_SORTORDER_NULLORDERING"]._serialized_start = 3186
+ _globals["_EXPRESSION_SORTORDER_NULLORDERING"]._serialized_end = 3271
+ _globals["_EXPRESSION_CAST"]._serialized_start = 3274
+ _globals["_EXPRESSION_CAST"]._serialized_end = 3589
+ _globals["_EXPRESSION_CAST_EVALMODE"]._serialized_start = 3475
+ _globals["_EXPRESSION_CAST_EVALMODE"]._serialized_end = 3573
+ _globals["_EXPRESSION_LITERAL"]._serialized_start = 3592
+ _globals["_EXPRESSION_LITERAL"]._serialized_end = 5155
+ _globals["_EXPRESSION_LITERAL_DECIMAL"]._serialized_start = 4427
+ _globals["_EXPRESSION_LITERAL_DECIMAL"]._serialized_end = 4544
+ _globals["_EXPRESSION_LITERAL_CALENDARINTERVAL"]._serialized_start = 4546
+ _globals["_EXPRESSION_LITERAL_CALENDARINTERVAL"]._serialized_end = 4644
+ _globals["_EXPRESSION_LITERAL_ARRAY"]._serialized_start = 4647
+ _globals["_EXPRESSION_LITERAL_ARRAY"]._serialized_end = 4777
+ _globals["_EXPRESSION_LITERAL_MAP"]._serialized_start = 4780
+ _globals["_EXPRESSION_LITERAL_MAP"]._serialized_end = 5007
+ _globals["_EXPRESSION_LITERAL_STRUCT"]._serialized_start = 5010
+ _globals["_EXPRESSION_LITERAL_STRUCT"]._serialized_end = 5139
+ _globals["_EXPRESSION_UNRESOLVEDATTRIBUTE"]._serialized_start = 5158
+ _globals["_EXPRESSION_UNRESOLVEDATTRIBUTE"]._serialized_end = 5344
+ _globals["_EXPRESSION_UNRESOLVEDFUNCTION"]._serialized_start = 5347
+ _globals["_EXPRESSION_UNRESOLVEDFUNCTION"]._serialized_end = 5551
+ _globals["_EXPRESSION_EXPRESSIONSTRING"]._serialized_start = 5553
+ _globals["_EXPRESSION_EXPRESSIONSTRING"]._serialized_end = 5603
+ _globals["_EXPRESSION_UNRESOLVEDSTAR"]._serialized_start = 5605
+ _globals["_EXPRESSION_UNRESOLVEDSTAR"]._serialized_end = 5729
+ _globals["_EXPRESSION_UNRESOLVEDREGEX"]._serialized_start = 5731
+ _globals["_EXPRESSION_UNRESOLVEDREGEX"]._serialized_end = 5817
+ _globals["_EXPRESSION_UNRESOLVEDEXTRACTVALUE"]._serialized_start = 5820
+ _globals["_EXPRESSION_UNRESOLVEDEXTRACTVALUE"]._serialized_end = 5952
+ _globals["_EXPRESSION_UPDATEFIELDS"]._serialized_start = 5955
+ _globals["_EXPRESSION_UPDATEFIELDS"]._serialized_end = 6142
+ _globals["_EXPRESSION_ALIAS"]._serialized_start = 6144
+ _globals["_EXPRESSION_ALIAS"]._serialized_end = 6264
+ _globals["_EXPRESSION_LAMBDAFUNCTION"]._serialized_start = 6267
+ _globals["_EXPRESSION_LAMBDAFUNCTION"]._serialized_end = 6425
+ _globals["_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE"]._serialized_start =
6427
+ _globals["_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE"]._serialized_end =
6489
+ _globals["_EXPRESSIONCOMMON"]._serialized_start = 6504
+ _globals["_EXPRESSIONCOMMON"]._serialized_end = 6569
+ _globals["_COMMONINLINEUSERDEFINEDFUNCTION"]._serialized_start = 6572
+ _globals["_COMMONINLINEUSERDEFINEDFUNCTION"]._serialized_end = 6936
+ _globals["_PYTHONUDF"]._serialized_start = 6939
+ _globals["_PYTHONUDF"]._serialized_end = 7143
+ _globals["_SCALARSCALAUDF"]._serialized_start = 7146
+ _globals["_SCALARSCALAUDF"]._serialized_end = 7360
+ _globals["_JAVAUDF"]._serialized_start = 7363
+ _globals["_JAVAUDF"]._serialized_end = 7512
+ _globals["_TYPEDAGGREGATEEXPRESSION"]._serialized_start = 7514
+ _globals["_TYPEDAGGREGATEEXPRESSION"]._serialized_end = 7613
+ _globals["_CALLFUNCTION"]._serialized_start = 7615
+ _globals["_CALLFUNCTION"]._serialized_end = 7723
+ _globals["_NAMEDARGUMENTEXPRESSION"]._serialized_start = 7725
+ _globals["_NAMEDARGUMENTEXPRESSION"]._serialized_end = 7817
+ _globals["_MERGEACTION"]._serialized_start = 7820
+ _globals["_MERGEACTION"]._serialized_end = 8332
+ _globals["_MERGEACTION_ASSIGNMENT"]._serialized_start = 8042
+ _globals["_MERGEACTION_ASSIGNMENT"]._serialized_end = 8148
+ _globals["_MERGEACTION_ACTIONTYPE"]._serialized_start = 8151
+ _globals["_MERGEACTION_ACTIONTYPE"]._serialized_end = 8318
+ _globals["_LAZYEXPRESSION"]._serialized_start = 8334
+ _globals["_LAZYEXPRESSION"]._serialized_end = 8399
+ _globals["_SUBQUERYEXPRESSION"]._serialized_start = 8402
+ _globals["_SUBQUERYEXPRESSION"]._serialized_end = 8627
+ _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_start = 8534
+ _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_end = 8627
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index df4106cfc5f7..0a6f3caee8b5 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1185,6 +1185,7 @@ class Expression(google.protobuf.message.Message):
MERGE_ACTION_FIELD_NUMBER: builtins.int
TYPED_AGGREGATE_EXPRESSION_FIELD_NUMBER: builtins.int
LAZY_EXPRESSION_FIELD_NUMBER: builtins.int
+ SUBQUERY_EXPRESSION_FIELD_NUMBER: builtins.int
EXTENSION_FIELD_NUMBER: builtins.int
@property
def common(self) -> global___ExpressionCommon: ...
@@ -1231,6 +1232,8 @@ class Expression(google.protobuf.message.Message):
@property
def lazy_expression(self) -> global___LazyExpression: ...
@property
+ def subquery_expression(self) -> global___SubqueryExpression: ...
+ @property
def extension(self) -> google.protobuf.any_pb2.Any:
"""This field is used to mark extensions to the protocol. When plugins
generate arbitrary
relations they can add them here. During the planning the correct
resolution is done.
@@ -1260,6 +1263,7 @@ class Expression(google.protobuf.message.Message):
merge_action: global___MergeAction | None = ...,
typed_aggregate_expression: global___TypedAggregateExpression | None =
...,
lazy_expression: global___LazyExpression | None = ...,
+ subquery_expression: global___SubqueryExpression | None = ...,
extension: google.protobuf.any_pb2.Any | None = ...,
) -> None: ...
def HasField(
@@ -1293,6 +1297,8 @@ class Expression(google.protobuf.message.Message):
b"named_argument_expression",
"sort_order",
b"sort_order",
+ "subquery_expression",
+ b"subquery_expression",
"typed_aggregate_expression",
b"typed_aggregate_expression",
"unresolved_attribute",
@@ -1344,6 +1350,8 @@ class Expression(google.protobuf.message.Message):
b"named_argument_expression",
"sort_order",
b"sort_order",
+ "subquery_expression",
+ b"subquery_expression",
"typed_aggregate_expression",
b"typed_aggregate_expression",
"unresolved_attribute",
@@ -1388,6 +1396,7 @@ class Expression(google.protobuf.message.Message):
"merge_action",
"typed_aggregate_expression",
"lazy_expression",
+ "subquery_expression",
"extension",
]
| None
@@ -1829,3 +1838,47 @@ class LazyExpression(google.protobuf.message.Message):
def ClearField(self, field_name: typing_extensions.Literal["child",
b"child"]) -> None: ...
global___LazyExpression = LazyExpression
+
+class SubqueryExpression(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ class _SubqueryType:
+ ValueType = typing.NewType("ValueType", builtins.int)
+ V: typing_extensions.TypeAlias = ValueType
+
+ class _SubqueryTypeEnumTypeWrapper(
+ google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[
+ SubqueryExpression._SubqueryType.ValueType
+ ],
+ builtins.type,
+ ): # noqa: F821
+ DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
+ SUBQUERY_TYPE_UNKNOWN: SubqueryExpression._SubqueryType.ValueType # 0
+ SUBQUERY_TYPE_SCALAR: SubqueryExpression._SubqueryType.ValueType # 1
+ SUBQUERY_TYPE_EXISTS: SubqueryExpression._SubqueryType.ValueType # 2
+
+ class SubqueryType(_SubqueryType, metaclass=_SubqueryTypeEnumTypeWrapper):
...
+ SUBQUERY_TYPE_UNKNOWN: SubqueryExpression.SubqueryType.ValueType # 0
+ SUBQUERY_TYPE_SCALAR: SubqueryExpression.SubqueryType.ValueType # 1
+ SUBQUERY_TYPE_EXISTS: SubqueryExpression.SubqueryType.ValueType # 2
+
+ PLAN_ID_FIELD_NUMBER: builtins.int
+ SUBQUERY_TYPE_FIELD_NUMBER: builtins.int
+ plan_id: builtins.int
+ """(Required) The id of corresponding connect plan."""
+ subquery_type: global___SubqueryExpression.SubqueryType.ValueType
+ """(Required) The type of the subquery."""
+ def __init__(
+ self,
+ *,
+ plan_id: builtins.int = ...,
+ subquery_type: global___SubqueryExpression.SubqueryType.ValueType =
...,
+ ) -> None: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "plan_id", b"plan_id", "subquery_type", b"subquery_type"
+ ],
+ ) -> None: ...
+
+global___SubqueryExpression = SubqueryExpression
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index f88ca5348ff2..660f577f56f8 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -6612,7 +6612,7 @@ class DataFrame:
>>> from pyspark.sql import functions as sf
>>> employees.where(
... sf.col("salary") > employees.select(sf.avg("salary")).scalar()
- ... ).select("name", "salary", "department_id").show()
+ ... ).select("name", "salary", "department_id").orderBy("name").show()
+-----+------+-------------+
| name|salary|department_id|
+-----+------+-------------+
@@ -6630,7 +6630,7 @@ class DataFrame:
... > employees.alias("e2").where(
... sf.col("e2.department_id") ==
sf.col("e1.department_id").outer()
... ).select(sf.avg("salary")).scalar()
- ... ).select("name", "salary", "department_id").show()
+ ... ).select("name", "salary", "department_id").orderBy("name").show()
+-----+------+-------------+
| name|salary|department_id|
+-----+------+-------------+
@@ -6651,15 +6651,15 @@ class DataFrame:
...
).select(sf.sum("salary")).scalar().alias("avg_salary"),
... 1
... ).alias("salary_proportion_in_department")
- ... ).show()
+ ... ).orderBy("name").show()
+-------+------+-------------+-------------------------------+
| name|salary|department_id|salary_proportion_in_department|
+-------+------+-------------+-------------------------------+
| Alice| 45000| 101| 30.6|
| Bob| 54000| 101| 36.7|
|Charlie| 29000| 102| 32.2|
- | Eve| 48000| 101| 32.7|
| David| 61000| 102| 67.8|
+ | Eve| 48000| 101| 32.7|
+-------+------+-------------+-------------------------------+
"""
...
diff --git a/python/pyspark/sql/tests/connect/test_parity_subquery.py
b/python/pyspark/sql/tests/connect/test_parity_subquery.py
index cffb6fc39059..dae60a354d20 100644
--- a/python/pyspark/sql/tests/connect/test_parity_subquery.py
+++ b/python/pyspark/sql/tests/connect/test_parity_subquery.py
@@ -17,42 +17,37 @@
import unittest
+from pyspark.sql import functions as sf
from pyspark.sql.tests.test_subquery import SubqueryTestsMixin
+from pyspark.testing import assertDataFrameEqual
from pyspark.testing.connectutils import ReusedConnectTestCase
class SubqueryParityTests(SubqueryTestsMixin, ReusedConnectTestCase):
- @unittest.skip("TODO(SPARK-50134): Support subquery in connect")
- def test_simple_uncorrelated_scalar_subquery(self):
- super().test_simple_uncorrelated_scalar_subquery()
-
- @unittest.skip("TODO(SPARK-50134): Support subquery in connect")
- def test_uncorrelated_scalar_subquery_with_view(self):
- super().test_uncorrelated_scalar_subquery_with_view()
-
- @unittest.skip("TODO(SPARK-50134): Support subquery in connect")
- def test_scalar_subquery_against_local_relations(self):
- super().test_scalar_subquery_against_local_relations()
-
- @unittest.skip("TODO(SPARK-50134): Support subquery in connect")
- def test_correlated_scalar_subquery(self):
- super().test_correlated_scalar_subquery()
-
- @unittest.skip("TODO(SPARK-50134): Support subquery in connect")
- def test_exists_subquery(self):
- super().test_exists_subquery()
-
- @unittest.skip("TODO(SPARK-50134): Support subquery in connect")
- def test_scalar_subquery_with_outer_reference_errors(self):
- super().test_scalar_subquery_with_outer_reference_errors()
-
- @unittest.skip("TODO(SPARK-50134): Support subquery in connect")
- def test_scalar_subquery_inside_lateral_join(self):
- super().test_scalar_subquery_inside_lateral_join()
-
- @unittest.skip("TODO(SPARK-50134): Support subquery in connect")
- def test_lateral_join_inside_subquery(self):
- super().test_lateral_join_inside_subquery()
+ def test_scalar_subquery_with_missing_outer_reference(self):
+ with self.tempView("l", "r"):
+ self.df1.createOrReplaceTempView("l")
+ self.df2.createOrReplaceTempView("r")
+
+ assertDataFrameEqual(
+ self.spark.table("l").select(
+ "a",
+ (
+ self.spark.table("r")
+ .where(sf.col("c") == sf.col("a"))
+ .select(sf.sum("d"))
+ .scalar()
+ ),
+ ),
+ self.spark.sql("""SELECT a, (SELECT sum(d) FROM r WHERE c = a)
FROM l"""),
+ )
+
+ def test_subquery_in_unpivot(self):
+ self.check_subquery_in_unpivot(None, None)
+
+ @unittest.skip("SPARK-50601: Fix the SparkConnectPlanner to support this
case")
+ def test_subquery_in_with_columns(self):
+ super().test_subquery_in_with_columns()
if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/test_subquery.py
b/python/pyspark/sql/tests/test_subquery.py
index 91789f74d9da..0f431589b461 100644
--- a/python/pyspark/sql/tests/test_subquery.py
+++ b/python/pyspark/sql/tests/test_subquery.py
@@ -459,30 +459,29 @@ class SubqueryTestsMixin:
),
)
- def test_scalar_subquery_with_outer_reference_errors(self):
+ def test_scalar_subquery_with_missing_outer_reference(self):
with self.tempView("l", "r"):
self.df1.createOrReplaceTempView("l")
self.df2.createOrReplaceTempView("r")
- with self.subTest("missing `outer()`"):
- with self.assertRaises(AnalysisException) as pe:
- self.spark.table("l").select(
- "a",
- (
- self.spark.table("r")
- .where(sf.col("c") == sf.col("a"))
- .select(sf.sum("d"))
- .scalar()
- ),
- ).collect()
+ with self.assertRaises(AnalysisException) as pe:
+ self.spark.table("l").select(
+ "a",
+ (
+ self.spark.table("r")
+ .where(sf.col("c") == sf.col("a"))
+ .select(sf.sum("d"))
+ .scalar()
+ ),
+ ).collect()
- self.check_error(
- exception=pe.exception,
- errorClass="UNRESOLVED_COLUMN.WITH_SUGGESTION",
- messageParameters={"objectName": "`a`", "proposal": "`c`,
`d`"},
- query_context_type=QueryContextType.DataFrame,
- fragment="col",
- )
+ self.check_error(
+ exception=pe.exception,
+ errorClass="UNRESOLVED_COLUMN.WITH_SUGGESTION",
+ messageParameters={"objectName": "`a`", "proposal": "`c`,
`d`"},
+ query_context_type=QueryContextType.DataFrame,
+ fragment="col",
+ )
def table1(self):
t1 = self.spark.sql("VALUES (0, 1), (1, 2) AS t1(c1, c2)")
@@ -833,6 +832,90 @@ class SubqueryTestsMixin:
),
)
+ def test_subquery_with_generator_and_tvf(self):
+ with self.tempView("t1"):
+ t1 = self.table1()
+
+ assertDataFrameEqual(
+
self.spark.range(1).select(sf.explode(t1.select(sf.collect_list("c2")).scalar())),
+ self.spark.sql("""SELECT EXPLODE((SELECT COLLECT_LIST(c2) FROM
t1))"""),
+ )
+ assertDataFrameEqual(
+
self.spark.tvf.explode(t1.select(sf.collect_list("c2")).scalar()),
+ self.spark.sql("""SELECT * FROM EXPLODE((SELECT
COLLECT_LIST(c2) FROM t1))"""),
+ )
+
+ def test_subquery_in_join_condition(self):
+ with self.tempView("t1", "t2"):
+ t1 = self.table1()
+ t2 = self.table2()
+
+ assertDataFrameEqual(
+ t1.join(t2, sf.col("t1.c1") ==
t1.select(sf.max("c1")).scalar()),
+ self.spark.sql("""SELECT * FROM t1 JOIN t2 ON t1.c1 = (SELECT
MAX(c1) FROM t1)"""),
+ )
+
+ def test_subquery_in_unpivot(self):
+ self.check_subquery_in_unpivot(QueryContextType.DataFrame, "exists")
+
+ def check_subquery_in_unpivot(self, query_context_type, fragment):
+ with self.tempView("t1", "t2"):
+ t1 = self.table1()
+ t2 = self.table2()
+
+ with self.assertRaises(AnalysisException) as pe:
+ t1.unpivot("c1", t2.exists(), "c1", "c2").collect()
+
+ self.check_error(
+ exception=pe.exception,
+ errorClass=(
+
"UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_IN_EXISTS_SUBQUERY"
+ ),
+ messageParameters={"treeNode": "Expand.*"},
+ query_context_type=query_context_type,
+ fragment=fragment,
+ matchPVals=True,
+ )
+
+ def test_subquery_in_transpose(self):
+ with self.tempView("t1"):
+ t1 = self.table1()
+
+ with self.assertRaises(AnalysisException) as pe:
+ t1.transpose(t1.select(sf.max("c1")).scalar()).collect()
+
+ self.check_error(
+ exception=pe.exception,
+ errorClass="TRANSPOSE_INVALID_INDEX_COLUMN",
+ messageParameters={"reason": "Index column must be an atomic
attribute"},
+ )
+
+ def test_subquery_in_with_columns(self):
+ with self.tempView("t1"):
+ t1 = self.table1()
+
+ assertDataFrameEqual(
+ t1.withColumn(
+ "scalar",
+ self.spark.range(1)
+ .select(sf.col("c1").outer() + sf.col("c2").outer())
+ .scalar(),
+ ),
+ t1.withColumn("scalar", sf.col("c1") + sf.col("c2")),
+ )
+
+ def test_subquery_in_drop(self):
+ with self.tempView("t1"):
+ t1 = self.table1()
+
+
assertDataFrameEqual(t1.drop(self.spark.range(1).select(sf.lit("c1")).scalar()),
t1)
+
+ def test_subquery_in_repartition(self):
+ with self.tempView("t1"):
+ t1 = self.table1()
+
+
assertDataFrameEqual(t1.repartition(self.spark.range(1).select(sf.lit(1)).scalar()),
t1)
+
class SubqueryTests(SubqueryTestsMixin, ReusedSQLTestCase):
pass
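
As a quick orientation for the tests above, the DataFrame-level API they exercise boils down to `DataFrame.scalar()`, `DataFrame.exists()`, and `Column.outer()`. The sketch below uses made-up data and column names and follows the same patterns as the tests and docstrings in this patch.

```python
# Usage sketch with made-up data; mirrors the patterns in the tests above.
from pyspark.sql import SparkSession, functions as sf

spark = SparkSession.builder.getOrCreate()
orders = spark.createDataFrame([(1, 10.0), (1, 20.0), (2, 5.0)], ["cust_id", "amount"])
customers = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], ["id", "name"])

# Uncorrelated scalar subquery: DataFrame.scalar() turns a single-value plan
# into a Column usable inside expressions.
above_avg = orders.where(sf.col("amount") > orders.select(sf.avg("amount")).scalar())

# Correlated EXISTS subquery: Column.outer() marks `id` as an outer reference.
with_orders = customers.where(
    orders.where(sf.col("cust_id") == sf.col("id").outer()).exists()
)

above_avg.show()
with_orders.show()
```
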
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index d5f097065dc5..233b432766b7 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -344,6 +344,7 @@ class PySparkErrorTestUtils:
messageParameters: Optional[Dict[str, str]] = None,
query_context_type: Optional[QueryContextType] = None,
fragment: Optional[str] = None,
+ matchPVals: bool = False,
):
query_context = exception.getQueryContext()
assert bool(query_context) == (query_context_type is not None), (
@@ -367,9 +368,30 @@ class PySparkErrorTestUtils:
# Test message parameters
expected = messageParameters
actual = exception.getMessageParameters()
- self.assertEqual(
- expected, actual, f"Expected message parameters was '{expected}',
got '{actual}'"
- )
+ if matchPVals:
+ self.assertEqual(
+ len(expected),
+ len(actual),
+ "Expected message parameters count does not match actual
message parameters count"
+ f": {len(expected)}, {len(actual)}.",
+ )
+ for key, value in expected.items():
+ self.assertIn(
+ key,
+ actual,
+ f"Expected message parameter key '{key}' was not found "
+ "in actual message parameters.",
+ )
+ self.assertRegex(
+ actual[key],
+ value,
+ f"Expected message parameter value '{value}' does not match actual message "
+ f"parameter value '{actual[key]}'.",
+ )
+ else:
+ self.assertEqual(
+ expected, actual, f"Expected message parameters was
'{expected}', got '{actual}'"
+ )
# Test query context
if query_context:
diff --git
a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala
b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala
index f745c152170e..ef4bdb8d5bdf 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala
@@ -70,6 +70,19 @@ private[sql] trait ColumnNode extends ColumnNodeLike {
trait ColumnNodeLike {
private[internal] def normalize(): ColumnNodeLike = this
private[internal] def sql: String
+ private[internal] def children: Seq[ColumnNodeLike]
+
+ private[sql] def foreach(f: ColumnNodeLike => Unit): Unit = {
+ f(this)
+ children.foreach(_.foreach(f))
+ }
+
+ private[sql] def collect[A](pf: PartialFunction[ColumnNodeLike, A]): Seq[A]
= {
+ val ret = new collection.mutable.ArrayBuffer[A]()
+ val lifted = pf.lift
+ foreach(node => lifted(node).foreach(ret.+=))
+ ret.toSeq
+ }
}
private[internal] object ColumnNode {
@@ -118,6 +131,8 @@ private[sql] case class Literal(
case v: Short => toSQLValue(v)
case _ => value.toString
}
+
+ override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
}
/**
@@ -141,6 +156,8 @@ private[sql] case class UnresolvedAttribute(
copy(planId = None, origin = NO_ORIGIN)
override def sql: String = nameParts.map(n => if (n.contains(".")) s"`$n`"
else n).mkString(".")
+
+ override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
}
private[sql] object UnresolvedAttribute {
@@ -183,6 +200,7 @@ private[sql] case class UnresolvedStar(
override private[internal] def normalize(): UnresolvedStar =
copy(planId = None, origin = NO_ORIGIN)
override def sql: String = unparsedTarget.map(_ + ".*").getOrElse("*")
+ override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
}
/**
@@ -208,6 +226,8 @@ private[sql] case class UnresolvedFunction(
copy(arguments = ColumnNode.normalize(arguments), origin = NO_ORIGIN)
override def sql: String = functionName + argumentsToSql(arguments)
+
+ override private[internal] def children: Seq[ColumnNodeLike] = arguments
}
/**
@@ -222,6 +242,7 @@ private[sql] case class SqlExpression(
extends ColumnNode {
override private[internal] def normalize(): SqlExpression = copy(origin =
NO_ORIGIN)
override def sql: String = expression
+ override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
}
/**
@@ -250,6 +271,8 @@ private[sql] case class Alias(
}
s"${child.sql} AS $alias"
}
+
+ override private[internal] def children: Seq[ColumnNodeLike] = Seq(child)
}
/**
@@ -275,10 +298,14 @@ private[sql] case class Cast(
override def sql: String = {
s"${optionToSql(evalMode)}CAST(${child.sql} AS ${dataType.sql})"
}
+
+ override private[internal] def children: Seq[ColumnNodeLike] = Seq(child) ++
evalMode
}
private[sql] object Cast {
- sealed abstract class EvalMode(override val sql: String = "") extends
ColumnNodeLike
+ sealed abstract class EvalMode(override val sql: String = "") extends
ColumnNodeLike {
+ override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
+ }
object Legacy extends EvalMode
object Ansi extends EvalMode
object Try extends EvalMode("TRY_")
@@ -300,6 +327,7 @@ private[sql] case class UnresolvedRegex(
override private[internal] def normalize(): UnresolvedRegex =
copy(planId = None, origin = NO_ORIGIN)
override def sql: String = regex
+ override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
}
/**
@@ -322,13 +350,19 @@ private[sql] case class SortOrder(
copy(child = child.normalize(), origin = NO_ORIGIN)
override def sql: String = s"${child.sql} ${sortDirection.sql}
${nullOrdering.sql}"
+
+ override def children: Seq[ColumnNodeLike] = Seq(child, sortDirection,
nullOrdering)
}
private[sql] object SortOrder {
- sealed abstract class SortDirection(override val sql: String) extends
ColumnNodeLike
+ sealed abstract class SortDirection(override val sql: String) extends
ColumnNodeLike {
+ override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
+ }
object Ascending extends SortDirection("ASC")
object Descending extends SortDirection("DESC")
- sealed abstract class NullOrdering(override val sql: String) extends
ColumnNodeLike
+ sealed abstract class NullOrdering(override val sql: String) extends
ColumnNodeLike {
+ override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
+ }
object NullsFirst extends NullOrdering("NULLS FIRST")
object NullsLast extends NullOrdering("NULLS LAST")
}
@@ -352,6 +386,8 @@ private[sql] case class Window(
origin = NO_ORIGIN)
override def sql: String = s"${windowFunction.sql} OVER (${windowSpec.sql})"
+
+ override private[internal] def children: Seq[ColumnNodeLike] =
Seq(windowFunction, windowSpec)
}
private[sql] case class WindowSpec(
@@ -370,6 +406,9 @@ private[sql] case class WindowSpec(
optionToSql(frame))
parts.filter(_.nonEmpty).mkString(" ")
}
+ override private[internal] def children: Seq[ColumnNodeLike] = {
+ partitionColumns ++ sortColumns ++ frame
+ }
}
private[sql] case class WindowFrame(
@@ -381,15 +420,19 @@ private[sql] case class WindowFrame(
copy(lower = lower.normalize(), upper = upper.normalize())
override private[internal] def sql: String =
s"${frameType.sql} BETWEEN ${lower.sql} AND ${upper.sql}"
+ override private[internal] def children: Seq[ColumnNodeLike] =
Seq(frameType, lower, upper)
}
private[sql] object WindowFrame {
- sealed abstract class FrameType(override val sql: String) extends
ColumnNodeLike
+ sealed abstract class FrameType(override val sql: String) extends
ColumnNodeLike {
+ override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
+ }
object Row extends FrameType("ROWS")
object Range extends FrameType("RANGE")
sealed abstract class FrameBoundary extends ColumnNodeLike {
override private[internal] def normalize(): FrameBoundary = this
+ override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
}
object CurrentRow extends FrameBoundary {
override private[internal] def sql = "CURRENT ROW"
@@ -403,6 +446,7 @@ private[sql] object WindowFrame {
case class Value(value: ColumnNode) extends FrameBoundary {
override private[internal] def normalize(): Value = copy(value.normalize())
override private[internal] def sql: String = value.sql
+ override private[internal] def children: Seq[ColumnNodeLike] = Seq(value)
}
def value(i: Int): Value = Value(Literal(i, Some(IntegerType)))
def value(l: Long): Value = Value(Literal(l, Some(LongType)))
@@ -434,6 +478,8 @@ private[sql] case class LambdaFunction(
}
argumentsSql + " -> " + function.sql
}
+
+ override private[internal] def children: Seq[ColumnNodeLike] = function +:
arguments
}
object LambdaFunction {
@@ -455,6 +501,8 @@ private[sql] case class UnresolvedNamedLambdaVariable(
copy(origin = NO_ORIGIN)
override def sql: String = name
+
+ override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
}
object UnresolvedNamedLambdaVariable {
@@ -495,6 +543,8 @@ private[sql] case class UnresolvedExtractValue(
copy(child = child.normalize(), extraction = extraction.normalize(),
origin = NO_ORIGIN)
override def sql: String = s"${child.sql}[${extraction.sql}]"
+
+ override private[internal] def children: Seq[ColumnNodeLike] = Seq(child,
extraction)
}
/**
@@ -521,6 +571,9 @@ private[sql] case class UpdateFields(
case Some(value) => s"update_field(${structExpression.sql}, $fieldName,
${value.sql})"
case None => s"drop_field(${structExpression.sql}, $fieldName)"
}
+ override private[internal] def children: Seq[ColumnNodeLike] = {
+ structExpression +: valueExpression.toSeq
+ }
}
/**
@@ -549,6 +602,11 @@ private[sql] case class CaseWhenOtherwise(
branches.map(cv => s" WHEN ${cv._1.sql} THEN ${cv._2.sql}").mkString +
otherwise.map(o => s" ELSE ${o.sql}").getOrElse("") +
" END"
+
+ override private[internal] def children: Seq[ColumnNodeLike] = {
+ val branchChildren = branches.flatMap { case (condition, value) =>
Seq(condition, value) }
+ branchChildren ++ otherwise
+ }
}
/**
@@ -570,6 +628,8 @@ private[sql] case class InvokeInlineUserDefinedFunction(
override def sql: String =
function.name + argumentsToSql(arguments)
+
+ override private[internal] def children: Seq[ColumnNodeLike] = arguments
}
private[sql] trait UserDefinedFunctionLike {
@@ -589,4 +649,5 @@ private[sql] case class LazyExpression(
override private[internal] def normalize(): ColumnNode =
copy(child = child.normalize(), origin = NO_ORIGIN)
override def sql: String = "lazy" + argumentsToSql(Seq(child))
+ override private[internal] def children: Seq[ColumnNodeLike] = Seq(child)
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 61b68b743a5c..87a5e94d9f63 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -1056,3 +1056,18 @@ case class LazyExpression(child: Expression) extends
UnaryExpression with Uneval
}
final override val nodePatterns: Seq[TreePattern] = Seq(LAZY_EXPRESSION)
}
+
+trait UnresolvedPlanId extends LeafExpression with Unevaluable {
+ override def nullable: Boolean = throw new UnresolvedException("nullable")
+ override def dataType: DataType = throw new UnresolvedException("dataType")
+ override lazy val resolved = false
+
+ def planId: Long
+ def withPlan(plan: LogicalPlan): Expression
+
+ final override val nodePatterns: Seq[TreePattern] =
+ Seq(UNRESOLVED_PLAN_ID) ++ nodePatternsInternal()
+
+ // Subclasses can override this function to provide more TreePatterns.
+ def nodePatternsInternal(): Seq[TreePattern] = Seq()
+}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index 0c8253659dd5..c0a2bf25fbe6 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.sql.catalyst.analysis.UnresolvedPlanId
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -439,6 +440,14 @@ object ScalarSubquery {
}
}
+case class UnresolvedScalarSubqueryPlanId(planId: Long)
+ extends UnresolvedPlanId {
+
+ override def withPlan(plan: LogicalPlan): Expression = {
+ ScalarSubquery(plan)
+ }
+}
+
/**
* A subquery that can return multiple rows and columns. This should be
rewritten as a join
* with the outer query during the optimization phase.
@@ -592,3 +601,11 @@ case class Exists(
final override def nodePatternsInternal(): Seq[TreePattern] =
Seq(EXISTS_SUBQUERY)
}
+
+case class UnresolvedExistsPlanId(planId: Long)
+ extends UnresolvedPlanId {
+
+ override def withPlan(plan: LogicalPlan): Expression = {
+ Exists(plan)
+ }
+}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 95b5832392ec..1dfb0336ecf0 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -155,6 +155,7 @@ object TreePattern extends Enumeration {
val UNRESOLVED_FUNCTION: Value = Value
val UNRESOLVED_HINT: Value = Value
val UNRESOLVED_WINDOW_EXPRESSION: Value = Value
+ val UNRESOLVED_PLAN_ID: Value = Value
// Unresolved Plan patterns (Alphabetically ordered)
val UNRESOLVED_FUNC: Value = Value
diff --git
a/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
b/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 811dd032aa41..a01b5229a7b7 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -53,6 +53,7 @@ message Expression {
MergeAction merge_action = 19;
TypedAggregateExpression typed_aggregate_expression = 20;
LazyExpression lazy_expression = 21;
+ SubqueryExpression subquery_expression = 22;
// This field is used to mark extensions to the protocol. When plugins
generate arbitrary
// relations they can add them here. During the planning the correct
resolution is done.
@@ -457,3 +458,17 @@ message LazyExpression {
// (Required) The expression to be marked as lazy.
Expression child = 1;
}
+
+message SubqueryExpression {
+ // (Required) The id of corresponding connect plan.
+ int64 plan_id = 1;
+
+ // (Required) The type of the subquery.
+ SubqueryType subquery_type = 2;
+
+ enum SubqueryType {
+ SUBQUERY_TYPE_UNKNOWN = 0;
+ SUBQUERY_TYPE_SCALAR = 1;
+ SUBQUERY_TYPE_EXISTS = 2;
+ }
+}
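
For reference, this is the message a client sets on the `Expression` oneof when a `scalar()` or `exists()` column is serialized. A rough Python construction (the plan id is arbitrary here, and the surrounding client plumbing is omitted) might look like this:

```python
# Rough illustration of the wire message built from the proto definition above.
import pyspark.sql.connect.proto as proto

expr = proto.Expression()
expr.subquery_expression.plan_id = 8  # arbitrary example: plan id of the subquery relation
expr.subquery_expression.subquery_type = (
    proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_SCALAR
)
print(expr)
```
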
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index bfb5f8f3fab7..5ace916ba3e9 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -45,7 +45,7 @@ import org.apache.spark.internal.LogKeys.{DATAFRAME_ID,
SESSION_ID}
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile,
TaskResourceProfile, TaskResourceRequest}
import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter,
Observation, RelationalGroupedDataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier,
FunctionIdentifier, QueryPlanningTracker}
-import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry,
GlobalTempView, LazyExpression, LocalTempView, MultiAlias,
NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias,
UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer,
UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex,
UnresolvedRelation, UnresolvedStar, UnresolvedTableValuedFunction,
UnresolvedTranspose}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry,
GlobalTempView, LazyExpression, LocalTempView, MultiAlias,
NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias,
UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer,
UnresolvedExtractValue, UnresolvedFunction, UnresolvedPlanId, UnresolvedRegex,
UnresolvedRelation, UnresolvedStar, UnresolvedTableValuedFunction,
UnresolvedTranspose}
import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder,
ExpressionEncoder, RowEncoder}
import
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
import org.apache.spark.sql.catalyst.expressions._
@@ -55,7 +55,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter,
Inner, JoinType, L
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, Assignment,
CoGroup, CollectMetrics, CommandResult, Deduplicate,
DeduplicateWithinWatermark, DeleteAction, DeserializeToObject, Except,
FlatMapGroupsWithState, InsertAction, InsertStarAction, Intersect, JoinWith,
LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions,
MergeAction, Project, Sample, SerializeFromObject, Sort, SubqueryAlias,
TypedFilter, Union, Unpivot, UnresolvedHint, UpdateAction, UpdateSt [...]
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
-import org.apache.spark.sql.catalyst.trees.CurrentOrigin
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreePattern}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap,
CharVarcharUtils}
import org.apache.spark.sql.classic.ClassicConversions._
@@ -161,9 +161,8 @@ class SparkConnectPlanner(
case proto.Relation.RelTypeCase.DROP => transformDrop(rel.getDrop)
case proto.Relation.RelTypeCase.AGGREGATE =>
transformAggregate(rel.getAggregate)
case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql)
- case proto.Relation.RelTypeCase.WITH_RELATIONS
- if isValidSQLWithRefs(rel.getWithRelations) =>
- transformSqlWithRefs(rel.getWithRelations)
+ case proto.Relation.RelTypeCase.WITH_RELATIONS =>
+ transformWithRelations(rel.getWithRelations)
case proto.Relation.RelTypeCase.LOCAL_RELATION =>
transformLocalRelation(rel.getLocalRelation)
case proto.Relation.RelTypeCase.SAMPLE =>
transformSample(rel.getSample)
@@ -1559,6 +1558,8 @@ class SparkConnectPlanner(
transformTypedAggregateExpression(exp.getTypedAggregateExpression,
baseRelationOpt)
case proto.Expression.ExprTypeCase.LAZY_EXPRESSION =>
transformLazyExpression(exp.getLazyExpression)
+ case proto.Expression.ExprTypeCase.SUBQUERY_EXPRESSION =>
+ transformSubqueryExpression(exp.getSubqueryExpression)
case _ =>
throw InvalidPlanInput(
s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not
supported")
@@ -3724,7 +3725,56 @@ class SparkConnectPlanner(
LazyExpression(transformExpression(getLazyExpression.getChild))
}
- private def assertPlan(assertion: Boolean, message: String = ""): Unit = {
+ private def transformSubqueryExpression(
+ getSubqueryExpression: proto.SubqueryExpression): Expression = {
+ val planId = getSubqueryExpression.getPlanId
+ getSubqueryExpression.getSubqueryType match {
+ case proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_SCALAR =>
+ UnresolvedScalarSubqueryPlanId(planId)
+ case proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_EXISTS =>
+ UnresolvedExistsPlanId(planId)
+ case other => throw InvalidPlanInput(s"Unknown SubqueryType $other")
+ }
+ }
+
+ private def transformWithRelations(getWithRelations: proto.WithRelations):
LogicalPlan = {
+ if (isValidSQLWithRefs(getWithRelations)) {
+ transformSqlWithRefs(getWithRelations)
+ } else {
+ // Wrap the plan to keep the original planId.
+ val plan = Project(Seq(UnresolvedStar(None)),
transformRelation(getWithRelations.getRoot))
+
+ val relations = getWithRelations.getReferencesList.asScala.map { ref =>
+ if (ref.hasCommon && ref.getCommon.hasPlanId) {
+ val planId = ref.getCommon.getPlanId
+ val plan = transformRelation(ref)
+ planId -> plan
+ } else {
+ throw InvalidPlanInput("Invalid WithRelation reference")
+ }
+ }.toMap
+
+ val missingPlanIds = mutable.Set.empty[Long]
+ val withRelations = plan
+
.transformAllExpressionsWithPruning(_.containsPattern(TreePattern.UNRESOLVED_PLAN_ID))
{
+ case u: UnresolvedPlanId =>
+ if (relations.contains(u.planId)) {
+ u.withPlan(relations(u.planId))
+ } else {
+ missingPlanIds += u.planId
+ u
+ }
+ }
+ assertPlan(
+ missingPlanIds.isEmpty,
+ "Missing relation in WithRelations: " +
+ s"${missingPlanIds.mkString("(", ", ", ")")} not in " +
+ s"${relations.keys.mkString("(", ", ", ")")}")
+ withRelations
+ }
+ }
+
+ private def assertPlan(assertion: Boolean, message: => String = ""): Unit = {
if (!assertion) throw InvalidPlanInput(message)
}
}
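
For reference, a minimal client-side sketch of the DataFrame subquery usage this planner path handles (a hedged illustration, not part of the patch: it assumes an active SparkSession named `spark` and existing views `t1`/`t2` with columns `c1`/`c2`; `.scalar()`, `.exists()` and `.outer()` are the APIs exercised by the tests further below):

```scala
import org.apache.spark.sql.functions._
import spark.implicits._

val t1 = spark.table("t1")
val t2 = spark.table("t2")

// EXISTS subquery used as a filter condition: the client ships t2's plan as a
// WithRelations reference, and transformWithRelations substitutes it back by plan id.
val withExists = t1.where(t2.where($"c1" === $"c1".outer()).exists())

// SCALAR subquery in a projection, resolved the same way via UnresolvedScalarSubqueryPlanId.
val withScalar = t1.select($"c1", t1.select(max($"c2")).scalar().as("max_c2"))
```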
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala
index 8b4726114890..8f37f5c32de3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala
@@ -273,6 +273,8 @@ private[sql] case class ExpressionColumnNode private(
}
override def sql: String = expression.sql
+
+ override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
}
private[sql] object ExpressionColumnNode {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
index cd425162fb01..f94cf89276ec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
@@ -664,4 +664,102 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession {
)
}
}
+
+ test("subquery with generator / table-valued functions") {
+ withView("t1") {
+ val t1 = table1()
+
+ checkAnswer(
+ spark.range(1).select(explode(t1.select(collect_list("c2")).scalar())),
+ sql("SELECT EXPLODE((SELECT COLLECT_LIST(c2) FROM t1))")
+ )
+ checkAnswer(
+ spark.tvf.explode(t1.select(collect_list("c2")).scalar()),
+ sql("SELECT * FROM EXPLODE((SELECT COLLECT_LIST(c2) FROM t1))")
+ )
+ }
+ }
+
+ test("subquery in join condition") {
+ withView("t1", "t2") {
+ val t1 = table1()
+ val t2 = table2()
+
+ checkAnswer(
+ t1.join(t2, $"t1.c1" === t1.select(max("c1")).scalar()),
+ sql("SELECT * FROM t1 JOIN t2 ON t1.c1 = (SELECT MAX(c1) FROM t1)")
+ )
+ }
+ }
+
+ test("subquery in unpivot") {
+ withView("t1", "t2") {
+ val t1 = table1()
+ val t2 = table2()
+
+ checkError(
+ intercept[AnalysisException] {
+ t1.unpivot(Array(t2.exists()), "c1", "c2").collect()
+ },
+ "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_IN_EXISTS_SUBQUERY",
+ parameters = Map("treeNode" -> "(?s)'Unpivot.*"),
+ matchPVals = true,
+ queryContext = Array(ExpectedContext(
+ fragment = "exists",
+ callSitePattern = getCurrentClassCallSitePattern))
+ )
+ checkError(
+ intercept[AnalysisException] {
+ t1.unpivot(Array($"c1"), Array(t2.exists()), "c1", "c2").collect()
+ },
+ "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_IN_EXISTS_SUBQUERY",
+ parameters = Map("treeNode" -> "(?s)Expand.*"),
+ matchPVals = true,
+ queryContext = Array(ExpectedContext(
+ fragment = "exists",
+ callSitePattern = getCurrentClassCallSitePattern))
+ )
+ }
+ }
+
+ test("subquery in transpose") {
+ withView("t1") {
+ val t1 = table1()
+
+ checkError(
+ intercept[AnalysisException] {
+ t1.transpose(t1.select(max("c1")).scalar()).collect()
+ },
+ "TRANSPOSE_INVALID_INDEX_COLUMN",
+ parameters = Map("reason" -> "Index column must be an atomic attribute")
+ )
+ }
+ }
+
+ test("subquery in withColumns") {
+ withView("t1") {
+ val t1 = table1()
+
+ checkAnswer(
+ t1.withColumn("scalar", spark.range(1).select($"c1".outer() + $"c2".outer()).scalar()),
+ t1.withColumn("scalar", $"c1" + $"c2")
+ )
+ }
+ }
+
+ test("subquery in drop") {
+ withView("t1") {
+ val t1 = table1()
+
+ checkAnswer(t1.drop(spark.range(1).select(lit("c1")).scalar()), t1)
+ }
+ }
+
+ test("subquery in repartition") {
+ withView("t1") {
+ val t1 = table1()
+
+ checkAnswer(t1.repartition(spark.range(1).select(lit(1)).scalar()), t1)
+ }
+ }
}
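
The new tests rely on the suite's existing `table1()`/`table2()` helpers, which are defined earlier in the file and not shown in this hunk. A hypothetical, self-contained stand-in with the assumed shape, views `t1`/`t2` with integer columns `c1` and `c2`, for readers who want to run the snippets in isolation (the actual helpers may use different rows and DDL):

```scala
import org.apache.spark.sql.DataFrame

// Hypothetical stand-ins for DataFrameSubquerySuite's table1()/table2();
// the real helpers are defined earlier in the suite and may differ.
private def table1(): DataFrame = {
  spark.sql("CREATE OR REPLACE TEMPORARY VIEW t1(c1, c2) AS VALUES (0, 1), (1, 2)")
  spark.table("t1")
}

private def table2(): DataFrame = {
  spark.sql("CREATE OR REPLACE TEMPORARY VIEW t2(c1, c2) AS VALUES (0, 2), (0, 3)")
  spark.table("t2")
}
```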
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala
index 76fcdfc38095..d72e86450de2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala
@@ -405,4 +405,5 @@ private[internal] case class Nope(override val origin: Origin = CurrentOrigin.ge
extends ColumnNode {
override private[internal] def normalize(): Nope = this
override def sql: String = "nope"
+ override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
}