This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9c7978d80b8 [SPARK-43321][CONNECT] Dataset#joinWith
9c7978d80b8 is described below
commit 9c7978d80b8a95bd7fcc26769eea581849000862
Author: Zhen Li <[email protected]>
AuthorDate: Thu Jul 6 17:42:53 2023 -0400
[SPARK-43321][CONNECT] Dataset#joinWith
### What changes were proposed in this pull request?
Implement the missing method `Dataset#joinWith` in the Spark Connect Scala client, on top of the Join relation operation.
`joinWith` adds the `left` and `right` struct type info to the Join relation proto.
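For reference, a minimal usage sketch of the new API, condensed from the added E2E tests (it assumes an existing Spark Connect `SparkSession` bound to `spark`; the `ClassData` case class and the sample rows are illustrative only):

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

// Illustrative case class, mirroring the one used in ClientE2ETestSuite.
case class ClassData(a: String, b: Int)

val session: SparkSession = spark // an already-created Spark Connect session
import session.implicits._

val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS().as("left")
val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDS().as("right")

// Typed "left" outer join: each result is a Tuple2 of the matching objects,
// and the right element is null when there is no match.
val joined: Dataset[(ClassData, ClassData)] =
  left.joinWith(right, $"left.b" === $"right.b", "left")

// joined.collect() yields (in some order):
//   (ClassData("a", 1), null), (ClassData("b", 2), ClassData("x", 2))
```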
### Why are the changes needed?
`Dataset#joinWith` is missing from the Spark Connect Scala client's Dataset API.
### Does this PR introduce _any_ user-facing change?
Yes. This adds the missing `Dataset#joinWith` method.
### How was this patch tested?
E2E tests added to `ClientE2ETestSuite`.
Closes #40997 from zhenlineo/joinwith.
Authored-by: Zhen Li <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../main/scala/org/apache/spark/sql/Dataset.scala | 85 +++++++-
.../spark/sql/connect/client/SparkResult.scala | 34 +++-
.../org/apache/spark/sql/ClientE2ETestSuite.scala | 193 ++++++++++++++++++
.../CheckConnectJvmClientCompatibility.scala | 1 -
.../main/protobuf/spark/connect/relations.proto | 10 +
.../sql/connect/planner/SparkConnectPlanner.scala | 24 ++-
python/pyspark/sql/connect/proto/relations_pb2.py | 221 +++++++++++----------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 48 ++++-
.../sql/catalyst/encoders/AgnosticEncoder.scala | 44 ++--
.../spark/sql/catalyst/plans/logical/object.scala | 104 +++++++++-
.../spark/sql/errors/QueryCompilationErrors.scala | 11 +-
.../main/scala/org/apache/spark/sql/Dataset.scala | 104 +---------
12 files changed, 639 insertions(+), 240 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 2ea3169486b..4fa5c0b9641 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -20,6 +20,7 @@ import java.util.{Collections, Locale}
import scala.collection.JavaConverters._
import scala.collection.mutable
+import scala.reflect.ClassTag
import scala.util.control.NonFatal
import org.apache.spark.SparkException
@@ -568,7 +569,7 @@ class Dataset[T] private[sql] (
}
}
- private def toJoinType(name: String): proto.Join.JoinType = {
+ private def toJoinType(name: String, skipSemiAnti: Boolean = false): proto.Join.JoinType = {
name.trim.toLowerCase(Locale.ROOT) match {
case "inner" =>
proto.Join.JoinType.JOIN_TYPE_INNER
@@ -580,12 +581,12 @@ class Dataset[T] private[sql] (
proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER
case "right" | "rightouter" | "right_outer" =>
proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER
- case "semi" | "leftsemi" | "left_semi" =>
+ case "semi" | "leftsemi" | "left_semi" if !skipSemiAnti =>
proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI
- case "anti" | "leftanti" | "left_anti" =>
+ case "anti" | "leftanti" | "left_anti" if !skipSemiAnti =>
proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI
- case _ =>
- throw new IllegalArgumentException(s"Unsupported join type `joinType`.")
+ case e =>
+ throw new IllegalArgumentException(s"Unsupported join type '$e'.")
}
}
@@ -835,6 +836,80 @@ class Dataset[T] private[sql] (
}
}
+ /**
+ * Joins this Dataset returning a `Tuple2` for each pair where `condition` evaluates to true.
+ *
+ * This is similar to the relation `join` function with one important difference in the result
+ * schema. Since `joinWith` preserves objects present on either side of the join, the result
+ * schema is similarly nested into a tuple under the column names `_1` and `_2`.
+ *
+ * This type of join can be useful both for preserving type-safety with the original object
+ * types as well as working with relational data where either side of the join has column names
+ * in common.
+ *
+ * @param other
+ * Right side of the join.
+ * @param condition
+ * Join expression.
+ * @param joinType
+ * Type of join to perform. Default `inner`. Must be one of: `inner`, `cross`, `outer`,
+ * `full`, `fullouter`, `full_outer`, `left`, `leftouter`, `left_outer`, `right`, `rightouter`,
+ * `right_outer`.
+ *
+ * @group typedrel
+ * @since 3.5.0
+ */
+ def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = {
+ val joinTypeValue = toJoinType(joinType, skipSemiAnti = true)
+ val (leftNullable, rightNullable) = joinTypeValue match {
+ case proto.Join.JoinType.JOIN_TYPE_INNER | proto.Join.JoinType.JOIN_TYPE_CROSS =>
+ (false, false)
+ case proto.Join.JoinType.JOIN_TYPE_FULL_OUTER =>
+ (true, true)
+ case proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER =>
+ (false, true)
+ case proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER =>
+ (true, false)
+ case e =>
+ throw new IllegalArgumentException(s"Unsupported join type '$e'.")
+ }
+
+ val tupleEncoder =
+ ProductEncoder[(T, U)](
+ ClassTag(Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple2")),
+ Seq(
+ EncoderField(s"_1", this.encoder, leftNullable, Metadata.empty),
+ EncoderField(s"_2", other.encoder, rightNullable, Metadata.empty)))
+
+ sparkSession.newDataset(tupleEncoder) { builder =>
+ val joinBuilder = builder.getJoinBuilder
+ joinBuilder
+ .setLeft(plan.getRoot)
+ .setRight(other.plan.getRoot)
+ .setJoinType(joinTypeValue)
+ .setJoinCondition(condition.expr)
+ .setJoinDataType(joinBuilder.getJoinDataTypeBuilder
+ .setIsLeftFlattenableToRow(this.encoder.isFlattenable)
+ .setIsRightFlattenableToRow(other.encoder.isFlattenable))
+ }
+ }
+
+ /**
+ * Using inner equi-join to join this Dataset returning a `Tuple2` for each pair where
+ * `condition` evaluates to true.
+ *
+ * @param other
+ * Right side of the join.
+ * @param condition
+ * Join expression.
+ *
+ * @group typedrel
+ * @since 3.5.0
+ */
+ def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
+ joinWith(other, condition, "inner")
+ }
+
/**
* Returns a new Dataset with each partition sorted by the given expressions.
*
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index 86a7cf846f2..a6ed31c1869 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -28,11 +28,11 @@ import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer
import org.apache.spark.sql.connect.client.util.{AutoCloseables, Cleanable}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
@@ -50,15 +50,33 @@ private[sql] class SparkResult[T](
private val idxToBatches = mutable.Map.empty[Int, ColumnarBatch]
private def createEncoder(schema: StructType): ExpressionEncoder[T] = {
- val agnosticEncoder = if (encoder == UnboundRowEncoder) {
- // Create a row encoder based on the schema.
- RowEncoder.encoderFor(schema).asInstanceOf[AgnosticEncoder[T]]
- } else {
- encoder
- }
+ val agnosticEncoder = createEncoder(encoder, schema).asInstanceOf[AgnosticEncoder[T]]
ExpressionEncoder(agnosticEncoder)
}
+ /**
+ * Update RowEncoder and recursively update the fields of the ProductEncoder if found.
+ */
+ private def createEncoder[_](
+ enc: AgnosticEncoder[_],
+ dataType: DataType): AgnosticEncoder[_] = {
+ enc match {
+ case UnboundRowEncoder =>
+ // Replace the row encoder with the encoder inferred from the schema.
+ RowEncoder.encoderFor(dataType.asInstanceOf[StructType])
+ case ProductEncoder(clsTag, fields) if ProductEncoder.isTuple(clsTag) =>
+ // Recursively continue updating the tuple product encoder
+ val schema = dataType.asInstanceOf[StructType]
+ assert(fields.length <= schema.fields.length)
+ val updatedFields = fields.zipWithIndex.map { case (f, id) =>
+ f.copy(enc = createEncoder(f.enc, schema.fields(id).dataType))
+ }
+ ProductEncoder(clsTag, updatedFields)
+ case _ =>
+ enc
+ }
+ }
+
private def processResponses(stopOnFirstNonEmptyResponse: Boolean): Boolean = {
while (responses.hasNext) {
val response = responses.next()
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 48be815b236..73c04389c05 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -88,6 +88,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with
SQLHelper with PrivateM
}
test("read and write") {
+ assume(IntegrationTestUtils.isSparkHiveJarAvailable)
val testDataPath = java.nio.file.Paths
.get(
IntegrationTestUtils.sparkHome,
@@ -158,6 +159,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with
SQLHelper with PrivateM
}
test("textFile") {
+ assume(IntegrationTestUtils.isSparkHiveJarAvailable)
val testDataPath = java.nio.file.Paths
.get(
IntegrationTestUtils.sparkHome,
@@ -178,6 +180,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with
SQLHelper with PrivateM
}
test("write table") {
+ assume(IntegrationTestUtils.isSparkHiveJarAvailable)
withTable("myTable") {
val df = spark.range(10).limit(3)
df.write.mode(SaveMode.Overwrite).saveAsTable("myTable")
@@ -221,6 +224,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with
SQLHelper with PrivateM
}
test("write without table or path") {
+ assume(IntegrationTestUtils.isSparkHiveJarAvailable)
// Should receive no error to write noop
spark.range(10).write.format("noop").mode("append").save()
}
@@ -970,8 +974,197 @@ class ClientE2ETestSuite extends RemoteSparkSession with
SQLHelper with PrivateM
val result2 = spark.sql("select :c0 limit :l0", Map("l0" -> 1, "c0" ->
"abc")).collect()
assert(result2.length == 1 && result2(0).getString(0) === "abc")
}
+
+ test("joinWith, flat schema") {
+ val session: SparkSession = spark
+ import session.implicits._
+ val ds1 = Seq(1, 2, 3).toDS().as("a")
+ val ds2 = Seq(1, 2).toDS().as("b")
+
+ val joined = ds1.joinWith(ds2, $"a.value" === $"b.value", "inner")
+
+ val expectedSchema = StructType(
+ Seq(
+ StructField("_1", IntegerType, nullable = false),
+ StructField("_2", IntegerType, nullable = false)))
+
+ assert(joined.schema === expectedSchema)
+
+ val expected = Seq((1, 1), (2, 2))
+ checkSameResult(expected, joined)
+ }
+
+ test("joinWith tuple with primitive, expression") {
+ val session: SparkSession = spark
+ import session.implicits._
+ val ds1 = Seq(1, 1, 2).toDS()
+ val ds2 = Seq(("a", 1), ("b", 2)).toDS()
+
+ val joined = ds1.joinWith(ds2, $"value" === $"_2")
+
+ // This is an inner join, so both output fields are non-nullable
+ val expectedSchema = StructType(
+ Seq(
+ StructField("_1", IntegerType, nullable = false),
+ StructField(
+ "_2",
+ StructType(
+ Seq(StructField("_1", StringType), StructField("_2", IntegerType,
nullable = false))),
+ nullable = false)))
+ assert(joined.schema === expectedSchema)
+
+ checkSameResult(Seq((1, ("a", 1)), (1, ("a", 1)), (2, ("b", 2))), joined)
+ }
+
+ test("joinWith tuple with primitive, rows") {
+ val session: SparkSession = spark
+ import session.implicits._
+ val ds1 = Seq(1, 1, 2).toDF()
+ val ds2 = Seq(("a", 1), ("b", 2)).toDF()
+
+ val joined = ds1.joinWith(ds2, $"value" === $"_2")
+
+ checkSameResult(
+ Seq((Row(1), Row("a", 1)), (Row(1), Row("a", 1)), (Row(2), Row("b", 2))),
+ joined)
+ }
+
+ test("joinWith class with primitive, toDF") {
+ val session: SparkSession = spark
+ import session.implicits._
+ val ds1 = Seq(1, 1, 2).toDS()
+ val ds2 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+
+ val df = ds1
+ .joinWith(ds2, $"value" === $"b")
+ .toDF()
+ .select($"_1", $"_2.a", $"_2.b")
+ checkSameResult(Seq(Row(1, "a", 1), Row(1, "a", 1), Row(2, "b", 2)), df)
+ }
+
+ test("multi-level joinWith") {
+ val session: SparkSession = spark
+ import session.implicits._
+ val ds1 = Seq(("a", 1), ("b", 2)).toDS().as("a")
+ val ds2 = Seq(("a", 1), ("b", 2)).toDS().as("b")
+ val ds3 = Seq(("a", 1), ("b", 2)).toDS().as("c")
+
+ val joined = ds1
+ .joinWith(ds2, $"a._2" === $"b._2")
+ .as("ab")
+ .joinWith(ds3, $"ab._1._2" === $"c._2")
+
+ checkSameResult(
+ Seq(((("a", 1), ("a", 1)), ("a", 1)), ((("b", 2), ("b", 2)), ("b", 2))),
+ joined)
+ }
+
+ test("multi-level joinWith, rows") {
+ val session: SparkSession = spark
+ import session.implicits._
+ val ds1 = Seq(("a", 1), ("b", 2)).toDF().as("a")
+ val ds2 = Seq(("a", 1), ("b", 2)).toDF().as("b")
+ val ds3 = Seq(("a", 1), ("b", 2)).toDF().as("c")
+
+ val joined = ds1
+ .joinWith(ds2, $"a._2" === $"b._2")
+ .as("ab")
+ .joinWith(ds3, $"ab._1._2" === $"c._2")
+
+ checkSameResult(
+ Seq(((Row("a", 1), Row("a", 1)), Row("a", 1)), ((Row("b", 2), Row("b",
2)), Row("b", 2))),
+ joined)
+ }
+
+ test("self join") {
+ val session: SparkSession = spark
+ import session.implicits._
+ val ds = Seq("1", "2").toDS().as("a")
+ val joined = ds.joinWith(ds, lit(true), "cross")
+ checkSameResult(Seq(("1", "1"), ("1", "2"), ("2", "1"), ("2", "2")),
joined)
+ }
+
+ test("SPARK-11894: Incorrect results are returned when using null") {
+ val session: SparkSession = spark
+ import session.implicits._
+ val nullInt = null.asInstanceOf[java.lang.Integer]
+ val ds1 = Seq((nullInt, "1"), (java.lang.Integer.valueOf(22), "2")).toDS()
+ val ds2 = Seq((nullInt, "1"), (java.lang.Integer.valueOf(22), "2")).toDS()
+
+ checkSameResult(
+ Seq(
+ ((nullInt, "1"), (nullInt, "1")),
+ ((nullInt, "1"), (java.lang.Integer.valueOf(22), "2")),
+ ((java.lang.Integer.valueOf(22), "2"), (nullInt, "1")),
+ ((java.lang.Integer.valueOf(22), "2"), (java.lang.Integer.valueOf(22),
"2"))),
+ ds1.joinWith(ds2, lit(true), "cross"))
+ }
+
+ test("SPARK-15441: Dataset outer join") {
+ val session: SparkSession = spark
+ import session.implicits._
+ val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS().as("left")
+ val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDS().as("right")
+ val joined = left.joinWith(right, $"left.b" === $"right.b", "left")
+
+ val expectedSchema = StructType(
+ Seq(
+ StructField(
+ "_1",
+ StructType(
+ Seq(StructField("a", StringType), StructField("b", IntegerType,
nullable = false))),
+ nullable = false),
+ // This is a left join, so the right output is nullable:
+ StructField(
+ "_2",
+ StructType(
+ Seq(StructField("a", StringType), StructField("b", IntegerType,
nullable = false))))))
+ assert(joined.schema === expectedSchema)
+
+ val result = joined.collect().toSet
+ assert(result == Set(ClassData("a", 1) -> null, ClassData("b", 2) ->
ClassData("x", 2)))
+ }
+
+ test("SPARK-37829: DataFrame outer join") {
+ // Same as "SPARK-15441: Dataset outer join" but using DataFrames instead
of Datasets
+ val session: SparkSession = spark
+ import session.implicits._
+ val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDF().as("left")
+ val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDF().as("right")
+ val joined = left.joinWith(right, $"left.b" === $"right.b", "left")
+
+ val leftFieldSchema = StructType(
+ Seq(StructField("a", StringType), StructField("b", IntegerType, nullable
= false)))
+ val rightFieldSchema = StructType(
+ Seq(StructField("a", StringType), StructField("b", IntegerType, nullable
= false)))
+ val expectedSchema = StructType(
+ Seq(
+ StructField("_1", leftFieldSchema, nullable = false),
+ // This is a left join, so the right output is nullable:
+ StructField("_2", rightFieldSchema)))
+ assert(joined.schema === expectedSchema)
+
+ val result = joined.collect().toSet
+ val expected = Set(
+ new GenericRowWithSchema(Array("a", 1), leftFieldSchema) ->
+ null,
+ new GenericRowWithSchema(Array("b", 2), leftFieldSchema) ->
+ new GenericRowWithSchema(Array("x", 2), rightFieldSchema))
+ assert(result == expected)
+ }
+
+ test("SPARK-24762: joinWith on Option[Product]") {
+ val session: SparkSession = spark
+ import session.implicits._
+ val ds1 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("a")
+ val ds2 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("b")
+ val joined = ds1.joinWith(ds2, $"a.value._1" === $"b.value._2", "inner")
+ checkSameResult(Seq((Some((2, 3)), Some((1, 2)))), joined)
+ }
}
+private[sql] case class ClassData(a: String, b: Int)
+
private[sql] case class MyType(id: Long, a: Double, b: Double)
private[sql] case class KV(key: String, value: Int)
private[sql] class SimpleBean {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index f22baddc01e..576dadd3d9f 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -184,7 +184,6 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.queryExecution"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.encoder"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.sqlContext"),
- ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.joinWith"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.metadataColumn"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.selectUntyped"),
// protected
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.explode"),
// deprecated
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 29405a1332b..edde4819b51 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -225,6 +225,16 @@ message Join {
JOIN_TYPE_LEFT_SEMI = 6;
JOIN_TYPE_CROSS = 7;
}
+
+ // (Optional) Only used by joinWith. Set the left and right join data types.
+ optional JoinDataType join_data_type = 6;
+
+ message JoinDataType {
+ // If the left data type is a struct that can be flattened to a row.
+ bool is_left_flattenable_to_row = 1;
+ // If the right data type is a struct that can be flattened to a row.
+ bool is_right_flattenable_to_row = 2;
+ }
}
// Relation of type [[SetOperation]]
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 149d5512953..d3090e8b09b 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -49,7 +49,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, FlatMapGroupsWithState, Intersect, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, FlatMapGroupsWithState, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, UdfPacket}
@@ -101,7 +101,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
case proto.Relation.RelTypeCase.LIMIT => transformLimit(rel.getLimit)
case proto.Relation.RelTypeCase.OFFSET => transformOffset(rel.getOffset)
case proto.Relation.RelTypeCase.TAIL => transformTail(rel.getTail)
- case proto.Relation.RelTypeCase.JOIN => transformJoin(rel.getJoin)
+ case proto.Relation.RelTypeCase.JOIN => transformJoinOrJoinWith(rel.getJoin)
case proto.Relation.RelTypeCase.DEDUPLICATE =>
transformDeduplicate(rel.getDeduplicate)
case proto.Relation.RelTypeCase.SET_OP =>
transformSetOperation(rel.getSetOp)
case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort)
@@ -2066,6 +2066,26 @@ class SparkConnectPlanner(val sessionHolder:
SessionHolder) extends Logging {
}
}
+ private def transformJoinWith(rel: proto.Join): LogicalPlan = {
+ val joined =
+ session.sessionState.executePlan(transformJoin(rel)).analyzed.asInstanceOf[logical.Join]
+
+ JoinWith.typedJoinWith(
+ joined,
+ session.sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity,
+ session.sessionState.analyzer.resolver,
+ rel.getJoinDataType.getIsLeftFlattenableToRow,
+ rel.getJoinDataType.getIsRightFlattenableToRow)
+ }
+
+ private def transformJoinOrJoinWith(rel: proto.Join): LogicalPlan = {
+ if (rel.hasJoinDataType) {
+ transformJoinWith(rel)
+ } else {
+ transformJoin(rel)
+ }
+ }
+
private def transformJoin(rel: proto.Join): LogicalPlan = {
assert(rel.hasLeft && rel.hasRight, "Both join sides must be present")
if (rel.hasJoinCondition && rel.getUsingColumnsCount > 0) {
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 288bbe084c1..ce36df6f81e 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as
spark_dot_connect_dot_catal
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xd0\x17\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xd0\x17\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
)
@@ -53,6 +53,7 @@ _READ_DATASOURCE_OPTIONSENTRY =
_READ_DATASOURCE.nested_types_by_name["OptionsEn
_PROJECT = DESCRIPTOR.message_types_by_name["Project"]
_FILTER = DESCRIPTOR.message_types_by_name["Filter"]
_JOIN = DESCRIPTOR.message_types_by_name["Join"]
+_JOIN_JOINDATATYPE = _JOIN.nested_types_by_name["JoinDataType"]
_SETOPERATION = DESCRIPTOR.message_types_by_name["SetOperation"]
_LIMIT = DESCRIPTOR.message_types_by_name["Limit"]
_OFFSET = DESCRIPTOR.message_types_by_name["Offset"]
@@ -238,12 +239,22 @@ Join = _reflection.GeneratedProtocolMessageType(
"Join",
(_message.Message,),
{
+ "JoinDataType": _reflection.GeneratedProtocolMessageType(
+ "JoinDataType",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _JOIN_JOINDATATYPE,
+ "__module__": "spark.connect.relations_pb2"
+ #
@@protoc_insertion_point(class_scope:spark.connect.Join.JoinDataType)
+ },
+ ),
"DESCRIPTOR": _JOIN,
"__module__": "spark.connect.relations_pb2"
# @@protoc_insertion_point(class_scope:spark.connect.Join)
},
)
_sym_db.RegisterMessage(Join)
+_sym_db.RegisterMessage(Join.JoinDataType)
SetOperation = _reflection.GeneratedProtocolMessageType(
"SetOperation",
@@ -808,109 +819,111 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_FILTER._serialized_start = 4314
_FILTER._serialized_end = 4426
_JOIN._serialized_start = 4429
- _JOIN._serialized_end = 4900
- _JOIN_JOINTYPE._serialized_start = 4692
- _JOIN_JOINTYPE._serialized_end = 4900
- _SETOPERATION._serialized_start = 4903
- _SETOPERATION._serialized_end = 5382
- _SETOPERATION_SETOPTYPE._serialized_start = 5219
- _SETOPERATION_SETOPTYPE._serialized_end = 5333
- _LIMIT._serialized_start = 5384
- _LIMIT._serialized_end = 5460
- _OFFSET._serialized_start = 5462
- _OFFSET._serialized_end = 5541
- _TAIL._serialized_start = 5543
- _TAIL._serialized_end = 5618
- _AGGREGATE._serialized_start = 5621
- _AGGREGATE._serialized_end = 6203
- _AGGREGATE_PIVOT._serialized_start = 5960
- _AGGREGATE_PIVOT._serialized_end = 6071
- _AGGREGATE_GROUPTYPE._serialized_start = 6074
- _AGGREGATE_GROUPTYPE._serialized_end = 6203
- _SORT._serialized_start = 6206
- _SORT._serialized_end = 6366
- _DROP._serialized_start = 6369
- _DROP._serialized_end = 6510
- _DEDUPLICATE._serialized_start = 6513
- _DEDUPLICATE._serialized_end = 6753
- _LOCALRELATION._serialized_start = 6755
- _LOCALRELATION._serialized_end = 6844
- _CACHEDLOCALRELATION._serialized_start = 6846
- _CACHEDLOCALRELATION._serialized_end = 6941
- _CACHEDREMOTERELATION._serialized_start = 6943
- _CACHEDREMOTERELATION._serialized_end = 6998
- _SAMPLE._serialized_start = 7001
- _SAMPLE._serialized_end = 7274
- _RANGE._serialized_start = 7277
- _RANGE._serialized_end = 7422
- _SUBQUERYALIAS._serialized_start = 7424
- _SUBQUERYALIAS._serialized_end = 7538
- _REPARTITION._serialized_start = 7541
- _REPARTITION._serialized_end = 7683
- _SHOWSTRING._serialized_start = 7686
- _SHOWSTRING._serialized_end = 7828
- _HTMLSTRING._serialized_start = 7830
- _HTMLSTRING._serialized_end = 7944
- _STATSUMMARY._serialized_start = 7946
- _STATSUMMARY._serialized_end = 8038
- _STATDESCRIBE._serialized_start = 8040
- _STATDESCRIBE._serialized_end = 8121
- _STATCROSSTAB._serialized_start = 8123
- _STATCROSSTAB._serialized_end = 8224
- _STATCOV._serialized_start = 8226
- _STATCOV._serialized_end = 8322
- _STATCORR._serialized_start = 8325
- _STATCORR._serialized_end = 8462
- _STATAPPROXQUANTILE._serialized_start = 8465
- _STATAPPROXQUANTILE._serialized_end = 8629
- _STATFREQITEMS._serialized_start = 8631
- _STATFREQITEMS._serialized_end = 8756
- _STATSAMPLEBY._serialized_start = 8759
- _STATSAMPLEBY._serialized_end = 9068
- _STATSAMPLEBY_FRACTION._serialized_start = 8960
- _STATSAMPLEBY_FRACTION._serialized_end = 9059
- _NAFILL._serialized_start = 9071
- _NAFILL._serialized_end = 9205
- _NADROP._serialized_start = 9208
- _NADROP._serialized_end = 9342
- _NAREPLACE._serialized_start = 9345
- _NAREPLACE._serialized_end = 9641
- _NAREPLACE_REPLACEMENT._serialized_start = 9500
- _NAREPLACE_REPLACEMENT._serialized_end = 9641
- _TODF._serialized_start = 9643
- _TODF._serialized_end = 9731
- _WITHCOLUMNSRENAMED._serialized_start = 9734
- _WITHCOLUMNSRENAMED._serialized_end = 9973
- _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 9906
- _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 9973
- _WITHCOLUMNS._serialized_start = 9975
- _WITHCOLUMNS._serialized_end = 10094
- _WITHWATERMARK._serialized_start = 10097
- _WITHWATERMARK._serialized_end = 10231
- _HINT._serialized_start = 10234
- _HINT._serialized_end = 10366
- _UNPIVOT._serialized_start = 10369
- _UNPIVOT._serialized_end = 10696
- _UNPIVOT_VALUES._serialized_start = 10626
- _UNPIVOT_VALUES._serialized_end = 10685
- _TOSCHEMA._serialized_start = 10698
- _TOSCHEMA._serialized_end = 10804
- _REPARTITIONBYEXPRESSION._serialized_start = 10807
- _REPARTITIONBYEXPRESSION._serialized_end = 11010
- _MAPPARTITIONS._serialized_start = 11013
- _MAPPARTITIONS._serialized_end = 11194
- _GROUPMAP._serialized_start = 11197
- _GROUPMAP._serialized_end = 11832
- _COGROUPMAP._serialized_start = 11835
- _COGROUPMAP._serialized_end = 12361
- _APPLYINPANDASWITHSTATE._serialized_start = 12364
- _APPLYINPANDASWITHSTATE._serialized_end = 12721
- _COLLECTMETRICS._serialized_start = 12724
- _COLLECTMETRICS._serialized_end = 12860
- _PARSE._serialized_start = 12863
- _PARSE._serialized_end = 13251
+ _JOIN._serialized_end = 5135
+ _JOIN_JOINDATATYPE._serialized_start = 4769
+ _JOIN_JOINDATATYPE._serialized_end = 4905
+ _JOIN_JOINTYPE._serialized_start = 4908
+ _JOIN_JOINTYPE._serialized_end = 5116
+ _SETOPERATION._serialized_start = 5138
+ _SETOPERATION._serialized_end = 5617
+ _SETOPERATION_SETOPTYPE._serialized_start = 5454
+ _SETOPERATION_SETOPTYPE._serialized_end = 5568
+ _LIMIT._serialized_start = 5619
+ _LIMIT._serialized_end = 5695
+ _OFFSET._serialized_start = 5697
+ _OFFSET._serialized_end = 5776
+ _TAIL._serialized_start = 5778
+ _TAIL._serialized_end = 5853
+ _AGGREGATE._serialized_start = 5856
+ _AGGREGATE._serialized_end = 6438
+ _AGGREGATE_PIVOT._serialized_start = 6195
+ _AGGREGATE_PIVOT._serialized_end = 6306
+ _AGGREGATE_GROUPTYPE._serialized_start = 6309
+ _AGGREGATE_GROUPTYPE._serialized_end = 6438
+ _SORT._serialized_start = 6441
+ _SORT._serialized_end = 6601
+ _DROP._serialized_start = 6604
+ _DROP._serialized_end = 6745
+ _DEDUPLICATE._serialized_start = 6748
+ _DEDUPLICATE._serialized_end = 6988
+ _LOCALRELATION._serialized_start = 6990
+ _LOCALRELATION._serialized_end = 7079
+ _CACHEDLOCALRELATION._serialized_start = 7081
+ _CACHEDLOCALRELATION._serialized_end = 7176
+ _CACHEDREMOTERELATION._serialized_start = 7178
+ _CACHEDREMOTERELATION._serialized_end = 7233
+ _SAMPLE._serialized_start = 7236
+ _SAMPLE._serialized_end = 7509
+ _RANGE._serialized_start = 7512
+ _RANGE._serialized_end = 7657
+ _SUBQUERYALIAS._serialized_start = 7659
+ _SUBQUERYALIAS._serialized_end = 7773
+ _REPARTITION._serialized_start = 7776
+ _REPARTITION._serialized_end = 7918
+ _SHOWSTRING._serialized_start = 7921
+ _SHOWSTRING._serialized_end = 8063
+ _HTMLSTRING._serialized_start = 8065
+ _HTMLSTRING._serialized_end = 8179
+ _STATSUMMARY._serialized_start = 8181
+ _STATSUMMARY._serialized_end = 8273
+ _STATDESCRIBE._serialized_start = 8275
+ _STATDESCRIBE._serialized_end = 8356
+ _STATCROSSTAB._serialized_start = 8358
+ _STATCROSSTAB._serialized_end = 8459
+ _STATCOV._serialized_start = 8461
+ _STATCOV._serialized_end = 8557
+ _STATCORR._serialized_start = 8560
+ _STATCORR._serialized_end = 8697
+ _STATAPPROXQUANTILE._serialized_start = 8700
+ _STATAPPROXQUANTILE._serialized_end = 8864
+ _STATFREQITEMS._serialized_start = 8866
+ _STATFREQITEMS._serialized_end = 8991
+ _STATSAMPLEBY._serialized_start = 8994
+ _STATSAMPLEBY._serialized_end = 9303
+ _STATSAMPLEBY_FRACTION._serialized_start = 9195
+ _STATSAMPLEBY_FRACTION._serialized_end = 9294
+ _NAFILL._serialized_start = 9306
+ _NAFILL._serialized_end = 9440
+ _NADROP._serialized_start = 9443
+ _NADROP._serialized_end = 9577
+ _NAREPLACE._serialized_start = 9580
+ _NAREPLACE._serialized_end = 9876
+ _NAREPLACE_REPLACEMENT._serialized_start = 9735
+ _NAREPLACE_REPLACEMENT._serialized_end = 9876
+ _TODF._serialized_start = 9878
+ _TODF._serialized_end = 9966
+ _WITHCOLUMNSRENAMED._serialized_start = 9969
+ _WITHCOLUMNSRENAMED._serialized_end = 10208
+ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 10141
+ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 10208
+ _WITHCOLUMNS._serialized_start = 10210
+ _WITHCOLUMNS._serialized_end = 10329
+ _WITHWATERMARK._serialized_start = 10332
+ _WITHWATERMARK._serialized_end = 10466
+ _HINT._serialized_start = 10469
+ _HINT._serialized_end = 10601
+ _UNPIVOT._serialized_start = 10604
+ _UNPIVOT._serialized_end = 10931
+ _UNPIVOT_VALUES._serialized_start = 10861
+ _UNPIVOT_VALUES._serialized_end = 10920
+ _TOSCHEMA._serialized_start = 10933
+ _TOSCHEMA._serialized_end = 11039
+ _REPARTITIONBYEXPRESSION._serialized_start = 11042
+ _REPARTITIONBYEXPRESSION._serialized_end = 11245
+ _MAPPARTITIONS._serialized_start = 11248
+ _MAPPARTITIONS._serialized_end = 11429
+ _GROUPMAP._serialized_start = 11432
+ _GROUPMAP._serialized_end = 12067
+ _COGROUPMAP._serialized_start = 12070
+ _COGROUPMAP._serialized_end = 12596
+ _APPLYINPANDASWITHSTATE._serialized_start = 12599
+ _APPLYINPANDASWITHSTATE._serialized_end = 12956
+ _COLLECTMETRICS._serialized_start = 12959
+ _COLLECTMETRICS._serialized_end = 13095
+ _PARSE._serialized_start = 13098
+ _PARSE._serialized_end = 13486
_PARSE_OPTIONSENTRY._serialized_start = 3842
_PARSE_OPTIONSENTRY._serialized_end = 3900
- _PARSE_PARSEFORMAT._serialized_start = 13152
- _PARSE_PARSEFORMAT._serialized_end = 13240
+ _PARSE_PARSEFORMAT._serialized_start = 13387
+ _PARSE_PARSEFORMAT._serialized_end = 13475
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 8909d438c9d..75fd93e4e9b 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -964,11 +964,37 @@ class Join(google.protobuf.message.Message):
JOIN_TYPE_LEFT_SEMI: Join.JoinType.ValueType # 6
JOIN_TYPE_CROSS: Join.JoinType.ValueType # 7
+ class JoinDataType(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ IS_LEFT_FLATTENABLE_TO_ROW_FIELD_NUMBER: builtins.int
+ IS_RIGHT_FLATTENABLE_TO_ROW_FIELD_NUMBER: builtins.int
+ is_left_flattenable_to_row: builtins.bool
+ """If the left data type is a struct that can be flatten to a row."""
+ is_right_flattenable_to_row: builtins.bool
+ """If the right data type is a struct that can be flatten to a row."""
+ def __init__(
+ self,
+ *,
+ is_left_flattenable_to_row: builtins.bool = ...,
+ is_right_flattenable_to_row: builtins.bool = ...,
+ ) -> None: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "is_left_flattenable_to_row",
+ b"is_left_flattenable_to_row",
+ "is_right_flattenable_to_row",
+ b"is_right_flattenable_to_row",
+ ],
+ ) -> None: ...
+
LEFT_FIELD_NUMBER: builtins.int
RIGHT_FIELD_NUMBER: builtins.int
JOIN_CONDITION_FIELD_NUMBER: builtins.int
JOIN_TYPE_FIELD_NUMBER: builtins.int
USING_COLUMNS_FIELD_NUMBER: builtins.int
+ JOIN_DATA_TYPE_FIELD_NUMBER: builtins.int
@property
def left(self) -> global___Relation:
"""(Required) Left input relation for a Join."""
@@ -993,6 +1019,9 @@ class Join(google.protobuf.message.Message):
This field does not co-exist with join_condition.
"""
+ @property
+ def join_data_type(self) -> global___Join.JoinDataType:
+ """(Optional) Only used by joinWith. Set the left and right join data
types."""
def __init__(
self,
*,
@@ -1001,18 +1030,32 @@ class Join(google.protobuf.message.Message):
join_condition: pyspark.sql.connect.proto.expressions_pb2.Expression |
None = ...,
join_type: global___Join.JoinType.ValueType = ...,
using_columns: collections.abc.Iterable[builtins.str] | None = ...,
+ join_data_type: global___Join.JoinDataType | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
- "join_condition", b"join_condition", "left", b"left", "right",
b"right"
+ "_join_data_type",
+ b"_join_data_type",
+ "join_condition",
+ b"join_condition",
+ "join_data_type",
+ b"join_data_type",
+ "left",
+ b"left",
+ "right",
+ b"right",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
+ "_join_data_type",
+ b"_join_data_type",
"join_condition",
b"join_condition",
+ "join_data_type",
+ b"join_data_type",
"join_type",
b"join_type",
"left",
@@ -1023,6 +1066,9 @@ class Join(google.protobuf.message.Message):
b"using_columns",
],
) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_join_data_type",
b"_join_data_type"]
+ ) -> typing_extensions.Literal["join_data_type"] | None: ...
global___Join = Join
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index 6599916ec7f..443bca1daa9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -36,6 +36,8 @@ import org.apache.spark.util.Utils
* called lenient serialization. An example of this is lenient date
serialization, in this case both
* [[java.sql.Date]] and [[java.time.LocalDate]] are allowed. Deserialization
is never lenient; it
* will always produce instance of the external type.
+ *
+ * An encoder is flattenable if it contains fields that can be flattened to a row.
*/
trait AgnosticEncoder[T] extends Encoder[T] {
def isPrimitive: Boolean
@@ -43,6 +45,7 @@ trait AgnosticEncoder[T] extends Encoder[T] {
def dataType: DataType
override def schema: StructType = StructType(StructField("value", dataType,
nullable) :: Nil)
def lenientSerialization: Boolean = false
+ def isFlattenable: Boolean = false
}
object AgnosticEncoders {
@@ -99,19 +102,23 @@ object AgnosticEncoders {
def structField: StructField = StructField(name, enc.dataType, nullable,
metadata)
}
- // This supports both Product and DefinedByConstructorParams
- case class ProductEncoder[K](
- override val clsTag: ClassTag[K],
- fields: Seq[EncoderField])
- extends AgnosticEncoder[K] {
+ // Contains a sequence of fields. The fields can be flattened to columns in a row.
+ trait FieldsEncoder[K] extends AgnosticEncoder[K] {
+ val fields: Seq[EncoderField]
override def isPrimitive: Boolean = false
- override val schema: StructType = StructType(fields.map(_.structField))
+ override def schema: StructType = StructType(fields.map(_.structField))
override def dataType: DataType = schema
+ override val isFlattenable: Boolean = true
}
+ // This supports both Product and DefinedByConstructorParams
+ case class ProductEncoder[K](
+ override val clsTag: ClassTag[K],
+ override val fields: Seq[EncoderField]) extends FieldsEncoder[K]
+
object ProductEncoder {
val cachedCls = new ConcurrentHashMap[Int, Class[_]]
- def tuple(encoders: Seq[AgnosticEncoder[_]]): AgnosticEncoder[_] = {
+ private[sql] def tuple(encoders: Seq[AgnosticEncoder[_]]): AgnosticEncoder[_] = {
val fields = encoders.zipWithIndex.map {
case (e, id) => EncoderField(s"_${id + 1}", e, e.nullable,
Metadata.empty)
}
@@ -119,30 +126,27 @@ object AgnosticEncoders {
_ =>
Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}"))
ProductEncoder[Any](ClassTag(cls), fields)
}
+
+ private[sql] def isTuple(tag: ClassTag[_]): Boolean = {
+ tag.runtimeClass.getName.startsWith("scala.Tuple")
+ }
}
- abstract class BaseRowEncoder extends AgnosticEncoder[Row] {
- override def isPrimitive: Boolean = false
- override def dataType: DataType = schema
+ abstract class BaseRowEncoder extends FieldsEncoder[Row] {
override def clsTag: ClassTag[Row] = classTag[Row]
}
- case class RowEncoder(fields: Seq[EncoderField]) extends BaseRowEncoder {
- override val schema: StructType = StructType(fields.map(_.structField))
- }
+ case class RowEncoder(override val fields: Seq[EncoderField]) extends BaseRowEncoder
object UnboundRowEncoder extends BaseRowEncoder {
override val schema: StructType = new StructType()
- }
+ override val fields: Seq[EncoderField] = Seq.empty
+}
case class JavaBeanEncoder[K](
override val clsTag: ClassTag[K],
- fields: Seq[EncoderField])
- extends AgnosticEncoder[K] {
- override def isPrimitive: Boolean = false
- override val schema: StructType = StructType(fields.map(_.structField))
- override def dataType: DataType = schema
- }
+ override val fields: Seq[EncoderField])
+ extends FieldsEncoder[K]
// This will only work for encoding from/to Sparks' InternalRow format.
// It is here for compatibility.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 54c0b84ff52..980295f5e0d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.api.java.function.FilterFunction
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.sql.{Encoder, Row}
-import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
+import org.apache.spark.sql.{catalyst, Encoder, Row}
+import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedDeserializer}
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
-import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
+import org.apache.spark.sql.catalyst.plans.{InnerLike, LeftAnti, LeftSemi, ReferenceAllColumns}
import org.apache.spark.sql.catalyst.trees.TreePattern._
+import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
import org.apache.spark.sql.types._
@@ -719,3 +720,100 @@ case class CoGroup(
override protected def withNewChildrenInternal(
newLeft: LogicalPlan, newRight: LogicalPlan): CoGroup = copy(left =
newLeft, right = newRight)
}
+
+// TODO (SPARK-44225): Move this into analyzer
+object JoinWith {
+ /**
+ * Finds the trivially true predicates and automatically resolves them to both sides.
+ */
+ private[sql] def resolveSelfJoinCondition(resolver: Resolver, plan: Join): Join = {
+ val cond = plan.condition.map {
+ _.transform {
+ case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference)
+ if a.sameRef(b) =>
+ catalyst.expressions.EqualTo(
+ plan.left.resolveQuoted(a.name, resolver).getOrElse(
+ throw QueryCompilationErrors.resolveException(a.name, plan.left.schema.fieldNames)),
+ plan.right.resolveQuoted(b.name, resolver).getOrElse(
+ throw QueryCompilationErrors.resolveException(b.name, plan.right.schema.fieldNames)))
+ case catalyst.expressions.EqualNullSafe(a: AttributeReference, b: AttributeReference)
+ if a.sameRef(b) =>
+ catalyst.expressions.EqualNullSafe(
+ plan.left.resolveQuoted(a.name, resolver).getOrElse(
+ throw QueryCompilationErrors.resolveException(a.name, plan.left.schema.fieldNames)),
+ plan.right.resolveQuoted(b.name, resolver).getOrElse(
+ throw QueryCompilationErrors.resolveException(b.name, plan.right.schema.fieldNames)))
+ }
+ }
+ plan.copy(condition = cond)
+ }
+
+ private[sql] def typedJoinWith(
+ plan: Join,
+ isAutoSelfJoinAliasEnable: Boolean,
+ resolver: Resolver,
+ isLeftFlattenableToRow: Boolean,
+ isRightFlattenableToRow: Boolean): LogicalPlan = {
+ var joined = plan
+ if (joined.joinType == LeftSemi || joined.joinType == LeftAnti) {
+ throw QueryCompilationErrors.invalidJoinTypeInJoinWithError(joined.joinType)
+ }
+ // If auto self join alias is enabled
+ if (isAutoSelfJoinAliasEnable) {
+ joined = resolveSelfJoinCondition(resolver, joined)
+ }
+
+ val leftResultExpr = {
+ if (!isLeftFlattenableToRow) {
+ assert(joined.left.output.length == 1)
+ Alias(joined.left.output.head, "_1")()
+ } else {
+ Alias(CreateStruct(joined.left.output), "_1")()
+ }
+ }
+
+ val rightResultExpr = {
+ if (!isRightFlattenableToRow) {
+ assert(joined.right.output.length == 1)
+ Alias(joined.right.output.head, "_2")()
+ } else {
+ Alias(CreateStruct(joined.right.output), "_2")()
+ }
+ }
+
+ if (joined.joinType.isInstanceOf[InnerLike]) {
+ // For inner joins, we can directly perform the join and then can project the join
+ // results into structs. This ensures that data remains flat during shuffles /
+ // exchanges (unlike the outer join path, which nests the data before shuffling).
+ Project(Seq(leftResultExpr, rightResultExpr), joined)
+ } else { // outer joins
+ // For both join sides, combine all outputs into a single column and alias it with "_1
+ // or "_2", to match the schema for the encoder of the join result.
+ // Note that we do this before joining them, to enable the join operator to return null
+ // for one side, in cases like outer-join.
+ val left = Project(leftResultExpr :: Nil, joined.left)
+ val right = Project(rightResultExpr :: Nil, joined.right)
+
+ // Rewrites the join condition to make the attribute point to correct column/field,
+ // after we combine the outputs of each join side.
+ val conditionExpr = joined.condition.get transformUp {
+ case a: Attribute if joined.left.outputSet.contains(a) =>
+ if (!isLeftFlattenableToRow) {
+ left.output.head
+ } else {
+ val index = joined.left.output.indexWhere(_.exprId == a.exprId)
+ GetStructField(left.output.head, index)
+ }
+ case a: Attribute if joined.right.outputSet.contains(a) =>
+ if (!isRightFlattenableToRow) {
+ right.output.head
+ } else {
+ val index = joined.right.output.indexWhere(_.exprId == a.exprId)
+ GetStructField(right.output.head, index)
+ }
+ }
+
+ Join(left, right, joined.joinType, Some(conditionExpr), JoinHint.NONE)
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 6a7979afe41..ae5bd282919 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -2940,8 +2940,15 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
)
}
- def cannotParseIntervalError(intervalString: String, e: Throwable): Throwable = {
- val threshold = if (intervalString == null) "" else intervalString
+ def resolveException(colName: String, fields: Array[String]): AnalysisException = {
+ QueryCompilationErrors.unresolvedColumnWithSuggestionError(
+ colName,
+ fields.map(toSQLId).mkString(", ")
+ )
+ }
+
+ def cannotParseIntervalError(delayThreshold: String, e: Throwable): Throwable = {
+ val threshold = if (delayThreshold == null) "" else delayThreshold
new AnalysisException(
errorClass = "CANNOT_PARSE_INTERVAL",
messageParameters = Map("intervalString" -> toSQLValue(threshold, StringType)),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index c87f95294bf..a9017924e14 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -48,8 +48,8 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.util.IntervalUtils
+import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
import org.apache.spark.sql.errors.{QueryCompilationErrors,
QueryExecutionErrors}
-import org.apache.spark.sql.errors.QueryCompilationErrors.toSQLId
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.execution.arrow.{ArrowBatchStreamWriter,
ArrowConverters}
@@ -248,12 +248,7 @@ class Dataset[T] private[sql](
private[sql] def resolve(colName: String): NamedExpression = {
val resolver = sparkSession.sessionState.analyzer.resolver
queryExecution.analyzed.resolveQuoted(colName, resolver)
- .getOrElse(throw resolveException(colName, schema.fieldNames))
- }
-
- private def resolveException(colName: String, fields: Array[String]): AnalysisException = {
- QueryCompilationErrors.unresolvedColumnWithSuggestionError(
- colName, fields.map(toSQLId).mkString(", "))
+ .getOrElse(throw QueryCompilationErrors.resolveException(colName, schema.fieldNames))
}
private[sql] def numericColumns: Seq[Expression] = {
@@ -1116,30 +1111,6 @@ class Dataset[T] private[sql](
*/
def join(right: Dataset[_], joinExprs: Column): DataFrame = join(right,
joinExprs, "inner")
- /**
- * find the trivially true predicates and automatically resolves them to
both sides.
- */
- private def resolveSelfJoinCondition(plan: Join): Join = {
- val resolver = sparkSession.sessionState.analyzer.resolver
- val cond = plan.condition.map { _.transform {
- case catalyst.expressions.EqualTo(a: AttributeReference, b:
AttributeReference)
- if a.sameRef(b) =>
- catalyst.expressions.EqualTo(
- plan.left.resolveQuoted(a.name, resolver)
- .getOrElse(throw resolveException(a.name,
plan.left.schema.fieldNames)),
- plan.right.resolveQuoted(b.name, resolver)
- .getOrElse(throw resolveException(b.name,
plan.right.schema.fieldNames)))
- case catalyst.expressions.EqualNullSafe(a: AttributeReference, b:
AttributeReference)
- if a.sameRef(b) =>
- catalyst.expressions.EqualNullSafe(
- plan.left.resolveQuoted(a.name, resolver)
- .getOrElse(throw resolveException(a.name,
plan.left.schema.fieldNames)),
- plan.right.resolveQuoted(b.name, resolver)
- .getOrElse(throw resolveException(b.name,
plan.right.schema.fieldNames)))
- }}
- plan.copy(condition = cond)
- }
-
/**
* find the trivially true predicates and automatically resolves them to
both sides.
*/
@@ -1178,7 +1149,7 @@ class Dataset[T] private[sql](
// By the time we get here, since we have already run analysis, all
attributes should've been
// resolved and become AttributeReference.
- resolveSelfJoinCondition(plan)
+ JoinWith.resolveSelfJoinCondition(sparkSession.sessionState.analyzer.resolver, plan)
}
/**
@@ -1249,7 +1220,7 @@ class Dataset[T] private[sql](
def joinWith[U](other: Dataset[U], condition: Column, joinType: String):
Dataset[(T, U)] = {
// Creates a Join node and resolve it first, to get join condition
resolved, self-join resolved,
// etc.
- var joined = sparkSession.sessionState.executePlan(
+ val joined = sparkSession.sessionState.executePlan(
Join(
this.logicalPlan,
other.logicalPlan,
@@ -1257,70 +1228,15 @@ class Dataset[T] private[sql](
Some(condition.expr),
JoinHint.NONE)).analyzed.asInstanceOf[Join]
- if (joined.joinType == LeftSemi || joined.joinType == LeftAnti) {
- throw
QueryCompilationErrors.invalidJoinTypeInJoinWithError(joined.joinType)
- }
-
- // If auto self join alias is enable
- if (sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity) {
- joined = resolveSelfJoinCondition(joined)
- }
-
implicit val tuple2Encoder: Encoder[(T, U)] =
ExpressionEncoder.tuple(this.exprEnc, other.exprEnc)
- val leftResultExpr = {
- if (!this.exprEnc.isSerializedAsStructForTopLevel) {
- assert(joined.left.output.length == 1)
- Alias(joined.left.output.head, "_1")()
- } else {
- Alias(CreateStruct(joined.left.output), "_1")()
- }
- }
-
- val rightResultExpr = {
- if (!other.exprEnc.isSerializedAsStructForTopLevel) {
- assert(joined.right.output.length == 1)
- Alias(joined.right.output.head, "_2")()
- } else {
- Alias(CreateStruct(joined.right.output), "_2")()
- }
- }
-
- if (joined.joinType.isInstanceOf[InnerLike]) {
- // For inner joins, we can directly perform the join and then can
project the join
- // results into structs. This ensures that data remains flat during
shuffles /
- // exchanges (unlike the outer join path, which nests the data before
shuffling).
- withTypedPlan(Project(Seq(leftResultExpr, rightResultExpr), joined))
- } else { // outer joins
- // For both join sides, combine all outputs into a single column and
alias it with "_1
- // or "_2", to match the schema for the encoder of the join result.
- // Note that we do this before joining them, to enable the join operator
to return null
- // for one side, in cases like outer-join.
- val left = Project(leftResultExpr :: Nil, joined.left)
- val right = Project(rightResultExpr :: Nil, joined.right)
-
- // Rewrites the join condition to make the attribute point to correct
column/field,
- // after we combine the outputs of each join side.
- val conditionExpr = joined.condition.get transformUp {
- case a: Attribute if joined.left.outputSet.contains(a) =>
- if (!this.exprEnc.isSerializedAsStructForTopLevel) {
- left.output.head
- } else {
- val index = joined.left.output.indexWhere(_.exprId == a.exprId)
- GetStructField(left.output.head, index)
- }
- case a: Attribute if joined.right.outputSet.contains(a) =>
- if (!other.exprEnc.isSerializedAsStructForTopLevel) {
- right.output.head
- } else {
- val index = joined.right.output.indexWhere(_.exprId == a.exprId)
- GetStructField(right.output.head, index)
- }
- }
-
- withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr),
JoinHint.NONE))
- }
+ withTypedPlan(JoinWith.typedJoinWith(
+ joined,
+ sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity,
+ sparkSession.sessionState.analyzer.resolver,
+ this.exprEnc.isSerializedAsStructForTopLevel,
+ other.exprEnc.isSerializedAsStructForTopLevel))
}
/**
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]