This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new aac494e74c60 [SPARK-50134][SPARK-50130][SQL][CONNECT] Support DataFrame API for SCALAR and EXISTS subqueries in Spark Connect
aac494e74c60 is described below

commit aac494e74c60acfd9abd65386a931fdbbf75c433
Author: Takuya Ueshin <[email protected]>
AuthorDate: Thu Dec 26 13:31:02 2024 -0800

    [SPARK-50134][SPARK-50130][SQL][CONNECT] Support DataFrame API for SCALAR and EXISTS subqueries in Spark Connect
    
    ### What changes were proposed in this pull request?
    
    Supports the DataFrame API for SCALAR and EXISTS subqueries in Spark Connect.
    
    The resulting proto plan uses `with_relations`:
    
    ```
    with_relations [id 10]
        root: plan  [id 9]  using subquery expressions holding plan ids
        reference:
             refs#1: [id 8]  plan for the subquery 1
             refs#2: [id 5]  plan for the subquery 2
    ```
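    
    For example, a query like the following (a sketch mirroring the Scala tests added in this PR; `l(a, b)` and `r(c, d)` are the temp views defined in `DataFrameSubquerySuite`, and the `max_d`/`min_d` aliases are only illustrative) produces one `with_relations` node whose references hold the two subquery plans:
    
    ```scala
    import org.apache.spark.sql.functions.{col, max, min}
    
    // Both scalar subqueries are columns of the same Project, so the client collects
    // their plans into the references of a single with_relations relation; the
    // subquery expressions in the root refer to them only by plan id.
    spark.table("l").select(
      spark.table("r").select(max(col("d"))).scalar().as("max_d"),
      spark.table("r").select(min(col("d"))).scalar().as("min_d"))
    ```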
    
    ### Why are the changes needed?
    
    DataFrame APIs for SCALAR and EXISTS subqueries are missing in Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, SCALAR and EXISTS subqueries will be available in Spark Connect.
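    
    For example (a sketch based on the new Scala tests, again assuming the suite's `l(a, b)` and `r(c, d)` temp views):
    
    ```scala
    import org.apache.spark.sql.functions.{col, max}
    
    // Correlated scalar subquery in a filter, equivalent to:
    //   SELECT * FROM l WHERE b < (SELECT MAX(d) FROM r WHERE a = c)
    spark.table("l").where(
      col("b") < spark.table("r").where(col("a").outer() === col("c")).select(max(col("d"))).scalar())
    
    // EXISTS predicate subquery, equivalent to:
    //   SELECT * FROM l WHERE EXISTS (SELECT * FROM r WHERE l.a = r.c)
    spark.table("l").where(
      spark.table("r").where(col("a").outer() === col("c")).exists())
    ```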
    
    ### How was this patch tested?
    
    Added related tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #49182 from ueshin/issues/SPARK-50130/scalar_exists_connect.
    
    Authored-by: Takuya Ueshin <[email protected]>
    Signed-off-by: Takuya Ueshin <[email protected]>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  68 ++--
 .../spark/sql/RelationalGroupedDataset.scala       |   2 +-
 .../scala/org/apache/spark/sql/SparkSession.scala  |  95 ++++-
 .../org/apache/spark/sql/TableValuedFunction.scala |   2 +-
 .../spark/sql/internal/columnNodeSupport.scala     |  29 ++
 .../apache/spark/sql/DataFrameSubquerySuite.scala  | 440 +++++++++++++++++++++
 .../internal/ColumnNodeToProtoConverterSuite.scala |   1 +
 .../org/apache/spark/sql/test/QueryTest.scala      | 155 ++++++++
 .../org/apache/spark/sql/test/SQLHelper.scala      |  17 +
 python/pyspark/sql/connect/dataframe.py            |  21 +-
 python/pyspark/sql/connect/expressions.py          | 100 ++++-
 python/pyspark/sql/connect/plan.py                 | 234 +++++++----
 .../pyspark/sql/connect/proto/expressions_pb2.py   | 156 ++++----
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  |  53 +++
 python/pyspark/sql/dataframe.py                    |   8 +-
 .../sql/tests/connect/test_parity_subquery.py      |  57 ++-
 python/pyspark/sql/tests/test_subquery.py          | 121 +++++-
 python/pyspark/testing/utils.py                    |  28 +-
 .../apache/spark/sql/internal/columnNodes.scala    |  69 +++-
 .../spark/sql/catalyst/analysis/unresolved.scala   |  15 +
 .../spark/sql/catalyst/expressions/subquery.scala  |  17 +
 .../spark/sql/catalyst/trees/TreePatterns.scala    |   1 +
 .../main/protobuf/spark/connect/expressions.proto  |  15 +
 .../sql/connect/planner/SparkConnectPlanner.scala  |  62 ++-
 .../spark/sql/internal/columnNodeSupport.scala     |   2 +
 .../apache/spark/sql/DataFrameSubquerySuite.scala  |  98 +++++
 .../ColumnNodeToExpressionConverterSuite.scala     |   1 +
 27 files changed, 1593 insertions(+), 274 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index ffaa8a70cc7c..75df538678a3 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.expressions.SparkUserDefinedFunction
 import org.apache.spark.sql.functions.{struct, to_json}
-import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter, 
DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, ToScalaUDF, 
UDFAdaptors, UnresolvedAttribute, UnresolvedRegex}
+import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter, 
DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, 
SubqueryExpressionNode, SubqueryType, ToScalaUDF, UDFAdaptors, 
UnresolvedAttribute, UnresolvedRegex}
 import org.apache.spark.sql.streaming.DataStreamWriter
 import org.apache.spark.sql.types.{Metadata, StructType}
 import org.apache.spark.storage.StorageLevel
@@ -288,9 +288,10 @@ class Dataset[T] private[sql] (
   /** @inheritdoc */
   def stat: DataFrameStatFunctions = new DataFrameStatFunctions(toDF())
 
-  private def buildJoin(right: Dataset[_])(f: proto.Join.Builder => Unit): 
DataFrame = {
+  private def buildJoin(right: Dataset[_], cols: Seq[Column] = Seq.empty)(
+      f: proto.Join.Builder => Unit): DataFrame = {
     checkSameSparkSession(right)
-    sparkSession.newDataFrame { builder =>
+    sparkSession.newDataFrame(cols) { builder =>
       val joinBuilder = builder.getJoinBuilder
       joinBuilder.setLeft(plan.getRoot).setRight(right.plan.getRoot)
       f(joinBuilder)
@@ -334,7 +335,7 @@ class Dataset[T] private[sql] (
 
   /** @inheritdoc */
   def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame 
= {
-    buildJoin(right) { builder =>
+    buildJoin(right, Seq(joinExprs)) { builder =>
       builder
         .setJoinType(toJoinType(joinType))
         .setJoinCondition(joinExprs.expr)
@@ -394,7 +395,7 @@ class Dataset[T] private[sql] (
       case _ =>
         throw new IllegalArgumentException(s"Unsupported lateral join type 
$joinType")
     }
-    sparkSession.newDataFrame { builder =>
+    sparkSession.newDataFrame(joinExprs.toSeq) { builder =>
       val lateralJoinBuilder = builder.getLateralJoinBuilder
       lateralJoinBuilder.setLeft(plan.getRoot).setRight(right.plan.getRoot)
       joinExprs.foreach(c => lateralJoinBuilder.setJoinCondition(c.expr))
@@ -426,7 +427,7 @@ class Dataset[T] private[sql] (
     val sortExprs = sortCols.map { c =>
       ColumnNodeToProtoConverter(c.sortOrder).getSortOrder
     }
-    sparkSession.newDataset(agnosticEncoder) { builder =>
+    sparkSession.newDataset(agnosticEncoder, sortCols) { builder =>
       builder.getSortBuilder
         .setInput(plan.getRoot)
         .setIsGlobal(global)
@@ -502,7 +503,7 @@ class Dataset[T] private[sql] (
    * methods and typed select methods is the encoder used to build the return 
dataset.
    */
   private def selectUntyped(encoder: AgnosticEncoder[_], cols: Seq[Column]): 
Dataset[_] = {
-    sparkSession.newDataset(encoder) { builder =>
+    sparkSession.newDataset(encoder, cols) { builder =>
       builder.getProjectBuilder
         .setInput(plan.getRoot)
         .addAllExpressions(cols.map(_.typedExpr(this.encoder)).asJava)
@@ -510,29 +511,32 @@ class Dataset[T] private[sql] (
   }
 
   /** @inheritdoc */
-  def filter(condition: Column): Dataset[T] = 
sparkSession.newDataset(agnosticEncoder) {
-    builder =>
+  def filter(condition: Column): Dataset[T] = {
+    sparkSession.newDataset(agnosticEncoder, Seq(condition)) { builder =>
       
builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
+    }
   }
 
   private def buildUnpivot(
       ids: Array[Column],
       valuesOption: Option[Array[Column]],
       variableColumnName: String,
-      valueColumnName: String): DataFrame = sparkSession.newDataFrame { 
builder =>
-    val unpivot = builder.getUnpivotBuilder
-      .setInput(plan.getRoot)
-      .addAllIds(ids.toImmutableArraySeq.map(_.expr).asJava)
-      .setVariableColumnName(variableColumnName)
-      .setValueColumnName(valueColumnName)
-    valuesOption.foreach { values =>
-      unpivot.getValuesBuilder
-        .addAllValues(values.toImmutableArraySeq.map(_.expr).asJava)
+      valueColumnName: String): DataFrame = {
+    sparkSession.newDataFrame(ids.toSeq ++ valuesOption.toSeq.flatten) { 
builder =>
+      val unpivot = builder.getUnpivotBuilder
+        .setInput(plan.getRoot)
+        .addAllIds(ids.toImmutableArraySeq.map(_.expr).asJava)
+        .setVariableColumnName(variableColumnName)
+        .setValueColumnName(valueColumnName)
+      valuesOption.foreach { values =>
+        unpivot.getValuesBuilder
+          .addAllValues(values.toImmutableArraySeq.map(_.expr).asJava)
+      }
     }
   }
 
   private def buildTranspose(indices: Seq[Column]): DataFrame =
-    sparkSession.newDataFrame { builder =>
+    sparkSession.newDataFrame(indices) { builder =>
       val transpose = builder.getTransposeBuilder.setInput(plan.getRoot)
       indices.foreach { indexColumn =>
         transpose.addIndexColumns(indexColumn.expr)
@@ -624,18 +628,15 @@ class Dataset[T] private[sql] (
   def transpose(): DataFrame =
     buildTranspose(Seq.empty)
 
-  // TODO(SPARK-50134): Support scalar Subquery API in Spark Connect
-  // scalastyle:off not.implemented.error.usage
   /** @inheritdoc */
   def scalar(): Column = {
-    ???
+    Column(SubqueryExpressionNode(plan.getRoot, SubqueryType.SCALAR))
   }
 
   /** @inheritdoc */
   def exists(): Column = {
-    ???
+    Column(SubqueryExpressionNode(plan.getRoot, SubqueryType.EXISTS))
   }
-  // scalastyle:on not.implemented.error.usage
 
   /** @inheritdoc */
   def limit(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { 
builder =>
@@ -782,7 +783,7 @@ class Dataset[T] private[sql] (
     val aliases = values.zip(names).map { case (value, name) =>
       value.name(name).expr.getAlias
     }
-    sparkSession.newDataFrame { builder =>
+    sparkSession.newDataFrame(values) { builder =>
       builder.getWithColumnsBuilder
         .setInput(plan.getRoot)
         .addAllAliases(aliases.asJava)
@@ -842,10 +843,12 @@ class Dataset[T] private[sql] (
   @scala.annotation.varargs
   def drop(col: Column, cols: Column*): DataFrame = buildDrop(col +: cols)
 
-  private def buildDrop(cols: Seq[Column]): DataFrame = 
sparkSession.newDataFrame { builder =>
-    builder.getDropBuilder
-      .setInput(plan.getRoot)
-      .addAllColumns(cols.map(_.expr).asJava)
+  private def buildDrop(cols: Seq[Column]): DataFrame = {
+    sparkSession.newDataFrame(cols) { builder =>
+      builder.getDropBuilder
+        .setInput(plan.getRoot)
+        .addAllColumns(cols.map(_.expr).asJava)
+    }
   }
 
   private def buildDropByNames(cols: Seq[String]): DataFrame = 
sparkSession.newDataFrame {
@@ -1015,12 +1018,13 @@ class Dataset[T] private[sql] (
 
   private def buildRepartitionByExpression(
       numPartitions: Option[Int],
-      partitionExprs: Seq[Column]): Dataset[T] = 
sparkSession.newDataset(agnosticEncoder) {
-    builder =>
+      partitionExprs: Seq[Column]): Dataset[T] = {
+    sparkSession.newDataset(agnosticEncoder, partitionExprs) { builder =>
       val repartitionBuilder = builder.getRepartitionByExpressionBuilder
         .setInput(plan.getRoot)
         .addAllPartitionExprs(partitionExprs.map(_.expr).asJava)
       numPartitions.foreach(repartitionBuilder.setNumPartitions)
+    }
   }
 
   /** @inheritdoc */
@@ -1152,7 +1156,7 @@ class Dataset[T] private[sql] (
   /** @inheritdoc */
   @scala.annotation.varargs
   def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = {
-    sparkSession.newDataset(agnosticEncoder) { builder =>
+    sparkSession.newDataset(agnosticEncoder, expr +: exprs) { builder =>
       builder.getCollectMetricsBuilder
         .setInput(plan.getRoot)
         .setName(name)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 5bded40b0d13..0944c88a6790 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -45,7 +45,7 @@ class RelationalGroupedDataset private[sql] (
   import df.sparkSession.RichColumn
 
   protected def toDF(aggExprs: Seq[Column]): DataFrame = {
-    df.sparkSession.newDataFrame { builder =>
+    df.sparkSession.newDataFrame(groupingExprs ++ aggExprs) { builder =>
       val aggBuilder = builder.getAggregateBuilder
         .setInput(df.plan.getRoot)
       groupingExprs.foreach(c => aggBuilder.addGroupingExpressions(c.expr))
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index b6bba8251913..89519034d07c 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -46,7 +46,7 @@ import org.apache.spark.sql.connect.client.{ClassFinder, 
CloseableIterator, Spar
 import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
 import org.apache.spark.sql.functions.lit
-import org.apache.spark.sql.internal.{CatalogImpl, ConnectRuntimeConfig, 
SessionCleaner, SessionState, SharedState, SqlApiConf}
+import org.apache.spark.sql.internal.{CatalogImpl, ConnectRuntimeConfig, 
SessionCleaner, SessionState, SharedState, SqlApiConf, SubqueryExpressionNode}
 import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.{toExpr, 
toTypedExpr}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.streaming.DataStreamReader
@@ -324,20 +324,111 @@ class SparkSession private[sql] (
     }
   }
 
+  /**
+   * Create a DataFrame including the proto plan built by the given function.
+   *
+   * @param f
+   *   The function to build the proto plan.
+   * @return
+   *   The DataFrame created from the proto plan.
+   */
   @Since("4.0.0")
   @DeveloperApi
   def newDataFrame(f: proto.Relation.Builder => Unit): DataFrame = {
     newDataset(UnboundRowEncoder)(f)
   }
 
+  /**
+   * Create a DataFrame including the proto plan built by the given function.
+   *
+   * Use this method when columns are used to create a new DataFrame. When 
there are columns
+   * referring to other Dataset or DataFrame, the plan will be wrapped with a 
`WithRelation`.
+   *
+   * {{{
+   *   with_relations [id 10]
+   *     root: plan  [id 9]  using columns referring to other Dataset or 
DataFrame, holding plan ids
+   *     reference:
+   *          refs#1: [id 8]  plan for the reference 1
+   *          refs#2: [id 5]  plan for the reference 2
+   * }}}
+   *
+   * @param cols
+   *   The columns to be used in the DataFrame.
+   * @param f
+   *   The function to build the proto plan.
+   * @return
+   *   The DataFrame created from the proto plan.
+   */
+  @Since("4.0.0")
+  @DeveloperApi
+  def newDataFrame(cols: Seq[Column])(f: proto.Relation.Builder => Unit): 
DataFrame = {
+    newDataset(UnboundRowEncoder, cols)(f)
+  }
+
+  /**
+   * Create a Dataset including the proto plan built by the given function.
+   *
+   * @param encoder
+   *   The encoder for the Dataset.
+   * @param f
+   *   The function to build the proto plan.
+   * @return
+   *   The Dataset created from the proto plan.
+   */
   @Since("4.0.0")
   @DeveloperApi
   def newDataset[T](encoder: AgnosticEncoder[T])(
       f: proto.Relation.Builder => Unit): Dataset[T] = {
+    newDataset[T](encoder, Seq.empty)(f)
+  }
+
+  /**
+   * Create a Dataset including the proto plan built by the given function.
+   *
+   * Use this method when columns are used to create a new Dataset. When there 
are columns
+   * referring to other Dataset or DataFrame, the plan will be wrapped with a 
`WithRelation`.
+   *
+   * {{{
+   *   with_relations [id 10]
+   *     root: plan  [id 9]  using columns referring to other Dataset or 
DataFrame, holding plan ids
+   *     reference:
+   *          refs#1: [id 8]  plan for the reference 1
+   *          refs#2: [id 5]  plan for the reference 2
+   * }}}
+   *
+   * @param encoder
+   *   The encoder for the Dataset.
+   * @param cols
+   *   The columns to be used in the DataFrame.
+   * @param f
+   *   The function to build the proto plan.
+   * @return
+   *   The Dataset created from the proto plan.
+   */
+  @Since("4.0.0")
+  @DeveloperApi
+  def newDataset[T](encoder: AgnosticEncoder[T], cols: Seq[Column])(
+      f: proto.Relation.Builder => Unit): Dataset[T] = {
+    val references = cols.flatMap(_.node.collect { case n: 
SubqueryExpressionNode =>
+      n.relation
+    })
+
     val builder = proto.Relation.newBuilder()
     f(builder)
     builder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement())
-    val plan = proto.Plan.newBuilder().setRoot(builder).build()
+
+    val rootBuilder = if (references.length == 0) {
+      builder
+    } else {
+      val rootBuilder = proto.Relation.newBuilder()
+      rootBuilder.getWithRelationsBuilder
+        .setRoot(builder)
+        .addAllReferences(references.asJava)
+      rootBuilder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement())
+      rootBuilder
+    }
+
+    val plan = proto.Plan.newBuilder().setRoot(rootBuilder).build()
     new Dataset[T](this, plan, encoder)
   }
 
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala
index 4f2687b53786..2a5afd1d5871 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala
@@ -47,7 +47,7 @@ class TableValuedFunction(sparkSession: SparkSession) extends 
api.TableValuedFun
   }
 
   private def fn(name: String, args: Seq[Column]): Dataset[Row] = {
-    sparkSession.newDataFrame { builder =>
+    sparkSession.newDataFrame(args) { builder =>
       builder.getUnresolvedTableValuedFunctionBuilder
         .setFunctionName(name)
         .addAllArguments(args.map(toExpr).asJava)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala
index 0e8889e19de2..8d57a8d3efd4 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala
@@ -167,6 +167,15 @@ object ColumnNodeToProtoConverter extends (ColumnNode => 
proto.Expression) {
       case LazyExpression(child, _) =>
         builder.getLazyExpressionBuilder.setChild(apply(child, e))
 
+      case SubqueryExpressionNode(relation, subqueryType, _) =>
+        val b = builder.getSubqueryExpressionBuilder
+        b.setSubqueryType(subqueryType match {
+          case SubqueryType.SCALAR => 
proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_SCALAR
+          case SubqueryType.EXISTS => 
proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_EXISTS
+        })
+        assert(relation.hasCommon && relation.getCommon.hasPlanId)
+        b.setPlanId(relation.getCommon.getPlanId)
+
       case ProtoColumnNode(e, _) =>
         return e
 
@@ -217,4 +226,24 @@ case class ProtoColumnNode(
     override val origin: Origin = CurrentOrigin.get)
     extends ColumnNode {
   override def sql: String = expr.toString
+  override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
+}
+
+sealed trait SubqueryType
+
+object SubqueryType {
+  case object SCALAR extends SubqueryType
+  case object EXISTS extends SubqueryType
+}
+
+case class SubqueryExpressionNode(
+    relation: proto.Relation,
+    subqueryType: SubqueryType,
+    override val origin: Origin = CurrentOrigin.get)
+    extends ColumnNode {
+  override def sql: String = subqueryType match {
+    case SubqueryType.SCALAR => s"($relation)"
+    case _ => s"$subqueryType ($relation)"
+  }
+  override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
index 91f60b1fefb9..fc37444f7719 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
@@ -17,12 +17,306 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.{SparkException, SparkRuntimeException}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.{QueryTest, RemoteSparkSession}
 
 class DataFrameSubquerySuite extends QueryTest with RemoteSparkSession {
   import testImplicits._
 
+  val row = identity[(java.lang.Integer, java.lang.Double)](_)
+
+  lazy val l = Seq(
+    row((1, 2.0)),
+    row((1, 2.0)),
+    row((2, 1.0)),
+    row((2, 1.0)),
+    row((3, 3.0)),
+    row((null, null)),
+    row((null, 5.0)),
+    row((6, null))).toDF("a", "b")
+
+  lazy val r = Seq(
+    row((2, 3.0)),
+    row((2, 3.0)),
+    row((3, 2.0)),
+    row((4, 1.0)),
+    row((null, null)),
+    row((null, 5.0)),
+    row((6, null))).toDF("c", "d")
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    l.createOrReplaceTempView("l")
+    r.createOrReplaceTempView("r")
+  }
+
+  test("noop outer()") {
+    checkAnswer(spark.range(1).select($"id".outer()), Row(0))
+    checkError(
+      
intercept[AnalysisException](spark.range(1).select($"outer_col".outer()).collect()),
+      "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+      parameters = Map("objectName" -> "`outer_col`", "proposal" -> "`id`"))
+  }
+
+  test("simple uncorrelated scalar subquery") {
+    checkAnswer(
+      spark.range(1).select(spark.range(1).select(lit(1)).scalar().as("b")),
+      sql("select (select 1 as b) as b"))
+
+    checkAnswer(
+      spark
+        .range(1)
+        .select(
+          spark.range(1).select(spark.range(1).select(lit(1)).scalar() + 
1).scalar() + lit(1)),
+      sql("select (select (select 1) + 1) + 1"))
+
+    // string type
+    checkAnswer(
+      spark.range(1).select(spark.range(1).select(lit("s")).scalar().as("b")),
+      sql("select (select 's' as s) as b"))
+  }
+
+  test("uncorrelated scalar subquery should return null if there is 0 rows") {
+    checkAnswer(
+      
spark.range(1).select(spark.range(1).select(lit("s")).limit(0).scalar().as("b")),
+      sql("select (select 's' as s limit 0) as b"))
+  }
+
+  test("uncorrelated scalar subquery on a DataFrame generated query") {
+    withTempView("subqueryData") {
+      val df = Seq((1, "one"), (2, "two"), (3, "three")).toDF("key", "value")
+      df.createOrReplaceTempView("subqueryData")
+
+      checkAnswer(
+        spark
+          .range(1)
+          .select(
+            spark
+              .table("subqueryData")
+              .select($"key")
+              .where($"key" > 2)
+              .orderBy($"key")
+              .limit(1)
+              .scalar() + lit(1)),
+        sql("select (select key from subqueryData where key > 2 order by key 
limit 1) + 1"))
+
+      checkAnswer(
+        
spark.range(1).select(-spark.table("subqueryData").select(max($"key")).scalar()),
+        sql("select -(select max(key) from subqueryData)"))
+
+      checkAnswer(
+        
spark.range(1).select(spark.table("subqueryData").select($"value").limit(0).scalar()),
+        sql("select (select value from subqueryData limit 0)"))
+
+      checkAnswer(
+        spark
+          .range(1)
+          .select(
+            spark
+              .table("subqueryData")
+              .where($"key" === 
spark.table("subqueryData").select(max($"key")).scalar() - lit(1))
+              .select(min($"value"))
+              .scalar()),
+        sql(
+          "select (select min(value) from subqueryData" +
+            " where key = (select max(key) from subqueryData) - 1)"))
+    }
+  }
+
+  test("correlated scalar subquery in SELECT with outer() function") {
+    val df1 = spark.table("l").as("t1")
+    val df2 = spark.table("l").as("t2")
+    // We can use the `.outer()` function to wrap either the outer column, or 
the entire condition,
+    // or the SQL string of the condition.
+    Seq($"t1.a" === $"t2.a".outer(), ($"t1.a" === $"t2.a").outer(), expr("t1.a 
= t2.a").outer())
+      .foreach { cond =>
+        checkAnswer(
+          df1.select($"a", 
df2.where(cond).select(sum($"b")).scalar().as("sum_b")),
+          sql("select a, (select sum(b) from l t1 where t1.a = t2.a) sum_b 
from l t2"))
+      }
+  }
+
+  test("correlated scalar subquery in WHERE with outer() function") {
+    // We can use the `.outer()` function to wrap either the outer column, or 
the entire condition,
+    // or the SQL string of the condition.
+    Seq($"a".outer() === $"c", ($"a" === $"c").outer(), expr("a = 
c").outer()).foreach { cond =>
+      checkAnswer(
+        spark.table("l").where($"b" < 
spark.table("r").where(cond).select(max($"d")).scalar()),
+        sql("select * from l where b < (select max(d) from r where a = c)"))
+    }
+  }
+
+  test("EXISTS predicate subquery with outer() function") {
+    // We can use the `.outer()` function to wrap either the outer column, or 
the entire condition,
+    // or the SQL string of the condition.
+    Seq($"a".outer() === $"c", ($"a" === $"c").outer(), expr("a = 
c").outer()).foreach { cond =>
+      checkAnswer(
+        spark.table("l").where(spark.table("r").where(cond).exists()),
+        sql("select * from l where exists (select * from r where l.a = r.c)"))
+
+      checkAnswer(
+        spark.table("l").where(spark.table("r").where(cond).exists() && $"a" 
<= lit(2)),
+        sql("select * from l where exists (select * from r where l.a = r.c) 
and l.a <= 2"))
+    }
+  }
+
+  test("SPARK-15677: Queries against local relations with scalar subquery in 
Select list") {
+    withTempView("t1", "t2") {
+      Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1")
+      Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2")
+
+      checkAnswer(
+        
spark.table("t1").select(spark.range(1).select(lit(1).as("col")).scalar()),
+        sql("SELECT (select 1 as col) from t1"))
+
+      checkAnswer(
+        
spark.table("t1").select(spark.table("t2").select(max($"c1")).scalar()),
+        sql("SELECT (select max(c1) from t2) from t1"))
+
+      checkAnswer(
+        spark.table("t1").select(lit(1) + 
spark.range(1).select(lit(1).as("col")).scalar()),
+        sql("SELECT 1 + (select 1 as col) from t1"))
+
+      checkAnswer(
+        spark.table("t1").select($"c1", 
spark.table("t2").select(max($"c1")).scalar() + $"c2"),
+        sql("SELECT c1, (select max(c1) from t2) + c2 from t1"))
+
+      checkAnswer(
+        spark
+          .table("t1")
+          .select(
+            $"c1",
+            spark.table("t2").where($"t1.c2".outer() === 
$"t2.c2").select(max($"c1")).scalar()),
+        sql("SELECT c1, (select max(c1) from t2 where t1.c2 = t2.c2) from t1"))
+    }
+  }
+
+  test("NOT EXISTS predicate subquery") {
+    checkAnswer(
+      spark.table("l").where(!spark.table("r").where($"a".outer() === 
$"c").exists()),
+      sql("select * from l where not exists (select * from r where l.a = 
r.c)"))
+
+    checkAnswer(
+      spark
+        .table("l")
+        .where(!spark.table("r").where($"a".outer() === $"c" && $"b".outer() < 
$"d").exists()),
+      sql("select * from l where not exists (select * from r where l.a = r.c 
and l.b < r.d)"))
+  }
+
+  test("EXISTS predicate subquery within OR") {
+    checkAnswer(
+      spark
+        .table("l")
+        .where(spark.table("r").where($"a".outer() === $"c").exists() ||
+          spark.table("r").where($"a".outer() === $"c").exists()),
+      sql(
+        "select * from l where exists (select * from r where l.a = r.c)" +
+          " or exists (select * from r where l.a = r.c)"))
+
+    checkAnswer(
+      spark
+        .table("l")
+        .where(!spark.table("r").where($"a".outer() === $"c" && $"b".outer() < 
$"d").exists() ||
+          !spark.table("r").where($"a".outer() === $"c").exists()),
+      sql(
+        "select * from l where not exists (select * from r where l.a = r.c and 
l.b < r.d)" +
+          " or not exists (select * from r where l.a = r.c)"))
+  }
+
+  test("correlated scalar subquery in select (null safe equal)") {
+    val df1 = spark.table("l").as("t1")
+    val df2 = spark.table("l").as("t2")
+    checkAnswer(
+      df1.select(
+        $"a",
+        df2.where($"t2.a" <=> 
$"t1.a".outer()).select(sum($"b")).scalar().as("sum_b")),
+      sql("select a, (select sum(b) from l t2 where t2.a <=> t1.a) sum_b from 
l t1"))
+  }
+
+  test("correlated scalar subquery in aggregate") {
+    checkAnswer(
+      spark
+        .table("l")
+        .groupBy(
+          $"a",
+          spark.table("r").where($"a".outer() === 
$"c").select(sum($"d")).scalar().as("sum_d"))
+        .agg(Map.empty[String, String]),
+      sql("select a, (select sum(d) from r where a = c) sum_d from l l1 group 
by 1, 2"))
+  }
+
+  test("SPARK-34269: correlated subquery with view in aggregate's grouping 
expression") {
+    withTable("tr") {
+      withView("vr") {
+        r.write.saveAsTable("tr")
+        sql("create view vr as select * from tr")
+        checkAnswer(
+          spark
+            .table("l")
+            .groupBy(
+              $"a",
+              spark
+                .table("vr")
+                .where($"a".outer() === $"c")
+                .select(sum($"d"))
+                .scalar()
+                .as("sum_d"))
+            .agg(Map.empty[String, String]),
+          sql("select a, (select sum(d) from vr where a = c) sum_d from l l1 
group by 1, 2"))
+      }
+    }
+  }
+
+  test("non-aggregated correlated scalar subquery") {
+    val df1 = spark.table("l").as("t1")
+    val df2 = spark.table("l").as("t2")
+    val exception1 = intercept[SparkRuntimeException] {
+      df1
+        .select($"a", df2.where($"t1.a" === 
$"t2.a".outer()).select($"b").scalar().as("sum_b"))
+        .collect()
+    }
+    checkError(exception1, condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS")
+  }
+
+  test("non-equal correlated scalar subquery") {
+    val df1 = spark.table("l").as("t1")
+    val df2 = spark.table("l").as("t2")
+    checkAnswer(
+      df1.select(
+        $"a",
+        df2.where($"t2.a" < 
$"t1.a".outer()).select(sum($"b")).scalar().as("sum_b")),
+      sql("select a, (select sum(b) from l t2 where t2.a < t1.a) sum_b from l 
t1"))
+  }
+
+  test("disjunctive correlated scalar subquery") {
+    checkAnswer(
+      spark
+        .table("l")
+        .where(
+          spark
+            .table("r")
+            .where(($"a".outer() === $"c" && $"d" === 2.0) ||
+              ($"a".outer() === $"c" && $"d" === 1.0))
+            .select(count(lit(1)))
+            .scalar() > 0)
+        .select($"a"),
+      sql("""
+            |select a
+            |from   l
+            |where  (select count(*)
+            |        from   r
+            |        where (a = c and d = 2.0) or (a = c and d = 1.0)) > 0
+        """.stripMargin))
+  }
+
+  test("correlated scalar subquery with missing outer reference") {
+    checkAnswer(
+      spark
+        .table("l")
+        .select($"a", spark.table("r").where($"c" === 
$"a").select(sum($"d")).scalar()),
+      sql("select a, (select sum(d) from r where c = a) from l"))
+  }
+
   private def table1() = {
     sql("CREATE VIEW t1(c1, c2) AS VALUES (0, 1), (1, 2)")
     spark.table("t1")
@@ -182,6 +476,60 @@ class DataFrameSubquerySuite extends QueryTest with 
RemoteSparkSession {
     }
   }
 
+  test("scalar subquery inside lateral join") {
+    withView("t1", "t2") {
+      val t1 = table1()
+      val t2 = table2()
+
+      // uncorrelated
+      checkAnswer(
+        t1.lateralJoin(spark.range(1).select($"c2".outer(), 
t2.select(min($"c2")).scalar()))
+          .toDF("c1", "c2", "c3", "c4"),
+        sql("SELECT * FROM t1, LATERAL (SELECT c2, (SELECT MIN(c2) FROM t2))")
+          .toDF("c1", "c2", "c3", "c4"))
+
+      // correlated
+      checkAnswer(
+        t1.lateralJoin(
+          spark
+            .range(1)
+            .select($"c1".outer().as("a"))
+            .select(t2.where($"c1" === 
$"a".outer()).select(sum($"c2")).scalar())),
+        sql("""
+              |SELECT * FROM t1, LATERAL (
+              |    SELECT (SELECT SUM(c2) FROM t2 WHERE c1 = a) FROM (SELECT 
c1 AS a)
+              |)
+              |""".stripMargin))
+    }
+  }
+
+  test("lateral join inside subquery") {
+    withView("t1", "t2") {
+      val t1 = table1()
+      val t2 = table2()
+
+      // uncorrelated
+      checkAnswer(
+        t1.where(
+          $"c1" === t2
+            .lateralJoin(spark.range(1).select($"c1".outer().as("a")))
+            .select(min($"a"))
+            .scalar()),
+        sql("SELECT * FROM t1 WHERE c1 = (SELECT MIN(a) FROM t2, LATERAL 
(SELECT c1 AS a))"))
+      // correlated
+      checkAnswer(
+        t1.where(
+          $"c1" === t2
+            .lateralJoin(spark.range(1).select($"c1".outer().as("a")))
+            .where($"c1" === $"t1.c1".outer())
+            .select(min($"a"))
+            .scalar()),
+        sql(
+          "SELECT * FROM t1 " +
+            "WHERE c1 = (SELECT MIN(a) FROM t2, LATERAL (SELECT c1 AS a) WHERE 
c1 = t1.c1)"))
+    }
+  }
+
   test("lateral join with table-valued functions") {
     withView("t1", "t3") {
       val t1 = table1()
@@ -219,4 +567,96 @@ class DataFrameSubquerySuite extends QueryTest with 
RemoteSparkSession {
         sql("SELECT * FROM t3 LEFT JOIN LATERAL EXPLODE(c2) t(c3) ON t3.c1 = 
c3"))
     }
   }
+
+  test("subquery with generator / table-valued functions") {
+    withView("t1") {
+      val t1 = table1()
+
+      checkAnswer(
+        spark.range(1).select(explode(t1.select(collect_list("c2")).scalar())),
+        sql("SELECT EXPLODE((SELECT COLLECT_LIST(c2) FROM t1))"))
+      checkAnswer(
+        spark.tvf.explode(t1.select(collect_list("c2")).scalar()),
+        sql("SELECT * FROM EXPLODE((SELECT COLLECT_LIST(c2) FROM t1))"))
+    }
+  }
+
+  test("subquery in join condition") {
+    withView("t1", "t2") {
+      val t1 = table1()
+      val t2 = table2()
+
+      checkAnswer(
+        t1.join(t2, $"t1.c1" === t1.select(max("c1")).scalar()).toDF("c1", 
"c2", "c3", "c4"),
+        sql("SELECT * FROM t1 JOIN t2 ON t1.c1 = (SELECT MAX(c1) FROM t1)")
+          .toDF("c1", "c2", "c3", "c4"))
+    }
+  }
+
+  test("subquery in unpivot") {
+    withView("t1", "t2") {
+      val t1 = table1()
+      val t2 = table2()
+
+      checkError(
+        intercept[AnalysisException] {
+          t1.unpivot(Array(t2.exists()), "c1", "c2").collect()
+        },
+        
"UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_IN_EXISTS_SUBQUERY",
+        parameters = Map("treeNode" -> "(?s)'Unpivot.*"),
+        matchPVals = true)
+      checkError(
+        intercept[AnalysisException] {
+          t1.unpivot(Array($"c1"), Array(t2.exists()), "c1", "c2").collect()
+        },
+        
"UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_IN_EXISTS_SUBQUERY",
+        parameters = Map("treeNode" -> "(?s)Expand.*"),
+        matchPVals = true)
+    }
+  }
+
+  test("subquery in transpose") {
+    withView("t1") {
+      val t1 = table1()
+
+      checkError(
+        intercept[AnalysisException] {
+          t1.transpose(t1.select(max("c1")).scalar()).collect()
+        },
+        "TRANSPOSE_INVALID_INDEX_COLUMN",
+        parameters = Map("reason" -> "Index column must be an atomic 
attribute"))
+    }
+  }
+
+  test("subquery in withColumns") {
+    withView("t1") {
+      val t1 = table1()
+
+      // TODO(SPARK-50601): Fix the SparkConnectPlanner to support this case
+      checkError(
+        intercept[SparkException] {
+          t1.withColumn("scalar", spark.range(1).select($"c1".outer() + 
$"c2".outer()).scalar())
+            .collect()
+        },
+        "INTERNAL_ERROR",
+        parameters = Map("message" -> "Found the unresolved operator: .*"),
+        matchPVals = true)
+    }
+  }
+
+  test("subquery in drop") {
+    withView("t1") {
+      val t1 = table1()
+
+      checkAnswer(t1.drop(spark.range(1).select(lit("c1")).scalar()), t1)
+    }
+  }
+
+  test("subquery in repartition") {
+    withView("t1") {
+      val t1 = table1()
+
+      checkAnswer(t1.repartition(spark.range(1).select(lit(1)).scalar()), t1)
+    }
+  }
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToProtoConverterSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToProtoConverterSuite.scala
index 2efd39673519..4cb03420c4d0 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToProtoConverterSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToProtoConverterSuite.scala
@@ -431,4 +431,5 @@ class ColumnNodeToProtoConverterSuite extends 
ConnectFunSuite {
 private[internal] case class Nope(override val origin: Origin = 
CurrentOrigin.get)
     extends ColumnNode {
   override def sql: String = "nope"
+  override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala
index 8837c76b76ae..f22644074324 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala
@@ -19,8 +19,11 @@ package org.apache.spark.sql.test
 
 import java.util.TimeZone
 
+import scala.jdk.CollectionConverters._
+
 import org.scalatest.Assertions
 
+import org.apache.spark.{QueryContextType, SparkThrowable}
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.catalyst.util.SparkStringUtils.sideBySide
 import org.apache.spark.util.ArrayImplicits._
@@ -53,6 +56,158 @@ abstract class QueryTest extends ConnectFunSuite with 
SQLHelper {
     checkAnswer(df, expectedAnswer.toImmutableArraySeq)
   }
 
+  case class ExpectedContext(
+      contextType: QueryContextType,
+      objectType: String,
+      objectName: String,
+      startIndex: Int,
+      stopIndex: Int,
+      fragment: String,
+      callSitePattern: String)
+
+  object ExpectedContext {
+    def apply(fragment: String, start: Int, stop: Int): ExpectedContext = {
+      ExpectedContext("", "", start, stop, fragment)
+    }
+
+    def apply(
+        objectType: String,
+        objectName: String,
+        startIndex: Int,
+        stopIndex: Int,
+        fragment: String): ExpectedContext = {
+      new ExpectedContext(
+        QueryContextType.SQL,
+        objectType,
+        objectName,
+        startIndex,
+        stopIndex,
+        fragment,
+        "")
+    }
+
+    def apply(fragment: String, callSitePattern: String): ExpectedContext = {
+      new ExpectedContext(QueryContextType.DataFrame, "", "", -1, -1, 
fragment, callSitePattern)
+    }
+  }
+
+  /**
+   * Checks an exception with an error condition against expected results.
+   * @param exception
+   *   The exception to check
+   * @param condition
+   *   The expected error condition identifying the error
+   * @param sqlState
+   *   Optional the expected SQLSTATE, not verified if not supplied
+   * @param parameters
+   *   A map of parameter names and values. The names are as defined in the 
error-classes file.
+   * @param matchPVals
+   *   Optionally treat the parameters value as regular expression pattern. 
false if not supplied.
+   */
+  protected def checkError(
+      exception: SparkThrowable,
+      condition: String,
+      sqlState: Option[String] = None,
+      parameters: Map[String, String] = Map.empty,
+      matchPVals: Boolean = false,
+      queryContext: Array[ExpectedContext] = Array.empty): Unit = {
+    assert(exception.getCondition === condition)
+    sqlState.foreach(state => assert(exception.getSqlState === state))
+    val expectedParameters = exception.getMessageParameters.asScala
+    if (matchPVals) {
+      assert(expectedParameters.size === parameters.size)
+      expectedParameters.foreach(exp => {
+        val parm = parameters.getOrElse(
+          exp._1,
+          throw new IllegalArgumentException("Missing parameter" + exp._1))
+        if (!exp._2.matches(parm)) {
+          throw new IllegalArgumentException(
+            "For parameter '" + exp._1 + "' value '" + exp._2 +
+              "' does not match: " + parm)
+        }
+      })
+    } else {
+      assert(expectedParameters === parameters)
+    }
+    val actualQueryContext = exception.getQueryContext()
+    assert(
+      actualQueryContext.length === queryContext.length,
+      "Invalid length of the query context")
+    actualQueryContext.zip(queryContext).foreach { case (actual, expected) =>
+      assert(
+        actual.contextType() === expected.contextType,
+        "Invalid contextType of a query context Actual:" + actual.toString)
+      if (actual.contextType() == QueryContextType.SQL) {
+        assert(
+          actual.objectType() === expected.objectType,
+          "Invalid objectType of a query context Actual:" + actual.toString)
+        assert(
+          actual.objectName() === expected.objectName,
+          "Invalid objectName of a query context. Actual:" + actual.toString)
+        assert(
+          actual.startIndex() === expected.startIndex,
+          "Invalid startIndex of a query context. Actual:" + actual.toString)
+        assert(
+          actual.stopIndex() === expected.stopIndex,
+          "Invalid stopIndex of a query context. Actual:" + actual.toString)
+        assert(
+          actual.fragment() === expected.fragment,
+          "Invalid fragment of a query context. Actual:" + actual.toString)
+      } else if (actual.contextType() == QueryContextType.DataFrame) {
+        assert(
+          actual.fragment() === expected.fragment,
+          "Invalid code fragment of a query context. Actual:" + 
actual.toString)
+        if (expected.callSitePattern.nonEmpty) {
+          assert(
+            actual.callSite().matches(expected.callSitePattern),
+            "Invalid callSite of a query context. Actual:" + actual.toString)
+        }
+      }
+    }
+  }
+
+  protected def checkError(
+      exception: SparkThrowable,
+      condition: String,
+      sqlState: String,
+      parameters: Map[String, String]): Unit =
+    checkError(exception, condition, Some(sqlState), parameters)
+
+  protected def checkError(
+      exception: SparkThrowable,
+      condition: String,
+      sqlState: String,
+      parameters: Map[String, String],
+      context: ExpectedContext): Unit =
+    checkError(exception, condition, Some(sqlState), parameters, false, 
Array(context))
+
+  protected def checkError(
+      exception: SparkThrowable,
+      condition: String,
+      parameters: Map[String, String],
+      context: ExpectedContext): Unit =
+    checkError(exception, condition, None, parameters, false, Array(context))
+
+  protected def checkError(
+      exception: SparkThrowable,
+      condition: String,
+      sqlState: String,
+      context: ExpectedContext): Unit =
+    checkError(exception, condition, Some(sqlState), Map.empty, false, 
Array(context))
+
+  protected def checkError(
+      exception: SparkThrowable,
+      condition: String,
+      sqlState: Option[String],
+      parameters: Map[String, String],
+      context: ExpectedContext): Unit =
+    checkError(exception, condition, sqlState, parameters, false, 
Array(context))
+
+  protected def getCurrentClassCallSitePattern: String = {
+    val cs = Thread.currentThread().getStackTrace()(2)
+    s"${cs.getClassName}\\..*\\(${cs.getFileName}:\\d+\\)"
+  }
+
   /**
    * Evaluates a dataset to make sure that the result of calling collect 
matches the given
    * expected answer.
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala
index 007d4f0648e4..d9828ae92267 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala
@@ -22,6 +22,7 @@ import java.util.UUID
 import org.scalatest.Assertions.fail
 
 import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession, 
SQLImplicits}
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
 import org.apache.spark.util.{SparkErrorUtils, SparkFileUtils}
 
 trait SQLHelper {
@@ -110,6 +111,22 @@ trait SQLHelper {
     finally SparkFileUtils.deleteRecursively(path)
   }
 
+  /**
+   * Drops temporary view `viewNames` after calling `f`.
+   */
+  protected def withTempView(viewNames: String*)(f: => Unit): Unit = {
+    SparkErrorUtils.tryWithSafeFinally(f) {
+      viewNames.foreach { viewName =>
+        try spark.catalog.dropTempView(viewName)
+        catch {
+          // If the test failed part way, we don't want to mask the failure by 
failing to remove
+          // temp views that never got created.
+          case _: NoSuchTableException =>
+        }
+      }
+    }
+  }
+
   /**
    * Drops table `tableName` after calling `f`.
    */
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index ee1886b8ef29..185ddc88cd08 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -79,6 +79,7 @@ from pyspark.sql.connect.streaming.readwriter import 
DataStreamWriter
 from pyspark.sql.column import Column
 from pyspark.sql.connect.expressions import (
     ColumnReference,
+    SubqueryExpression,
     UnresolvedRegex,
     UnresolvedStar,
 )
@@ -1801,18 +1802,14 @@ class DataFrame(ParentDataFrame):
         )
 
     def scalar(self) -> Column:
-        # TODO(SPARK-50134): Implement this method
-        raise PySparkNotImplementedError(
-            errorClass="NOT_IMPLEMENTED",
-            messageParameters={"feature": "scalar()"},
-        )
+        from pyspark.sql.connect.column import Column as ConnectColumn
+
+        return ConnectColumn(SubqueryExpression(self._plan, 
subquery_type="scalar"))
 
     def exists(self) -> Column:
-        # TODO(SPARK-50134): Implement this method
-        raise PySparkNotImplementedError(
-            errorClass="NOT_IMPLEMENTED",
-            messageParameters={"feature": "exists()"},
-        )
+        from pyspark.sql.connect.column import Column as ConnectColumn
+
+        return ConnectColumn(SubqueryExpression(self._plan, 
subquery_type="exists"))
 
     @property
     def schema(self) -> StructType:
@@ -2278,10 +2275,6 @@ def _test() -> None:
         del pyspark.sql.dataframe.DataFrame.toJSON.__doc__
         del pyspark.sql.dataframe.DataFrame.rdd.__doc__
 
-    # TODO(SPARK-50134): Support subquery in connect
-    del pyspark.sql.dataframe.DataFrame.scalar.__doc__
-    del pyspark.sql.dataframe.DataFrame.exists.__doc__
-
     globs["spark"] = (
         PySparkSession.builder.appName("sql.connect.dataframe tests")
         .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 5d7b348f6d38..413a69181683 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -82,6 +82,7 @@ from pyspark.sql.utils import is_timestamp_ntz_preferred, 
enum_to_value
 if TYPE_CHECKING:
     from pyspark.sql.connect.client import SparkConnectClient
     from pyspark.sql.connect.window import WindowSpec
+    from pyspark.sql.connect.plan import LogicalPlan
 
 
 class Expression:
@@ -128,6 +129,15 @@ class Expression:
             plan.common.origin.CopyFrom(self.origin)
         return plan
 
+    @property
+    def children(self) -> Sequence["Expression"]:
+        return []
+
+    def foreach(self, f: Callable[["Expression"], None]) -> None:
+        f(self)
+        for c in self.children:
+            c.foreach(f)
+
 
 class CaseWhen(Expression):
     def __init__(
@@ -162,6 +172,16 @@ class CaseWhen(Expression):
 
         return unresolved_function.to_plan(session)
 
+    @property
+    def children(self) -> Sequence["Expression"]:
+        children = []
+        for branch in self._branches:
+            children.append(branch[0])
+            children.append(branch[1])
+        if self._else_value is not None:
+            children.append(self._else_value)
+        return children
+
     def __repr__(self) -> str:
         _cases = "".join([f" WHEN {c} THEN {v}" for c, v in self._branches])
         _else = f" ELSE {self._else_value}" if self._else_value is not None 
else ""
@@ -196,6 +216,10 @@ class ColumnAlias(Expression):
             exp.alias.expr.CopyFrom(self._child.to_plan(session))
             return exp
 
+    @property
+    def children(self) -> Sequence["Expression"]:
+        return [self._child]
+
     def __repr__(self) -> str:
         return f"{self._child} AS {','.join(self._alias)}"
 
@@ -622,6 +646,10 @@ class SortOrder(Expression):
 
         return sort
 
+    @property
+    def children(self) -> Sequence["Expression"]:
+        return [self._child]
+
 
 class UnresolvedFunction(Expression):
     def __init__(
@@ -649,6 +677,10 @@ class UnresolvedFunction(Expression):
         fun.unresolved_function.is_distinct = self._is_distinct
         return fun
 
+    @property
+    def children(self) -> Sequence["Expression"]:
+        return self._args
+
     def __repr__(self) -> str:
         # Default print handling:
         if self._is_distinct:
@@ -730,12 +762,12 @@ class CommonInlineUserDefinedFunction(Expression):
         function_name: str,
         function: Union[PythonUDF, JavaUDF],
         deterministic: bool = False,
-        arguments: Sequence[Expression] = [],
+        arguments: Optional[Sequence[Expression]] = None,
     ):
         super().__init__()
         self._function_name = function_name
         self._deterministic = deterministic
-        self._arguments = arguments
+        self._arguments: Sequence[Expression] = arguments or []
         self._function = function
 
     def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
@@ -770,6 +802,10 @@ class CommonInlineUserDefinedFunction(Expression):
         expr.java_udf.CopyFrom(cast(proto.JavaUDF, 
self._function.to_plan(session)))
         return expr
 
+    @property
+    def children(self) -> Sequence["Expression"]:
+        return self._arguments
+
     def __repr__(self) -> str:
         return f"{self._function_name}({', '.join([str(arg) for arg in 
self._arguments])})"
 
@@ -799,6 +835,10 @@ class WithField(Expression):
         
expr.update_fields.value_expression.CopyFrom(self._valueExpr.to_plan(session))
         return expr
 
+    @property
+    def children(self) -> Sequence["Expression"]:
+        return [self._structExpr, self._valueExpr]
+
     def __repr__(self) -> str:
         return f"update_field({self._structExpr}, {self._fieldName}, 
{self._valueExpr})"
 
@@ -823,6 +863,10 @@ class DropField(Expression):
         expr.update_fields.field_name = self._fieldName
         return expr
 
+    @property
+    def children(self) -> Sequence["Expression"]:
+        return [self._structExpr]
+
     def __repr__(self) -> str:
         return f"drop_field({self._structExpr}, {self._fieldName})"
 
@@ -847,6 +891,10 @@ class UnresolvedExtractValue(Expression):
         
expr.unresolved_extract_value.extraction.CopyFrom(self._extraction.to_plan(session))
         return expr
 
+    @property
+    def children(self) -> Sequence["Expression"]:
+        return [self._child, self._extraction]
+
     def __repr__(self) -> str:
         return f"{self._child}['{self._extraction}']"
 
@@ -906,6 +954,10 @@ class CastExpression(Expression):
 
         return fun
 
+    @property
+    def children(self) -> Sequence["Expression"]:
+        return [self._expr]
+
     def __repr__(self) -> str:
         # We cannot guarantee the string representations be exactly the same, 
e.g.
         # str(sf.col("a").cast("long")):
@@ -989,6 +1041,10 @@ class LambdaFunction(Expression):
         )
         return expr
 
+    @property
+    def children(self) -> Sequence["Expression"]:
+        return [self._function] + self._arguments
+
     def __repr__(self) -> str:
         return (
             f"LambdaFunction({str(self._function)}, "
@@ -1098,6 +1154,12 @@ class WindowExpression(Expression):
 
         return expr
 
+    @property
+    def children(self) -> Sequence["Expression"]:
+        return (
+            [self._windowFunction] + self._windowSpec._partitionSpec + 
self._windowSpec._orderSpec
+        )
+
     def __repr__(self) -> str:
         return f"WindowExpression({str(self._windowFunction)}, 
({str(self._windowSpec)}))"
 
@@ -1128,6 +1190,10 @@ class CallFunction(Expression):
             expr.call_function.arguments.extend([arg.to_plan(session) for arg 
in self._args])
         return expr
 
+    @property
+    def children(self) -> Sequence["Expression"]:
+        return self._args
+
     def __repr__(self) -> str:
         if len(self._args) > 0:
             return f"CallFunction('{self._name}', {', '.join([str(arg) for arg 
in self._args])})"
@@ -1151,6 +1217,10 @@ class NamedArgumentExpression(Expression):
         
expr.named_argument_expression.value.CopyFrom(self._value.to_plan(session))
         return expr
 
+    @property
+    def children(self) -> Sequence["Expression"]:
+        return [self._value]
+
     def __repr__(self) -> str:
         return f"{self._key} => {self._value}"
 
@@ -1166,5 +1236,31 @@ class LazyExpression(Expression):
         expr.lazy_expression.child.CopyFrom(self._expr.to_plan(session))
         return expr
 
+    @property
+    def children(self) -> Sequence["Expression"]:
+        return [self._expr]
+
     def __repr__(self) -> str:
         return f"lazy({self._expr})"
+
+
+class SubqueryExpression(Expression):
+    def __init__(self, plan: "LogicalPlan", subquery_type: str) -> None:
+        assert isinstance(subquery_type, str)
+        assert subquery_type in ("scalar", "exists")
+
+        super().__init__()
+        self._plan = plan
+        self._subquery_type = subquery_type
+
+    def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
+        expr = self._create_proto_expression()
+        expr.subquery_expression.plan_id = self._plan._plan_id
+        if self._subquery_type == "scalar":
+            expr.subquery_expression.subquery_type = 
proto.SubqueryExpression.SUBQUERY_TYPE_SCALAR
+        elif self._subquery_type == "exists":
+            expr.subquery_expression.subquery_type = 
proto.SubqueryExpression.SUBQUERY_TYPE_EXISTS
+        return expr
+
+    def __repr__(self) -> str:
+        return f"SubqueryExpression({self._plan}, {self._subquery_type})"
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index c411baf17ce9..02b60381ab93 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -52,7 +52,7 @@ from pyspark.sql.column import Column
 from pyspark.sql.connect.logging import logger
 from pyspark.sql.connect.proto import base_pb2 as 
spark_dot_connect_dot_base__pb2
 from pyspark.sql.connect.conversion import storage_level_to_proto
-from pyspark.sql.connect.expressions import Expression
+from pyspark.sql.connect.expressions import Expression, SubqueryExpression
 from pyspark.sql.connect.types import pyspark_types_to_proto_types, 
UnparsedDataType
 from pyspark.errors import (
     AnalysisException,
@@ -73,9 +73,30 @@ class LogicalPlan:
 
     INDENT = 2
 
-    def __init__(self, child: Optional["LogicalPlan"]) -> None:
+    def __init__(
+        self, child: Optional["LogicalPlan"], references: 
Optional[Sequence["LogicalPlan"]] = None
+    ) -> None:
+        """
+
+        Parameters
+        ----------
+        child : :class:`LogicalPlan`, optional.
+            The child logical plan.
+        references : list of :class:`LogicalPlan`, optional.
+            The list of logical plans that are referenced as subqueries in 
this logical plan.
+        """
         self._child = child
-        self._plan_id = LogicalPlan._fresh_plan_id()
+        self._root_plan_id = LogicalPlan._fresh_plan_id()
+
+        self._references: Sequence["LogicalPlan"] = references or []
+        self._plan_id_with_rel: Optional[int] = None
+        if len(self._references) > 0:
+            assert all(isinstance(r, LogicalPlan) for r in self._references)
+            self._plan_id_with_rel = LogicalPlan._fresh_plan_id()
+
+    @property
+    def _plan_id(self) -> int:
+        return self._plan_id_with_rel or self._root_plan_id
 
     @staticmethod
     def _fresh_plan_id() -> int:
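
A tiny sketch of the id bookkeeping introduced above (constructing LogicalPlan directly
is for illustration only): a plan with references reserves a second, fresh id for the
wrapping with_relations node, and _plan_id exposes that outer id.

```
from pyspark.sql.connect.plan import LogicalPlan

ref = LogicalPlan(None)                     # a plan referenced as a subquery
node = LogicalPlan(None, references=[ref])  # reserves an extra id for with_relations

assert node._plan_id == node._plan_id_with_rel  # the outer (with_relations) id
assert node._plan_id != node._root_plan_id      # the node's own id stays separate
```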
@@ -89,7 +110,7 @@ class LogicalPlan:
 
     def _create_proto_relation(self) -> proto.Relation:
         plan = proto.Relation()
-        plan.common.plan_id = self._plan_id
+        plan.common.plan_id = self._root_plan_id
         return plan
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:  # type: 
ignore[empty-body]
@@ -136,6 +157,42 @@ class LogicalPlan:
         else:
             return self._child.observations
 
+    @staticmethod
+    def _collect_references(
+        cols_or_exprs: Sequence[Union[Column, Expression]]
+    ) -> Sequence["LogicalPlan"]:
+        references: List[LogicalPlan] = []
+
+        def append_reference(e: Expression) -> None:
+            if isinstance(e, SubqueryExpression):
+                references.append(e._plan)
+
+        for col_or_expr in cols_or_exprs:
+            if isinstance(col_or_expr, Column):
+                col_or_expr._expr.foreach(append_reference)
+            else:
+                col_or_expr.foreach(append_reference)
+        return references
+
+    def _with_relations(
+        self, root: proto.Relation, session: "SparkConnectClient"
+    ) -> proto.Relation:
+        if len(self._references) == 0:
+            return root
+        else:
+            # When there are references to other DataFrames, e.g., subqueries,
+            # build a new plan like:
+            # with_relations [id 10]
+            #     root: plan  [id 9]
+            #     reference:
+            #          refs#1: [id 8]
+            #          refs#2: [id 5]
+            plan = proto.Relation()
+            assert isinstance(self._plan_id_with_rel, int)
+            plan.common.plan_id = self._plan_id_with_rel
+            plan.with_relations.root.CopyFrom(root)
+            plan.with_relations.references.extend([ref.plan(session) for ref 
in self._references])
+            return plan
+
     def _parameters_to_print(self, parameters: Mapping[str, Any]) -> 
Mapping[str, Any]:
         """
         Extracts the parameters that are able to be printed. It looks up the 
signature
@@ -192,6 +249,7 @@ class LogicalPlan:
                 getattr(a, "__forward_arg__", "").endswith("LogicalPlan")
                 for a in getattr(tpe.annotation, "__args__", ())
             )
+
             if (
                 not is_logical_plan
                 and not is_forwardref_logical_plan
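
Before the per-relation constructor changes below, a rough sketch of how
_collect_references and _with_relations fit together at the DataFrame level
(assumes a Spark Connect session bound to `spark`):

```
from pyspark.sql import functions as sf

df = spark.range(10)
avg_df = spark.range(10).select(sf.avg("id"))

# The filter condition embeds a SubqueryExpression, so _collect_references
# picks up avg_df's plan and _with_relations wraps the Filter roughly as:
#
#   with_relations [fresh plan id]
#       root: filter(range(10), id > scalar-subquery(plan id of avg_df))
#       reference: plan of avg_df
filtered = df.where(sf.col("id") > avg_df.scalar())
filtered.show()
```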
@@ -473,8 +531,8 @@ class Project(LogicalPlan):
         child: Optional["LogicalPlan"],
         columns: List[Column],
     ) -> None:
-        super().__init__(child)
         assert all(isinstance(c, Column) for c in columns)
+        super().__init__(child, self._collect_references(columns))
         self._columns = columns
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
@@ -482,7 +540,8 @@ class Project(LogicalPlan):
         plan = self._create_proto_relation()
         plan.project.input.CopyFrom(self._child.plan(session))
         plan.project.expressions.extend([c.to_plan(session) for c in 
self._columns])
-        return plan
+
+        return self._with_relations(plan, session)
 
 
 class WithColumns(LogicalPlan):
@@ -495,8 +554,6 @@ class WithColumns(LogicalPlan):
         columns: Sequence[Column],
         metadata: Optional[Sequence[str]] = None,
     ) -> None:
-        super().__init__(child)
-
         assert isinstance(columnNames, list)
         assert len(columnNames) > 0
         assert all(isinstance(c, str) for c in columnNames)
@@ -513,6 +570,8 @@ class WithColumns(LogicalPlan):
                 # validate json string
                 assert m == "" or json.loads(m) is not None
 
+        super().__init__(child, self._collect_references(columns))
+
         self._columnNames = columnNames
         self._columns = columns
         self._metadata = metadata
@@ -530,7 +589,7 @@ class WithColumns(LogicalPlan):
                 alias.metadata = self._metadata[i]
             plan.with_columns.aliases.append(alias)
 
-        return plan
+        return self._with_relations(plan, session)
 
 
 class WithWatermark(LogicalPlan):
@@ -608,16 +667,14 @@ class Hint(LogicalPlan):
         name: str,
         parameters: Sequence[Column],
     ) -> None:
-        super().__init__(child)
-
         assert isinstance(name, str)
 
-        self._name = name
-
         assert parameters is not None and isinstance(parameters, List)
         for param in parameters:
             assert isinstance(param, Column)
 
+        super().__init__(child, self._collect_references(parameters))
+        self._name = name
         self._parameters = parameters
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
@@ -626,12 +683,12 @@ class Hint(LogicalPlan):
         plan.hint.input.CopyFrom(self._child.plan(session))
         plan.hint.name = self._name
         plan.hint.parameters.extend([param.to_plan(session) for param in 
self._parameters])
-        return plan
+        return self._with_relations(plan, session)
 
 
 class Filter(LogicalPlan):
     def __init__(self, child: Optional["LogicalPlan"], filter: Column) -> None:
-        super().__init__(child)
+        super().__init__(child, self._collect_references([filter]))
         self.filter = filter
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
@@ -639,7 +696,7 @@ class Filter(LogicalPlan):
         plan = self._create_proto_relation()
         plan.filter.input.CopyFrom(self._child.plan(session))
         plan.filter.condition.CopyFrom(self.filter.to_plan(session))
-        return plan
+        return self._with_relations(plan, session)
 
 
 class Limit(LogicalPlan):
@@ -712,11 +769,10 @@ class Sort(LogicalPlan):
         columns: List[Column],
         is_global: bool,
     ) -> None:
-        super().__init__(child)
-
         assert all(isinstance(c, Column) for c in columns)
         assert isinstance(is_global, bool)
 
+        super().__init__(child, self._collect_references(columns))
         self.columns = columns
         self.is_global = is_global
 
@@ -726,7 +782,7 @@ class Sort(LogicalPlan):
         plan.sort.input.CopyFrom(self._child.plan(session))
         plan.sort.order.extend([c.to_plan(session).sort_order for c in 
self.columns])
         plan.sort.is_global = self.is_global
-        return plan
+        return self._with_relations(plan, session)
 
 
 class Drop(LogicalPlan):
@@ -735,9 +791,12 @@ class Drop(LogicalPlan):
         child: Optional["LogicalPlan"],
         columns: List[Union[Column, str]],
     ) -> None:
-        super().__init__(child)
         if len(columns) > 0:
             assert all(isinstance(c, (Column, str)) for c in columns)
+
+        super().__init__(
+            child, self._collect_references([c for c in columns if 
isinstance(c, Column)])
+        )
         self._columns = columns
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
@@ -749,7 +808,7 @@ class Drop(LogicalPlan):
                 plan.drop.columns.append(c.to_plan(session))
             else:
                 plan.drop.column_names.append(c)
-        return plan
+        return self._with_relations(plan, session)
 
 
 class Sample(LogicalPlan):
@@ -792,8 +851,6 @@ class Aggregate(LogicalPlan):
         pivot_values: Optional[Sequence[Column]],
         grouping_sets: Optional[Sequence[Sequence[Column]]],
     ) -> None:
-        super().__init__(child)
-
         assert isinstance(group_type, str) and group_type in [
             "groupby",
             "rollup",
@@ -801,15 +858,12 @@ class Aggregate(LogicalPlan):
             "pivot",
             "grouping_sets",
         ]
-        self._group_type = group_type
 
         assert isinstance(grouping_cols, list) and all(isinstance(c, Column) 
for c in grouping_cols)
-        self._grouping_cols = grouping_cols
 
         assert isinstance(aggregate_cols, list) and all(
             isinstance(c, Column) for c in aggregate_cols
         )
-        self._aggregate_cols = aggregate_cols
 
         if group_type == "pivot":
             assert pivot_col is not None and isinstance(pivot_col, Column)
@@ -821,6 +875,19 @@ class Aggregate(LogicalPlan):
             assert pivot_values is None
             assert grouping_sets is None
 
+        super().__init__(
+            child,
+            self._collect_references(
+                grouping_cols
+                + aggregate_cols
+                + ([pivot_col] if pivot_col is not None else [])
+                + (pivot_values if pivot_values is not None else [])
+                + ([g for gs in grouping_sets for g in gs] if grouping_sets is 
not None else [])
+            ),
+        )
+        self._group_type = group_type
+        self._grouping_cols = grouping_cols
+        self._aggregate_cols = aggregate_cols
         self._pivot_col = pivot_col
         self._pivot_values = pivot_values
         self._grouping_sets = grouping_sets
@@ -859,7 +926,7 @@ class Aggregate(LogicalPlan):
                         grouping_set=[c.to_plan(session) for c in grouping_set]
                     )
                 )
-        return plan
+        return self._with_relations(plan, session)
 
 
 class Join(LogicalPlan):
@@ -870,7 +937,16 @@ class Join(LogicalPlan):
         on: Optional[Union[str, List[str], Column, List[Column]]],
         how: Optional[str],
     ) -> None:
-        super().__init__(left)
+        super().__init__(
+            left,
+            self._collect_references(
+                []
+                if on is None or isinstance(on, str)
+                else [on]
+                if isinstance(on, Column)
+                else [c for c in on if isinstance(c, Column)]
+            ),
+        )
         self.left = cast(LogicalPlan, left)
         self.right = right
         self.on = on
@@ -942,7 +1018,7 @@ class Join(LogicalPlan):
                     merge_column = functools.reduce(lambda c1, c2: c1 & c2, 
self.on)
                     plan.join.join_condition.CopyFrom(cast(Column, 
merge_column).to_plan(session))
         plan.join.join_type = self.how
-        return plan
+        return self._with_relations(plan, session)
 
     @property
     def observations(self) -> Dict[str, "Observation"]:
@@ -982,7 +1058,20 @@ class AsOfJoin(LogicalPlan):
         allow_exact_matches: bool,
         direction: str,
     ) -> None:
-        super().__init__(left)
+        super().__init__(
+            left,
+            self._collect_references(
+                [left_as_of, right_as_of]
+                + (
+                    []
+                    if on is None or isinstance(on, str)
+                    else [on]
+                    if isinstance(on, Column)
+                    else [c for c in on if isinstance(c, Column)]
+                )
+                + ([tolerance] if tolerance is not None else [])
+            ),
+        )
         self.left = left
         self.right = right
         self.left_as_of = left_as_of
@@ -1022,7 +1111,7 @@ class AsOfJoin(LogicalPlan):
         plan.as_of_join.allow_exact_matches = self.allow_exact_matches
         plan.as_of_join.direction = self.direction
 
-        return plan
+        return self._with_relations(plan, session)
 
     @property
     def observations(self) -> Dict[str, "Observation"]:
@@ -1064,7 +1153,7 @@ class LateralJoin(LogicalPlan):
         on: Optional[Column],
         how: Optional[str],
     ) -> None:
-        super().__init__(left)
+        super().__init__(left, self._collect_references([on] if on is not None 
else []))
         self.left = cast(LogicalPlan, left)
         self.right = right
         self.on = on
@@ -1097,7 +1186,7 @@ class LateralJoin(LogicalPlan):
         if self.on is not None:
             plan.lateral_join.join_condition.CopyFrom(self.on.to_plan(session))
         plan.lateral_join.join_type = self.how
-        return plan
+        return self._with_relations(plan, session)
 
     @property
     def observations(self) -> Dict[str, "Observation"]:
@@ -1225,9 +1314,9 @@ class RepartitionByExpression(LogicalPlan):
         num_partitions: Optional[int],
         columns: List[Column],
     ) -> None:
-        super().__init__(child)
-        self.num_partitions = num_partitions
         assert all(isinstance(c, Column) for c in columns)
+        super().__init__(child, self._collect_references(columns))
+        self.num_partitions = num_partitions
         self.columns = columns
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
@@ -1240,7 +1329,7 @@ class RepartitionByExpression(LogicalPlan):
             
plan.repartition_by_expression.input.CopyFrom(self._child.plan(session))
         if self.num_partitions is not None:
             plan.repartition_by_expression.num_partitions = self.num_partitions
-        return plan
+        return self._with_relations(plan, session)
 
 
 class SubqueryAlias(LogicalPlan):
@@ -1286,8 +1375,6 @@ class SQL(LogicalPlan):
         named_args: Optional[Dict[str, Column]] = None,
         views: Optional[Sequence[SubqueryAlias]] = None,
     ) -> None:
-        super().__init__(None)
-
         if args is not None:
             assert isinstance(args, List)
             assert all(isinstance(arg, Column) for arg in args)
@@ -1301,10 +1388,8 @@ class SQL(LogicalPlan):
         if views is not None:
             assert isinstance(views, List)
             assert all(isinstance(v, SubqueryAlias) for v in views)
-            if len(views) > 0:
-                # reserved plan id for WithRelations
-                self._plan_id_with_rel = LogicalPlan._fresh_plan_id()
 
+        super().__init__(None, views)
         self._query = query
         self._args = args
         self._named_args = named_args
@@ -1320,20 +1405,7 @@ class SQL(LogicalPlan):
             for k, arg in self._named_args.items():
                 plan.sql.named_arguments[k].CopyFrom(arg.to_plan(session))
 
-        if self._views is not None and len(self._views) > 0:
-            # build new plan like
-            # with_relations [id 10]
-            #     root: sql  [id 9]
-            #     reference:
-            #          view#1: [id 8]
-            #          view#2: [id 5]
-            sql_plan = plan
-            plan = proto.Relation()
-            plan.common.plan_id = self._plan_id_with_rel
-            plan.with_relations.root.CopyFrom(sql_plan)
-            plan.with_relations.references.extend([v.plan(session) for v in 
self._views])
-
-        return plan
+        return self._with_relations(plan, session)
 
     def command(self, session: "SparkConnectClient") -> proto.Command:
         cmd = proto.Command()
@@ -1407,7 +1479,7 @@ class Unpivot(LogicalPlan):
         variable_column_name: str,
         value_column_name: str,
     ) -> None:
-        super().__init__(child)
+        super().__init__(child, self._collect_references(ids + (values or [])))
         self.ids = ids
         self.values = values
         self.variable_column_name = variable_column_name
@@ -1422,7 +1494,7 @@ class Unpivot(LogicalPlan):
             plan.unpivot.values.values.extend([v.to_plan(session) for v in 
self.values])
         plan.unpivot.variable_column_name = self.variable_column_name
         plan.unpivot.value_column_name = self.value_column_name
-        return plan
+        return self._with_relations(plan, session)
 
 
 class Transpose(LogicalPlan):
@@ -1433,7 +1505,7 @@ class Transpose(LogicalPlan):
         child: Optional["LogicalPlan"],
         index_columns: Sequence[Column],
     ) -> None:
-        super().__init__(child)
+        super().__init__(child, self._collect_references(index_columns))
         self.index_columns = index_columns
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
@@ -1443,12 +1515,12 @@ class Transpose(LogicalPlan):
         if self.index_columns is not None and len(self.index_columns) > 0:
             for index_column in self.index_columns:
                 
plan.transpose.index_columns.append(index_column.to_plan(session))
-        return plan
+        return self._with_relations(plan, session)
 
 
 class UnresolvedTableValuedFunction(LogicalPlan):
     def __init__(self, name: str, args: Sequence[Column]):
-        super().__init__(None)
+        super().__init__(None, self._collect_references(args))
         self._name = name
         self._args = args
 
@@ -1457,7 +1529,7 @@ class UnresolvedTableValuedFunction(LogicalPlan):
         plan.unresolved_table_valued_function.function_name = self._name
         for arg in self._args:
             
plan.unresolved_table_valued_function.arguments.append(arg.to_plan(session))
-        return plan
+        return self._with_relations(plan, session)
 
 
 class CollectMetrics(LogicalPlan):
@@ -1469,9 +1541,9 @@ class CollectMetrics(LogicalPlan):
         observation: Union[str, "Observation"],
         exprs: List[Column],
     ) -> None:
-        super().__init__(child)
-        self._observation = observation
         assert all(isinstance(e, Column) for e in exprs)
+        super().__init__(child, self._collect_references(exprs))
+        self._observation = observation
         self._exprs = exprs
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
@@ -1484,7 +1556,7 @@ class CollectMetrics(LogicalPlan):
             else str(self._observation._name)
         )
         plan.collect_metrics.metrics.extend([e.to_plan(session) for e in 
self._exprs])
-        return plan
+        return self._with_relations(plan, session)
 
     @property
     def observations(self) -> Dict[str, "Observation"]:
@@ -1569,13 +1641,13 @@ class NAReplace(LogicalPlan):
         cols: Optional[List[str]],
         replacements: Sequence[Tuple[Column, Column]],
     ) -> None:
-        super().__init__(child)
-        self.cols = cols
-
         assert replacements is not None and isinstance(replacements, List)
         for k, v in replacements:
             assert k is not None and isinstance(k, Column)
             assert v is not None and isinstance(v, Column)
+
+        super().__init__(child, self._collect_references([e for t in 
replacements for e in t]))
+        self.cols = cols
         self.replacements = replacements
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
@@ -1590,7 +1662,7 @@ class NAReplace(LogicalPlan):
                 
replacement.old_value.CopyFrom(old_value.to_plan(session).literal)
                 
replacement.new_value.CopyFrom(new_value.to_plan(session).literal)
                 plan.replace.replacements.append(replacement)
-        return plan
+        return self._with_relations(plan, session)
 
 
 class StatSummary(LogicalPlan):
@@ -1700,8 +1772,6 @@ class StatSampleBy(LogicalPlan):
         fractions: Sequence[Tuple[Column, float]],
         seed: int,
     ) -> None:
-        super().__init__(child)
-
         assert col is not None and isinstance(col, (Column, str))
 
         assert fractions is not None and isinstance(fractions, List)
@@ -1711,6 +1781,12 @@ class StatSampleBy(LogicalPlan):
 
         assert seed is None or isinstance(seed, int)
 
+        super().__init__(
+            child,
+            self._collect_references(
+                ([col] if isinstance(col, Column) else []) + [c for c, _ in fractions]
+            ),
+        )
         self._col = col
         self._fractions = fractions
         self._seed = seed
@@ -1727,7 +1803,7 @@ class StatSampleBy(LogicalPlan):
                 fraction.fraction = float(v)
                 plan.sample_by.fractions.append(fraction)
         plan.sample_by.seed = self._seed
-        return plan
+        return self._with_relations(plan, session)
 
 
 class StatCorr(LogicalPlan):
@@ -2375,7 +2451,7 @@ class GroupMap(LogicalPlan):
     ):
         assert isinstance(grouping_cols, list) and all(isinstance(c, Column) 
for c in grouping_cols)
 
-        super().__init__(child)
+        super().__init__(child, self._collect_references(grouping_cols))
         self._grouping_cols = grouping_cols
         self._function = 
function._build_common_inline_user_defined_function(*cols)
 
@@ -2387,7 +2463,7 @@ class GroupMap(LogicalPlan):
             [c.to_plan(session) for c in self._grouping_cols]
         )
         plan.group_map.func.CopyFrom(self._function.to_plan_udf(session))
-        return plan
+        return self._with_relations(plan, session)
 
 
 class CoGroupMap(LogicalPlan):
@@ -2408,7 +2484,7 @@ class CoGroupMap(LogicalPlan):
             isinstance(c, Column) for c in other_grouping_cols
         )
 
-        super().__init__(input)
+        super().__init__(input, self._collect_references(input_grouping_cols + 
other_grouping_cols))
         self._input_grouping_cols = input_grouping_cols
         self._other_grouping_cols = other_grouping_cols
         self._other = cast(LogicalPlan, other)
@@ -2428,7 +2504,7 @@ class CoGroupMap(LogicalPlan):
             [c.to_plan(session) for c in self._other_grouping_cols]
         )
         plan.co_group_map.func.CopyFrom(self._function.to_plan_udf(session))
-        return plan
+        return self._with_relations(plan, session)
 
 
 class ApplyInPandasWithState(LogicalPlan):
@@ -2447,7 +2523,7 @@ class ApplyInPandasWithState(LogicalPlan):
     ):
         assert isinstance(grouping_cols, list) and all(isinstance(c, Column) 
for c in grouping_cols)
 
-        super().__init__(child)
+        super().__init__(child, self._collect_references(grouping_cols))
         self._grouping_cols = grouping_cols
         self._function = 
function._build_common_inline_user_defined_function(*cols)
         self._output_schema = output_schema
@@ -2467,7 +2543,7 @@ class ApplyInPandasWithState(LogicalPlan):
         plan.apply_in_pandas_with_state.state_schema = self._state_schema
         plan.apply_in_pandas_with_state.output_mode = self._output_mode
         plan.apply_in_pandas_with_state.timeout_conf = self._timeout_conf
-        return plan
+        return self._with_relations(plan, session)
 
 
 class PythonUDTF:
@@ -2531,7 +2607,7 @@ class CommonInlineUserDefinedTableFunction(LogicalPlan):
         deterministic: bool,
         arguments: Sequence[Expression],
     ) -> None:
-        super().__init__(None)
+        super().__init__(None, self._collect_references(arguments))
         self._function_name = function_name
         self._deterministic = deterministic
         self._arguments = arguments
@@ -2548,7 +2624,7 @@ class CommonInlineUserDefinedTableFunction(LogicalPlan):
         plan.common_inline_user_defined_table_function.python_udtf.CopyFrom(
             self._function.to_plan(session)
         )
-        return plan
+        return self._with_relations(plan, session)
 
     def udtf_plan(
         self, session: "SparkConnectClient"
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py 
b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 87070fd5ad3c..093997a0d0c5 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -40,7 +40,7 @@ from pyspark.sql.connect.proto import common_pb2 as 
spark_dot_connect_dot_common
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\x8b\x31\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12
 
\x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01
 
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
 
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolved 
[...]
+    
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\xe1\x31\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12
 
\x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01
 
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
 
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolved 
[...]
 )
 
 _globals = globals()
@@ -54,79 +54,83 @@ if not _descriptor._USE_C_DESCRIPTORS:
         "DESCRIPTOR"
     ]._serialized_options = 
b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated"
     _globals["_EXPRESSION"]._serialized_start = 133
-    _globals["_EXPRESSION"]._serialized_end = 6416
-    _globals["_EXPRESSION_WINDOW"]._serialized_start = 1974
-    _globals["_EXPRESSION_WINDOW"]._serialized_end = 2757
-    _globals["_EXPRESSION_WINDOW_WINDOWFRAME"]._serialized_start = 2264
-    _globals["_EXPRESSION_WINDOW_WINDOWFRAME"]._serialized_end = 2757
-    _globals["_EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY"]._serialized_start 
= 2531
-    _globals["_EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY"]._serialized_end = 
2676
-    _globals["_EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE"]._serialized_start = 
2678
-    _globals["_EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE"]._serialized_end = 2757
-    _globals["_EXPRESSION_SORTORDER"]._serialized_start = 2760
-    _globals["_EXPRESSION_SORTORDER"]._serialized_end = 3185
-    _globals["_EXPRESSION_SORTORDER_SORTDIRECTION"]._serialized_start = 2990
-    _globals["_EXPRESSION_SORTORDER_SORTDIRECTION"]._serialized_end = 3098
-    _globals["_EXPRESSION_SORTORDER_NULLORDERING"]._serialized_start = 3100
-    _globals["_EXPRESSION_SORTORDER_NULLORDERING"]._serialized_end = 3185
-    _globals["_EXPRESSION_CAST"]._serialized_start = 3188
-    _globals["_EXPRESSION_CAST"]._serialized_end = 3503
-    _globals["_EXPRESSION_CAST_EVALMODE"]._serialized_start = 3389
-    _globals["_EXPRESSION_CAST_EVALMODE"]._serialized_end = 3487
-    _globals["_EXPRESSION_LITERAL"]._serialized_start = 3506
-    _globals["_EXPRESSION_LITERAL"]._serialized_end = 5069
-    _globals["_EXPRESSION_LITERAL_DECIMAL"]._serialized_start = 4341
-    _globals["_EXPRESSION_LITERAL_DECIMAL"]._serialized_end = 4458
-    _globals["_EXPRESSION_LITERAL_CALENDARINTERVAL"]._serialized_start = 4460
-    _globals["_EXPRESSION_LITERAL_CALENDARINTERVAL"]._serialized_end = 4558
-    _globals["_EXPRESSION_LITERAL_ARRAY"]._serialized_start = 4561
-    _globals["_EXPRESSION_LITERAL_ARRAY"]._serialized_end = 4691
-    _globals["_EXPRESSION_LITERAL_MAP"]._serialized_start = 4694
-    _globals["_EXPRESSION_LITERAL_MAP"]._serialized_end = 4921
-    _globals["_EXPRESSION_LITERAL_STRUCT"]._serialized_start = 4924
-    _globals["_EXPRESSION_LITERAL_STRUCT"]._serialized_end = 5053
-    _globals["_EXPRESSION_UNRESOLVEDATTRIBUTE"]._serialized_start = 5072
-    _globals["_EXPRESSION_UNRESOLVEDATTRIBUTE"]._serialized_end = 5258
-    _globals["_EXPRESSION_UNRESOLVEDFUNCTION"]._serialized_start = 5261
-    _globals["_EXPRESSION_UNRESOLVEDFUNCTION"]._serialized_end = 5465
-    _globals["_EXPRESSION_EXPRESSIONSTRING"]._serialized_start = 5467
-    _globals["_EXPRESSION_EXPRESSIONSTRING"]._serialized_end = 5517
-    _globals["_EXPRESSION_UNRESOLVEDSTAR"]._serialized_start = 5519
-    _globals["_EXPRESSION_UNRESOLVEDSTAR"]._serialized_end = 5643
-    _globals["_EXPRESSION_UNRESOLVEDREGEX"]._serialized_start = 5645
-    _globals["_EXPRESSION_UNRESOLVEDREGEX"]._serialized_end = 5731
-    _globals["_EXPRESSION_UNRESOLVEDEXTRACTVALUE"]._serialized_start = 5734
-    _globals["_EXPRESSION_UNRESOLVEDEXTRACTVALUE"]._serialized_end = 5866
-    _globals["_EXPRESSION_UPDATEFIELDS"]._serialized_start = 5869
-    _globals["_EXPRESSION_UPDATEFIELDS"]._serialized_end = 6056
-    _globals["_EXPRESSION_ALIAS"]._serialized_start = 6058
-    _globals["_EXPRESSION_ALIAS"]._serialized_end = 6178
-    _globals["_EXPRESSION_LAMBDAFUNCTION"]._serialized_start = 6181
-    _globals["_EXPRESSION_LAMBDAFUNCTION"]._serialized_end = 6339
-    _globals["_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE"]._serialized_start = 
6341
-    _globals["_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE"]._serialized_end = 
6403
-    _globals["_EXPRESSIONCOMMON"]._serialized_start = 6418
-    _globals["_EXPRESSIONCOMMON"]._serialized_end = 6483
-    _globals["_COMMONINLINEUSERDEFINEDFUNCTION"]._serialized_start = 6486
-    _globals["_COMMONINLINEUSERDEFINEDFUNCTION"]._serialized_end = 6850
-    _globals["_PYTHONUDF"]._serialized_start = 6853
-    _globals["_PYTHONUDF"]._serialized_end = 7057
-    _globals["_SCALARSCALAUDF"]._serialized_start = 7060
-    _globals["_SCALARSCALAUDF"]._serialized_end = 7274
-    _globals["_JAVAUDF"]._serialized_start = 7277
-    _globals["_JAVAUDF"]._serialized_end = 7426
-    _globals["_TYPEDAGGREGATEEXPRESSION"]._serialized_start = 7428
-    _globals["_TYPEDAGGREGATEEXPRESSION"]._serialized_end = 7527
-    _globals["_CALLFUNCTION"]._serialized_start = 7529
-    _globals["_CALLFUNCTION"]._serialized_end = 7637
-    _globals["_NAMEDARGUMENTEXPRESSION"]._serialized_start = 7639
-    _globals["_NAMEDARGUMENTEXPRESSION"]._serialized_end = 7731
-    _globals["_MERGEACTION"]._serialized_start = 7734
-    _globals["_MERGEACTION"]._serialized_end = 8246
-    _globals["_MERGEACTION_ASSIGNMENT"]._serialized_start = 7956
-    _globals["_MERGEACTION_ASSIGNMENT"]._serialized_end = 8062
-    _globals["_MERGEACTION_ACTIONTYPE"]._serialized_start = 8065
-    _globals["_MERGEACTION_ACTIONTYPE"]._serialized_end = 8232
-    _globals["_LAZYEXPRESSION"]._serialized_start = 8248
-    _globals["_LAZYEXPRESSION"]._serialized_end = 8313
+    _globals["_EXPRESSION"]._serialized_end = 6502
+    _globals["_EXPRESSION_WINDOW"]._serialized_start = 2060
+    _globals["_EXPRESSION_WINDOW"]._serialized_end = 2843
+    _globals["_EXPRESSION_WINDOW_WINDOWFRAME"]._serialized_start = 2350
+    _globals["_EXPRESSION_WINDOW_WINDOWFRAME"]._serialized_end = 2843
+    _globals["_EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY"]._serialized_start 
= 2617
+    _globals["_EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY"]._serialized_end = 
2762
+    _globals["_EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE"]._serialized_start = 
2764
+    _globals["_EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE"]._serialized_end = 2843
+    _globals["_EXPRESSION_SORTORDER"]._serialized_start = 2846
+    _globals["_EXPRESSION_SORTORDER"]._serialized_end = 3271
+    _globals["_EXPRESSION_SORTORDER_SORTDIRECTION"]._serialized_start = 3076
+    _globals["_EXPRESSION_SORTORDER_SORTDIRECTION"]._serialized_end = 3184
+    _globals["_EXPRESSION_SORTORDER_NULLORDERING"]._serialized_start = 3186
+    _globals["_EXPRESSION_SORTORDER_NULLORDERING"]._serialized_end = 3271
+    _globals["_EXPRESSION_CAST"]._serialized_start = 3274
+    _globals["_EXPRESSION_CAST"]._serialized_end = 3589
+    _globals["_EXPRESSION_CAST_EVALMODE"]._serialized_start = 3475
+    _globals["_EXPRESSION_CAST_EVALMODE"]._serialized_end = 3573
+    _globals["_EXPRESSION_LITERAL"]._serialized_start = 3592
+    _globals["_EXPRESSION_LITERAL"]._serialized_end = 5155
+    _globals["_EXPRESSION_LITERAL_DECIMAL"]._serialized_start = 4427
+    _globals["_EXPRESSION_LITERAL_DECIMAL"]._serialized_end = 4544
+    _globals["_EXPRESSION_LITERAL_CALENDARINTERVAL"]._serialized_start = 4546
+    _globals["_EXPRESSION_LITERAL_CALENDARINTERVAL"]._serialized_end = 4644
+    _globals["_EXPRESSION_LITERAL_ARRAY"]._serialized_start = 4647
+    _globals["_EXPRESSION_LITERAL_ARRAY"]._serialized_end = 4777
+    _globals["_EXPRESSION_LITERAL_MAP"]._serialized_start = 4780
+    _globals["_EXPRESSION_LITERAL_MAP"]._serialized_end = 5007
+    _globals["_EXPRESSION_LITERAL_STRUCT"]._serialized_start = 5010
+    _globals["_EXPRESSION_LITERAL_STRUCT"]._serialized_end = 5139
+    _globals["_EXPRESSION_UNRESOLVEDATTRIBUTE"]._serialized_start = 5158
+    _globals["_EXPRESSION_UNRESOLVEDATTRIBUTE"]._serialized_end = 5344
+    _globals["_EXPRESSION_UNRESOLVEDFUNCTION"]._serialized_start = 5347
+    _globals["_EXPRESSION_UNRESOLVEDFUNCTION"]._serialized_end = 5551
+    _globals["_EXPRESSION_EXPRESSIONSTRING"]._serialized_start = 5553
+    _globals["_EXPRESSION_EXPRESSIONSTRING"]._serialized_end = 5603
+    _globals["_EXPRESSION_UNRESOLVEDSTAR"]._serialized_start = 5605
+    _globals["_EXPRESSION_UNRESOLVEDSTAR"]._serialized_end = 5729
+    _globals["_EXPRESSION_UNRESOLVEDREGEX"]._serialized_start = 5731
+    _globals["_EXPRESSION_UNRESOLVEDREGEX"]._serialized_end = 5817
+    _globals["_EXPRESSION_UNRESOLVEDEXTRACTVALUE"]._serialized_start = 5820
+    _globals["_EXPRESSION_UNRESOLVEDEXTRACTVALUE"]._serialized_end = 5952
+    _globals["_EXPRESSION_UPDATEFIELDS"]._serialized_start = 5955
+    _globals["_EXPRESSION_UPDATEFIELDS"]._serialized_end = 6142
+    _globals["_EXPRESSION_ALIAS"]._serialized_start = 6144
+    _globals["_EXPRESSION_ALIAS"]._serialized_end = 6264
+    _globals["_EXPRESSION_LAMBDAFUNCTION"]._serialized_start = 6267
+    _globals["_EXPRESSION_LAMBDAFUNCTION"]._serialized_end = 6425
+    _globals["_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE"]._serialized_start = 
6427
+    _globals["_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE"]._serialized_end = 
6489
+    _globals["_EXPRESSIONCOMMON"]._serialized_start = 6504
+    _globals["_EXPRESSIONCOMMON"]._serialized_end = 6569
+    _globals["_COMMONINLINEUSERDEFINEDFUNCTION"]._serialized_start = 6572
+    _globals["_COMMONINLINEUSERDEFINEDFUNCTION"]._serialized_end = 6936
+    _globals["_PYTHONUDF"]._serialized_start = 6939
+    _globals["_PYTHONUDF"]._serialized_end = 7143
+    _globals["_SCALARSCALAUDF"]._serialized_start = 7146
+    _globals["_SCALARSCALAUDF"]._serialized_end = 7360
+    _globals["_JAVAUDF"]._serialized_start = 7363
+    _globals["_JAVAUDF"]._serialized_end = 7512
+    _globals["_TYPEDAGGREGATEEXPRESSION"]._serialized_start = 7514
+    _globals["_TYPEDAGGREGATEEXPRESSION"]._serialized_end = 7613
+    _globals["_CALLFUNCTION"]._serialized_start = 7615
+    _globals["_CALLFUNCTION"]._serialized_end = 7723
+    _globals["_NAMEDARGUMENTEXPRESSION"]._serialized_start = 7725
+    _globals["_NAMEDARGUMENTEXPRESSION"]._serialized_end = 7817
+    _globals["_MERGEACTION"]._serialized_start = 7820
+    _globals["_MERGEACTION"]._serialized_end = 8332
+    _globals["_MERGEACTION_ASSIGNMENT"]._serialized_start = 8042
+    _globals["_MERGEACTION_ASSIGNMENT"]._serialized_end = 8148
+    _globals["_MERGEACTION_ACTIONTYPE"]._serialized_start = 8151
+    _globals["_MERGEACTION_ACTIONTYPE"]._serialized_end = 8318
+    _globals["_LAZYEXPRESSION"]._serialized_start = 8334
+    _globals["_LAZYEXPRESSION"]._serialized_end = 8399
+    _globals["_SUBQUERYEXPRESSION"]._serialized_start = 8402
+    _globals["_SUBQUERYEXPRESSION"]._serialized_end = 8627
+    _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_start = 8534
+    _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_end = 8627
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi 
b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index df4106cfc5f7..0a6f3caee8b5 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1185,6 +1185,7 @@ class Expression(google.protobuf.message.Message):
     MERGE_ACTION_FIELD_NUMBER: builtins.int
     TYPED_AGGREGATE_EXPRESSION_FIELD_NUMBER: builtins.int
     LAZY_EXPRESSION_FIELD_NUMBER: builtins.int
+    SUBQUERY_EXPRESSION_FIELD_NUMBER: builtins.int
     EXTENSION_FIELD_NUMBER: builtins.int
     @property
     def common(self) -> global___ExpressionCommon: ...
@@ -1231,6 +1232,8 @@ class Expression(google.protobuf.message.Message):
     @property
     def lazy_expression(self) -> global___LazyExpression: ...
     @property
+    def subquery_expression(self) -> global___SubqueryExpression: ...
+    @property
     def extension(self) -> google.protobuf.any_pb2.Any:
         """This field is used to mark extensions to the protocol. When plugins 
generate arbitrary
         relations they can add them here. During the planning the correct 
resolution is done.
@@ -1260,6 +1263,7 @@ class Expression(google.protobuf.message.Message):
         merge_action: global___MergeAction | None = ...,
         typed_aggregate_expression: global___TypedAggregateExpression | None = 
...,
         lazy_expression: global___LazyExpression | None = ...,
+        subquery_expression: global___SubqueryExpression | None = ...,
         extension: google.protobuf.any_pb2.Any | None = ...,
     ) -> None: ...
     def HasField(
@@ -1293,6 +1297,8 @@ class Expression(google.protobuf.message.Message):
             b"named_argument_expression",
             "sort_order",
             b"sort_order",
+            "subquery_expression",
+            b"subquery_expression",
             "typed_aggregate_expression",
             b"typed_aggregate_expression",
             "unresolved_attribute",
@@ -1344,6 +1350,8 @@ class Expression(google.protobuf.message.Message):
             b"named_argument_expression",
             "sort_order",
             b"sort_order",
+            "subquery_expression",
+            b"subquery_expression",
             "typed_aggregate_expression",
             b"typed_aggregate_expression",
             "unresolved_attribute",
@@ -1388,6 +1396,7 @@ class Expression(google.protobuf.message.Message):
             "merge_action",
             "typed_aggregate_expression",
             "lazy_expression",
+            "subquery_expression",
             "extension",
         ]
         | None
@@ -1829,3 +1838,47 @@ class LazyExpression(google.protobuf.message.Message):
     def ClearField(self, field_name: typing_extensions.Literal["child", 
b"child"]) -> None: ...
 
 global___LazyExpression = LazyExpression
+
+class SubqueryExpression(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    class _SubqueryType:
+        ValueType = typing.NewType("ValueType", builtins.int)
+        V: typing_extensions.TypeAlias = ValueType
+
+    class _SubqueryTypeEnumTypeWrapper(
+        google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[
+            SubqueryExpression._SubqueryType.ValueType
+        ],
+        builtins.type,
+    ):  # noqa: F821
+        DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
+        SUBQUERY_TYPE_UNKNOWN: SubqueryExpression._SubqueryType.ValueType  # 0
+        SUBQUERY_TYPE_SCALAR: SubqueryExpression._SubqueryType.ValueType  # 1
+        SUBQUERY_TYPE_EXISTS: SubqueryExpression._SubqueryType.ValueType  # 2
+
+    class SubqueryType(_SubqueryType, metaclass=_SubqueryTypeEnumTypeWrapper): 
...
+    SUBQUERY_TYPE_UNKNOWN: SubqueryExpression.SubqueryType.ValueType  # 0
+    SUBQUERY_TYPE_SCALAR: SubqueryExpression.SubqueryType.ValueType  # 1
+    SUBQUERY_TYPE_EXISTS: SubqueryExpression.SubqueryType.ValueType  # 2
+
+    PLAN_ID_FIELD_NUMBER: builtins.int
+    SUBQUERY_TYPE_FIELD_NUMBER: builtins.int
+    plan_id: builtins.int
+    """(Required) The id of corresponding connect plan."""
+    subquery_type: global___SubqueryExpression.SubqueryType.ValueType
+    """(Required) The type of the subquery."""
+    def __init__(
+        self,
+        *,
+        plan_id: builtins.int = ...,
+        subquery_type: global___SubqueryExpression.SubqueryType.ValueType = 
...,
+    ) -> None: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "plan_id", b"plan_id", "subquery_type", b"subquery_type"
+        ],
+    ) -> None: ...
+
+global___SubqueryExpression = SubqueryExpression
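
The stub above is enough to build the new message directly; a small sketch (the plan id
is arbitrary):

```
from pyspark.sql.connect.proto.expressions_pb2 import SubqueryExpression

msg = SubqueryExpression(plan_id=8, subquery_type=SubqueryExpression.SUBQUERY_TYPE_EXISTS)
print(SubqueryExpression.SubqueryType.Name(msg.subquery_type))  # SUBQUERY_TYPE_EXISTS
```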
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index f88ca5348ff2..660f577f56f8 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -6612,7 +6612,7 @@ class DataFrame:
         >>> from pyspark.sql import functions as sf
         >>> employees.where(
         ...     sf.col("salary") > employees.select(sf.avg("salary")).scalar()
-        ... ).select("name", "salary", "department_id").show()
+        ... ).select("name", "salary", "department_id").orderBy("name").show()
         +-----+------+-------------+
         | name|salary|department_id|
         +-----+------+-------------+
@@ -6630,7 +6630,7 @@ class DataFrame:
         ...     > employees.alias("e2").where(
         ...         sf.col("e2.department_id") == 
sf.col("e1.department_id").outer()
         ...     ).select(sf.avg("salary")).scalar()
-        ... ).select("name", "salary", "department_id").show()
+        ... ).select("name", "salary", "department_id").orderBy("name").show()
         +-----+------+-------------+
         | name|salary|department_id|
         +-----+------+-------------+
@@ -6651,15 +6651,15 @@ class DataFrame:
         ...             
).select(sf.sum("salary")).scalar().alias("avg_salary"),
         ...         1
         ...     ).alias("salary_proportion_in_department")
-        ... ).show()
+        ... ).orderBy("name").show()
         +-------+------+-------------+-------------------------------+
         |   name|salary|department_id|salary_proportion_in_department|
         +-------+------+-------------+-------------------------------+
         |  Alice| 45000|          101|                           30.6|
         |    Bob| 54000|          101|                           36.7|
         |Charlie| 29000|          102|                           32.2|
-        |    Eve| 48000|          101|                           32.7|
         |  David| 61000|          102|                           67.8|
+        |    Eve| 48000|          101|                           32.7|
         +-------+------+-------------+-------------------------------+
         """
         ...
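
For comparison with the scalar doctests above, a hedged sketch of the EXISTS form using
the same hypothetical `employees` DataFrame (not a doctest; output omitted):

```
from pyspark.sql import functions as sf

# Employees whose department has at least one employee earning more than 50000.
peers = employees.alias("e2").where(
    (sf.col("e2.department_id") == sf.col("e1.department_id").outer())
    & (sf.col("e2.salary") > 50000)
)
employees.alias("e1").where(peers.exists()).select("name", "department_id").show()
```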
diff --git a/python/pyspark/sql/tests/connect/test_parity_subquery.py 
b/python/pyspark/sql/tests/connect/test_parity_subquery.py
index cffb6fc39059..dae60a354d20 100644
--- a/python/pyspark/sql/tests/connect/test_parity_subquery.py
+++ b/python/pyspark/sql/tests/connect/test_parity_subquery.py
@@ -17,42 +17,37 @@
 
 import unittest
 
+from pyspark.sql import functions as sf
 from pyspark.sql.tests.test_subquery import SubqueryTestsMixin
+from pyspark.testing import assertDataFrameEqual
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
 class SubqueryParityTests(SubqueryTestsMixin, ReusedConnectTestCase):
-    @unittest.skip("TODO(SPARK-50134): Support subquery in connect")
-    def test_simple_uncorrelated_scalar_subquery(self):
-        super().test_simple_uncorrelated_scalar_subquery()
-
-    @unittest.skip("TODO(SPARK-50134): Support subquery in connect")
-    def test_uncorrelated_scalar_subquery_with_view(self):
-        super().test_uncorrelated_scalar_subquery_with_view()
-
-    @unittest.skip("TODO(SPARK-50134): Support subquery in connect")
-    def test_scalar_subquery_against_local_relations(self):
-        super().test_scalar_subquery_against_local_relations()
-
-    @unittest.skip("TODO(SPARK-50134): Support subquery in connect")
-    def test_correlated_scalar_subquery(self):
-        super().test_correlated_scalar_subquery()
-
-    @unittest.skip("TODO(SPARK-50134): Support subquery in connect")
-    def test_exists_subquery(self):
-        super().test_exists_subquery()
-
-    @unittest.skip("TODO(SPARK-50134): Support subquery in connect")
-    def test_scalar_subquery_with_outer_reference_errors(self):
-        super().test_scalar_subquery_with_outer_reference_errors()
-
-    @unittest.skip("TODO(SPARK-50134): Support subquery in connect")
-    def test_scalar_subquery_inside_lateral_join(self):
-        super().test_scalar_subquery_inside_lateral_join()
-
-    @unittest.skip("TODO(SPARK-50134): Support subquery in connect")
-    def test_lateral_join_inside_subquery(self):
-        super().test_lateral_join_inside_subquery()
+    def test_scalar_subquery_with_missing_outer_reference(self):
+        with self.tempView("l", "r"):
+            self.df1.createOrReplaceTempView("l")
+            self.df2.createOrReplaceTempView("r")
+
+            assertDataFrameEqual(
+                self.spark.table("l").select(
+                    "a",
+                    (
+                        self.spark.table("r")
+                        .where(sf.col("c") == sf.col("a"))
+                        .select(sf.sum("d"))
+                        .scalar()
+                    ),
+                ),
+                self.spark.sql("""SELECT a, (SELECT sum(d) FROM r WHERE c = a) 
FROM l"""),
+            )
+
+    def test_subquery_in_unpivot(self):
+        self.check_subquery_in_unpivot(None, None)
+
+    @unittest.skip("SPARK-50601: Fix the SparkConnectPlanner to support this 
case")
+    def test_subquery_in_with_columns(self):
+        super().test_subquery_in_with_columns()
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/test_subquery.py 
b/python/pyspark/sql/tests/test_subquery.py
index 91789f74d9da..0f431589b461 100644
--- a/python/pyspark/sql/tests/test_subquery.py
+++ b/python/pyspark/sql/tests/test_subquery.py
@@ -459,30 +459,29 @@ class SubqueryTestsMixin:
                     ),
                 )
 
-    def test_scalar_subquery_with_outer_reference_errors(self):
+    def test_scalar_subquery_with_missing_outer_reference(self):
         with self.tempView("l", "r"):
             self.df1.createOrReplaceTempView("l")
             self.df2.createOrReplaceTempView("r")
 
-            with self.subTest("missing `outer()`"):
-                with self.assertRaises(AnalysisException) as pe:
-                    self.spark.table("l").select(
-                        "a",
-                        (
-                            self.spark.table("r")
-                            .where(sf.col("c") == sf.col("a"))
-                            .select(sf.sum("d"))
-                            .scalar()
-                        ),
-                    ).collect()
+            with self.assertRaises(AnalysisException) as pe:
+                self.spark.table("l").select(
+                    "a",
+                    (
+                        self.spark.table("r")
+                        .where(sf.col("c") == sf.col("a"))
+                        .select(sf.sum("d"))
+                        .scalar()
+                    ),
+                ).collect()
 
-                self.check_error(
-                    exception=pe.exception,
-                    errorClass="UNRESOLVED_COLUMN.WITH_SUGGESTION",
-                    messageParameters={"objectName": "`a`", "proposal": "`c`, 
`d`"},
-                    query_context_type=QueryContextType.DataFrame,
-                    fragment="col",
-                )
+            self.check_error(
+                exception=pe.exception,
+                errorClass="UNRESOLVED_COLUMN.WITH_SUGGESTION",
+                messageParameters={"objectName": "`a`", "proposal": "`c`, 
`d`"},
+                query_context_type=QueryContextType.DataFrame,
+                fragment="col",
+            )
 
     def table1(self):
         t1 = self.spark.sql("VALUES (0, 1), (1, 2) AS t1(c1, c2)")
@@ -833,6 +832,90 @@ class SubqueryTestsMixin:
                 ),
             )
 
+    def test_subquery_with_generator_and_tvf(self):
+        with self.tempView("t1"):
+            t1 = self.table1()
+
+            assertDataFrameEqual(
+                
self.spark.range(1).select(sf.explode(t1.select(sf.collect_list("c2")).scalar())),
+                self.spark.sql("""SELECT EXPLODE((SELECT COLLECT_LIST(c2) FROM 
t1))"""),
+            )
+            assertDataFrameEqual(
+                
self.spark.tvf.explode(t1.select(sf.collect_list("c2")).scalar()),
+                self.spark.sql("""SELECT * FROM EXPLODE((SELECT 
COLLECT_LIST(c2) FROM t1))"""),
+            )
+
+    def test_subquery_in_join_condition(self):
+        with self.tempView("t1", "t2"):
+            t1 = self.table1()
+            t2 = self.table2()
+
+            assertDataFrameEqual(
+                t1.join(t2, sf.col("t1.c1") == 
t1.select(sf.max("c1")).scalar()),
+                self.spark.sql("""SELECT * FROM t1 JOIN t2 ON t1.c1 = (SELECT 
MAX(c1) FROM t1)"""),
+            )
+
+    def test_subquery_in_unpivot(self):
+        self.check_subquery_in_unpivot(QueryContextType.DataFrame, "exists")
+
+    def check_subquery_in_unpivot(self, query_context_type, fragment):
+        with self.tempView("t1", "t2"):
+            t1 = self.table1()
+            t2 = self.table2()
+
+            with self.assertRaises(AnalysisException) as pe:
+                t1.unpivot("c1", t2.exists(), "c1", "c2").collect()
+
+            self.check_error(
+                exception=pe.exception,
+                errorClass=(
+                    
"UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_IN_EXISTS_SUBQUERY"
+                ),
+                messageParameters={"treeNode": "Expand.*"},
+                query_context_type=query_context_type,
+                fragment=fragment,
+                matchPVals=True,
+            )
+
+    def test_subquery_in_transpose(self):
+        with self.tempView("t1"):
+            t1 = self.table1()
+
+            with self.assertRaises(AnalysisException) as pe:
+                t1.transpose(t1.select(sf.max("c1")).scalar()).collect()
+
+            self.check_error(
+                exception=pe.exception,
+                errorClass="TRANSPOSE_INVALID_INDEX_COLUMN",
+                messageParameters={"reason": "Index column must be an atomic 
attribute"},
+            )
+
+    def test_subquery_in_with_columns(self):
+        with self.tempView("t1"):
+            t1 = self.table1()
+
+            assertDataFrameEqual(
+                t1.withColumn(
+                    "scalar",
+                    self.spark.range(1)
+                    .select(sf.col("c1").outer() + sf.col("c2").outer())
+                    .scalar(),
+                ),
+                t1.withColumn("scalar", sf.col("c1") + sf.col("c2")),
+            )
+
+    def test_subquery_in_drop(self):
+        with self.tempView("t1"):
+            t1 = self.table1()
+
+            
assertDataFrameEqual(t1.drop(self.spark.range(1).select(sf.lit("c1")).scalar()),
 t1)
+
+    def test_subquery_in_repartition(self):
+        with self.tempView("t1"):
+            t1 = self.table1()
+
+            
assertDataFrameEqual(t1.repartition(self.spark.range(1).select(sf.lit(1)).scalar()),
 t1)
+
 
 class SubqueryTests(SubqueryTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index d5f097065dc5..233b432766b7 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -344,6 +344,7 @@ class PySparkErrorTestUtils:
         messageParameters: Optional[Dict[str, str]] = None,
         query_context_type: Optional[QueryContextType] = None,
         fragment: Optional[str] = None,
+        matchPVals: bool = False,
     ):
         query_context = exception.getQueryContext()
         assert bool(query_context) == (query_context_type is not None), (
@@ -367,9 +368,30 @@ class PySparkErrorTestUtils:
         # Test message parameters
         expected = messageParameters
         actual = exception.getMessageParameters()
-        self.assertEqual(
-            expected, actual, f"Expected message parameters was '{expected}', 
got '{actual}'"
-        )
+        if matchPVals:
+            self.assertEqual(
+                len(expected),
+                len(actual),
+                "Expected message parameters count does not match actual 
message parameters count"
+                f": {len(expected)}, {len(actual)}.",
+            )
+            for key, value in expected.items():
+                self.assertIn(
+                    key,
+                    actual,
+                    f"Expected message parameter key '{key}' was not found "
+                    "in actual message parameters.",
+                )
+                self.assertRegex(
+                    actual[key],
+                    value,
+                    f"Expected message parameter value '{value}' does not 
match actual message "
+                    f"parameter value '{actual[key]}'.",
+                ),
+        else:
+            self.assertEqual(
+                expected, actual, f"Expected message parameters was 
'{expected}', got '{actual}'"
+            )
 
         # Test query context
         if query_context:
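
To make the new flag concrete, here is the check_error call from check_subquery_in_unpivot
above, repeated as a fragment: with matchPVals=True the expected value "Expand.*" is applied
as a regular expression via assertRegex instead of being compared for exact equality.

```
self.check_error(
    exception=pe.exception,
    errorClass="UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_IN_EXISTS_SUBQUERY",
    messageParameters={"treeNode": "Expand.*"},  # regex pattern, not a literal value
    query_context_type=QueryContextType.DataFrame,
    fragment="exists",
    matchPVals=True,
)
```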
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala
index f745c152170e..ef4bdb8d5bdf 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala
@@ -70,6 +70,19 @@ private[sql] trait ColumnNode extends ColumnNodeLike {
 trait ColumnNodeLike {
   private[internal] def normalize(): ColumnNodeLike = this
   private[internal] def sql: String
+  private[internal] def children: Seq[ColumnNodeLike]
+
+  private[sql] def foreach(f: ColumnNodeLike => Unit): Unit = {
+    f(this)
+    children.foreach(_.foreach(f))
+  }
+
+  private[sql] def collect[A](pf: PartialFunction[ColumnNodeLike, A]): Seq[A] 
= {
+    val ret = new collection.mutable.ArrayBuffer[A]()
+    val lifted = pf.lift
+    foreach(node => lifted(node).foreach(ret.+=))
+    ret.toSeq
+  }
 }
 
 private[internal] object ColumnNode {
@@ -118,6 +131,8 @@ private[sql] case class Literal(
     case v: Short => toSQLValue(v)
     case _ => value.toString
   }
+
+  override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
 }
 
 /**
@@ -141,6 +156,8 @@ private[sql] case class UnresolvedAttribute(
     copy(planId = None, origin = NO_ORIGIN)
 
   override def sql: String = nameParts.map(n => if (n.contains(".")) s"`$n`" 
else n).mkString(".")
+
+  override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
 }
 
 private[sql] object UnresolvedAttribute {
@@ -183,6 +200,7 @@ private[sql] case class UnresolvedStar(
   override private[internal] def normalize(): UnresolvedStar =
     copy(planId = None, origin = NO_ORIGIN)
   override def sql: String = unparsedTarget.map(_ + ".*").getOrElse("*")
+  override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
 }
 
 /**
@@ -208,6 +226,8 @@ private[sql] case class UnresolvedFunction(
     copy(arguments = ColumnNode.normalize(arguments), origin = NO_ORIGIN)
 
   override def sql: String = functionName + argumentsToSql(arguments)
+
+  override private[internal] def children: Seq[ColumnNodeLike] = arguments
 }
 
 /**
@@ -222,6 +242,7 @@ private[sql] case class SqlExpression(
     extends ColumnNode {
   override private[internal] def normalize(): SqlExpression = copy(origin = 
NO_ORIGIN)
   override def sql: String = expression
+  override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
 }
 
 /**
@@ -250,6 +271,8 @@ private[sql] case class Alias(
     }
     s"${child.sql} AS $alias"
   }
+
+  override private[internal] def children: Seq[ColumnNodeLike] = Seq(child)
 }
 
 /**
@@ -275,10 +298,14 @@ private[sql] case class Cast(
   override def sql: String = {
     s"${optionToSql(evalMode)}CAST(${child.sql} AS ${dataType.sql})"
   }
+
+  override private[internal] def children: Seq[ColumnNodeLike] = Seq(child) ++ 
evalMode
 }
 
 private[sql] object Cast {
-  sealed abstract class EvalMode(override val sql: String = "") extends 
ColumnNodeLike
+  sealed abstract class EvalMode(override val sql: String = "") extends 
ColumnNodeLike {
+    override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
+  }
   object Legacy extends EvalMode
   object Ansi extends EvalMode
   object Try extends EvalMode("TRY_")
@@ -300,6 +327,7 @@ private[sql] case class UnresolvedRegex(
   override private[internal] def normalize(): UnresolvedRegex =
     copy(planId = None, origin = NO_ORIGIN)
   override def sql: String = regex
+  override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
 }
 
 /**
@@ -322,13 +350,19 @@ private[sql] case class SortOrder(
     copy(child = child.normalize(), origin = NO_ORIGIN)
 
   override def sql: String = s"${child.sql} ${sortDirection.sql} 
${nullOrdering.sql}"
+
+  override def children: Seq[ColumnNodeLike] = Seq(child, sortDirection, 
nullOrdering)
 }
 
 private[sql] object SortOrder {
-  sealed abstract class SortDirection(override val sql: String) extends ColumnNodeLike
+  sealed abstract class SortDirection(override val sql: String) extends ColumnNodeLike {
+    override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
+  }
   object Ascending extends SortDirection("ASC")
   object Descending extends SortDirection("DESC")
-  sealed abstract class NullOrdering(override val sql: String) extends ColumnNodeLike
+  sealed abstract class NullOrdering(override val sql: String) extends ColumnNodeLike {
+    override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
+  }
   object NullsFirst extends NullOrdering("NULLS FIRST")
   object NullsLast extends NullOrdering("NULLS LAST")
 }
@@ -352,6 +386,8 @@ private[sql] case class Window(
     origin = NO_ORIGIN)
 
   override def sql: String = s"${windowFunction.sql} OVER (${windowSpec.sql})"
+
+  override private[internal] def children: Seq[ColumnNodeLike] = Seq(windowFunction, windowSpec)
 }
 
 private[sql] case class WindowSpec(
@@ -370,6 +406,9 @@ private[sql] case class WindowSpec(
       optionToSql(frame))
     parts.filter(_.nonEmpty).mkString(" ")
   }
+  override private[internal] def children: Seq[ColumnNodeLike] = {
+    partitionColumns ++ sortColumns ++ frame
+  }
 }
 
 private[sql] case class WindowFrame(
@@ -381,15 +420,19 @@ private[sql] case class WindowFrame(
     copy(lower = lower.normalize(), upper = upper.normalize())
   override private[internal] def sql: String =
     s"${frameType.sql} BETWEEN ${lower.sql} AND ${upper.sql}"
+  override private[internal] def children: Seq[ColumnNodeLike] = Seq(frameType, lower, upper)
 }
 
 private[sql] object WindowFrame {
-  sealed abstract class FrameType(override val sql: String) extends ColumnNodeLike
+  sealed abstract class FrameType(override val sql: String) extends ColumnNodeLike {
+    override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
+  }
   object Row extends FrameType("ROWS")
   object Range extends FrameType("RANGE")
 
   sealed abstract class FrameBoundary extends ColumnNodeLike {
     override private[internal] def normalize(): FrameBoundary = this
+    override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
   }
   object CurrentRow extends FrameBoundary {
     override private[internal] def sql = "CURRENT ROW"
@@ -403,6 +446,7 @@ private[sql] object WindowFrame {
   case class Value(value: ColumnNode) extends FrameBoundary {
     override private[internal] def normalize(): Value = copy(value.normalize())
     override private[internal] def sql: String = value.sql
+    override private[internal] def children: Seq[ColumnNodeLike] = Seq(value)
   }
   def value(i: Int): Value = Value(Literal(i, Some(IntegerType)))
   def value(l: Long): Value = Value(Literal(l, Some(LongType)))
@@ -434,6 +478,8 @@ private[sql] case class LambdaFunction(
     }
     argumentsSql + " -> " + function.sql
   }
+
+  override private[internal] def children: Seq[ColumnNodeLike] = function +: arguments
 }
 
 object LambdaFunction {
@@ -455,6 +501,8 @@ private[sql] case class UnresolvedNamedLambdaVariable(
     copy(origin = NO_ORIGIN)
 
   override def sql: String = name
+
+  override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
 }
 
 object UnresolvedNamedLambdaVariable {
@@ -495,6 +543,8 @@ private[sql] case class UnresolvedExtractValue(
     copy(child = child.normalize(), extraction = extraction.normalize(), origin = NO_ORIGIN)
 
   override def sql: String = s"${child.sql}[${extraction.sql}]"
+
+  override private[internal] def children: Seq[ColumnNodeLike] = Seq(child, extraction)
 }
 
 /**
@@ -521,6 +571,9 @@ private[sql] case class UpdateFields(
     case Some(value) => s"update_field(${structExpression.sql}, $fieldName, 
${value.sql})"
     case None => s"drop_field(${structExpression.sql}, $fieldName)"
   }
+  override private[internal] def children: Seq[ColumnNodeLike] = {
+    structExpression +: valueExpression.toSeq
+  }
 }
 
 /**
@@ -549,6 +602,11 @@ private[sql] case class CaseWhenOtherwise(
       branches.map(cv => s" WHEN ${cv._1.sql} THEN ${cv._2.sql}").mkString +
       otherwise.map(o => s" ELSE ${o.sql}").getOrElse("") +
       " END"
+
+  override private[internal] def children: Seq[ColumnNodeLike] = {
+    val branchChildren = branches.flatMap { case (condition, value) => Seq(condition, value) }
+    branchChildren ++ otherwise
+  }
 }
 
 /**
@@ -570,6 +628,8 @@ private[sql] case class InvokeInlineUserDefinedFunction(
 
   override def sql: String =
     function.name + argumentsToSql(arguments)
+
+  override private[internal] def children: Seq[ColumnNodeLike] = arguments
 }
 
 private[sql] trait UserDefinedFunctionLike {
@@ -589,4 +649,5 @@ private[sql] case class LazyExpression(
   override private[internal] def normalize(): ColumnNode =
     copy(child = child.normalize(), origin = NO_ORIGIN)
   override def sql: String = "lazy" + argumentsToSql(Seq(child))
+  override private[internal] def children: Seq[ColumnNodeLike] = Seq(child)
 }
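
The `children`-based `foreach`/`collect` helpers added above give every column node a uniform tree traversal. A minimal sketch of how code inside `org.apache.spark.sql.internal` could use them (the object and method names here are hypothetical; the helpers themselves are `private[sql]`):

```
package org.apache.spark.sql.internal

// Hypothetical example object; not part of the change above.
object ColumnTreeWalk {
  // Collect the dotted name of every unresolved attribute referenced in a column tree,
  // using the new children-driven collect traversal.
  def referencedNames(node: ColumnNode): Seq[String] = node.collect {
    case a: UnresolvedAttribute => a.nameParts.mkString(".")
  }
}
```
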
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 61b68b743a5c..87a5e94d9f63 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -1056,3 +1056,18 @@ case class LazyExpression(child: Expression) extends UnaryExpression with Uneval
   }
   final override val nodePatterns: Seq[TreePattern] = Seq(LAZY_EXPRESSION)
 }
+
+trait UnresolvedPlanId extends LeafExpression with Unevaluable {
+  override def nullable: Boolean = throw new UnresolvedException("nullable")
+  override def dataType: DataType = throw new UnresolvedException("dataType")
+  override lazy val resolved = false
+
+  def planId: Long
+  def withPlan(plan: LogicalPlan): Expression
+
+  final override val nodePatterns: Seq[TreePattern] =
+    Seq(UNRESOLVED_PLAN_ID) ++ nodePatternsInternal()
+
+  // Subclasses can override this function to provide more TreePatterns.
+  def nodePatternsInternal(): Seq[TreePattern] = Seq()
+}
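
`UnresolvedPlanId` is the placeholder the Connect planner substitutes once the referenced plan is available. A simplified sketch of that substitution, assuming a map from plan id to the already-transformed `LogicalPlan` (the helper name is hypothetical; the real logic is `transformWithRelations` in `SparkConnectPlanner` below):

```
import org.apache.spark.sql.catalyst.analysis.UnresolvedPlanId
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Hypothetical helper: replace each plan-id placeholder with the plan it points to.
def resolvePlanIds(root: LogicalPlan, refs: Map[Long, LogicalPlan]): LogicalPlan =
  root.transformAllExpressions {
    case u: UnresolvedPlanId if refs.contains(u.planId) => u.withPlan(refs(u.planId))
  }
```
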
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index 0c8253659dd5..c0a2bf25fbe6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.sql.catalyst.analysis.UnresolvedPlanId
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -439,6 +440,14 @@ object ScalarSubquery {
   }
 }
 
+case class UnresolvedScalarSubqueryPlanId(planId: Long)
+  extends UnresolvedPlanId {
+
+  override def withPlan(plan: LogicalPlan): Expression = {
+    ScalarSubquery(plan)
+  }
+}
+
 /**
  * A subquery that can return multiple rows and columns. This should be rewritten as a join
  * with the outer query during the optimization phase.
@@ -592,3 +601,11 @@ case class Exists(
 
   final override def nodePatternsInternal(): Seq[TreePattern] = Seq(EXISTS_SUBQUERY)
 }
+
+case class UnresolvedExistsPlanId(planId: Long)
+  extends UnresolvedPlanId {
+
+  override def withPlan(plan: LogicalPlan): Expression = {
+    Exists(plan)
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 95b5832392ec..1dfb0336ecf0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -155,6 +155,7 @@ object TreePattern extends Enumeration  {
   val UNRESOLVED_FUNCTION: Value = Value
   val UNRESOLVED_HINT: Value = Value
   val UNRESOLVED_WINDOW_EXPRESSION: Value = Value
+  val UNRESOLVED_PLAN_ID: Value = Value
 
   // Unresolved Plan patterns (Alphabetically ordered)
   val UNRESOLVED_FUNC: Value = Value
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto b/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 811dd032aa41..a01b5229a7b7 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -53,6 +53,7 @@ message Expression {
     MergeAction merge_action = 19;
     TypedAggregateExpression typed_aggregate_expression = 20;
     LazyExpression lazy_expression = 21;
+    SubqueryExpression subquery_expression = 22;
 
     // This field is used to mark extensions to the protocol. When plugins generate arbitrary
     // relations they can add them here. During the planning the correct resolution is done.
@@ -457,3 +458,17 @@ message LazyExpression {
   // (Required) The expression to be marked as lazy.
   Expression child = 1;
 }
+
+message SubqueryExpression {
+  // (Required) The id of corresponding connect plan.
+  int64 plan_id = 1;
+
+  // (Required) The type of the subquery.
+  SubqueryType subquery_type = 2;
+
+  enum SubqueryType {
+    SUBQUERY_TYPE_UNKNOWN = 0;
+    SUBQUERY_TYPE_SCALAR = 1;
+    SUBQUERY_TYPE_EXISTS = 2;
+  }
+}
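
A client embeds the new message in the expression oneof roughly as follows (a sketch assuming the standard Java protobuf builders generated under `org.apache.spark.connect.proto`; the helper name is hypothetical):

```
import org.apache.spark.connect.proto

// Build an Expression that refers to the plan with the given plan id as a SCALAR
// subquery; the referenced plan itself is shipped separately under with_relations.
def scalarSubqueryExpr(planId: Long): proto.Expression = {
  val subquery = proto.SubqueryExpression
    .newBuilder()
    .setPlanId(planId)
    .setSubqueryType(proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_SCALAR)
    .build()
  proto.Expression.newBuilder().setSubqueryExpression(subquery).build()
}
```
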
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index bfb5f8f3fab7..5ace916ba3e9 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -45,7 +45,7 @@ import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
 import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
 import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, Row, SparkSession}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker}
-import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LazyExpression, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedTableValuedFunction, UnresolvedTranspose}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LazyExpression, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedPlanId, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedTableValuedFunction, UnresolvedTranspose}
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder, ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
 import org.apache.spark.sql.catalyst.expressions._
@@ -55,7 +55,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, L
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, Assignment, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeleteAction, DeserializeToObject, Except, FlatMapGroupsWithState, InsertAction, InsertStarAction, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, MergeAction, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint, UpdateAction, UpdateSt [...]
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
-import org.apache.spark.sql.catalyst.trees.CurrentOrigin
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreePattern}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
 import org.apache.spark.sql.classic.ClassicConversions._
@@ -161,9 +161,8 @@ class SparkConnectPlanner(
         case proto.Relation.RelTypeCase.DROP => transformDrop(rel.getDrop)
         case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate)
         case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql)
-        case proto.Relation.RelTypeCase.WITH_RELATIONS
-            if isValidSQLWithRefs(rel.getWithRelations) =>
-          transformSqlWithRefs(rel.getWithRelations)
+        case proto.Relation.RelTypeCase.WITH_RELATIONS =>
+          transformWithRelations(rel.getWithRelations)
         case proto.Relation.RelTypeCase.LOCAL_RELATION =>
           transformLocalRelation(rel.getLocalRelation)
         case proto.Relation.RelTypeCase.SAMPLE => transformSample(rel.getSample)
@@ -1559,6 +1558,8 @@ class SparkConnectPlanner(
         transformTypedAggregateExpression(exp.getTypedAggregateExpression, baseRelationOpt)
       case proto.Expression.ExprTypeCase.LAZY_EXPRESSION =>
         transformLazyExpression(exp.getLazyExpression)
+      case proto.Expression.ExprTypeCase.SUBQUERY_EXPRESSION =>
+        transformSubqueryExpression(exp.getSubqueryExpression)
       case _ =>
         throw InvalidPlanInput(
           s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not 
supported")
@@ -3724,7 +3725,56 @@ class SparkConnectPlanner(
     LazyExpression(transformExpression(getLazyExpression.getChild))
   }
 
-  private def assertPlan(assertion: Boolean, message: String = ""): Unit = {
+  private def transformSubqueryExpression(
+      getSubqueryExpression: proto.SubqueryExpression): Expression = {
+    val planId = getSubqueryExpression.getPlanId
+    getSubqueryExpression.getSubqueryType match {
+      case proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_SCALAR =>
+        UnresolvedScalarSubqueryPlanId(planId)
+      case proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_EXISTS =>
+        UnresolvedExistsPlanId(planId)
+      case other => throw InvalidPlanInput(s"Unknown SubqueryType $other")
+    }
+  }
+
+  private def transformWithRelations(getWithRelations: proto.WithRelations): LogicalPlan = {
+    if (isValidSQLWithRefs(getWithRelations)) {
+      transformSqlWithRefs(getWithRelations)
+    } else {
+      // Wrap the plan to keep the original planId.
+      val plan = Project(Seq(UnresolvedStar(None)), transformRelation(getWithRelations.getRoot))
+
+      val relations = getWithRelations.getReferencesList.asScala.map { ref =>
+        if (ref.hasCommon && ref.getCommon.hasPlanId) {
+          val planId = ref.getCommon.getPlanId
+          val plan = transformRelation(ref)
+          planId -> plan
+        } else {
+          throw InvalidPlanInput("Invalid WithRelation reference")
+        }
+      }.toMap
+
+      val missingPlanIds = mutable.Set.empty[Long]
+      val withRelations = plan
+        .transformAllExpressionsWithPruning(_.containsPattern(TreePattern.UNRESOLVED_PLAN_ID)) {
+          case u: UnresolvedPlanId =>
+            if (relations.contains(u.planId)) {
+              u.withPlan(relations(u.planId))
+            } else {
+              missingPlanIds += u.planId
+              u
+            }
+        }
+      assertPlan(
+        missingPlanIds.isEmpty,
+        "Missing relation in WithRelations: " +
+          s"${missingPlanIds.mkString("(", ", ", ")")} not in " +
+          s"${relations.keys.mkString("(", ", ", ")")}")
+      withRelations
+    }
+  }
+
+  private def assertPlan(assertion: Boolean, message: => String = ""): Unit = {
     if (!assertion) throw InvalidPlanInput(message)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala
index 8b4726114890..8f37f5c32de3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala
@@ -273,6 +273,8 @@ private[sql] case class ExpressionColumnNode private(
   }
 
   override def sql: String = expression.sql
+
+  override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
 }
 
 private[sql] object ExpressionColumnNode {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
index cd425162fb01..f94cf89276ec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
@@ -664,4 +664,102 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession {
       )
     }
   }
+
+  test("subquery with generator / table-valued functions") {
+    withView("t1") {
+      val t1 = table1()
+
+      checkAnswer(
+        spark.range(1).select(explode(t1.select(collect_list("c2")).scalar())),
+        sql("SELECT EXPLODE((SELECT COLLECT_LIST(c2) FROM t1))")
+      )
+      checkAnswer(
+        spark.tvf.explode(t1.select(collect_list("c2")).scalar()),
+        sql("SELECT * FROM EXPLODE((SELECT COLLECT_LIST(c2) FROM t1))")
+      )
+    }
+  }
+
+  test("subquery in join condition") {
+    withView("t1", "t2") {
+      val t1 = table1()
+      val t2 = table2()
+
+      checkAnswer(
+        t1.join(t2, $"t1.c1" === t1.select(max("c1")).scalar()),
+        sql("SELECT * FROM t1 JOIN t2 ON t1.c1 = (SELECT MAX(c1) FROM t1)")
+      )
+    }
+  }
+
+  test("subquery in unpivot") {
+    withView("t1", "t2") {
+      val t1 = table1()
+      val t2 = table2()
+
+      checkError(
+        intercept[AnalysisException] {
+          t1.unpivot(Array(t2.exists()), "c1", "c2").collect()
+        },
+        "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_IN_EXISTS_SUBQUERY",
+        parameters = Map("treeNode" -> "(?s)'Unpivot.*"),
+        matchPVals = true,
+        queryContext = Array(ExpectedContext(
+          fragment = "exists",
+          callSitePattern = getCurrentClassCallSitePattern))
+      )
+      checkError(
+        intercept[AnalysisException] {
+          t1.unpivot(Array($"c1"), Array(t2.exists()), "c1", "c2").collect()
+        },
+        "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_IN_EXISTS_SUBQUERY",
+        parameters = Map("treeNode" -> "(?s)Expand.*"),
+        matchPVals = true,
+        queryContext = Array(ExpectedContext(
+          fragment = "exists",
+          callSitePattern = getCurrentClassCallSitePattern))
+      )
+    }
+  }
+
+  test("subquery in transpose") {
+    withView("t1") {
+      val t1 = table1()
+
+      checkError(
+        intercept[AnalysisException] {
+          t1.transpose(t1.select(max("c1")).scalar()).collect()
+        },
+        "TRANSPOSE_INVALID_INDEX_COLUMN",
+        parameters = Map("reason" -> "Index column must be an atomic 
attribute")
+      )
+    }
+  }
+
+  test("subquery in withColumns") {
+    withView("t1") {
+      val t1 = table1()
+
+      checkAnswer(
+        t1.withColumn("scalar", spark.range(1).select($"c1".outer() + 
$"c2".outer()).scalar()),
+        t1.withColumn("scalar", $"c1" + $"c2")
+      )
+    }
+  }
+
+  test("subquery in drop") {
+    withView("t1") {
+      val t1 = table1()
+
+      checkAnswer(t1.drop(spark.range(1).select(lit("c1")).scalar()), t1)
+    }
+  }
+
+  test("subquery in repartition") {
+    withView("t1") {
+      val t1 = table1()
+
+      checkAnswer(t1.repartition(spark.range(1).select(lit(1)).scalar()), t1)
+    }
+  }
 }
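
For reference, the API exercised by these tests reads roughly as follows in user code (a self-contained sketch based on the `scalar()`, `exists()` and `outer()` calls above; the data and column names are placeholders):

```
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val t1 = Seq((1, 10), (2, 20)).toDF("c1", "c2").as("t1")
val t2 = Seq((2, 200), (3, 300)).toDF("c1", "c2").as("t2")

// Scalar subquery: a single-value DataFrame result used as a Column.
t2.select($"c1", t1.select(max($"c1")).scalar().as("max_t1_c1")).show()

// Correlated EXISTS subquery: outer() marks references to the enclosing query.
t1.where(t2.where($"t2.c1" === $"t1.c1".outer()).exists()).show()
```
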
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala
index 76fcdfc38095..d72e86450de2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala
@@ -405,4 +405,5 @@ private[internal] case class Nope(override val origin: Origin = CurrentOrigin.ge
   extends ColumnNode {
   override private[internal] def normalize(): Nope = this
   override def sql: String = "nope"
+  override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
 }

