This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9c7978d80b8 [SPARK-43321][CONNECT] Dataset#joinWith
9c7978d80b8 is described below

commit 9c7978d80b8a95bd7fcc26769eea581849000862
Author: Zhen Li <[email protected]>
AuthorDate: Thu Jul 6 17:42:53 2023 -0400

    [SPARK-43321][CONNECT] Dataset#joinWith
    
    ### What changes were proposed in this pull request?
    Implement the missing `joinWith` method using the Join relation operation.
    
    `joinWith` adds the `left` and `right` struct type information to the Join relation proto.
    
    ### Why are the changes needed?
    This Dataset API was missing from the Spark Connect Scala client.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. Adds the missing `Dataset#joinWith` method.
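    
    For example, with the Connect Scala client (a minimal usage sketch mirroring the new E2E
    tests; `spark` is assumed to be a Connect `SparkSession`):
    
        val session: SparkSession = spark
        import session.implicits._
        val ds1 = Seq(1, 2, 3).toDS().as("a")
        val ds2 = Seq(1, 2).toDS().as("b")
        // Inner joinWith returns a Dataset[(Int, Int)] whose sides are nested
        // under the result columns `_1` and `_2`.
        val joined = ds1.joinWith(ds2, $"a.value" === $"b.value", "inner")
        joined.collect() // Array((1,1), (2,2))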
    
    ### How was this patch tested?
    E2E tests.
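    
    For example, the new suite can be run locally with something like the following
    (the sbt module name `connect-client-jvm` is an assumption and the exact invocation
    may differ depending on the build setup):
    
        build/sbt "connect-client-jvm/testOnly org.apache.spark.sql.ClientE2ETestSuite"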
    
    Closes #40997 from zhenlineo/joinwith.
    
    Authored-by: Zhen Li <[email protected]>
    Signed-off-by: Herman van Hovell <[email protected]>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  85 +++++++-
 .../spark/sql/connect/client/SparkResult.scala     |  34 +++-
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  | 193 ++++++++++++++++++
 .../CheckConnectJvmClientCompatibility.scala       |   1 -
 .../main/protobuf/spark/connect/relations.proto    |  10 +
 .../sql/connect/planner/SparkConnectPlanner.scala  |  24 ++-
 python/pyspark/sql/connect/proto/relations_pb2.py  | 221 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  48 ++++-
 .../sql/catalyst/encoders/AgnosticEncoder.scala    |  44 ++--
 .../spark/sql/catalyst/plans/logical/object.scala  | 104 +++++++++-
 .../spark/sql/errors/QueryCompilationErrors.scala  |  11 +-
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 104 +---------
 12 files changed, 639 insertions(+), 240 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 2ea3169486b..4fa5c0b9641 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -20,6 +20,7 @@ import java.util.{Collections, Locale}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.reflect.ClassTag
 import scala.util.control.NonFatal
 
 import org.apache.spark.SparkException
@@ -568,7 +569,7 @@ class Dataset[T] private[sql] (
     }
   }
 
-  private def toJoinType(name: String): proto.Join.JoinType = {
+  private def toJoinType(name: String, skipSemiAnti: Boolean = false): proto.Join.JoinType = {
     name.trim.toLowerCase(Locale.ROOT) match {
       case "inner" =>
         proto.Join.JoinType.JOIN_TYPE_INNER
@@ -580,12 +581,12 @@ class Dataset[T] private[sql] (
         proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER
       case "right" | "rightouter" | "right_outer" =>
         proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER
-      case "semi" | "leftsemi" | "left_semi" =>
+      case "semi" | "leftsemi" | "left_semi" if !skipSemiAnti =>
         proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI
-      case "anti" | "leftanti" | "left_anti" =>
+      case "anti" | "leftanti" | "left_anti" if !skipSemiAnti =>
         proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI
-      case _ =>
-        throw new IllegalArgumentException(s"Unsupported join type 
`joinType`.")
+      case e =>
+        throw new IllegalArgumentException(s"Unsupported join type '$e'.")
     }
   }
 
@@ -835,6 +836,80 @@ class Dataset[T] private[sql] (
     }
   }
 
+  /**
+   * Joins this Dataset returning a `Tuple2` for each pair where `condition` evaluates to true.
+   *
+   * This is similar to the relation `join` function with one important difference in the result
+   * schema. Since `joinWith` preserves objects present on either side of the join, the result
+   * schema is similarly nested into a tuple under the column names `_1` and `_2`.
+   *
+   * This type of join can be useful both for preserving type-safety with the original object
+   * types as well as working with relational data where either side of the join has column names
+   * in common.
+   *
+   * @param other
+   *   Right side of the join.
+   * @param condition
+   *   Join expression.
+   * @param joinType
+   *   Type of join to perform. Default `inner`. Must be one of: `inner`, `cross`, `outer`,
+   *   `full`, `fullouter`,`full_outer`, `left`, `leftouter`, `left_outer`, `right`, `rightouter`,
+   *   `right_outer`.
+   *
+   * @group typedrel
+   * @since 3.5.0
+   */
+  def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = {
+    val joinTypeValue = toJoinType(joinType, skipSemiAnti = true)
+    val (leftNullable, rightNullable) = joinTypeValue match {
+      case proto.Join.JoinType.JOIN_TYPE_INNER | proto.Join.JoinType.JOIN_TYPE_CROSS =>
+        (false, false)
+      case proto.Join.JoinType.JOIN_TYPE_FULL_OUTER =>
+        (true, true)
+      case proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER =>
+        (false, true)
+      case proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER =>
+        (true, false)
+      case e =>
+        throw new IllegalArgumentException(s"Unsupported join type '$e'.")
+    }
+
+    val tupleEncoder =
+      ProductEncoder[(T, U)](
+        ClassTag(Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple2")),
+        Seq(
+          EncoderField(s"_1", this.encoder, leftNullable, Metadata.empty),
+          EncoderField(s"_2", other.encoder, rightNullable, Metadata.empty)))
+
+    sparkSession.newDataset(tupleEncoder) { builder =>
+      val joinBuilder = builder.getJoinBuilder
+      joinBuilder
+        .setLeft(plan.getRoot)
+        .setRight(other.plan.getRoot)
+        .setJoinType(joinTypeValue)
+        .setJoinCondition(condition.expr)
+        .setJoinDataType(joinBuilder.getJoinDataTypeBuilder
+          .setIsLeftFlattenableToRow(this.encoder.isFlattenable)
+          .setIsRightFlattenableToRow(other.encoder.isFlattenable))
+    }
+  }
+
+  /**
+   * Using inner equi-join to join this Dataset returning a `Tuple2` for each pair where
+   * `condition` evaluates to true.
+   *
+   * @param other
+   *   Right side of the join.
+   * @param condition
+   *   Join expression.
+   *
+   * @group typedrel
+   * @since 3.5.0
+   */
+  def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
+    joinWith(other, condition, "inner")
+  }
+
   /**
    * Returns a new Dataset with each partition sorted by the given expressions.
    *
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index 86a7cf846f2..a6ed31c1869 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -28,11 +28,11 @@ import org.apache.arrow.vector.ipc.ArrowStreamReader
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer
 import org.apache.spark.sql.connect.client.util.{AutoCloseables, Cleanable}
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
 
@@ -50,15 +50,33 @@ private[sql] class SparkResult[T](
   private val idxToBatches = mutable.Map.empty[Int, ColumnarBatch]
 
   private def createEncoder(schema: StructType): ExpressionEncoder[T] = {
-    val agnosticEncoder = if (encoder == UnboundRowEncoder) {
-      // Create a row encoder based on the schema.
-      RowEncoder.encoderFor(schema).asInstanceOf[AgnosticEncoder[T]]
-    } else {
-      encoder
-    }
+    val agnosticEncoder = createEncoder(encoder, schema).asInstanceOf[AgnosticEncoder[T]]
     ExpressionEncoder(agnosticEncoder)
   }
 
+  /**
+   * Update RowEncoder and recursively update the fields of the ProductEncoder if found.
+   */
+  private def createEncoder[_](
+      enc: AgnosticEncoder[_],
+      dataType: DataType): AgnosticEncoder[_] = {
+    enc match {
+      case UnboundRowEncoder =>
+        // Replace the row encoder with the encoder inferred from the schema.
+        RowEncoder.encoderFor(dataType.asInstanceOf[StructType])
+      case ProductEncoder(clsTag, fields) if ProductEncoder.isTuple(clsTag) =>
+        // Recursively continue updating the tuple product encoder
+        val schema = dataType.asInstanceOf[StructType]
+        assert(fields.length <= schema.fields.length)
+        val updatedFields = fields.zipWithIndex.map { case (f, id) =>
+          f.copy(enc = createEncoder(f.enc, schema.fields(id).dataType))
+        }
+        ProductEncoder(clsTag, updatedFields)
+      case _ =>
+        enc
+    }
+  }
+
   private def processResponses(stopOnFirstNonEmptyResponse: Boolean): Boolean = {
     while (responses.hasNext) {
       val response = responses.next()
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 48be815b236..73c04389c05 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -88,6 +88,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with 
SQLHelper with PrivateM
   }
 
   test("read and write") {
+    assume(IntegrationTestUtils.isSparkHiveJarAvailable)
     val testDataPath = java.nio.file.Paths
       .get(
         IntegrationTestUtils.sparkHome,
@@ -158,6 +159,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with 
SQLHelper with PrivateM
   }
 
   test("textFile") {
+    assume(IntegrationTestUtils.isSparkHiveJarAvailable)
     val testDataPath = java.nio.file.Paths
       .get(
         IntegrationTestUtils.sparkHome,
@@ -178,6 +180,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with 
SQLHelper with PrivateM
   }
 
   test("write table") {
+    assume(IntegrationTestUtils.isSparkHiveJarAvailable)
     withTable("myTable") {
       val df = spark.range(10).limit(3)
       df.write.mode(SaveMode.Overwrite).saveAsTable("myTable")
@@ -221,6 +224,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with 
SQLHelper with PrivateM
   }
 
   test("write without table or path") {
+    assume(IntegrationTestUtils.isSparkHiveJarAvailable)
     // Should receive no error to write noop
     spark.range(10).write.format("noop").mode("append").save()
   }
@@ -970,8 +974,197 @@ class ClientE2ETestSuite extends RemoteSparkSession with 
SQLHelper with PrivateM
     val result2 = spark.sql("select :c0 limit :l0", Map("l0" -> 1, "c0" -> 
"abc")).collect()
     assert(result2.length == 1 && result2(0).getString(0) === "abc")
   }
+
+  test("joinWith, flat schema") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val ds1 = Seq(1, 2, 3).toDS().as("a")
+    val ds2 = Seq(1, 2).toDS().as("b")
+
+    val joined = ds1.joinWith(ds2, $"a.value" === $"b.value", "inner")
+
+    val expectedSchema = StructType(
+      Seq(
+        StructField("_1", IntegerType, nullable = false),
+        StructField("_2", IntegerType, nullable = false)))
+
+    assert(joined.schema === expectedSchema)
+
+    val expected = Seq((1, 1), (2, 2))
+    checkSameResult(expected, joined)
+  }
+
+  test("joinWith tuple with primitive, expression") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val ds1 = Seq(1, 1, 2).toDS()
+    val ds2 = Seq(("a", 1), ("b", 2)).toDS()
+
+    val joined = ds1.joinWith(ds2, $"value" === $"_2")
+
+    // This is an inner join, so both output fields are non-nullable
+    val expectedSchema = StructType(
+      Seq(
+        StructField("_1", IntegerType, nullable = false),
+        StructField(
+          "_2",
+          StructType(
+            Seq(StructField("_1", StringType), StructField("_2", IntegerType, 
nullable = false))),
+          nullable = false)))
+    assert(joined.schema === expectedSchema)
+
+    checkSameResult(Seq((1, ("a", 1)), (1, ("a", 1)), (2, ("b", 2))), joined)
+  }
+
+  test("joinWith tuple with primitive, rows") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val ds1 = Seq(1, 1, 2).toDF()
+    val ds2 = Seq(("a", 1), ("b", 2)).toDF()
+
+    val joined = ds1.joinWith(ds2, $"value" === $"_2")
+
+    checkSameResult(
+      Seq((Row(1), Row("a", 1)), (Row(1), Row("a", 1)), (Row(2), Row("b", 2))),
+      joined)
+  }
+
+  test("joinWith class with primitive, toDF") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val ds1 = Seq(1, 1, 2).toDS()
+    val ds2 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+
+    val df = ds1
+      .joinWith(ds2, $"value" === $"b")
+      .toDF()
+      .select($"_1", $"_2.a", $"_2.b")
+    checkSameResult(Seq(Row(1, "a", 1), Row(1, "a", 1), Row(2, "b", 2)), df)
+  }
+
+  test("multi-level joinWith") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val ds1 = Seq(("a", 1), ("b", 2)).toDS().as("a")
+    val ds2 = Seq(("a", 1), ("b", 2)).toDS().as("b")
+    val ds3 = Seq(("a", 1), ("b", 2)).toDS().as("c")
+
+    val joined = ds1
+      .joinWith(ds2, $"a._2" === $"b._2")
+      .as("ab")
+      .joinWith(ds3, $"ab._1._2" === $"c._2")
+
+    checkSameResult(
+      Seq(((("a", 1), ("a", 1)), ("a", 1)), ((("b", 2), ("b", 2)), ("b", 2))),
+      joined)
+  }
+
+  test("multi-level joinWith, rows") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val ds1 = Seq(("a", 1), ("b", 2)).toDF().as("a")
+    val ds2 = Seq(("a", 1), ("b", 2)).toDF().as("b")
+    val ds3 = Seq(("a", 1), ("b", 2)).toDF().as("c")
+
+    val joined = ds1
+      .joinWith(ds2, $"a._2" === $"b._2")
+      .as("ab")
+      .joinWith(ds3, $"ab._1._2" === $"c._2")
+
+    checkSameResult(
+      Seq(((Row("a", 1), Row("a", 1)), Row("a", 1)), ((Row("b", 2), Row("b", 
2)), Row("b", 2))),
+      joined)
+  }
+
+  test("self join") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val ds = Seq("1", "2").toDS().as("a")
+    val joined = ds.joinWith(ds, lit(true), "cross")
+    checkSameResult(Seq(("1", "1"), ("1", "2"), ("2", "1"), ("2", "2")), 
joined)
+  }
+
+  test("SPARK-11894: Incorrect results are returned when using null") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val nullInt = null.asInstanceOf[java.lang.Integer]
+    val ds1 = Seq((nullInt, "1"), (java.lang.Integer.valueOf(22), "2")).toDS()
+    val ds2 = Seq((nullInt, "1"), (java.lang.Integer.valueOf(22), "2")).toDS()
+
+    checkSameResult(
+      Seq(
+        ((nullInt, "1"), (nullInt, "1")),
+        ((nullInt, "1"), (java.lang.Integer.valueOf(22), "2")),
+        ((java.lang.Integer.valueOf(22), "2"), (nullInt, "1")),
+        ((java.lang.Integer.valueOf(22), "2"), (java.lang.Integer.valueOf(22), 
"2"))),
+      ds1.joinWith(ds2, lit(true), "cross"))
+  }
+
+  test("SPARK-15441: Dataset outer join") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS().as("left")
+    val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDS().as("right")
+    val joined = left.joinWith(right, $"left.b" === $"right.b", "left")
+
+    val expectedSchema = StructType(
+      Seq(
+        StructField(
+          "_1",
+          StructType(
+            Seq(StructField("a", StringType), StructField("b", IntegerType, 
nullable = false))),
+          nullable = false),
+        // This is a left join, so the right output is nullable:
+        StructField(
+          "_2",
+          StructType(
+            Seq(StructField("a", StringType), StructField("b", IntegerType, 
nullable = false))))))
+    assert(joined.schema === expectedSchema)
+
+    val result = joined.collect().toSet
+    assert(result == Set(ClassData("a", 1) -> null, ClassData("b", 2) -> 
ClassData("x", 2)))
+  }
+
+  test("SPARK-37829: DataFrame outer join") {
+    // Same as "SPARK-15441: Dataset outer join" but using DataFrames instead 
of Datasets
+    val session: SparkSession = spark
+    import session.implicits._
+    val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDF().as("left")
+    val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDF().as("right")
+    val joined = left.joinWith(right, $"left.b" === $"right.b", "left")
+
+    val leftFieldSchema = StructType(
+      Seq(StructField("a", StringType), StructField("b", IntegerType, nullable 
= false)))
+    val rightFieldSchema = StructType(
+      Seq(StructField("a", StringType), StructField("b", IntegerType, nullable 
= false)))
+    val expectedSchema = StructType(
+      Seq(
+        StructField("_1", leftFieldSchema, nullable = false),
+        // This is a left join, so the right output is nullable:
+        StructField("_2", rightFieldSchema)))
+    assert(joined.schema === expectedSchema)
+
+    val result = joined.collect().toSet
+    val expected = Set(
+      new GenericRowWithSchema(Array("a", 1), leftFieldSchema) ->
+        null,
+      new GenericRowWithSchema(Array("b", 2), leftFieldSchema) ->
+        new GenericRowWithSchema(Array("x", 2), rightFieldSchema))
+    assert(result == expected)
+  }
+
+  test("SPARK-24762: joinWith on Option[Product]") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val ds1 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("a")
+    val ds2 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("b")
+    val joined = ds1.joinWith(ds2, $"a.value._1" === $"b.value._2", "inner")
+    checkSameResult(Seq((Some((2, 3)), Some((1, 2)))), joined)
+  }
 }
 
+private[sql] case class ClassData(a: String, b: Int)
+
 private[sql] case class MyType(id: Long, a: Double, b: Double)
 private[sql] case class KV(key: String, value: Int)
 private[sql] class SimpleBean {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index f22baddc01e..576dadd3d9f 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -184,7 +184,6 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.queryExecution"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.encoder"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.sqlContext"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.joinWith"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.metadataColumn"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.selectUntyped"), // protected
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.explode"), // deprecated
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 29405a1332b..edde4819b51 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -225,6 +225,16 @@ message Join {
     JOIN_TYPE_LEFT_SEMI = 6;
     JOIN_TYPE_CROSS = 7;
   }
+
+  // (Optional) Only used by joinWith. Set the left and right join data types.
+  optional JoinDataType join_data_type = 6;
+
+  message JoinDataType {
+    // If the left data type is a struct that can be flattened to a row.
+    bool is_left_flattenable_to_row = 1;
+    // If the right data type is a struct that can be flattened to a row.
+    bool is_right_flattenable_to_row = 2;
+  }
 }
 
 // Relation of type [[SetOperation]]
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 149d5512953..d3090e8b09b 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -49,7 +49,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, FlatMapGroupsWithState, Intersect, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, FlatMapGroupsWithState, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, UdfPacket}
@@ -101,7 +101,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
       case proto.Relation.RelTypeCase.LIMIT => transformLimit(rel.getLimit)
       case proto.Relation.RelTypeCase.OFFSET => transformOffset(rel.getOffset)
       case proto.Relation.RelTypeCase.TAIL => transformTail(rel.getTail)
-      case proto.Relation.RelTypeCase.JOIN => transformJoin(rel.getJoin)
+      case proto.Relation.RelTypeCase.JOIN => transformJoinOrJoinWith(rel.getJoin)
       case proto.Relation.RelTypeCase.DEDUPLICATE => transformDeduplicate(rel.getDeduplicate)
       case proto.Relation.RelTypeCase.SET_OP => transformSetOperation(rel.getSetOp)
       case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort)
@@ -2066,6 +2066,26 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
     }
   }
 
+  private def transformJoinWith(rel: proto.Join): LogicalPlan = {
+    val joined =
+      session.sessionState.executePlan(transformJoin(rel)).analyzed.asInstanceOf[logical.Join]
+
+    JoinWith.typedJoinWith(
+      joined,
+      session.sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity,
+      session.sessionState.analyzer.resolver,
+      rel.getJoinDataType.getIsLeftFlattenableToRow,
+      rel.getJoinDataType.getIsRightFlattenableToRow)
+  }
+
+  private def transformJoinOrJoinWith(rel: proto.Join): LogicalPlan = {
+    if (rel.hasJoinDataType) {
+      transformJoinWith(rel)
+    } else {
+      transformJoin(rel)
+    }
+  }
+
   private def transformJoin(rel: proto.Join): LogicalPlan = {
     assert(rel.hasLeft && rel.hasRight, "Both join sides must be present")
     if (rel.hasJoinCondition && rel.getUsingColumnsCount > 0) {
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 288bbe084c1..ce36df6f81e 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as 
spark_dot_connect_dot_catal
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xd0\x17\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
 
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
 
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xd0\x17\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
 
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
 
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
 
 
@@ -53,6 +53,7 @@ _READ_DATASOURCE_OPTIONSENTRY = 
_READ_DATASOURCE.nested_types_by_name["OptionsEn
 _PROJECT = DESCRIPTOR.message_types_by_name["Project"]
 _FILTER = DESCRIPTOR.message_types_by_name["Filter"]
 _JOIN = DESCRIPTOR.message_types_by_name["Join"]
+_JOIN_JOINDATATYPE = _JOIN.nested_types_by_name["JoinDataType"]
 _SETOPERATION = DESCRIPTOR.message_types_by_name["SetOperation"]
 _LIMIT = DESCRIPTOR.message_types_by_name["Limit"]
 _OFFSET = DESCRIPTOR.message_types_by_name["Offset"]
@@ -238,12 +239,22 @@ Join = _reflection.GeneratedProtocolMessageType(
     "Join",
     (_message.Message,),
     {
+        "JoinDataType": _reflection.GeneratedProtocolMessageType(
+            "JoinDataType",
+            (_message.Message,),
+            {
+                "DESCRIPTOR": _JOIN_JOINDATATYPE,
+                "__module__": "spark.connect.relations_pb2"
+                # @@protoc_insertion_point(class_scope:spark.connect.Join.JoinDataType)
+            },
+        ),
         "DESCRIPTOR": _JOIN,
         "__module__": "spark.connect.relations_pb2"
         # @@protoc_insertion_point(class_scope:spark.connect.Join)
     },
 )
 _sym_db.RegisterMessage(Join)
+_sym_db.RegisterMessage(Join.JoinDataType)
 
 SetOperation = _reflection.GeneratedProtocolMessageType(
     "SetOperation",
@@ -808,109 +819,111 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _FILTER._serialized_start = 4314
     _FILTER._serialized_end = 4426
     _JOIN._serialized_start = 4429
-    _JOIN._serialized_end = 4900
-    _JOIN_JOINTYPE._serialized_start = 4692
-    _JOIN_JOINTYPE._serialized_end = 4900
-    _SETOPERATION._serialized_start = 4903
-    _SETOPERATION._serialized_end = 5382
-    _SETOPERATION_SETOPTYPE._serialized_start = 5219
-    _SETOPERATION_SETOPTYPE._serialized_end = 5333
-    _LIMIT._serialized_start = 5384
-    _LIMIT._serialized_end = 5460
-    _OFFSET._serialized_start = 5462
-    _OFFSET._serialized_end = 5541
-    _TAIL._serialized_start = 5543
-    _TAIL._serialized_end = 5618
-    _AGGREGATE._serialized_start = 5621
-    _AGGREGATE._serialized_end = 6203
-    _AGGREGATE_PIVOT._serialized_start = 5960
-    _AGGREGATE_PIVOT._serialized_end = 6071
-    _AGGREGATE_GROUPTYPE._serialized_start = 6074
-    _AGGREGATE_GROUPTYPE._serialized_end = 6203
-    _SORT._serialized_start = 6206
-    _SORT._serialized_end = 6366
-    _DROP._serialized_start = 6369
-    _DROP._serialized_end = 6510
-    _DEDUPLICATE._serialized_start = 6513
-    _DEDUPLICATE._serialized_end = 6753
-    _LOCALRELATION._serialized_start = 6755
-    _LOCALRELATION._serialized_end = 6844
-    _CACHEDLOCALRELATION._serialized_start = 6846
-    _CACHEDLOCALRELATION._serialized_end = 6941
-    _CACHEDREMOTERELATION._serialized_start = 6943
-    _CACHEDREMOTERELATION._serialized_end = 6998
-    _SAMPLE._serialized_start = 7001
-    _SAMPLE._serialized_end = 7274
-    _RANGE._serialized_start = 7277
-    _RANGE._serialized_end = 7422
-    _SUBQUERYALIAS._serialized_start = 7424
-    _SUBQUERYALIAS._serialized_end = 7538
-    _REPARTITION._serialized_start = 7541
-    _REPARTITION._serialized_end = 7683
-    _SHOWSTRING._serialized_start = 7686
-    _SHOWSTRING._serialized_end = 7828
-    _HTMLSTRING._serialized_start = 7830
-    _HTMLSTRING._serialized_end = 7944
-    _STATSUMMARY._serialized_start = 7946
-    _STATSUMMARY._serialized_end = 8038
-    _STATDESCRIBE._serialized_start = 8040
-    _STATDESCRIBE._serialized_end = 8121
-    _STATCROSSTAB._serialized_start = 8123
-    _STATCROSSTAB._serialized_end = 8224
-    _STATCOV._serialized_start = 8226
-    _STATCOV._serialized_end = 8322
-    _STATCORR._serialized_start = 8325
-    _STATCORR._serialized_end = 8462
-    _STATAPPROXQUANTILE._serialized_start = 8465
-    _STATAPPROXQUANTILE._serialized_end = 8629
-    _STATFREQITEMS._serialized_start = 8631
-    _STATFREQITEMS._serialized_end = 8756
-    _STATSAMPLEBY._serialized_start = 8759
-    _STATSAMPLEBY._serialized_end = 9068
-    _STATSAMPLEBY_FRACTION._serialized_start = 8960
-    _STATSAMPLEBY_FRACTION._serialized_end = 9059
-    _NAFILL._serialized_start = 9071
-    _NAFILL._serialized_end = 9205
-    _NADROP._serialized_start = 9208
-    _NADROP._serialized_end = 9342
-    _NAREPLACE._serialized_start = 9345
-    _NAREPLACE._serialized_end = 9641
-    _NAREPLACE_REPLACEMENT._serialized_start = 9500
-    _NAREPLACE_REPLACEMENT._serialized_end = 9641
-    _TODF._serialized_start = 9643
-    _TODF._serialized_end = 9731
-    _WITHCOLUMNSRENAMED._serialized_start = 9734
-    _WITHCOLUMNSRENAMED._serialized_end = 9973
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 9906
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 9973
-    _WITHCOLUMNS._serialized_start = 9975
-    _WITHCOLUMNS._serialized_end = 10094
-    _WITHWATERMARK._serialized_start = 10097
-    _WITHWATERMARK._serialized_end = 10231
-    _HINT._serialized_start = 10234
-    _HINT._serialized_end = 10366
-    _UNPIVOT._serialized_start = 10369
-    _UNPIVOT._serialized_end = 10696
-    _UNPIVOT_VALUES._serialized_start = 10626
-    _UNPIVOT_VALUES._serialized_end = 10685
-    _TOSCHEMA._serialized_start = 10698
-    _TOSCHEMA._serialized_end = 10804
-    _REPARTITIONBYEXPRESSION._serialized_start = 10807
-    _REPARTITIONBYEXPRESSION._serialized_end = 11010
-    _MAPPARTITIONS._serialized_start = 11013
-    _MAPPARTITIONS._serialized_end = 11194
-    _GROUPMAP._serialized_start = 11197
-    _GROUPMAP._serialized_end = 11832
-    _COGROUPMAP._serialized_start = 11835
-    _COGROUPMAP._serialized_end = 12361
-    _APPLYINPANDASWITHSTATE._serialized_start = 12364
-    _APPLYINPANDASWITHSTATE._serialized_end = 12721
-    _COLLECTMETRICS._serialized_start = 12724
-    _COLLECTMETRICS._serialized_end = 12860
-    _PARSE._serialized_start = 12863
-    _PARSE._serialized_end = 13251
+    _JOIN._serialized_end = 5135
+    _JOIN_JOINDATATYPE._serialized_start = 4769
+    _JOIN_JOINDATATYPE._serialized_end = 4905
+    _JOIN_JOINTYPE._serialized_start = 4908
+    _JOIN_JOINTYPE._serialized_end = 5116
+    _SETOPERATION._serialized_start = 5138
+    _SETOPERATION._serialized_end = 5617
+    _SETOPERATION_SETOPTYPE._serialized_start = 5454
+    _SETOPERATION_SETOPTYPE._serialized_end = 5568
+    _LIMIT._serialized_start = 5619
+    _LIMIT._serialized_end = 5695
+    _OFFSET._serialized_start = 5697
+    _OFFSET._serialized_end = 5776
+    _TAIL._serialized_start = 5778
+    _TAIL._serialized_end = 5853
+    _AGGREGATE._serialized_start = 5856
+    _AGGREGATE._serialized_end = 6438
+    _AGGREGATE_PIVOT._serialized_start = 6195
+    _AGGREGATE_PIVOT._serialized_end = 6306
+    _AGGREGATE_GROUPTYPE._serialized_start = 6309
+    _AGGREGATE_GROUPTYPE._serialized_end = 6438
+    _SORT._serialized_start = 6441
+    _SORT._serialized_end = 6601
+    _DROP._serialized_start = 6604
+    _DROP._serialized_end = 6745
+    _DEDUPLICATE._serialized_start = 6748
+    _DEDUPLICATE._serialized_end = 6988
+    _LOCALRELATION._serialized_start = 6990
+    _LOCALRELATION._serialized_end = 7079
+    _CACHEDLOCALRELATION._serialized_start = 7081
+    _CACHEDLOCALRELATION._serialized_end = 7176
+    _CACHEDREMOTERELATION._serialized_start = 7178
+    _CACHEDREMOTERELATION._serialized_end = 7233
+    _SAMPLE._serialized_start = 7236
+    _SAMPLE._serialized_end = 7509
+    _RANGE._serialized_start = 7512
+    _RANGE._serialized_end = 7657
+    _SUBQUERYALIAS._serialized_start = 7659
+    _SUBQUERYALIAS._serialized_end = 7773
+    _REPARTITION._serialized_start = 7776
+    _REPARTITION._serialized_end = 7918
+    _SHOWSTRING._serialized_start = 7921
+    _SHOWSTRING._serialized_end = 8063
+    _HTMLSTRING._serialized_start = 8065
+    _HTMLSTRING._serialized_end = 8179
+    _STATSUMMARY._serialized_start = 8181
+    _STATSUMMARY._serialized_end = 8273
+    _STATDESCRIBE._serialized_start = 8275
+    _STATDESCRIBE._serialized_end = 8356
+    _STATCROSSTAB._serialized_start = 8358
+    _STATCROSSTAB._serialized_end = 8459
+    _STATCOV._serialized_start = 8461
+    _STATCOV._serialized_end = 8557
+    _STATCORR._serialized_start = 8560
+    _STATCORR._serialized_end = 8697
+    _STATAPPROXQUANTILE._serialized_start = 8700
+    _STATAPPROXQUANTILE._serialized_end = 8864
+    _STATFREQITEMS._serialized_start = 8866
+    _STATFREQITEMS._serialized_end = 8991
+    _STATSAMPLEBY._serialized_start = 8994
+    _STATSAMPLEBY._serialized_end = 9303
+    _STATSAMPLEBY_FRACTION._serialized_start = 9195
+    _STATSAMPLEBY_FRACTION._serialized_end = 9294
+    _NAFILL._serialized_start = 9306
+    _NAFILL._serialized_end = 9440
+    _NADROP._serialized_start = 9443
+    _NADROP._serialized_end = 9577
+    _NAREPLACE._serialized_start = 9580
+    _NAREPLACE._serialized_end = 9876
+    _NAREPLACE_REPLACEMENT._serialized_start = 9735
+    _NAREPLACE_REPLACEMENT._serialized_end = 9876
+    _TODF._serialized_start = 9878
+    _TODF._serialized_end = 9966
+    _WITHCOLUMNSRENAMED._serialized_start = 9969
+    _WITHCOLUMNSRENAMED._serialized_end = 10208
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 10141
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 10208
+    _WITHCOLUMNS._serialized_start = 10210
+    _WITHCOLUMNS._serialized_end = 10329
+    _WITHWATERMARK._serialized_start = 10332
+    _WITHWATERMARK._serialized_end = 10466
+    _HINT._serialized_start = 10469
+    _HINT._serialized_end = 10601
+    _UNPIVOT._serialized_start = 10604
+    _UNPIVOT._serialized_end = 10931
+    _UNPIVOT_VALUES._serialized_start = 10861
+    _UNPIVOT_VALUES._serialized_end = 10920
+    _TOSCHEMA._serialized_start = 10933
+    _TOSCHEMA._serialized_end = 11039
+    _REPARTITIONBYEXPRESSION._serialized_start = 11042
+    _REPARTITIONBYEXPRESSION._serialized_end = 11245
+    _MAPPARTITIONS._serialized_start = 11248
+    _MAPPARTITIONS._serialized_end = 11429
+    _GROUPMAP._serialized_start = 11432
+    _GROUPMAP._serialized_end = 12067
+    _COGROUPMAP._serialized_start = 12070
+    _COGROUPMAP._serialized_end = 12596
+    _APPLYINPANDASWITHSTATE._serialized_start = 12599
+    _APPLYINPANDASWITHSTATE._serialized_end = 12956
+    _COLLECTMETRICS._serialized_start = 12959
+    _COLLECTMETRICS._serialized_end = 13095
+    _PARSE._serialized_start = 13098
+    _PARSE._serialized_end = 13486
     _PARSE_OPTIONSENTRY._serialized_start = 3842
     _PARSE_OPTIONSENTRY._serialized_end = 3900
-    _PARSE_PARSEFORMAT._serialized_start = 13152
-    _PARSE_PARSEFORMAT._serialized_end = 13240
+    _PARSE_PARSEFORMAT._serialized_start = 13387
+    _PARSE_PARSEFORMAT._serialized_end = 13475
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 8909d438c9d..75fd93e4e9b 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -964,11 +964,37 @@ class Join(google.protobuf.message.Message):
     JOIN_TYPE_LEFT_SEMI: Join.JoinType.ValueType  # 6
     JOIN_TYPE_CROSS: Join.JoinType.ValueType  # 7
 
+    class JoinDataType(google.protobuf.message.Message):
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        IS_LEFT_FLATTENABLE_TO_ROW_FIELD_NUMBER: builtins.int
+        IS_RIGHT_FLATTENABLE_TO_ROW_FIELD_NUMBER: builtins.int
+        is_left_flattenable_to_row: builtins.bool
+        """If the left data type is a struct that can be flattened to a row."""
+        is_right_flattenable_to_row: builtins.bool
+        """If the right data type is a struct that can be flattened to a row."""
+        def __init__(
+            self,
+            *,
+            is_left_flattenable_to_row: builtins.bool = ...,
+            is_right_flattenable_to_row: builtins.bool = ...,
+        ) -> None: ...
+        def ClearField(
+            self,
+            field_name: typing_extensions.Literal[
+                "is_left_flattenable_to_row",
+                b"is_left_flattenable_to_row",
+                "is_right_flattenable_to_row",
+                b"is_right_flattenable_to_row",
+            ],
+        ) -> None: ...
+
     LEFT_FIELD_NUMBER: builtins.int
     RIGHT_FIELD_NUMBER: builtins.int
     JOIN_CONDITION_FIELD_NUMBER: builtins.int
     JOIN_TYPE_FIELD_NUMBER: builtins.int
     USING_COLUMNS_FIELD_NUMBER: builtins.int
+    JOIN_DATA_TYPE_FIELD_NUMBER: builtins.int
     @property
     def left(self) -> global___Relation:
         """(Required) Left input relation for a Join."""
@@ -993,6 +1019,9 @@ class Join(google.protobuf.message.Message):
 
         This field does not co-exist with join_condition.
         """
+    @property
+    def join_data_type(self) -> global___Join.JoinDataType:
+        """(Optional) Only used by joinWith. Set the left and right join data 
types."""
     def __init__(
         self,
         *,
@@ -1001,18 +1030,32 @@ class Join(google.protobuf.message.Message):
         join_condition: pyspark.sql.connect.proto.expressions_pb2.Expression | None = ...,
         join_type: global___Join.JoinType.ValueType = ...,
         using_columns: collections.abc.Iterable[builtins.str] | None = ...,
+        join_data_type: global___Join.JoinDataType | None = ...,
     ) -> None: ...
     def HasField(
         self,
         field_name: typing_extensions.Literal[
-            "join_condition", b"join_condition", "left", b"left", "right", 
b"right"
+            "_join_data_type",
+            b"_join_data_type",
+            "join_condition",
+            b"join_condition",
+            "join_data_type",
+            b"join_data_type",
+            "left",
+            b"left",
+            "right",
+            b"right",
         ],
     ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
+            "_join_data_type",
+            b"_join_data_type",
             "join_condition",
             b"join_condition",
+            "join_data_type",
+            b"join_data_type",
             "join_type",
             b"join_type",
             "left",
@@ -1023,6 +1066,9 @@ class Join(google.protobuf.message.Message):
             b"using_columns",
         ],
     ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_join_data_type", b"_join_data_type"]
+    ) -> typing_extensions.Literal["join_data_type"] | None: ...
 
 global___Join = Join
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index 6599916ec7f..443bca1daa9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -36,6 +36,8 @@ import org.apache.spark.util.Utils
  * called lenient serialization. An example of this is lenient date serialization, in this case both
  * [[java.sql.Date]] and [[java.time.LocalDate]] are allowed. Deserialization is never lenient; it
  * will always produce instance of the external type.
+ *
+ * An encoder is flattenable if it contains fields that can be flattened to a row.
  */
 trait AgnosticEncoder[T] extends Encoder[T] {
   def isPrimitive: Boolean
@@ -43,6 +45,7 @@ trait AgnosticEncoder[T] extends Encoder[T] {
   def dataType: DataType
   override def schema: StructType = StructType(StructField("value", dataType, 
nullable) :: Nil)
   def lenientSerialization: Boolean = false
+  def isFlattenable: Boolean = false
 }
 
 object AgnosticEncoders {
@@ -99,19 +102,23 @@ object AgnosticEncoders {
     def structField: StructField = StructField(name, enc.dataType, nullable, metadata)
   }
 
-  // This supports both Product and DefinedByConstructorParams
-  case class ProductEncoder[K](
-      override val clsTag: ClassTag[K],
-      fields: Seq[EncoderField])
-    extends AgnosticEncoder[K] {
+  // Contains a sequence of fields. The fields can be flattened to columns in a row.
+  trait FieldsEncoder[K] extends AgnosticEncoder[K] {
+    val fields: Seq[EncoderField]
     override def isPrimitive: Boolean = false
-    override val schema: StructType = StructType(fields.map(_.structField))
+    override def schema: StructType = StructType(fields.map(_.structField))
     override def dataType: DataType = schema
+    override val isFlattenable: Boolean = true
   }
 
+  // This supports both Product and DefinedByConstructorParams
+  case class ProductEncoder[K](
+      override val clsTag: ClassTag[K],
+      override val fields: Seq[EncoderField]) extends FieldsEncoder[K]
+
   object ProductEncoder {
     val cachedCls = new ConcurrentHashMap[Int, Class[_]]
-    def tuple(encoders: Seq[AgnosticEncoder[_]]): AgnosticEncoder[_] = {
+    private[sql] def tuple(encoders: Seq[AgnosticEncoder[_]]): AgnosticEncoder[_] = {
       val fields = encoders.zipWithIndex.map {
         case (e, id) => EncoderField(s"_${id + 1}", e, e.nullable, Metadata.empty)
       }
@@ -119,30 +126,27 @@ object AgnosticEncoders {
         _ => Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}"))
       ProductEncoder[Any](ClassTag(cls), fields)
     }
+
+    private[sql] def isTuple(tag: ClassTag[_]): Boolean = {
+      tag.runtimeClass.getName.startsWith("scala.Tuple")
+    }
   }
 
-  abstract class BaseRowEncoder extends AgnosticEncoder[Row] {
-    override def isPrimitive: Boolean = false
-    override def dataType: DataType = schema
+  abstract class BaseRowEncoder extends FieldsEncoder[Row] {
     override def clsTag: ClassTag[Row] = classTag[Row]
   }
 
-  case class RowEncoder(fields: Seq[EncoderField]) extends BaseRowEncoder {
-    override val schema: StructType = StructType(fields.map(_.structField))
-  }
+  case class RowEncoder(override val fields: Seq[EncoderField]) extends BaseRowEncoder
 
   object UnboundRowEncoder extends BaseRowEncoder {
     override val schema: StructType = new StructType()
-  }
+    override val fields: Seq[EncoderField] = Seq.empty
+  }
 
   case class JavaBeanEncoder[K](
       override val clsTag: ClassTag[K],
-      fields: Seq[EncoderField])
-    extends AgnosticEncoder[K] {
-    override def isPrimitive: Boolean = false
-    override val schema: StructType = StructType(fields.map(_.structField))
-    override def dataType: DataType = schema
-  }
+      override val fields: Seq[EncoderField])
+    extends FieldsEncoder[K]
 
   // This will only work for encoding from/to Sparks' InternalRow format.
   // It is here for compatibility.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 54c0b84ff52..980295f5e0d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.api.java.function.FilterFunction
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.sql.{Encoder, Row}
-import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
+import org.apache.spark.sql.{catalyst, Encoder, Row}
+import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedDeserializer}
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
-import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
+import org.apache.spark.sql.catalyst.plans.{InnerLike, LeftAnti, LeftSemi, ReferenceAllColumns}
 import org.apache.spark.sql.catalyst.trees.TreePattern._
+import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
 import org.apache.spark.sql.types._
@@ -719,3 +720,100 @@ case class CoGroup(
   override protected def withNewChildrenInternal(
       newLeft: LogicalPlan, newRight: LogicalPlan): CoGroup = copy(left = newLeft, right = newRight)
 }
+
+// TODO (SPARK-44225): Move this into analyzer
+object JoinWith {
+  /**
+   * Finds the trivially true predicates and automatically resolves them to both sides.
+   */
+  private[sql] def resolveSelfJoinCondition(resolver: Resolver, plan: Join): Join = {
+    val cond = plan.condition.map {
+      _.transform {
+        case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference)
+          if a.sameRef(b) =>
+          catalyst.expressions.EqualTo(
+            plan.left.resolveQuoted(a.name, resolver).getOrElse(
+              throw QueryCompilationErrors.resolveException(a.name, plan.left.schema.fieldNames)),
+            plan.right.resolveQuoted(b.name, resolver).getOrElse(
+              throw QueryCompilationErrors.resolveException(b.name, plan.right.schema.fieldNames)))
+        case catalyst.expressions.EqualNullSafe(a: AttributeReference, b: AttributeReference)
+          if a.sameRef(b) =>
+          catalyst.expressions.EqualNullSafe(
+            plan.left.resolveQuoted(a.name, resolver).getOrElse(
+              throw QueryCompilationErrors.resolveException(a.name, plan.left.schema.fieldNames)),
+            plan.right.resolveQuoted(b.name, resolver).getOrElse(
+              throw QueryCompilationErrors.resolveException(b.name, plan.right.schema.fieldNames)))
+      }
+    }
+    plan.copy(condition = cond)
+  }
+
+  private[sql] def typedJoinWith(
+      plan: Join,
+      isAutoSelfJoinAliasEnable: Boolean,
+      resolver: Resolver,
+      isLeftFlattenableToRow: Boolean,
+      isRightFlattenableToRow: Boolean): LogicalPlan = {
+    var joined = plan
+    if (joined.joinType == LeftSemi || joined.joinType == LeftAnti) {
+      throw QueryCompilationErrors.invalidJoinTypeInJoinWithError(joined.joinType)
+    }
+    // If auto self join alias is enabled
+    if (isAutoSelfJoinAliasEnable) {
+      joined = resolveSelfJoinCondition(resolver, joined)
+    }
+
+    val leftResultExpr = {
+      if (!isLeftFlattenableToRow) {
+        assert(joined.left.output.length == 1)
+        Alias(joined.left.output.head, "_1")()
+      } else {
+        Alias(CreateStruct(joined.left.output), "_1")()
+      }
+    }
+
+    val rightResultExpr = {
+      if (!isRightFlattenableToRow) {
+        assert(joined.right.output.length == 1)
+        Alias(joined.right.output.head, "_2")()
+      } else {
+        Alias(CreateStruct(joined.right.output), "_2")()
+      }
+    }
+
+    if (joined.joinType.isInstanceOf[InnerLike]) {
+      // For inner joins, we can directly perform the join and then can project the join
+      // results into structs. This ensures that data remains flat during shuffles /
+      // exchanges (unlike the outer join path, which nests the data before shuffling).
+      Project(Seq(leftResultExpr, rightResultExpr), joined)
+    } else { // outer joins
+      // For both join sides, combine all outputs into a single column and alias it with "_1
+      // or "_2", to match the schema for the encoder of the join result.
+      // Note that we do this before joining them, to enable the join operator to return null
+      // for one side, in cases like outer-join.
+      val left = Project(leftResultExpr :: Nil, joined.left)
+      val right = Project(rightResultExpr :: Nil, joined.right)
+
+      // Rewrites the join condition to make the attribute point to correct column/field,
+      // after we combine the outputs of each join side.
+      val conditionExpr = joined.condition.get transformUp {
+        case a: Attribute if joined.left.outputSet.contains(a) =>
+          if (!isLeftFlattenableToRow) {
+            left.output.head
+          } else {
+            val index = joined.left.output.indexWhere(_.exprId == a.exprId)
+            GetStructField(left.output.head, index)
+          }
+        case a: Attribute if joined.right.outputSet.contains(a) =>
+          if (!isRightFlattenableToRow) {
+            right.output.head
+          } else {
+            val index = joined.right.output.indexWhere(_.exprId == a.exprId)
+            GetStructField(right.output.head, index)
+          }
+      }
+
+      Join(left, right, joined.joinType, Some(conditionExpr), JoinHint.NONE)
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 6a7979afe41..ae5bd282919 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -2940,8 +2940,15 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
     )
   }
 
-  def cannotParseIntervalError(intervalString: String, e: Throwable): Throwable = {
-    val threshold = if (intervalString == null) "" else intervalString
+  def resolveException(colName: String, fields: Array[String]): AnalysisException = {
+    QueryCompilationErrors.unresolvedColumnWithSuggestionError(
+      colName,
+      fields.map(toSQLId).mkString(", ")
+    )
+  }
+
+  def cannotParseIntervalError(delayThreshold: String, e: Throwable): Throwable = {
+    val threshold = if (delayThreshold == null) "" else delayThreshold
     new AnalysisException(
       errorClass = "CANNOT_PARSE_INTERVAL",
       messageParameters = Map("intervalString" -> toSQLValue(threshold, 
StringType)),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index c87f95294bf..a9017924e14 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -48,8 +48,8 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.catalyst.util.IntervalUtils
+import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.errors.QueryCompilationErrors.toSQLId
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.execution.arrow.{ArrowBatchStreamWriter, ArrowConverters}
@@ -248,12 +248,7 @@ class Dataset[T] private[sql](
   private[sql] def resolve(colName: String): NamedExpression = {
     val resolver = sparkSession.sessionState.analyzer.resolver
     queryExecution.analyzed.resolveQuoted(colName, resolver)
-      .getOrElse(throw resolveException(colName, schema.fieldNames))
-  }
-
-  private def resolveException(colName: String, fields: Array[String]): AnalysisException = {
-    QueryCompilationErrors.unresolvedColumnWithSuggestionError(
-      colName, fields.map(toSQLId).mkString(", "))
+      .getOrElse(throw QueryCompilationErrors.resolveException(colName, schema.fieldNames))
   }
 
   private[sql] def numericColumns: Seq[Expression] = {
@@ -1116,30 +1111,6 @@ class Dataset[T] private[sql](
    */
   def join(right: Dataset[_], joinExprs: Column): DataFrame = join(right, joinExprs, "inner")
 
-  /**
-   * find the trivially true predicates and automatically resolves them to both sides.
-   */
-  private def resolveSelfJoinCondition(plan: Join): Join = {
-    val resolver = sparkSession.sessionState.analyzer.resolver
-    val cond = plan.condition.map { _.transform {
-      case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference)
-        if a.sameRef(b) =>
-        catalyst.expressions.EqualTo(
-          plan.left.resolveQuoted(a.name, resolver)
-            .getOrElse(throw resolveException(a.name, plan.left.schema.fieldNames)),
-          plan.right.resolveQuoted(b.name, resolver)
-            .getOrElse(throw resolveException(b.name, plan.right.schema.fieldNames)))
-      case catalyst.expressions.EqualNullSafe(a: AttributeReference, b: AttributeReference)
-        if a.sameRef(b) =>
-        catalyst.expressions.EqualNullSafe(
-          plan.left.resolveQuoted(a.name, resolver)
-            .getOrElse(throw resolveException(a.name, plan.left.schema.fieldNames)),
-          plan.right.resolveQuoted(b.name, resolver)
-            .getOrElse(throw resolveException(b.name, plan.right.schema.fieldNames)))
-    }}
-    plan.copy(condition = cond)
-  }
-
   /**
    * find the trivially true predicates and automatically resolves them to both sides.
    */
@@ -1178,7 +1149,7 @@ class Dataset[T] private[sql](
     // By the time we get here, since we have already run analysis, all attributes should've been
     // resolved and become AttributeReference.
 
-    resolveSelfJoinCondition(plan)
+    JoinWith.resolveSelfJoinCondition(sparkSession.sessionState.analyzer.resolver, plan)
   }
 
   /**
@@ -1249,7 +1220,7 @@ class Dataset[T] private[sql](
   def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = {
     // Creates a Join node and resolve it first, to get join condition resolved, self-join resolved,
     // etc.
-    var joined = sparkSession.sessionState.executePlan(
+    val joined = sparkSession.sessionState.executePlan(
       Join(
         this.logicalPlan,
         other.logicalPlan,
@@ -1257,70 +1228,15 @@ class Dataset[T] private[sql](
         Some(condition.expr),
         JoinHint.NONE)).analyzed.asInstanceOf[Join]
 
-    if (joined.joinType == LeftSemi || joined.joinType == LeftAnti) {
-      throw QueryCompilationErrors.invalidJoinTypeInJoinWithError(joined.joinType)
-    }
-
-    // If auto self join alias is enable
-    if (sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity) {
-      joined = resolveSelfJoinCondition(joined)
-    }
-
     implicit val tuple2Encoder: Encoder[(T, U)] =
       ExpressionEncoder.tuple(this.exprEnc, other.exprEnc)
 
-    val leftResultExpr = {
-      if (!this.exprEnc.isSerializedAsStructForTopLevel) {
-        assert(joined.left.output.length == 1)
-        Alias(joined.left.output.head, "_1")()
-      } else {
-        Alias(CreateStruct(joined.left.output), "_1")()
-      }
-    }
-
-    val rightResultExpr = {
-      if (!other.exprEnc.isSerializedAsStructForTopLevel) {
-        assert(joined.right.output.length == 1)
-        Alias(joined.right.output.head, "_2")()
-      } else {
-        Alias(CreateStruct(joined.right.output), "_2")()
-      }
-    }
-
-    if (joined.joinType.isInstanceOf[InnerLike]) {
-      // For inner joins, we can directly perform the join and then can project the join
-      // results into structs. This ensures that data remains flat during shuffles /
-      // exchanges (unlike the outer join path, which nests the data before shuffling).
-      withTypedPlan(Project(Seq(leftResultExpr, rightResultExpr), joined))
-    } else { // outer joins
-      // For both join sides, combine all outputs into a single column and alias it with "_1
-      // or "_2", to match the schema for the encoder of the join result.
-      // Note that we do this before joining them, to enable the join operator to return null
-      // for one side, in cases like outer-join.
-      val left = Project(leftResultExpr :: Nil, joined.left)
-      val right = Project(rightResultExpr :: Nil, joined.right)
-
-      // Rewrites the join condition to make the attribute point to correct column/field,
-      // after we combine the outputs of each join side.
-      val conditionExpr = joined.condition.get transformUp {
-        case a: Attribute if joined.left.outputSet.contains(a) =>
-          if (!this.exprEnc.isSerializedAsStructForTopLevel) {
-            left.output.head
-          } else {
-            val index = joined.left.output.indexWhere(_.exprId == a.exprId)
-            GetStructField(left.output.head, index)
-          }
-        case a: Attribute if joined.right.outputSet.contains(a) =>
-          if (!other.exprEnc.isSerializedAsStructForTopLevel) {
-            right.output.head
-          } else {
-            val index = joined.right.output.indexWhere(_.exprId == a.exprId)
-            GetStructField(right.output.head, index)
-          }
-      }
-
-      withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr), JoinHint.NONE))
-    }
+    withTypedPlan(JoinWith.typedJoinWith(
+      joined,
+      sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity,
+      sparkSession.sessionState.analyzer.resolver,
+      this.exprEnc.isSerializedAsStructForTopLevel,
+      other.exprEnc.isSerializedAsStructForTopLevel))
   }
 
   /**


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
