This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new a27ccd78875 [SPARK-40879][CONNECT] Support Join UsingColumns in proto a27ccd78875 is described below commit a27ccd788750c1b1394c8274f79643cb2ad6cf49 Author: Rui Wang <rui.w...@databricks.com> AuthorDate: Tue Oct 25 23:49:02 2022 +0800 [SPARK-40879][CONNECT] Support Join UsingColumns in proto ### What changes were proposed in this pull request? I was working on refactoring Connect proto tests from Catalyst DSL to DataFrame API, and identified that Join in Connect does not support `UsingColumns`. This is a gap between the Connect proto and DataFrame API. This also blocks the refactoring work because without `UsingColumns`, there is no compatible DataFrame Join API that we can convert existing tests to. This PR adds support for Join's `UsingColumns`. ### Why are the changes needed? 1. Improve API coverage. 2. Unblock testing refactoring. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? UT Closes #38345 from amaliujia/proto-join-using-columns. 
Authored-by: Rui Wang <rui.w...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../main/protobuf/spark/connect/relations.proto | 6 ++ .../org/apache/spark/sql/connect/dsl/package.scala | 32 ++++++++++- .../sql/connect/planner/SparkConnectPlanner.scala | 17 ++++-- .../connect/planner/SparkConnectPlannerSuite.scala | 14 +++++ .../connect/planner/SparkConnectProtoSuite.scala | 16 +++++- python/pyspark/sql/connect/proto/relations_pb2.py | 64 +++++++++++----------- python/pyspark/sql/connect/proto/relations_pb2.pyi | 14 +++++ 7 files changed, 123 insertions(+), 40 deletions(-) diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto index 7dbde775ee8..94010487ee5 100644 --- a/connector/connect/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto @@ -109,6 +109,12 @@ message Join { Relation right = 2; Expression join_condition = 3; JoinType join_type = 4; + // Optional. using_columns provides a list of columns that should present on both sides of + // the join inputs that this Join will join on. For example A JOIN B USING col_name is + // equivalent to A JOIN B on A.col_name = B.col_name. + // + // This field does not co-exist with join_condition. 
+ repeated string using_columns = 5; enum JoinType { JOIN_TYPE_UNSPECIFIED = 0; diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 4630c86049c..6ae6dfa1577 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -236,15 +236,45 @@ package object dsl { .build() def join( + otherPlan: proto.Relation, + joinType: JoinType, + condition: Option[proto.Expression]): proto.Relation = { + join(otherPlan, joinType, Seq(), condition) + } + + def join(otherPlan: proto.Relation, condition: Option[proto.Expression]): proto.Relation = { + join(otherPlan, JoinType.JOIN_TYPE_INNER, Seq(), condition) + } + + def join(otherPlan: proto.Relation): proto.Relation = { + join(otherPlan, JoinType.JOIN_TYPE_INNER, Seq(), None) + } + + def join(otherPlan: proto.Relation, joinType: JoinType): proto.Relation = { + join(otherPlan, joinType, Seq(), None) + } + + def join( + otherPlan: proto.Relation, + joinType: JoinType, + usingColumns: Seq[String]): proto.Relation = { + join(otherPlan, joinType, usingColumns, None) + } + + private def join( otherPlan: proto.Relation, joinType: JoinType = JoinType.JOIN_TYPE_INNER, - condition: Option[proto.Expression] = None): proto.Relation = { + usingColumns: Seq[String], + condition: Option[proto.Expression]): proto.Relation = { val relation = proto.Relation.newBuilder() val join = proto.Join.newBuilder() join .setLeft(logicalPlan) .setRight(otherPlan) .setJoinType(joinType) + if (usingColumns.nonEmpty) { + join.addAllUsingColumns(usingColumns.asJava) + } if (condition.isDefined) { join.setJoinCondition(condition.get) } diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala 
b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 880618cc333..9e3899f4a1a 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttrib import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin} import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, LogicalPlan, Sample, SubqueryAlias} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.QueryExecution @@ -292,14 +292,23 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { private def transformJoin(rel: proto.Join): LogicalPlan = { assert(rel.hasLeft && rel.hasRight, "Both join sides must be present") + if (rel.hasJoinCondition && rel.getUsingColumnsCount > 0) { + throw InvalidPlanInput( + s"Using columns or join conditions cannot be set at the same time in Join") + } val joinCondition = if (rel.hasJoinCondition) Some(transformExpression(rel.getJoinCondition)) else None - + val catalystJointype = transformJoinType( + if (rel.getJoinType != null) rel.getJoinType else proto.Join.JoinType.JOIN_TYPE_INNER) + val joinType = if (rel.getUsingColumnsCount > 0) { + UsingJoin(catalystJointype, rel.getUsingColumnsList.asScala.toSeq) + } else { + catalystJointype + } logical.Join( left = transformRelation(rel.getLeft), right = 
transformRelation(rel.getRight), - joinType = transformJoinType( - if (rel.getJoinType != null) rel.getJoinType else proto.Join.JoinType.JOIN_TYPE_INNER), + joinType = joinType, condition = joinCondition, hint = logical.JoinHint.NONE) } diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 980e899c26e..6fc47e07c59 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -220,6 +220,20 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { assert(res.nodeName == "Join") assert(res != null) + val e = intercept[InvalidPlanInput] { + val simpleJoin = proto.Relation.newBuilder + .setJoin( + proto.Join.newBuilder + .setLeft(readRel) + .setRight(readRel) + .addUsingColumns("test_col") + .setJoinCondition(joinCondition)) + .build() + transform(simpleJoin) + } + assert( + e.getMessage.contains( + "Using columns or join conditions cannot be set at the same time in Join")) } test("Simple Projection") { diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index d8bb1684cb8..0325b6573bd 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -20,7 +20,7 @@ import org.apache.spark.connect.proto import org.apache.spark.connect.proto.Join.JoinType import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, 
Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter} +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter, UsingJoin} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation /** @@ -32,11 +32,13 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { lazy val connectTestRelation = createLocalRelationProto(Seq($"id".int, $"name".string)) - lazy val connectTestRelation2 = createLocalRelationProto(Seq($"key".int, $"value".int)) + lazy val connectTestRelation2 = createLocalRelationProto( + Seq($"key".int, $"value".int, $"name".string)) lazy val sparkTestRelation: LocalRelation = LocalRelation($"id".int, $"name".string) - lazy val sparkTestRelation2: LocalRelation = LocalRelation($"key".int, $"value".int) + lazy val sparkTestRelation2: LocalRelation = + LocalRelation($"key".int, $"value".int, $"name".string) test("Basic select") { val connectPlan = { @@ -117,6 +119,14 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { val sparkPlan3 = sparkTestRelation.join(sparkTestRelation2, y) comparePlans(connectPlan3.analyze, sparkPlan3.analyze, false) } + + val connectPlan4 = { + import org.apache.spark.sql.connect.dsl.plans._ + transform( + connectTestRelation.join(connectTestRelation2, JoinType.JOIN_TYPE_INNER, Seq("name"))) + } + val sparkPlan4 = sparkTestRelation.join(sparkTestRelation2, UsingJoin(Inner, Seq("name"))) + comparePlans(connectPlan4.analyze, sparkPlan4.analyze, false) } test("Test sample") { diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index d9a596fba8c..2a38a014926 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - 
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8f\x06\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...] + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8f\x06\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...] ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -64,35 +64,35 @@ if _descriptor._USE_C_DESCRIPTORS == False: _FILTER._serialized_start = 1512 _FILTER._serialized_end = 1624 _JOIN._serialized_start = 1627 - _JOIN._serialized_end = 2040 - _JOIN_JOINTYPE._serialized_start = 1853 - _JOIN_JOINTYPE._serialized_end = 2040 - _UNION._serialized_start = 2043 - _UNION._serialized_end = 2248 - _UNION_UNIONTYPE._serialized_start = 2164 - _UNION_UNIONTYPE._serialized_end = 2248 - _LIMIT._serialized_start = 2250 - _LIMIT._serialized_end = 2326 - _OFFSET._serialized_start = 2328 - _OFFSET._serialized_end = 2407 - _AGGREGATE._serialized_start = 2410 - _AGGREGATE._serialized_end = 2735 - _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2639 - _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 2735 - _SORT._serialized_start = 2738 - _SORT._serialized_end = 3240 - _SORT_SORTFIELD._serialized_start = 2858 - _SORT_SORTFIELD._serialized_end = 3046 - _SORT_SORTDIRECTION._serialized_start = 3048 - 
_SORT_SORTDIRECTION._serialized_end = 3156 - _SORT_SORTNULLS._serialized_start = 3158 - _SORT_SORTNULLS._serialized_end = 3240 - _DEDUPLICATE._serialized_start = 3243 - _DEDUPLICATE._serialized_end = 3385 - _LOCALRELATION._serialized_start = 3387 - _LOCALRELATION._serialized_end = 3480 - _SAMPLE._serialized_start = 3483 - _SAMPLE._serialized_end = 3723 - _SAMPLE_SEED._serialized_start = 3697 - _SAMPLE_SEED._serialized_end = 3723 + _JOIN._serialized_end = 2077 + _JOIN_JOINTYPE._serialized_start = 1890 + _JOIN_JOINTYPE._serialized_end = 2077 + _UNION._serialized_start = 2080 + _UNION._serialized_end = 2285 + _UNION_UNIONTYPE._serialized_start = 2201 + _UNION_UNIONTYPE._serialized_end = 2285 + _LIMIT._serialized_start = 2287 + _LIMIT._serialized_end = 2363 + _OFFSET._serialized_start = 2365 + _OFFSET._serialized_end = 2444 + _AGGREGATE._serialized_start = 2447 + _AGGREGATE._serialized_end = 2772 + _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2676 + _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 2772 + _SORT._serialized_start = 2775 + _SORT._serialized_end = 3277 + _SORT_SORTFIELD._serialized_start = 2895 + _SORT_SORTFIELD._serialized_end = 3083 + _SORT_SORTDIRECTION._serialized_start = 3085 + _SORT_SORTDIRECTION._serialized_end = 3193 + _SORT_SORTNULLS._serialized_start = 3195 + _SORT_SORTNULLS._serialized_end = 3277 + _DEDUPLICATE._serialized_start = 3280 + _DEDUPLICATE._serialized_end = 3422 + _LOCALRELATION._serialized_start = 3424 + _LOCALRELATION._serialized_end = 3517 + _SAMPLE._serialized_start = 3520 + _SAMPLE._serialized_end = 3760 + _SAMPLE_SEED._serialized_start = 3734 + _SAMPLE_SEED._serialized_end = 3760 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index df179df1480..d3186c4e3df 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -467,6 +467,7 @@ class 
Join(google.protobuf.message.Message): RIGHT_FIELD_NUMBER: builtins.int JOIN_CONDITION_FIELD_NUMBER: builtins.int JOIN_TYPE_FIELD_NUMBER: builtins.int + USING_COLUMNS_FIELD_NUMBER: builtins.int @property def left(self) -> global___Relation: ... @property @@ -474,6 +475,16 @@ class Join(google.protobuf.message.Message): @property def join_condition(self) -> pyspark.sql.connect.proto.expressions_pb2.Expression: ... join_type: global___Join.JoinType.ValueType + @property + def using_columns( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """Optional. using_columns provides a list of columns that should present on both sides of + the join inputs that this Join will join on. For example A JOIN B USING col_name is + equivalent to A JOIN B on A.col_name = B.col_name. + + This field does not co-exist with join_condition. + """ def __init__( self, *, @@ -481,6 +492,7 @@ class Join(google.protobuf.message.Message): right: global___Relation | None = ..., join_condition: pyspark.sql.connect.proto.expressions_pb2.Expression | None = ..., join_type: global___Join.JoinType.ValueType = ..., + using_columns: collections.abc.Iterable[builtins.str] | None = ..., ) -> None: ... def HasField( self, @@ -499,6 +511,8 @@ class Join(google.protobuf.message.Message): b"left", "right", b"right", + "using_columns", + b"using_columns", ], ) -> None: ... --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org